ranker.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package core
  2. import (
  3. "github.com/huichen/wukong/types"
  4. "github.com/huichen/wukong/utils"
  5. "log"
  6. "sort"
  7. "sync"
  8. )
  9. type Ranker struct {
  10. lock struct {
  11. sync.RWMutex
  12. fields map[uint64]interface{}
  13. }
  14. initialized bool
  15. }
  16. func (ranker *Ranker) Init() {
  17. if ranker.initialized == true {
  18. log.Fatal("排序器不能初始化两次")
  19. }
  20. ranker.initialized = true
  21. ranker.lock.fields = make(map[uint64]interface{})
  22. }
  23. // 给某个文档添加评分字段
  24. func (ranker *Ranker) AddScoringFields(docId uint64, fields interface{}) {
  25. if ranker.initialized == false {
  26. log.Fatal("排序器尚未初始化")
  27. }
  28. ranker.lock.Lock()
  29. ranker.lock.fields[docId] = fields
  30. ranker.lock.Unlock()
  31. }
  32. // 删除某个文档的评分字段
  33. func (ranker *Ranker) RemoveScoringFields(docId uint64) {
  34. if ranker.initialized == false {
  35. log.Fatal("排序器尚未初始化")
  36. }
  37. ranker.lock.Lock()
  38. delete(ranker.lock.fields, docId)
  39. ranker.lock.Unlock()
  40. }
  41. // 给文档评分并排序
  42. func (ranker *Ranker) Rank(
  43. docs []types.IndexedDocument, options types.RankOptions) (outputDocs types.ScoredDocuments) {
  44. if ranker.initialized == false {
  45. log.Fatal("排序器尚未初始化")
  46. }
  47. // 对每个文档评分
  48. for _, d := range docs {
  49. ranker.lock.RLock()
  50. fs := ranker.lock.fields[d.DocId]
  51. ranker.lock.RUnlock()
  52. // 计算评分并剔除没有分值的文档
  53. scores := options.ScoringCriteria.Score(d, fs)
  54. if len(scores) > 0 {
  55. outputDocs = append(outputDocs, types.ScoredDocument{
  56. DocId: d.DocId,
  57. Scores: scores,
  58. TokenSnippetLocations: d.TokenSnippetLocations,
  59. TokenLocations: d.TokenLocations})
  60. }
  61. }
  62. // 排序
  63. if options.ReverseOrder {
  64. sort.Sort(sort.Reverse(outputDocs))
  65. } else {
  66. sort.Sort(outputDocs)
  67. }
  68. // 当用户要求只返回部分结果时返回部分结果
  69. var start, end int
  70. if options.MaxOutputs != 0 {
  71. start = utils.MinInt(options.OutputOffset, len(outputDocs))
  72. end = utils.MinInt(options.OutputOffset+options.MaxOutputs, len(outputDocs))
  73. } else {
  74. start = utils.MinInt(options.OutputOffset, len(outputDocs))
  75. end = len(outputDocs)
  76. }
  77. return outputDocs[start:end]
  78. }