custom_scoring_criteria.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
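  // An example of searching Weibo posts with custom scoring criteria.
  //
  // Each line of the Weibo data file consists of exactly 10 fields separated by
  // "||||" (lines with any other field count are skipped). Counting from 0,
  // field 1 is <timestamp>, field 4 is <reposts count> and field 9 is <text>;
  // the remaining fields (such as <id> and <uid>) are not used.
  // <timestamp>, <reposts count> and the length of <text> are the scoring data.
  //
  // The custom scoring rules are:
  // 1. First, exclude posts whose keywords are more than 150 bytes apart
  //    (about fifty Chinese characters).
  // 2. Score by post date at day granularity; more recent posts score higher.
  // 3. Then rank by the integer part of the post's BM25.
  // 4. Posts still tied (same day and same BM25 integer part) are scored by
  //    repost count; more reposts score higher.
  // 5. Finally score by post length; longer posts score higher.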
  12. package main
  13. import (
  14. "bufio"
  15. "flag"
  16. "fmt"
  17. "github.com/huichen/wukong/engine"
  18. "github.com/huichen/wukong/types"
  19. "log"
  20. "os"
  21. "reflect"
  22. "strconv"
  23. "strings"
  24. )
  const (
  	// SecondsInADay is the length of one day in seconds; Score divides post
  	// timestamps by it to bucket posts into whole days.
  	SecondsInADay = 24 * 60 * 60
  	// MaxTokenProximity is the largest keyword proximity (in bytes, roughly
  	// fifty Chinese characters) a post may have and still be scored.
  	MaxTokenProximity = 150
  )
  29. var (
  30. weibo_data = flag.String(
  31. "weibo_data",
  32. "../testdata/weibo_data.txt",
  33. "索引的微博帖子,每行当作一个文档")
  34. query = flag.String(
  35. "query",
  36. "chinajoy游戏",
  37. "待搜索的短语")
  38. dictionaries = flag.String(
  39. "dictionaries",
  40. "../data/dictionary.txt",
  41. "分词字典文件")
  42. stop_token_file = flag.String(
  43. "stop_token_file",
  44. "../data/stop_tokens.txt",
  45. "停用词文件")
  46. searcher = engine.Engine{}
  47. options = types.RankOptions{
  48. ScoringCriteria: WeiboScoringCriteria{},
  49. OutputOffset: 0,
  50. MaxOutputs: 100,
  51. }
  52. searchQueries = []string{}
  53. )
  // WeiboScoringFields holds the per-post values that the custom scoring rules
  // rank on. One is stored alongside every indexed post.
  type WeiboScoringFields struct {
  	// Timestamp is the post's time, read from the data file as an unsigned
  	// 32-bit integer (presumably Unix seconds — confirm against the data).
  	// RepostsCount is how many times the post was reposted.
  	Timestamp, RepostsCount uint32
  	// TextLength is the length of the post text in bytes (it is filled in
  	// with len(text)).
  	TextLength int
  }
  63. // 自定义的微博评分规则
  64. type WeiboScoringCriteria struct {
  65. }
  66. func (criteria WeiboScoringCriteria) Score(
  67. doc types.IndexedDocument, fields interface{}) []float32 {
  68. if doc.TokenProximity > MaxTokenProximity { // 评分第一步
  69. return []float32{}
  70. }
  71. if reflect.TypeOf(fields) != reflect.TypeOf(WeiboScoringFields{}) {
  72. return []float32{}
  73. }
  74. output := make([]float32, 4)
  75. wsf := fields.(WeiboScoringFields)
  76. output[0] = float32(wsf.Timestamp / SecondsInADay) // 评分第二步
  77. output[1] = float32(int(doc.BM25)) // 评分第三步
  78. output[2] = float32(wsf.RepostsCount) // 评分第四步
  79. output[3] = float32(wsf.TextLength) // 评分第五步
  80. return output
  81. }
  82. func main() {
  83. // 解析命令行参数
  84. flag.Parse()
  85. log.Printf("待搜索的短语为\"%s\"", *query)
  86. // 初始化
  87. searcher.Init(types.EngineInitOptions{
  88. SegmenterDictionaries: *dictionaries,
  89. StopTokenFile: *stop_token_file,
  90. IndexerInitOptions: &types.IndexerInitOptions{
  91. IndexType: types.LocationsIndex,
  92. },
  93. DefaultRankOptions: &options,
  94. })
  95. // 读入微博数据
  96. file, err := os.Open(*weibo_data)
  97. if err != nil {
  98. log.Fatal(err)
  99. }
  100. defer file.Close()
  101. log.Printf("读入文本 %s", *weibo_data)
  102. scanner := bufio.NewScanner(file)
  103. lines := []string{}
  104. fieldsSlice := []WeiboScoringFields{}
  105. for scanner.Scan() {
  106. data := strings.Split(scanner.Text(), "||||")
  107. if len(data) != 10 {
  108. continue
  109. }
  110. timestamp, _ := strconv.ParseUint(data[1], 10, 32)
  111. repostsCount, _ := strconv.ParseUint(data[4], 10, 32)
  112. text := data[9]
  113. if text != "" {
  114. lines = append(lines, text)
  115. fields := WeiboScoringFields{
  116. Timestamp: uint32(timestamp),
  117. RepostsCount: uint32(repostsCount),
  118. TextLength: len(text),
  119. }
  120. fieldsSlice = append(fieldsSlice, fields)
  121. }
  122. }
  123. log.Printf("读入%d条微博\n", len(lines))
  124. // 建立索引
  125. log.Print("建立索引")
  126. for i, text := range lines {
  127. searcher.IndexDocument(uint64(i),
  128. types.DocumentIndexData{Content: text, Fields: fieldsSlice[i]})
  129. }
  130. searcher.FlushIndex()
  131. log.Print("索引建立完毕")
  132. // 搜索
  133. log.Printf("开始查询")
  134. output := searcher.Search(types.SearchRequest{Text: *query})
  135. // 显示
  136. fmt.Println()
  137. for _, doc := range output.Docs {
  138. fmt.Printf("%v %s\n\n", doc.Scores, lines[doc.DocId])
  139. }
  140. log.Printf("查询完毕")
  141. }