engine.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. package engine
  2. import (
  3. "fmt"
  4. "github.com/cznic/kv"
  5. "github.com/huichen/murmur"
  6. "github.com/huichen/sego"
  7. "github.com/huichen/wukong/core"
  8. "github.com/huichen/wukong/types"
  9. "github.com/huichen/wukong/utils"
  10. "log"
  11. "os"
  12. "runtime"
  13. "sort"
  14. "strconv"
  15. "sync/atomic"
  16. "time"
  17. )
  // NumNanosecondsInAMillisecond is the number of nanoseconds in one
  // millisecond.
  const NumNanosecondsInAMillisecond = 1000 * 1000
  21. type Engine struct {
  22. // 计数器,用来统计有多少文档被索引等信息
  23. numDocumentsIndexed uint64
  24. numIndexingRequests uint64
  25. numTokenIndexAdded uint64
  26. numDocumentsStored uint64
  27. // 记录初始化参数
  28. initOptions types.EngineInitOptions
  29. initialized bool
  30. indexers []core.Indexer
  31. rankers []core.Ranker
  32. segmenter sego.Segmenter
  33. stopTokens StopTokens
  34. dbs []*kv.DB
  35. // 建立索引器使用的通信通道
  36. segmenterChannel chan segmenterRequest
  37. indexerAddDocumentChannels []chan indexerAddDocumentRequest
  38. rankerAddScoringFieldsChannels []chan rankerAddScoringFieldsRequest
  39. // 建立排序器使用的通信通道
  40. indexerLookupChannels []chan indexerLookupRequest
  41. rankerRankChannels []chan rankerRankRequest
  42. rankerRemoveScoringFieldsChannels []chan rankerRemoveScoringFieldsRequest
  43. // 建立持久存储使用的通信通道
  44. persistentStorageIndexDocumentChannel chan persistentStorageIndexDocumentRequest
  45. persistentStorageInitChannel chan bool
  46. }
  47. func (engine *Engine) Init(options types.EngineInitOptions) {
  48. // 将线程数设置为CPU数
  49. runtime.GOMAXPROCS(runtime.NumCPU())
  50. // 初始化初始参数
  51. if engine.initialized {
  52. log.Fatal("请勿重复初始化引擎")
  53. }
  54. options.Init()
  55. engine.initOptions = options
  56. engine.initialized = true
  57. // 载入分词器词典
  58. engine.segmenter.LoadDictionary(options.SegmenterDictionaries)
  59. // 初始化停用词
  60. engine.stopTokens.Init(options.StopTokenFile)
  61. // 初始化索引器和排序器
  62. for shard := 0; shard < options.NumShards; shard++ {
  63. engine.indexers = append(engine.indexers, core.Indexer{})
  64. engine.indexers[shard].Init(*options.IndexerInitOptions)
  65. engine.rankers = append(engine.rankers, core.Ranker{})
  66. engine.rankers[shard].Init()
  67. }
  68. // 初始化分词器通道
  69. engine.segmenterChannel = make(
  70. chan segmenterRequest, options.NumSegmenterThreads)
  71. // 初始化索引器通道
  72. engine.indexerAddDocumentChannels = make(
  73. []chan indexerAddDocumentRequest, options.NumShards)
  74. engine.indexerLookupChannels = make(
  75. []chan indexerLookupRequest, options.NumShards)
  76. for shard := 0; shard < options.NumShards; shard++ {
  77. engine.indexerAddDocumentChannels[shard] = make(
  78. chan indexerAddDocumentRequest,
  79. options.IndexerBufferLength)
  80. engine.indexerLookupChannels[shard] = make(
  81. chan indexerLookupRequest,
  82. options.IndexerBufferLength)
  83. }
  84. // 初始化排序器通道
  85. engine.rankerAddScoringFieldsChannels = make(
  86. []chan rankerAddScoringFieldsRequest, options.NumShards)
  87. engine.rankerRankChannels = make(
  88. []chan rankerRankRequest, options.NumShards)
  89. engine.rankerRemoveScoringFieldsChannels = make(
  90. []chan rankerRemoveScoringFieldsRequest, options.NumShards)
  91. for shard := 0; shard < options.NumShards; shard++ {
  92. engine.rankerAddScoringFieldsChannels[shard] = make(
  93. chan rankerAddScoringFieldsRequest,
  94. options.RankerBufferLength)
  95. engine.rankerRankChannels[shard] = make(
  96. chan rankerRankRequest,
  97. options.RankerBufferLength)
  98. engine.rankerRemoveScoringFieldsChannels[shard] = make(
  99. chan rankerRemoveScoringFieldsRequest,
  100. options.RankerBufferLength)
  101. }
  102. // 初始化持久化存储通道
  103. if engine.initOptions.UsePersistentStorage {
  104. engine.persistentStorageIndexDocumentChannel = make(
  105. chan persistentStorageIndexDocumentRequest,
  106. engine.initOptions.PersistentStorageShards)
  107. engine.persistentStorageInitChannel = make(
  108. chan bool, engine.initOptions.PersistentStorageShards)
  109. }
  110. // 启动分词器
  111. for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
  112. go engine.segmenterWorker()
  113. }
  114. // 启动索引器和排序器
  115. for shard := 0; shard < options.NumShards; shard++ {
  116. go engine.indexerAddDocumentWorker(shard)
  117. go engine.rankerAddScoringFieldsWorker(shard)
  118. go engine.rankerRemoveScoringFieldsWorker(shard)
  119. for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
  120. go engine.indexerLookupWorker(shard)
  121. }
  122. for i := 0; i < options.NumRankerThreadsPerShard; i++ {
  123. go engine.rankerRankWorker(shard)
  124. }
  125. }
  126. // 启动持久化存储工作协程
  127. if engine.initOptions.UsePersistentStorage {
  128. err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
  129. if err != nil {
  130. log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
  131. }
  132. // 打开或者创建数据库
  133. engine.dbs = make([]*kv.DB, engine.initOptions.PersistentStorageShards)
  134. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  135. dbPath := engine.initOptions.PersistentStorageFolder + "/persist." + strconv.Itoa(shard) + "-of-" + strconv.Itoa(engine.initOptions.PersistentStorageShards)
  136. db, err := utils.OpenOrCreateKv(dbPath, &kv.Options{})
  137. if db == nil || err != nil {
  138. log.Fatal("无法打开数据库", dbPath, ": ", err)
  139. }
  140. engine.dbs[shard] = db
  141. }
  142. // 从数据库中恢复
  143. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  144. go engine.persistentStorageInitWorker(shard)
  145. }
  146. // 等待恢复完成
  147. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  148. <-engine.persistentStorageInitChannel
  149. }
  150. for {
  151. runtime.Gosched()
  152. if engine.numIndexingRequests == engine.numDocumentsIndexed {
  153. break
  154. }
  155. }
  156. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  157. go engine.persistentStorageIndexDocumentWorker(shard)
  158. }
  159. }
  160. atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
  161. }
  162. // 将文档加入索引
  163. //
  164. // 输入参数:
  165. // docId 标识文档编号,必须唯一
  166. // data 见DocumentIndexData注释
  167. //
  168. // 注意:
  169. // 1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
  170. // 2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
  171. // 如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
  172. func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData) {
  173. engine.internalIndexDocument(docId, data)
  174. if engine.initOptions.UsePersistentStorage {
  175. engine.persistentStorageIndexDocumentChannel <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
  176. }
  177. }
  178. func (engine *Engine) internalIndexDocument(docId uint64, data types.DocumentIndexData) {
  179. if !engine.initialized {
  180. log.Fatal("必须先初始化引擎")
  181. }
  182. atomic.AddUint64(&engine.numIndexingRequests, 1)
  183. hash := murmur.Murmur3([]byte(fmt.Sprint("%d%s", docId, data.Content)))
  184. engine.segmenterChannel <- segmenterRequest{
  185. docId: docId, hash: hash, data: data}
  186. }
  187. // 将文档从索引中删除
  188. //
  189. // 输入参数:
  190. // docId 标识文档编号,必须唯一
  191. //
  192. // 注意:这个函数仅从排序器中删除文档的自定义评分字段,索引器不会发生变化。所以
  193. // 你的自定义评分字段必须能够区别评分字段为nil的情况,并将其从排序结果中删除。
  194. func (engine *Engine) RemoveDocument(docId uint64) {
  195. if !engine.initialized {
  196. log.Fatal("必须先初始化引擎")
  197. }
  198. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  199. engine.rankerRemoveScoringFieldsChannels[shard] <- rankerRemoveScoringFieldsRequest{docId: docId}
  200. }
  201. if engine.initOptions.UsePersistentStorage {
  202. // 从数据库中删除
  203. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  204. go engine.persistentStorageRemoveDocumentWorker(docId, shard)
  205. }
  206. }
  207. }
  208. // 阻塞等待直到所有索引添加完毕
  209. func (engine *Engine) FlushIndex() {
  210. for {
  211. runtime.Gosched()
  212. if engine.numIndexingRequests == engine.numDocumentsIndexed &&
  213. (!engine.initOptions.UsePersistentStorage ||
  214. engine.numIndexingRequests == engine.numDocumentsStored) {
  215. return
  216. }
  217. }
  218. }
  219. // 查找满足搜索条件的文档,此函数线程安全
  220. func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
  221. if !engine.initialized {
  222. log.Fatal("必须先初始化引擎")
  223. }
  224. var rankOptions types.RankOptions
  225. if request.RankOptions == nil {
  226. rankOptions = *engine.initOptions.DefaultRankOptions
  227. } else {
  228. rankOptions = *request.RankOptions
  229. }
  230. if rankOptions.ScoringCriteria == nil {
  231. rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
  232. }
  233. // 收集关键词
  234. tokens := []string{}
  235. if request.Text != "" {
  236. querySegments := engine.segmenter.Segment([]byte(request.Text))
  237. for _, s := range querySegments {
  238. token := s.Token().Text()
  239. if !engine.stopTokens.IsStopToken(token) {
  240. tokens = append(tokens, s.Token().Text())
  241. }
  242. }
  243. } else {
  244. for _, t := range request.Tokens {
  245. tokens = append(tokens, t)
  246. }
  247. }
  248. // 建立排序器返回的通信通道
  249. rankerReturnChannel := make(
  250. chan rankerReturnRequest, engine.initOptions.NumShards)
  251. // 生成查找请求
  252. lookupRequest := indexerLookupRequest{
  253. tokens: tokens,
  254. labels: request.Labels,
  255. docIds: request.DocIds,
  256. options: rankOptions,
  257. rankerReturnChannel: rankerReturnChannel}
  258. // 向索引器发送查找请求
  259. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  260. engine.indexerLookupChannels[shard] <- lookupRequest
  261. }
  262. // 从通信通道读取排序器的输出
  263. rankOutput := types.ScoredDocuments{}
  264. timeout := request.Timeout
  265. isTimeout := false
  266. if timeout <= 0 {
  267. // 不设置超时
  268. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  269. rankerOutput := <-rankerReturnChannel
  270. for _, doc := range rankerOutput.docs {
  271. rankOutput = append(rankOutput, doc)
  272. }
  273. }
  274. } else {
  275. // 设置超时
  276. deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
  277. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  278. select {
  279. case rankerOutput := <-rankerReturnChannel:
  280. for _, doc := range rankerOutput.docs {
  281. rankOutput = append(rankOutput, doc)
  282. }
  283. case <-time.After(deadline.Sub(time.Now())):
  284. isTimeout = true
  285. break
  286. }
  287. }
  288. }
  289. // 再排序
  290. if rankOptions.ReverseOrder {
  291. sort.Sort(sort.Reverse(rankOutput))
  292. } else {
  293. sort.Sort(rankOutput)
  294. }
  295. // 准备输出
  296. output.Tokens = tokens
  297. var start, end int
  298. if rankOptions.MaxOutputs == 0 {
  299. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  300. end = len(rankOutput)
  301. } else {
  302. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  303. end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
  304. }
  305. output.Docs = rankOutput[start:end]
  306. output.Timeout = isTimeout
  307. return
  308. }
  309. // 关闭引擎
  310. func (engine *Engine) Close() {
  311. engine.FlushIndex()
  312. if engine.initOptions.UsePersistentStorage {
  313. for _, db := range engine.dbs {
  314. db.Close()
  315. }
  316. }
  317. }
  318. // 从文本hash得到要分配到的shard
  319. func (engine *Engine) getShard(hash uint32) int {
  320. return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))
  321. }