engine.go 8.3 KB


  1. package engine
  2. import (
  3. "fmt"
  4. "github.com/huichen/murmur"
  5. "github.com/huichen/sego"
  6. "github.com/huichen/wukong/core"
  7. "github.com/huichen/wukong/types"
  8. "github.com/huichen/wukong/utils"
  9. "log"
  10. "runtime"
  11. "sort"
  12. "sync/atomic"
  13. "time"
  14. )
  // NumNanosecondsInAMillisecond is the number of nanoseconds in one millisecond.
  const NumNanosecondsInAMillisecond = 1000000
  18. type Engine struct {
  19. // 记录初始化参数
  20. initOptions types.EngineInitOptions
  21. initialized bool
  22. indexers []core.Indexer
  23. rankers []core.Ranker
  24. segmenter sego.Segmenter
  25. stopTokens StopTokens
  26. // 建立索引器使用的通信通道
  27. segmenterChannel chan segmenterRequest
  28. indexerAddDocumentChannels []chan indexerAddDocumentRequest
  29. rankerAddScoringFieldsChannels []chan rankerAddScoringFieldsRequest
  30. // 建立排序器使用的通信通道
  31. indexerLookupChannels []chan indexerLookupRequest
  32. rankerRankChannels []chan rankerRankRequest
  33. rankerRemoveScoringFieldsChannels []chan rankerRemoveScoringFieldsRequest
  34. // 计数器,用来统计有多少文档被索引等信息
  35. numDocumentsIndexed uint64
  36. numIndexingRequests uint64
  37. numTokenIndexAdded uint64
  38. }
  39. func (engine *Engine) Init(options types.EngineInitOptions) {
  40. // 将线程数设置为CPU数
  41. runtime.GOMAXPROCS(runtime.NumCPU())
  42. // 初始化初始参数
  43. if engine.initialized {
  44. log.Fatal("请勿重复初始化引擎")
  45. }
  46. options.Init()
  47. engine.initOptions = options
  48. engine.initialized = true
  49. // 载入分词器词典
  50. engine.segmenter.LoadDictionary(options.SegmenterDictionaries)
  51. // 初始化停用词
  52. engine.stopTokens.Init(options.StopTokenFile)
  53. // 初始化索引器和排序器
  54. for shard := 0; shard < options.NumShards; shard++ {
  55. engine.indexers = append(engine.indexers, core.Indexer{})
  56. engine.indexers[shard].Init(*options.IndexerInitOptions)
  57. engine.rankers = append(engine.rankers, core.Ranker{})
  58. engine.rankers[shard].Init()
  59. }
  60. // 初始化分词器通道
  61. engine.segmenterChannel = make(
  62. chan segmenterRequest, options.NumSegmenterThreads)
  63. // 初始化索引器通道
  64. engine.indexerAddDocumentChannels = make(
  65. []chan indexerAddDocumentRequest, options.NumShards)
  66. engine.indexerLookupChannels = make(
  67. []chan indexerLookupRequest, options.NumShards)
  68. for shard := 0; shard < options.NumShards; shard++ {
  69. engine.indexerAddDocumentChannels[shard] = make(
  70. chan indexerAddDocumentRequest,
  71. options.IndexerBufferLength)
  72. engine.indexerLookupChannels[shard] = make(
  73. chan indexerLookupRequest,
  74. options.IndexerBufferLength)
  75. }
  76. // 初始化排序器通道
  77. engine.rankerAddScoringFieldsChannels = make(
  78. []chan rankerAddScoringFieldsRequest, options.NumShards)
  79. engine.rankerRankChannels = make(
  80. []chan rankerRankRequest, options.NumShards)
  81. engine.rankerRemoveScoringFieldsChannels = make(
  82. []chan rankerRemoveScoringFieldsRequest, options.NumShards)
  83. for shard := 0; shard < options.NumShards; shard++ {
  84. engine.rankerAddScoringFieldsChannels[shard] = make(
  85. chan rankerAddScoringFieldsRequest,
  86. options.RankerBufferLength)
  87. engine.rankerRankChannels[shard] = make(
  88. chan rankerRankRequest,
  89. options.RankerBufferLength)
  90. engine.rankerRemoveScoringFieldsChannels[shard] = make(
  91. chan rankerRemoveScoringFieldsRequest,
  92. options.RankerBufferLength)
  93. }
  94. // 启动分词器
  95. for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
  96. go engine.segmenterWorker()
  97. }
  98. // 启动索引器和排序器
  99. for shard := 0; shard < options.NumShards; shard++ {
  100. go engine.indexerAddDocumentWorker(shard)
  101. go engine.rankerAddScoringFieldsWorker(shard)
  102. go engine.rankerRemoveScoringFieldsWorker(shard)
  103. for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
  104. go engine.indexerLookupWorker(shard)
  105. }
  106. for i := 0; i < options.NumRankerThreadsPerShard; i++ {
  107. go engine.rankerRankWorker(shard)
  108. }
  109. }
  110. }
  111. // 将文档加入索引
  112. //
  113. // 输入参数:
  114. // docId 标识文档编号,必须唯一
  115. // data 见DocumentIndexData注释
  116. //
  117. // 注意:
  118. // 1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
  119. // 2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
  120. // 如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
  121. func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData) {
  122. if !engine.initialized {
  123. log.Fatal("必须先初始化引擎")
  124. }
  125. atomic.AddUint64(&engine.numIndexingRequests, 1)
  126. hash := murmur.Murmur3([]byte(fmt.Sprint("%d%s", docId, data.Content)))
  127. engine.segmenterChannel <- segmenterRequest{
  128. docId: docId, hash: hash, data: data}
  129. }
  130. // 将文档从索引中删除
  131. //
  132. // 输入参数:
  133. // docId 标识文档编号,必须唯一
  134. //
  135. // 注意:这个函数仅从排序器中删除文档的自定义评分字段,索引器不会发生变化。所以
  136. // 你的自定义评分字段必须能够区别评分字段为nil的情况,并将其从排序结果中删除。
  137. func (engine *Engine) RemoveDocument(docId uint64) {
  138. if !engine.initialized {
  139. log.Fatal("必须先初始化引擎")
  140. }
  141. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  142. engine.rankerRemoveScoringFieldsChannels[shard] <- rankerRemoveScoringFieldsRequest{docId: docId}
  143. }
  144. }
  145. // 阻塞等待直到所有索引添加完毕
  146. func (engine *Engine) FlushIndex() {
  147. for {
  148. runtime.Gosched()
  149. if engine.numIndexingRequests == engine.numDocumentsIndexed {
  150. return
  151. }
  152. }
  153. }
  154. // 查找满足搜索条件的文档,此函数线程安全
  155. func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
  156. if !engine.initialized {
  157. log.Fatal("必须先初始化引擎")
  158. }
  159. var rankOptions types.RankOptions
  160. if request.RankOptions == nil {
  161. rankOptions = *engine.initOptions.DefaultRankOptions
  162. } else {
  163. rankOptions = *request.RankOptions
  164. }
  165. if rankOptions.ScoringCriteria == nil {
  166. rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
  167. }
  168. // 收集关键词
  169. tokens := []string{}
  170. if request.Text != "" {
  171. querySegments := engine.segmenter.Segment([]byte(request.Text))
  172. for _, s := range querySegments {
  173. token := s.Token().Text()
  174. if !engine.stopTokens.IsStopToken(token) {
  175. tokens = append(tokens, s.Token().Text())
  176. }
  177. }
  178. } else {
  179. for _, t := range request.Tokens {
  180. tokens = append(tokens, t)
  181. }
  182. }
  183. // 建立排序器返回的通信通道
  184. rankerReturnChannel := make(
  185. chan rankerReturnRequest, engine.initOptions.NumShards)
  186. // 生成查找请求
  187. lookupRequest := indexerLookupRequest{
  188. tokens: tokens,
  189. labels: request.Labels,
  190. docIds: request.DocIds,
  191. options: rankOptions,
  192. rankerReturnChannel: rankerReturnChannel}
  193. // 向索引器发送查找请求
  194. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  195. engine.indexerLookupChannels[shard] <- lookupRequest
  196. }
  197. // 从通信通道读取排序器的输出
  198. rankOutput := types.ScoredDocuments{}
  199. timeout := request.Timeout
  200. isTimeout := false
  201. if timeout <= 0 {
  202. // 不设置超时
  203. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  204. rankerOutput := <-rankerReturnChannel
  205. for _, doc := range rankerOutput.docs {
  206. rankOutput = append(rankOutput, doc)
  207. }
  208. }
  209. } else {
  210. // 设置超时
  211. deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
  212. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  213. select {
  214. case rankerOutput := <-rankerReturnChannel:
  215. for _, doc := range rankerOutput.docs {
  216. rankOutput = append(rankOutput, doc)
  217. }
  218. case <-time.After(deadline.Sub(time.Now())):
  219. isTimeout = true
  220. break
  221. }
  222. }
  223. }
  224. // 再排序
  225. if rankOptions.ReverseOrder {
  226. sort.Sort(sort.Reverse(rankOutput))
  227. } else {
  228. sort.Sort(rankOutput)
  229. }
  230. // 准备输出
  231. output.Tokens = tokens
  232. var start, end int
  233. if rankOptions.MaxOutputs == 0 {
  234. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  235. end = len(rankOutput)
  236. } else {
  237. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  238. end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
  239. }
  240. output.Docs = rankOutput[start:end]
  241. output.Timeout = isTimeout
  242. return
  243. }
  244. // 从文本hash得到要分配到的shard
  245. func (engine *Engine) getShard(hash uint32) int {
  246. return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))
  247. }