ranker.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package core
  2. import (
  3. "log"
  4. "sort"
  5. "sync"
  6. "github.com/huichen/wukong/types"
  7. "github.com/huichen/wukong/utils"
  8. )
  9. type Ranker struct {
  10. lock struct {
  11. sync.RWMutex
  12. fields map[string]interface{}
  13. docs map[string]bool
  14. }
  15. initialized bool
  16. }
  17. func (ranker *Ranker) Init() {
  18. if ranker.initialized == true {
  19. log.Fatal("排序器不能初始化两次")
  20. }
  21. ranker.initialized = true
  22. ranker.lock.fields = make(map[string]interface{})
  23. ranker.lock.docs = make(map[string]bool)
  24. }
  25. // 给某个文档添加评分字段
  26. func (ranker *Ranker) AddDoc(docId string, fields interface{}) {
  27. if ranker.initialized == false {
  28. log.Fatal("排序器尚未初始化")
  29. }
  30. ranker.lock.Lock()
  31. ranker.lock.fields[docId] = fields
  32. ranker.lock.docs[docId] = true
  33. ranker.lock.Unlock()
  34. }
  35. // 删除某个文档的评分字段
  36. func (ranker *Ranker) RemoveDoc(docId string) {
  37. if ranker.initialized == false {
  38. log.Fatal("排序器尚未初始化")
  39. }
  40. ranker.lock.Lock()
  41. delete(ranker.lock.fields, docId)
  42. delete(ranker.lock.docs, docId)
  43. ranker.lock.Unlock()
  44. }
  45. // 给文档评分并排序
  46. func (ranker *Ranker) Rank(
  47. docs []types.IndexedDocument, options types.RankOptions, countDocsOnly bool) (types.ScoredDocuments, int) {
  48. if ranker.initialized == false {
  49. log.Fatal("排序器尚未初始化")
  50. }
  51. // 对每个文档评分
  52. var outputDocs types.ScoredDocuments
  53. numDocs := 0
  54. for _, d := range docs {
  55. ranker.lock.RLock()
  56. // 判断doc是否存在
  57. if _, ok := ranker.lock.docs[d.DocId]; ok {
  58. fs := ranker.lock.fields[d.DocId]
  59. ranker.lock.RUnlock()
  60. // 计算评分并剔除没有分值的文档
  61. scores := options.ScoringCriteria.Score(d, fs)
  62. if len(scores) > 0 {
  63. if !countDocsOnly {
  64. outputDocs = append(outputDocs, types.ScoredDocument{
  65. DocId: d.DocId,
  66. Scores: scores,
  67. TokenSnippetLocations: d.TokenSnippetLocations,
  68. TokenLocations: d.TokenLocations})
  69. }
  70. numDocs++
  71. }
  72. } else {
  73. ranker.lock.RUnlock()
  74. }
  75. }
  76. // 排序
  77. if !countDocsOnly {
  78. if options.ReverseOrder {
  79. sort.Sort(sort.Reverse(outputDocs))
  80. } else {
  81. sort.Sort(outputDocs)
  82. }
  83. // 当用户要求只返回部分结果时返回部分结果
  84. var start, end int
  85. if options.MaxOutputs != 0 {
  86. start = utils.MinInt(options.OutputOffset, len(outputDocs))
  87. end = utils.MinInt(options.OutputOffset+options.MaxOutputs, len(outputDocs))
  88. } else {
  89. start = utils.MinInt(options.OutputOffset, len(outputDocs))
  90. end = len(outputDocs)
  91. }
  92. return outputDocs[start:end], numDocs
  93. }
  94. return outputDocs, numDocs
  95. }