engine.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. package engine
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "runtime"
  7. "sort"
  8. "strconv"
  9. "sync/atomic"
  10. "time"
  11. "github.com/huichen/murmur"
  12. "github.com/huichen/sego"
  13. "github.com/huichen/wukong/core"
  14. "github.com/huichen/wukong/storage"
  15. "github.com/huichen/wukong/types"
  16. "github.com/huichen/wukong/utils"
  17. )
  // Package-level engine constants.
  const (
  	// NumNanosecondsInAMillisecond is the number of nanoseconds in one
  	// millisecond (10^6).
  	NumNanosecondsInAMillisecond = 1000 * 1000
  	// PersistentStorageFilePrefix is the file-name prefix of every
  	// persistent-storage shard database.
  	PersistentStorageFilePrefix = "wukong"
  )
  22. type Engine struct {
  23. // 计数器,用来统计有多少文档被索引等信息
  24. numDocumentsIndexed uint64
  25. numDocumentsRemoved uint64
  26. numDocumentsForceUpdated uint64
  27. numIndexingRequests uint64
  28. numRemovingRequests uint64
  29. numForceUpdatingRequests uint64
  30. numTokenIndexAdded uint64
  31. numDocumentsStored uint64
  32. // 记录初始化参数
  33. initOptions types.EngineInitOptions
  34. initialized bool
  35. indexers []core.Indexer
  36. rankers []core.Ranker
  37. segmenter sego.Segmenter
  38. stopTokens StopTokens
  39. dbs []storage.Storage
  40. // 建立索引器使用的通信通道
  41. segmenterChannel chan segmenterRequest
  42. indexerAddDocChannels []chan indexerAddDocumentRequest
  43. indexerRemoveDocChannels []chan indexerRemoveDocRequest
  44. rankerAddDocChannels []chan rankerAddDocRequest
  45. // 建立排序器使用的通信通道
  46. indexerLookupChannels []chan indexerLookupRequest
  47. rankerRankChannels []chan rankerRankRequest
  48. rankerRemoveDocChannels []chan rankerRemoveDocRequest
  49. // 建立持久存储使用的通信通道
  50. persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
  51. persistentStorageInitChannel chan bool
  52. }
  53. func (engine *Engine) Init(options types.EngineInitOptions) {
  54. // 将线程数设置为CPU数
  55. runtime.GOMAXPROCS(runtime.NumCPU())
  56. // 初始化初始参数
  57. if engine.initialized {
  58. log.Fatal("请勿重复初始化引擎")
  59. }
  60. options.Init()
  61. engine.initOptions = options
  62. engine.initialized = true
  63. if !options.NotUsingSegmenter {
  64. // 载入分词器词典
  65. engine.segmenter.LoadDictionary(options.SegmenterDictionaries)
  66. // 初始化停用词
  67. engine.stopTokens.Init(options.StopTokenFile)
  68. }
  69. // 初始化索引器和排序器
  70. for shard := 0; shard < options.NumShards; shard++ {
  71. engine.indexers = append(engine.indexers, core.Indexer{})
  72. engine.indexers[shard].Init(*options.IndexerInitOptions)
  73. engine.rankers = append(engine.rankers, core.Ranker{})
  74. engine.rankers[shard].Init()
  75. }
  76. // 初始化分词器通道
  77. engine.segmenterChannel = make(chan segmenterRequest, options.NumSegmenterThreads)
  78. // 初始化索引器通道
  79. engine.indexerAddDocChannels = make([]chan indexerAddDocumentRequest, options.NumShards)
  80. engine.indexerRemoveDocChannels = make([]chan indexerRemoveDocRequest, options.NumShards)
  81. engine.indexerLookupChannels = make([]chan indexerLookupRequest, options.NumShards)
  82. for shard := 0; shard < options.NumShards; shard++ {
  83. engine.indexerAddDocChannels[shard] = make(
  84. chan indexerAddDocumentRequest,
  85. options.IndexerBufferLength)
  86. engine.indexerRemoveDocChannels[shard] = make(
  87. chan indexerRemoveDocRequest,
  88. options.IndexerBufferLength)
  89. engine.indexerLookupChannels[shard] = make(
  90. chan indexerLookupRequest,
  91. options.IndexerBufferLength)
  92. }
  93. // 初始化排序器通道
  94. engine.rankerAddDocChannels = make([]chan rankerAddDocRequest, options.NumShards)
  95. engine.rankerRankChannels = make([]chan rankerRankRequest, options.NumShards)
  96. engine.rankerRemoveDocChannels = make([]chan rankerRemoveDocRequest, options.NumShards)
  97. for shard := 0; shard < options.NumShards; shard++ {
  98. engine.rankerAddDocChannels[shard] = make(
  99. chan rankerAddDocRequest,
  100. options.RankerBufferLength)
  101. engine.rankerRankChannels[shard] = make(
  102. chan rankerRankRequest,
  103. options.RankerBufferLength)
  104. engine.rankerRemoveDocChannels[shard] = make(
  105. chan rankerRemoveDocRequest,
  106. options.RankerBufferLength)
  107. }
  108. // 初始化持久化存储通道
  109. if engine.initOptions.UsePersistentStorage {
  110. engine.persistentStorageIndexDocumentChannels =
  111. make([]chan persistentStorageIndexDocumentRequest, engine.initOptions.PersistentStorageShards)
  112. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  113. engine.persistentStorageIndexDocumentChannels[shard] = make(chan persistentStorageIndexDocumentRequest)
  114. }
  115. engine.persistentStorageInitChannel = make(chan bool, engine.initOptions.PersistentStorageShards)
  116. }
  117. // 启动分词器
  118. for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
  119. go engine.segmenterWorker()
  120. }
  121. // 启动索引器和排序器
  122. for shard := 0; shard < options.NumShards; shard++ {
  123. go engine.indexerAddDocumentWorker(shard)
  124. go engine.indexerRemoveDocWorker(shard)
  125. go engine.rankerAddDocWorker(shard)
  126. go engine.rankerRemoveDocWorker(shard)
  127. for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
  128. go engine.indexerLookupWorker(shard)
  129. }
  130. for i := 0; i < options.NumRankerThreadsPerShard; i++ {
  131. go engine.rankerRankWorker(shard)
  132. }
  133. }
  134. // 启动持久化存储工作协程
  135. if engine.initOptions.UsePersistentStorage {
  136. err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
  137. if err != nil {
  138. log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
  139. }
  140. // 打开或者创建数据库
  141. engine.dbs = make([]storage.Storage, engine.initOptions.PersistentStorageShards)
  142. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  143. dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
  144. db, err := storage.OpenStorage(dbPath, engine.initOptions.PersistentStorageEngine)
  145. if db == nil || err != nil {
  146. log.Fatal("无法打开数据库", dbPath, ": ", err)
  147. }
  148. engine.dbs[shard] = db
  149. }
  150. // 从数据库中恢复
  151. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  152. go engine.persistentStorageInitWorker(shard)
  153. }
  154. // 等待恢复完成
  155. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  156. <-engine.persistentStorageInitChannel
  157. }
  158. for {
  159. runtime.Gosched()
  160. if engine.numIndexingRequests == engine.numDocumentsIndexed {
  161. break
  162. }
  163. }
  164. // 关闭并重新打开数据库
  165. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  166. engine.dbs[shard].Close()
  167. dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
  168. db, err := storage.OpenStorage(dbPath, engine.initOptions.PersistentStorageEngine)
  169. if db == nil || err != nil {
  170. log.Fatal("无法打开数据库", dbPath, ": ", err)
  171. }
  172. engine.dbs[shard] = db
  173. }
  174. for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
  175. go engine.persistentStorageIndexDocumentWorker(shard)
  176. }
  177. }
  178. atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
  179. }
  // uidToSidComp converts a legacy numeric document id into the string id
  // used by the *S APIs. The reserved id 0 maps to the empty string; any
  // other id maps to its base-10 representation.
  func uidToSidComp(docId uint64) string {
  	if docId == 0 {
  		return ""
  	}
  	return strconv.FormatUint(docId, 10)
  }
  187. // @deprecated Please use `IndexDocumentS` instead.
  188. // 将文档加入索引
  189. //
  190. // 输入参数:
  191. // docId 标识文档编号,必须唯一,docId == 0 表示非法文档(用于强制刷新索引),[1, +oo) 表示合法文档
  192. // data 见DocumentIndexData注释
  193. // forceUpdate 是否强制刷新 cache,如果设为 true,则尽快添加到索引,否则等待 cache 满之后一次全量添加
  194. //
  195. // 注意:
  196. // 1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
  197. // 2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
  198. // 如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
  199. func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData, forceUpdate bool) {
  200. id := uidToSidComp(docId)
  201. engine.IndexDocumentS(id, data, forceUpdate)
  202. }
  203. func (engine *Engine) IndexDocumentS(docId string, data types.DocumentIndexData, forceUpdate bool) {
  204. engine.internalIndexDocument(docId, data, forceUpdate)
  205. hash := murmur.Murmur3([]byte(docId)) % uint32(engine.initOptions.PersistentStorageShards)
  206. if engine.initOptions.UsePersistentStorage && docId != "" {
  207. engine.persistentStorageIndexDocumentChannels[hash] <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
  208. }
  209. }
  210. func (engine *Engine) internalIndexDocument(docId string, data types.DocumentIndexData, forceUpdate bool) {
  211. if !engine.initialized {
  212. log.Fatal("必须先初始化引擎")
  213. }
  214. if docId != "" {
  215. atomic.AddUint64(&engine.numIndexingRequests, 1)
  216. }
  217. if forceUpdate {
  218. atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
  219. }
  220. hash := murmur.Murmur3([]byte(fmt.Sprintf("%s%s", docId, data.Content)))
  221. engine.segmenterChannel <- segmenterRequest{
  222. docId: docId, hash: hash, data: data, forceUpdate: forceUpdate}
  223. }
  224. // @deprecated Please use `RemoveDocumentS` instead
  225. // 将文档从索引中删除
  226. //
  227. // 输入参数:
  228. // docId 标识文档编号,必须唯一,docId == 0 表示非法文档(用于强制刷新索引),[1, +oo) 表示合法文档
  229. // forceUpdate 是否强制刷新 cache,如果设为 true,则尽快删除索引,否则等待 cache 满之后一次全量删除
  230. //
  231. // 注意:
  232. // 1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
  233. // 2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
  234. // 如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
  235. func (engine *Engine) RemoveDocument(docId uint64, forceUpdate bool) {
  236. id := uidToSidComp(docId)
  237. engine.RemoveDocumentS(id, forceUpdate)
  238. }
  239. func (engine *Engine) RemoveDocumentS(docId string, forceUpdate bool) {
  240. if !engine.initialized {
  241. log.Fatal("必须先初始化引擎")
  242. }
  243. if docId != "" {
  244. atomic.AddUint64(&engine.numRemovingRequests, 1)
  245. }
  246. if forceUpdate {
  247. atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
  248. }
  249. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  250. engine.indexerRemoveDocChannels[shard] <- indexerRemoveDocRequest{docId: docId, forceUpdate: forceUpdate}
  251. if docId == "" {
  252. continue
  253. }
  254. engine.rankerRemoveDocChannels[shard] <- rankerRemoveDocRequest{docId: docId}
  255. }
  256. if engine.initOptions.UsePersistentStorage && docId != "" {
  257. // 从数据库中删除
  258. hash := murmur.Murmur3([]byte(docId)) % uint32(engine.initOptions.PersistentStorageShards)
  259. go engine.persistentStorageRemoveDocumentWorker(docId, hash)
  260. }
  261. }
  262. // 查找满足搜索条件的文档,此函数线程安全
  263. func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
  264. if !engine.initialized {
  265. log.Fatal("必须先初始化引擎")
  266. }
  267. var rankOptions types.RankOptions
  268. if request.RankOptions == nil {
  269. rankOptions = *engine.initOptions.DefaultRankOptions
  270. } else {
  271. rankOptions = *request.RankOptions
  272. }
  273. if rankOptions.ScoringCriteria == nil {
  274. rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
  275. }
  276. // 收集关键词
  277. tokens := []string{}
  278. if request.Text != "" {
  279. querySegments := engine.segmenter.Segment([]byte(request.Text))
  280. for _, s := range querySegments {
  281. token := s.Token().Text()
  282. if !engine.stopTokens.IsStopToken(token) {
  283. tokens = append(tokens, s.Token().Text())
  284. }
  285. }
  286. } else {
  287. for _, t := range request.Tokens {
  288. tokens = append(tokens, t)
  289. }
  290. }
  291. // 建立排序器返回的通信通道
  292. rankerReturnChannel := make(
  293. chan rankerReturnRequest, engine.initOptions.NumShards)
  294. // 生成查找请求
  295. lookupRequest := indexerLookupRequest{
  296. countDocsOnly: request.CountDocsOnly,
  297. tokens: tokens,
  298. labels: request.Labels,
  299. docIds: request.DocIds,
  300. options: rankOptions,
  301. rankerReturnChannel: rankerReturnChannel,
  302. orderless: request.Orderless,
  303. }
  304. // 向索引器发送查找请求
  305. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  306. engine.indexerLookupChannels[shard] <- lookupRequest
  307. }
  308. // 从通信通道读取排序器的输出
  309. numDocs := 0
  310. rankOutput := types.ScoredDocuments{}
  311. timeout := request.Timeout
  312. isTimeout := false
  313. if timeout <= 0 {
  314. // 不设置超时
  315. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  316. rankerOutput := <-rankerReturnChannel
  317. if !request.CountDocsOnly {
  318. for _, doc := range rankerOutput.docs {
  319. rankOutput = append(rankOutput, doc)
  320. }
  321. }
  322. numDocs += rankerOutput.numDocs
  323. }
  324. } else {
  325. // 设置超时
  326. deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
  327. for shard := 0; shard < engine.initOptions.NumShards; shard++ {
  328. select {
  329. case rankerOutput := <-rankerReturnChannel:
  330. if !request.CountDocsOnly {
  331. for _, doc := range rankerOutput.docs {
  332. rankOutput = append(rankOutput, doc)
  333. }
  334. }
  335. numDocs += rankerOutput.numDocs
  336. case <-time.After(deadline.Sub(time.Now())):
  337. isTimeout = true
  338. break
  339. }
  340. }
  341. }
  342. // 再排序
  343. if !request.CountDocsOnly && !request.Orderless {
  344. if rankOptions.ReverseOrder {
  345. sort.Sort(sort.Reverse(rankOutput))
  346. } else {
  347. sort.Sort(rankOutput)
  348. }
  349. }
  350. // 准备输出
  351. output.Tokens = tokens
  352. // 仅当CountDocsOnly为false时才充填output.Docs
  353. if !request.CountDocsOnly {
  354. if request.Orderless {
  355. // 无序状态无需对Offset截断
  356. output.Docs = rankOutput
  357. } else {
  358. var start, end int
  359. if rankOptions.MaxOutputs == 0 {
  360. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  361. end = len(rankOutput)
  362. } else {
  363. start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
  364. end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
  365. }
  366. output.Docs = rankOutput[start:end]
  367. }
  368. }
  369. output.NumDocs = numDocs
  370. output.Timeout = isTimeout
  371. return
  372. }
  373. // 阻塞等待直到所有索引添加完毕
  374. func (engine *Engine) FlushIndex() {
  375. for {
  376. runtime.Gosched()
  377. if engine.numIndexingRequests == engine.numDocumentsIndexed &&
  378. engine.numRemovingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsRemoved &&
  379. (!engine.initOptions.UsePersistentStorage || engine.numIndexingRequests == engine.numDocumentsStored) {
  380. // 保证 CHANNEL 中 REQUESTS 全部被执行完
  381. break
  382. }
  383. }
  384. // 强制更新,保证其为最后的请求
  385. engine.IndexDocumentS("", types.DocumentIndexData{}, true)
  386. for {
  387. runtime.Gosched()
  388. if engine.numForceUpdatingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsForceUpdated {
  389. return
  390. }
  391. }
  392. }
  393. // 关闭引擎
  394. func (engine *Engine) Close() {
  395. engine.FlushIndex()
  396. if engine.initOptions.UsePersistentStorage {
  397. for _, db := range engine.dbs {
  398. db.Close()
  399. }
  400. }
  401. }
  402. // 从文本hash得到要分配到的shard
  403. func (engine *Engine) getShard(hash uint32) int {
  404. return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))
  405. }