custom_scoring_criteria.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // 一个使用自定义评分规则搜索微博数据的例子
  2. //
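  3. // Each line of the weibo data file has 10 "||||"-separated fields: fields 1-3 are <id>, <timestamp> and <uid>, field 5 is <reposts count> and field 10 is <text>.
  4. // <timestamp>, <reposts count> and the length of <text> are used as the scoring data.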
  5. //
  6. // 自定义评分规则为:
  7. // 1. 首先排除关键词紧邻距离大于150个字节(五十个汉字)的微博
  8. // 2. 按照帖子距当前时间评分,精度为天,越晚的帖子评分越高
  9. // 3. 按照帖子BM25的整数部分排名
  10. // 4. 同一天的微博再按照转发数评分,转发越多的帖子评分越高
  11. // 5. 最后按照帖子长度评分,越长的帖子评分越高
  12. package main
  13. import (
  14. "bufio"
  15. "encoding/gob"
  16. "flag"
  17. "fmt"
  18. "log"
  19. "os"
  20. "reflect"
  21. "strconv"
  22. "strings"
  23. "github.com/huichen/wukong/engine"
  24. "github.com/huichen/wukong/types"
  25. )
  26. const (
  27. SecondsInADay = 86400
  28. MaxTokenProximity = 150
  29. )
  30. var (
  31. weibo_data = flag.String(
  32. "weibo_data",
  33. "../testdata/weibo_data.txt",
  34. "索引的微博帖子,每行当作一个文档")
  35. query = flag.String(
  36. "query",
  37. "chinajoy游戏",
  38. "待搜索的短语")
  39. dictionaries = flag.String(
  40. "dictionaries",
  41. "../data/dictionary.txt",
  42. "分词字典文件")
  43. stop_token_file = flag.String(
  44. "stop_token_file",
  45. "../data/stop_tokens.txt",
  46. "停用词文件")
  47. searcher = engine.Engine{}
  48. options = types.RankOptions{
  49. ScoringCriteria: WeiboScoringCriteria{},
  50. OutputOffset: 0,
  51. MaxOutputs: 100,
  52. }
  53. searchQueries = []string{}
  54. )
  55. // 微博评分字段
  56. type WeiboScoringFields struct {
  57. // 帖子的时间戳
  58. Timestamp uint32
  59. // 帖子的转发数
  60. RepostsCount uint32
  61. // 帖子的长度
  62. TextLength int
  63. }
  64. // 自定义的微博评分规则
  65. type WeiboScoringCriteria struct {
  66. }
  67. func (criteria WeiboScoringCriteria) Score(
  68. doc types.IndexedDocument, fields interface{}) []float32 {
  69. if doc.TokenProximity > MaxTokenProximity { // 评分第一步
  70. return []float32{}
  71. }
  72. if reflect.TypeOf(fields) != reflect.TypeOf(WeiboScoringFields{}) {
  73. return []float32{}
  74. }
  75. output := make([]float32, 4)
  76. wsf := fields.(WeiboScoringFields)
  77. output[0] = float32(wsf.Timestamp / SecondsInADay) // 评分第二步
  78. output[1] = float32(int(doc.BM25)) // 评分第三步
  79. output[2] = float32(wsf.RepostsCount) // 评分第四步
  80. output[3] = float32(wsf.TextLength) // 评分第五步
  81. return output
  82. }
  83. func main() {
  84. // 解析命令行参数
  85. flag.Parse()
  86. log.Printf("待搜索的短语为\"%s\"", *query)
  87. // 初始化
  88. gob.Register(WeiboScoringFields{})
  89. searcher.Init(types.EngineInitOptions{
  90. SegmenterDictionaries: *dictionaries,
  91. StopTokenFile: *stop_token_file,
  92. IndexerInitOptions: &types.IndexerInitOptions{
  93. IndexType: types.LocationsIndex,
  94. },
  95. DefaultRankOptions: &options,
  96. })
  97. defer searcher.Close()
  98. // 读入微博数据
  99. file, err := os.Open(*weibo_data)
  100. if err != nil {
  101. log.Fatal(err)
  102. }
  103. defer file.Close()
  104. log.Printf("读入文本 %s", *weibo_data)
  105. scanner := bufio.NewScanner(file)
  106. lines := []string{}
  107. fieldsSlice := []WeiboScoringFields{}
  108. for scanner.Scan() {
  109. data := strings.Split(scanner.Text(), "||||")
  110. if len(data) != 10 {
  111. continue
  112. }
  113. timestamp, _ := strconv.ParseUint(data[1], 10, 32)
  114. repostsCount, _ := strconv.ParseUint(data[4], 10, 32)
  115. text := data[9]
  116. if text != "" {
  117. lines = append(lines, text)
  118. fields := WeiboScoringFields{
  119. Timestamp: uint32(timestamp),
  120. RepostsCount: uint32(repostsCount),
  121. TextLength: len(text),
  122. }
  123. fieldsSlice = append(fieldsSlice, fields)
  124. }
  125. }
  126. log.Printf("读入%d条微博\n", len(lines))
  127. // 建立索引
  128. log.Print("建立索引")
  129. for i, text := range lines {
  130. searcher.IndexDocumentS(fmt.Sprintf("line:%d", i),
  131. types.DocumentIndexData{Content: text, Fields: fieldsSlice[i]}, false)
  132. }
  133. searcher.FlushIndex()
  134. log.Print("索引建立完毕")
  135. // 搜索
  136. log.Printf("开始查询")
  137. output := searcher.Search(types.SearchRequest{Text: *query})
  138. // 显示
  139. fmt.Println()
  140. for _, doc := range output.Docs {
  141. parts := strings.Split(doc.DocId, ":")
  142. index, _ := strconv.ParseInt(parts[1], 10, 64)
  143. fmt.Printf("%v score:%v %s\n\n", doc.DocId, doc.Scores, lines[int(index)])
  144. }
  145. log.Printf("查询完毕")
  146. }