search_server.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. // 一个微博搜索的例子。
  2. package main
  3. import (
  4. "bufio"
  5. "encoding/gob"
  6. "encoding/json"
  7. "flag"
  8. "github.com/huichen/wukong/engine"
  9. "github.com/huichen/wukong/types"
  10. "io"
  11. "log"
  12. "net/http"
  13. "os"
  14. "os/signal"
  15. "reflect"
  16. "strconv"
  17. "strings"
  18. )
  const (
  	// SecondsInADay is the number of seconds in one day. WeiboScoringCriteria
  	// uses it to group post timestamps into multi-day buckets.
  	SecondsInADay = 24 * 60 * 60

  	// MaxTokenProximity is the largest token proximity that still receives the
  	// full proximity score (1.0) in WeiboScoringCriteria.Score.
  	MaxTokenProximity = 2
  )
  23. var (
  24. searcher = engine.Engine{}
  25. wbs = map[uint64]Weibo{}
  26. )
  // Weibo is a single microblog post, as parsed from the test data file and as
  // sent to clients in search responses. The field order below is also the
  // order of keys in the JSON output, so it is kept as-is.
  type Weibo struct {
  	// Id identifies the post; indexWeibo also uses it as the engine doc id.
  	Id uint64 `json:"id"`
  	// Timestamp of the post; presumably Unix seconds, since scoring divides
  	// it by SecondsInADay — confirm against the data file.
  	Timestamp uint64 `json:"timestamp"`
  	// UserName is the author's display name.
  	UserName string `json:"user_name"`
  	// RepostsCount is how many times the post was reposted.
  	RepostsCount uint64 `json:"reposts_count"`
  	// Text is the post body; JsonRpcServer adds highlight markup to it
  	// before sending it to the client.
  	Text string `json:"text"`
  }
  34. /*******************************************************************************
  35. 索引
  36. *******************************************************************************/
  37. func indexWeibo() {
  38. // 读入微博数据
  39. file, err := os.Open("../../testdata/weibo_data.txt")
  40. if err != nil {
  41. log.Fatal(err)
  42. }
  43. defer file.Close()
  44. scanner := bufio.NewScanner(file)
  45. for scanner.Scan() {
  46. data := strings.Split(scanner.Text(), "||||")
  47. if len(data) != 10 {
  48. continue
  49. }
  50. wb := Weibo{}
  51. wb.Id, _ = strconv.ParseUint(data[0], 10, 64)
  52. wb.Timestamp, _ = strconv.ParseUint(data[1], 10, 64)
  53. wb.UserName = data[3]
  54. wb.RepostsCount, _ = strconv.ParseUint(data[4], 10, 64)
  55. wb.Text = data[9]
  56. wbs[wb.Id] = wb
  57. }
  58. log.Print("添加索引")
  59. for docId, weibo := range wbs {
  60. searcher.IndexDocument(docId, types.DocumentIndexData{
  61. Content: weibo.Text,
  62. Fields: WeiboScoringFields{
  63. Timestamp: weibo.Timestamp,
  64. RepostsCount: weibo.RepostsCount,
  65. },
  66. })
  67. }
  68. searcher.FlushIndex()
  69. log.Printf("索引了%d条微博\n", len(wbs))
  70. }
  71. /*******************************************************************************
  72. 评分
  73. *******************************************************************************/
  // WeiboScoringFields carries the per-document values used by
  // WeiboScoringCriteria. indexWeibo attaches one to each document as
  // DocumentIndexData.Fields, and main registers the type with gob.
  type WeiboScoringFields struct {
  	Timestamp    uint64 // copied from Weibo.Timestamp
  	RepostsCount uint64 // copied from Weibo.RepostsCount
  }
  78. type WeiboScoringCriteria struct {
  79. }
  80. func (criteria WeiboScoringCriteria) Score(
  81. doc types.IndexedDocument, fields interface{}) []float32 {
  82. if reflect.TypeOf(fields) != reflect.TypeOf(WeiboScoringFields{}) {
  83. return []float32{}
  84. }
  85. wsf := fields.(WeiboScoringFields)
  86. output := make([]float32, 3)
  87. if doc.TokenProximity > MaxTokenProximity {
  88. output[0] = 1.0 / float32(doc.TokenProximity)
  89. } else {
  90. output[0] = 1.0
  91. }
  92. output[1] = float32(wsf.Timestamp / (SecondsInADay * 3))
  93. output[2] = float32(doc.BM25 * (1 + float32(wsf.RepostsCount)/10000))
  94. return output
  95. }
  96. /*******************************************************************************
  97. JSON-RPC
  98. *******************************************************************************/
  99. type JsonResponse struct {
  100. Docs []*Weibo `json:"docs"`
  101. }
  102. func JsonRpcServer(w http.ResponseWriter, req *http.Request) {
  103. query := req.URL.Query().Get("query")
  104. output := searcher.Search(types.SearchRequest{
  105. Text: query,
  106. RankOptions: &types.RankOptions{
  107. ScoringCriteria: &WeiboScoringCriteria{},
  108. OutputOffset: 0,
  109. MaxOutputs: 100,
  110. },
  111. })
  112. // 整理为输出格式
  113. docs := []*Weibo{}
  114. for _, doc := range output.Docs {
  115. wb := wbs[doc.DocId]
  116. for _, t := range output.Tokens {
  117. wb.Text = strings.Replace(wb.Text, t, "<font color=red>"+t+"</font>", -1)
  118. }
  119. docs = append(docs, &wb)
  120. }
  121. response, _ := json.Marshal(&JsonResponse{Docs: docs})
  122. w.Header().Set("Content-Type", "application/json")
  123. io.WriteString(w, string(response))
  124. }
  125. /*******************************************************************************
  126. 主函数
  127. *******************************************************************************/
  128. func main() {
  129. // 解析命令行参数
  130. flag.Parse()
  131. // 初始化
  132. gob.Register(WeiboScoringFields{})
  133. searcher.Init(types.EngineInitOptions{
  134. SegmenterDictionaries: "../../data/dictionary.txt",
  135. StopTokenFile: "../../data/stop_tokens.txt",
  136. IndexerInitOptions: &types.IndexerInitOptions{
  137. IndexType: types.LocationsIndex,
  138. },
  139. UsePersistentStorage: true,
  140. PersistentStorageFolder: "db",
  141. })
  142. wbs = make(map[uint64]Weibo)
  143. // 索引
  144. go indexWeibo()
  145. // 捕获ctrl-c
  146. c := make(chan os.Signal, 1)
  147. signal.Notify(c, os.Interrupt)
  148. go func(){
  149. for _ = range c {
  150. log.Print("捕获Ctrl-c,退出服务器")
  151. searcher.Close()
  152. os.Exit(0)
  153. }
  154. }()
  155. http.HandleFunc("/json", JsonRpcServer)
  156. http.Handle("/", http.FileServer(http.Dir("static")))
  157. log.Print("服务器启动")
  158. http.ListenAndServe("localhost:8080", nil)
  159. }