engine.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. package engine
  2. import (
  3. "fmt"
  4. "github.com/huichen/murmur"
  5. "github.com/huichen/sego"
  6. "github.com/huichen/wukong/core"
  7. "github.com/huichen/wukong/storage"
  8. "github.com/huichen/wukong/types"
  9. "github.com/huichen/wukong/utils"
  10. "log"
  11. "os"
  12. "runtime"
  13. "sort"
  14. "strconv"
  15. "sync/atomic"
  16. "time"
  17. )
  // NumNanosecondsInAMillisecond is the number of nanoseconds in one
  // millisecond.
  const NumNanosecondsInAMillisecond = 1000000

  // PersistentStorageFilePrefix is the file-name prefix shared by every
  // persistent storage shard database.
  const PersistentStorageFilePrefix = "wukong"
  22. type Engine struct {
  23. // 计数器,用来统计有多少文档被索引等信息
  24. numDocumentsIndexed uint64
  25. numIndexingRequests uint64
  26. numTokenIndexAdded uint64
  27. numDocumentsStored uint64
  28. // 记录初始化参数
  29. initOptions types.EngineInitOptions
  30. initialized bool
  31. indexers []core.Indexer
  32. rankers []core.Ranker
  33. segmenter sego.Segmenter
  34. stopTokens StopTokens
  35. dbs []storage.Storage
  36. // 建立索引器使用的通信通道
  37. segmenterChannel chan segmenterRequest
  38. indexerAddDocumentChannels []chan indexerAddDocumentRequest
  39. indexerRemoveDocChannels []chan indexerRemoveDocRequest
  40. rankerAddDocChannels []chan rankerAddDocRequest
  41. // 建立排序器使用的通信通道
  42. indexerLookupChannels []chan indexerLookupRequest
  43. rankerRankChannels []chan rankerRankRequest
  44. rankerRemoveDocChannels []chan rankerRemoveDocRequest
  45. // 建立持久存储使用的通信通道
  46. persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
  47. persistentStorageInitChannel chan bool
  48. }
  49. func (engine *Engine) Init(options types.EngineInitOptions) {
  50. // 将线程数设置为CPU数
  51. runtime.GOMAXPROCS(runtime.NumCPU())
  52. // 初始化初始参数
  53. if engine.initialized {
  54. log.Fatal("请勿重复初始化引擎")
  55. }
  56. options.Init()
  57. engine.initOptions = options
  58. engine.initialized = true
  59. if !options.NotUsingSegmenter {
  60. // 载入分词器词典
  61. engine.segmenter.LoadDictionary(options.SegmenterDictionaries)
  62. // 初始化停用词
  63. engine.stopTokens.Init(options.StopTokenFile)
  64. }
  65. // 初始化索引器和排序器
  66. for shard := 0; shard < options.NumShards; shard++ {
  67. engine.indexers = append(engine.indexers, core.Indexer{})
  68. engine.indexers[shard].Init(*options.IndexerInitOptions)
  69. engine.rankers = append(engine.rankers, core.Ranker{})
  70. engine.rankers[shard].Init()
  71. }
  72. // 初始化分词器通道
  73. engine.segmenterChannel = make(
  74. chan segmenterRequest, options.NumSegmenterThreads)
  75. // 初始化索引器通道
  76. engine.indexerAddDocumentChannels = make(
  77. []chan indexerAddDocumentRequest, options.NumShards)
  78. engine.indexerRemoveDocChannels = make(
  79. []chan indexerRemoveDocRequest, options.NumShards)
  80. engine.indexerLookupChannels = make(
  81. []chan indexerLookupRequest, options.NumShards)
  82. for shard := 0; shard < options.NumShards; shard++ {
  83. engine.indexerAddDocumentChannels[shard] = make(
  84. chan indexerAddDocumentRequest,
  85. options.IndexerBufferLength)
  86. engine.indexerRemoveDocChannels[shard] = make(
  87. chan indexerRemoveDocRequest,
  88. options.IndexerBufferLength)
  89. engine.indexerLookupChannels[shard] = make(
  90. chan indexerLookupRequest,
  91. options.IndexerBufferLength)
  92. }
  93. // 初始化排序器通道
  94. engine.rankerAddDocChannels = make(
  95. []chan rankerAddDocRequest, options.NumShards)
  96. engine.rankerRankChannels = make(
  97. []chan rankerRankRequest, options.NumShards)
  98. engine.rankerRemoveDocChannels = make(
  99. []chan rankerRemoveDocRequest, options.NumShards)
  100. for shard := 0; shard < options.NumShards; shard++ {
  101. engine.rankerAddDocChannels[shard] = make(
  102. chan rankerAddDocRequest,
  103. options.RankerBufferLength)
  104. engine.rankerRankChannels[shard] = make(
  105. chan rankerRankRequest,
  106. options.RankerBufferLength)
  107. engine.rankerRemoveDocChannels[shard] = make(
  108. chan rankerRemoveDocRequest,
  109. options.RankerBufferLength)
  110. }
  111. // 初始化持久化存储通道
  112. if engine.initOptions.UsePersistentStorage {
  113. engine.persistentStorageIndexDocumentChannels =
  114. make([]chan persistentStorageIndexDocumentRequest,
  115. engine.initOptions.PersistentStorageShards)
  116. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  117. engine.persistentStorageIndexDocumentChannels[shard] = make(
  118. chan persistentStorageIndexDocumentRequest)
  119. }
  120. engine.persistentStorageInitChannel = make(
  121. chan bool, engine.initOptions.PersistentStorageShards)
  122. }
  123. // 启动分词器
  124. for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
  125. go engine.segmenterWorker()
  126. }
  127. // 启动索引器和排序器
  128. for shard := 0; shard < options.NumShards; shard++ {
  129. go engine.indexerAddDocumentWorker(shard)
  130. go engine.indexerRemoveDocWorker(shard)
  131. go engine.rankerAddDocWorker(shard)
  132. go engine.rankerRemoveDocWorker(shard)
  133. for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
  134. go engine.indexerLookupWorker(shard)
  135. }
  136. for i := 0; i < options.NumRankerThreadsPerShard; i++ {
  137. go engine.rankerRankWorker(shard)
  138. }
  139. }
  140. // 启动持久化存储工作协程
  141. if engine.initOptions.UsePersistentStorage {
  142. err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
  143. if err != nil {
  144. log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
  145. }
  146. // 打开或者创建数据库
  147. engine.dbs = make([]storage.Storage, engine.initOptions.PersistentStorageShards)
  148. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  149. dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
  150. db, err := storage.OpenStorage(dbPath)
  151. if db == nil || err != nil {
  152. log.Fatal("无法打开数据库", dbPath, ": ", err)
  153. }
  154. engine.dbs[shard] = db
  155. }
  156. // 从数据库中恢复
  157. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  158. go engine.persistentStorageInitWorker(shard)
  159. }
  160. // 等待恢复完成
  161. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  162. <-engine.persistentStorageInitChannel
  163. }
  164. for {
  165. runtime.Gosched()
  166. if engine.numIndexingRequests == engine.numDocumentsIndexed {
  167. break
  168. }
  169. }
  170. // 关闭并重新打开数据库
  171. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  172. engine.dbs[shard].Close()
  173. dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
  174. db, err := storage.OpenStorage(dbPath)
  175. if db == nil || err != nil {
  176. log.Fatal("无法打开数据库", dbPath, ": ", err)
  177. }
  178. engine.dbs[shard] = db
  179. }
  180. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  181. go engine.persistentStorageIndexDocumentWorker(shard)
  182. }
  183. }
  184. atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
  185. }
  186. // 将文档加入索引
  187. //
  188. // 输入参数:
  189. // docId 标识文档编号,必须唯一
  190. // data 见DocumentIndexData注释
  191. //
  192. // 注意:
  193. // 1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
  194. // 2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
  195. // 如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
  196. func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData) {
  197. engine.internalIndexDocument(docId, data)
  198. hash := murmur.Murmur3([]byte(fmt.Sprint("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
  199. if engine.initOptions.UsePersistentStorage {
  200. engine.persistentStorageIndexDocumentChannels[hash] <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
  201. }
  202. }
  203. func (engine *Engine) internalIndexDocument(docId uint64, data types.DocumentIndexData) {
  204. if !engine.initialized {
  205. log.Fatal("必须先初始化引擎")
  206. }
  207. atomic.AddUint64(&engine.numIndexingRequests, 1)
  208. hash := murmur.Murmur3([]byte(fmt.Sprint("%d%s", docId, data.Content)))
  209. engine.segmenterChannel <- segmenterRequest{
  210. docId: docId, hash: hash, data: data}
  211. }
  212. // 将文档从索引中删除
  213. //
  214. // 输入参数:
  215. // docId 标识文档编号,必须唯一
  216. //
  217. // 注意:这个函数仅从排序器中删除文档,索引器不会发生变化。
  218. func (engine *Engine) RemoveDocument(docId uint64) {
  219. if !engine.initialized {
  220. log.Fatal("必须先初始化引擎")
  221. }
  222. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  223. engine.indexerRemoveDocChannels[shard] <- indexerRemoveDocRequest{docId: docId}
  224. engine.rankerRemoveDocChannels[shard] <- rankerRemoveDocRequest{docId: docId}
  225. }
  226. if engine.initOptions.UsePersistentStorage {
  227. // 从数据库中删除
  228. hash := murmur.Murmur3([]byte(fmt.Sprint("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
  229. go engine.persistentStorageRemoveDocumentWorker(docId, hash)
  230. }
  231. }
  232. // 阻塞等待直到所有索引添加完毕
  233. func (engine *Engine) FlushIndex() {
  234. for {
  235. runtime.Gosched()
  236. if engine.numIndexingRequests == engine.numDocumentsIndexed &&
  237. (!engine.initOptions.UsePersistentStorage ||
  238. engine.numIndexingRequests == engine.numDocumentsStored) {
  239. return
  240. }
  241. }
  242. }
  243. // 查找满足搜索条件的文档,此函数线程安全
  244. func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
  245. if !engine.initialized {
  246. log.Fatal("必须先初始化引擎")
  247. }
  248. var rankOptions types.RankOptions
  249. if request.RankOptions == nil {
  250. rankOptions = *engine.initOptions.DefaultRankOptions
  251. } else {
  252. rankOptions = *request.RankOptions
  253. }
  254. if rankOptions.ScoringCriteria == nil {
  255. rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
  256. }
  257. // 收集关键词
  258. tokens := []string{}
  259. if request.Text != "" {
  260. querySegments := engine.segmenter.Segment([]byte(request.Text))
  261. for _, s := range querySegments {
  262. token := s.Token().Text()
  263. if !engine.stopTokens.IsStopToken(token) {
  264. tokens = append(tokens, s.Token().Text())
  265. }
  266. }
  267. } else {
  268. for _, t := range request.Tokens {
  269. tokens = append(tokens, t)
  270. }
  271. }
  272. // 建立排序器返回的通信通道
  273. rankerReturnChannel := make(
  274. chan rankerReturnRequest, engine.initOptions.NumShards)
  275. // 生成查找请求
  276. lookupRequest := indexerLookupRequest{
  277. countDocsOnly: request.CountDocsOnly,
  278. tokens: tokens,
  279. labels: request.Labels,
  280. docIds: request.DocIds,
  281. options: rankOptions,
  282. rankerReturnChannel: rankerReturnChannel,
  283. orderless: request.Orderless,
  284. }
  285. // 向索引器发送查找请求
  286. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  287. engine.indexerLookupChannels[shard] <- lookupRequest
  288. }
  289. // 从通信通道读取排序器的输出
  290. numDocs := 0
  291. rankOutput := types.ScoredDocuments{}
  292. timeout := request.Timeout
  293. isTimeout := false
  294. if timeout <= 0 {
  295. // 不设置超时
  296. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  297. rankerOutput := <-rankerReturnChannel
  298. if !request.CountDocsOnly {
  299. for _, doc := range rankerOutput.docs {
  300. rankOutput = append(rankOutput, doc)
  301. }
  302. }
  303. numDocs += rankerOutput.numDocs
  304. }
  305. } else {
  306. // 设置超时
  307. deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
  308. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  309. select {
  310. case rankerOutput := <-rankerReturnChannel:
  311. if !request.CountDocsOnly {
  312. for _, doc := range rankerOutput.docs {
  313. rankOutput = append(rankOutput, doc)
  314. }
  315. }
  316. numDocs += rankerOutput.numDocs
  317. case <-time.After(deadline.Sub(time.Now())):
  318. isTimeout = true
  319. break
  320. }
  321. }
  322. }
  323. // 再排序
  324. if !request.CountDocsOnly && !request.Orderless {
  325. if rankOptions.ReverseOrder {
  326. sort.Sort(sort.Reverse(rankOutput))
  327. } else {
  328. sort.Sort(rankOutput)
  329. }
  330. }
  331. // 准备输出
  332. output.Tokens = tokens
  333. // 仅当CountDocsOnly为false时才充填output.Docs
  334. if !request.CountDocsOnly {
  335. if request.Orderless {
  336. // 无序状态无需对Offset截断
  337. output.Docs = rankOutput
  338. } else {
  339. var start, end int
  340. if rankOptions.MaxOutputs == 0 {
  341. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  342. end = len(rankOutput)
  343. } else {
  344. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  345. end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
  346. }
  347. output.Docs = rankOutput[start:end]
  348. }
  349. }
  350. output.NumDocs = numDocs
  351. output.Timeout = isTimeout
  352. return
  353. }
  354. // 关闭引擎
  355. func (engine *Engine) Close() {
  356. engine.FlushIndex()
  357. if engine.initOptions.UsePersistentStorage {
  358. for _, db := range engine.dbs {
  359. db.Close()
  360. }
  361. }
  362. }
  363. // 从文本hash得到要分配到的shard
  364. func (engine *Engine) getShard(hash uint32) int {
  365. return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))
  366. }