engine_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. package engine
  2. import (
  3. "encoding/gob"
  4. "github.com/huichen/wukong/types"
  5. "github.com/huichen/wukong/utils"
  6. "os"
  7. "reflect"
  8. "testing"
  9. )
  10. type ScoringFields struct {
  11. A, B, C float32
  12. }
  13. func AddDocs(engine *Engine) {
  14. docId := uint64(1)
  15. engine.IndexDocument(docId, types.DocumentIndexData{
  16. Content: "中国有十三亿人口人口",
  17. Fields: ScoringFields{1, 2, 3},
  18. }, false)
  19. docId++
  20. engine.IndexDocument(docId, types.DocumentIndexData{
  21. Content: "中国人口",
  22. Fields: nil,
  23. }, false)
  24. docId++
  25. engine.IndexDocument(docId, types.DocumentIndexData{
  26. Content: "有人口",
  27. Fields: ScoringFields{2, 3, 1},
  28. }, false)
  29. docId++
  30. engine.IndexDocument(docId, types.DocumentIndexData{
  31. Content: "有十三亿人口",
  32. Fields: ScoringFields{2, 3, 3},
  33. }, false)
  34. docId++
  35. engine.IndexDocument(docId, types.DocumentIndexData{
  36. Content: "中国十三亿人口",
  37. Fields: ScoringFields{0, 9, 1},
  38. }, false)
  39. engine.FlushIndex()
  40. }
  41. func addDocsWithLabels(engine *Engine) {
  42. docId := uint64(1)
  43. engine.IndexDocument(docId, types.DocumentIndexData{
  44. Content: "此次百度收购将成中国互联网最大并购",
  45. Labels: []string{"百度", "中国"},
  46. }, false)
  47. docId++
  48. engine.IndexDocument(docId, types.DocumentIndexData{
  49. Content: "百度宣布拟全资收购91无线业务",
  50. Labels: []string{"百度"},
  51. }, false)
  52. docId++
  53. engine.IndexDocument(docId, types.DocumentIndexData{
  54. Content: "百度是中国最大的搜索引擎",
  55. Labels: []string{"百度"},
  56. }, false)
  57. docId++
  58. engine.IndexDocument(docId, types.DocumentIndexData{
  59. Content: "百度在研制无人汽车",
  60. Labels: []string{"百度"},
  61. }, false)
  62. docId++
  63. engine.IndexDocument(docId, types.DocumentIndexData{
  64. Content: "BAT是中国互联网三巨头",
  65. Labels: []string{"百度"},
  66. }, false)
  67. engine.FlushIndex()
  68. }
  69. type RankByTokenProximity struct {
  70. }
  71. func (rule RankByTokenProximity) Score(
  72. doc types.IndexedDocument, fields interface{}) []float32 {
  73. if doc.TokenProximity < 0 {
  74. return []float32{}
  75. }
  76. return []float32{1.0 / (float32(doc.TokenProximity) + 1)}
  77. }
  78. func TestEngineIndexDocument(t *testing.T) {
  79. var engine Engine
  80. engine.Init(types.EngineInitOptions{
  81. SegmenterDictionaries: "../testdata/test_dict.txt",
  82. DefaultRankOptions: &types.RankOptions{
  83. OutputOffset: 0,
  84. MaxOutputs: 10,
  85. ScoringCriteria: &RankByTokenProximity{},
  86. },
  87. IndexerInitOptions: &types.IndexerInitOptions{
  88. IndexType: types.LocationsIndex,
  89. },
  90. })
  91. AddDocs(&engine)
  92. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  93. utils.Expect(t, "2", len(outputs.Tokens))
  94. utils.Expect(t, "中国", outputs.Tokens[0])
  95. utils.Expect(t, "人口", outputs.Tokens[1])
  96. utils.Expect(t, "3", len(outputs.Docs))
  97. utils.Expect(t, "2", outputs.Docs[0].DocId)
  98. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  99. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  100. utils.Expect(t, "5", outputs.Docs[1].DocId)
  101. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  102. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  103. utils.Expect(t, "1", outputs.Docs[2].DocId)
  104. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  105. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  106. }
  107. func TestReverseOrder(t *testing.T) {
  108. var engine Engine
  109. engine.Init(types.EngineInitOptions{
  110. SegmenterDictionaries: "../testdata/test_dict.txt",
  111. DefaultRankOptions: &types.RankOptions{
  112. ReverseOrder: true,
  113. OutputOffset: 0,
  114. MaxOutputs: 10,
  115. ScoringCriteria: &RankByTokenProximity{},
  116. },
  117. IndexerInitOptions: &types.IndexerInitOptions{
  118. IndexType: types.LocationsIndex,
  119. },
  120. })
  121. AddDocs(&engine)
  122. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  123. utils.Expect(t, "3", len(outputs.Docs))
  124. utils.Expect(t, "1", outputs.Docs[0].DocId)
  125. utils.Expect(t, "5", outputs.Docs[1].DocId)
  126. utils.Expect(t, "2", outputs.Docs[2].DocId)
  127. }
  128. func TestOffsetAndMaxOutputs(t *testing.T) {
  129. var engine Engine
  130. engine.Init(types.EngineInitOptions{
  131. SegmenterDictionaries: "../testdata/test_dict.txt",
  132. DefaultRankOptions: &types.RankOptions{
  133. ReverseOrder: true,
  134. OutputOffset: 1,
  135. MaxOutputs: 3,
  136. ScoringCriteria: &RankByTokenProximity{},
  137. },
  138. IndexerInitOptions: &types.IndexerInitOptions{
  139. IndexType: types.LocationsIndex,
  140. },
  141. })
  142. AddDocs(&engine)
  143. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  144. utils.Expect(t, "2", len(outputs.Docs))
  145. utils.Expect(t, "5", outputs.Docs[0].DocId)
  146. utils.Expect(t, "2", outputs.Docs[1].DocId)
  147. }
  148. type TestScoringCriteria struct {
  149. }
  150. func (criteria TestScoringCriteria) Score(
  151. doc types.IndexedDocument, fields interface{}) []float32 {
  152. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  153. return []float32{}
  154. }
  155. fs := fields.(ScoringFields)
  156. return []float32{float32(doc.TokenProximity)*fs.A + fs.B*fs.C}
  157. }
  158. func TestSearchWithCriteria(t *testing.T) {
  159. var engine Engine
  160. engine.Init(types.EngineInitOptions{
  161. SegmenterDictionaries: "../testdata/test_dict.txt",
  162. DefaultRankOptions: &types.RankOptions{
  163. ScoringCriteria: TestScoringCriteria{},
  164. },
  165. IndexerInitOptions: &types.IndexerInitOptions{
  166. IndexType: types.LocationsIndex,
  167. },
  168. })
  169. AddDocs(&engine)
  170. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  171. utils.Expect(t, "2", len(outputs.Docs))
  172. utils.Expect(t, "1", outputs.Docs[0].DocId)
  173. utils.Expect(t, "18000", int(outputs.Docs[0].Scores[0]*1000))
  174. utils.Expect(t, "5", outputs.Docs[1].DocId)
  175. utils.Expect(t, "9000", int(outputs.Docs[1].Scores[0]*1000))
  176. }
  177. func TestCompactIndex(t *testing.T) {
  178. var engine Engine
  179. engine.Init(types.EngineInitOptions{
  180. SegmenterDictionaries: "../testdata/test_dict.txt",
  181. DefaultRankOptions: &types.RankOptions{
  182. ScoringCriteria: TestScoringCriteria{},
  183. },
  184. })
  185. AddDocs(&engine)
  186. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  187. utils.Expect(t, "2", len(outputs.Docs))
  188. utils.Expect(t, "5", outputs.Docs[0].DocId)
  189. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  190. utils.Expect(t, "1", outputs.Docs[1].DocId)
  191. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  192. }
  193. type BM25ScoringCriteria struct {
  194. }
  195. func (criteria BM25ScoringCriteria) Score(
  196. doc types.IndexedDocument, fields interface{}) []float32 {
  197. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  198. return []float32{}
  199. }
  200. return []float32{doc.BM25}
  201. }
  202. func TestFrequenciesIndex(t *testing.T) {
  203. var engine Engine
  204. engine.Init(types.EngineInitOptions{
  205. SegmenterDictionaries: "../testdata/test_dict.txt",
  206. DefaultRankOptions: &types.RankOptions{
  207. ScoringCriteria: BM25ScoringCriteria{},
  208. },
  209. IndexerInitOptions: &types.IndexerInitOptions{
  210. IndexType: types.FrequenciesIndex,
  211. },
  212. })
  213. AddDocs(&engine)
  214. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  215. utils.Expect(t, "2", len(outputs.Docs))
  216. utils.Expect(t, "5", outputs.Docs[0].DocId)
  217. utils.Expect(t, "2349", int(outputs.Docs[0].Scores[0]*1000))
  218. utils.Expect(t, "1", outputs.Docs[1].DocId)
  219. utils.Expect(t, "2320", int(outputs.Docs[1].Scores[0]*1000))
  220. }
  221. func TestRemoveDocument(t *testing.T) {
  222. var engine Engine
  223. engine.Init(types.EngineInitOptions{
  224. SegmenterDictionaries: "../testdata/test_dict.txt",
  225. DefaultRankOptions: &types.RankOptions{
  226. ScoringCriteria: TestScoringCriteria{},
  227. },
  228. })
  229. AddDocs(&engine)
  230. engine.RemoveDocument(5, false)
  231. engine.RemoveDocument(6, false)
  232. engine.FlushIndex()
  233. engine.IndexDocument(6, types.DocumentIndexData{
  234. Content: "中国人口有十三亿",
  235. Fields: ScoringFields{0, 9, 1},
  236. }, false)
  237. engine.FlushIndex()
  238. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  239. utils.Expect(t, "2", len(outputs.Docs))
  240. utils.Expect(t, "6", outputs.Docs[0].DocId)
  241. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  242. utils.Expect(t, "1", outputs.Docs[1].DocId)
  243. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  244. }
  245. func TestEngineIndexDocumentWithTokens(t *testing.T) {
  246. var engine Engine
  247. engine.Init(types.EngineInitOptions{
  248. SegmenterDictionaries: "../testdata/test_dict.txt",
  249. DefaultRankOptions: &types.RankOptions{
  250. OutputOffset: 0,
  251. MaxOutputs: 10,
  252. ScoringCriteria: &RankByTokenProximity{},
  253. },
  254. IndexerInitOptions: &types.IndexerInitOptions{
  255. IndexType: types.LocationsIndex,
  256. },
  257. })
  258. docId := uint64(1)
  259. engine.IndexDocument(docId, types.DocumentIndexData{
  260. Content: "",
  261. Tokens: []types.TokenData{
  262. {"中国", []int{0}},
  263. {"人口", []int{18, 24}},
  264. },
  265. Fields: ScoringFields{1, 2, 3},
  266. }, false)
  267. docId++
  268. engine.IndexDocument(docId, types.DocumentIndexData{
  269. Content: "",
  270. Tokens: []types.TokenData{
  271. {"中国", []int{0}},
  272. {"人口", []int{6}},
  273. },
  274. Fields: ScoringFields{1, 2, 3},
  275. }, false)
  276. docId++
  277. engine.IndexDocument(docId, types.DocumentIndexData{
  278. Content: "中国十三亿人口",
  279. Fields: ScoringFields{0, 9, 1},
  280. }, false)
  281. engine.FlushIndex()
  282. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  283. utils.Expect(t, "2", len(outputs.Tokens))
  284. utils.Expect(t, "中国", outputs.Tokens[0])
  285. utils.Expect(t, "人口", outputs.Tokens[1])
  286. utils.Expect(t, "3", len(outputs.Docs))
  287. utils.Expect(t, "2", outputs.Docs[0].DocId)
  288. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  289. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  290. utils.Expect(t, "3", outputs.Docs[1].DocId)
  291. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  292. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  293. utils.Expect(t, "1", outputs.Docs[2].DocId)
  294. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  295. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  296. }
  297. func TestEngineIndexDocumentWithContentAndLabels(t *testing.T) {
  298. var engine1, engine2 Engine
  299. engine1.Init(types.EngineInitOptions{
  300. SegmenterDictionaries: "../data/dictionary.txt",
  301. IndexerInitOptions: &types.IndexerInitOptions{
  302. IndexType: types.LocationsIndex,
  303. },
  304. })
  305. engine2.Init(types.EngineInitOptions{
  306. SegmenterDictionaries: "../data/dictionary.txt",
  307. IndexerInitOptions: &types.IndexerInitOptions{
  308. IndexType: types.DocIdsIndex,
  309. },
  310. })
  311. addDocsWithLabels(&engine1)
  312. addDocsWithLabels(&engine2)
  313. outputs1 := engine1.Search(types.SearchRequest{Text: "百度"})
  314. outputs2 := engine2.Search(types.SearchRequest{Text: "百度"})
  315. utils.Expect(t, "1", len(outputs1.Tokens))
  316. utils.Expect(t, "1", len(outputs2.Tokens))
  317. utils.Expect(t, "百度", outputs1.Tokens[0])
  318. utils.Expect(t, "百度", outputs2.Tokens[0])
  319. utils.Expect(t, "5", len(outputs1.Docs))
  320. utils.Expect(t, "5", len(outputs2.Docs))
  321. }
  322. func TestEngineIndexDocumentWithPersistentStorage(t *testing.T) {
  323. gob.Register(ScoringFields{})
  324. var engine Engine
  325. engine.Init(types.EngineInitOptions{
  326. SegmenterDictionaries: "../testdata/test_dict.txt",
  327. DefaultRankOptions: &types.RankOptions{
  328. OutputOffset: 0,
  329. MaxOutputs: 10,
  330. ScoringCriteria: &RankByTokenProximity{},
  331. },
  332. IndexerInitOptions: &types.IndexerInitOptions{
  333. IndexType: types.LocationsIndex,
  334. },
  335. UsePersistentStorage: true,
  336. PersistentStorageFolder: "wukong.persistent",
  337. PersistentStorageShards: 2,
  338. })
  339. AddDocs(&engine)
  340. engine.RemoveDocument(5, true)
  341. engine.Close()
  342. var engine1 Engine
  343. engine1.Init(types.EngineInitOptions{
  344. SegmenterDictionaries: "../testdata/test_dict.txt",
  345. DefaultRankOptions: &types.RankOptions{
  346. OutputOffset: 0,
  347. MaxOutputs: 10,
  348. ScoringCriteria: &RankByTokenProximity{},
  349. },
  350. IndexerInitOptions: &types.IndexerInitOptions{
  351. IndexType: types.LocationsIndex,
  352. },
  353. UsePersistentStorage: true,
  354. PersistentStorageFolder: "wukong.persistent",
  355. PersistentStorageShards: 2,
  356. })
  357. engine1.FlushIndex()
  358. outputs := engine1.Search(types.SearchRequest{Text: "中国人口"})
  359. utils.Expect(t, "2", len(outputs.Tokens))
  360. utils.Expect(t, "中国", outputs.Tokens[0])
  361. utils.Expect(t, "人口", outputs.Tokens[1])
  362. utils.Expect(t, "2", len(outputs.Docs))
  363. utils.Expect(t, "2", outputs.Docs[0].DocId)
  364. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  365. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  366. utils.Expect(t, "1", outputs.Docs[1].DocId)
  367. utils.Expect(t, "76", int(outputs.Docs[1].Scores[0]*1000))
  368. utils.Expect(t, "[0 18]", outputs.Docs[1].TokenSnippetLocations)
  369. engine1.Close()
  370. os.RemoveAll("wukong.persistent")
  371. }
  372. func TestCountDocsOnly(t *testing.T) {
  373. var engine Engine
  374. engine.Init(types.EngineInitOptions{
  375. SegmenterDictionaries: "../testdata/test_dict.txt",
  376. DefaultRankOptions: &types.RankOptions{
  377. ReverseOrder: true,
  378. OutputOffset: 0,
  379. MaxOutputs: 1,
  380. ScoringCriteria: &RankByTokenProximity{},
  381. },
  382. IndexerInitOptions: &types.IndexerInitOptions{
  383. IndexType: types.LocationsIndex,
  384. },
  385. })
  386. AddDocs(&engine)
  387. engine.RemoveDocument(5, false)
  388. engine.FlushIndex()
  389. outputs := engine.Search(types.SearchRequest{Text: "中国人口", CountDocsOnly: true})
  390. utils.Expect(t, "0", len(outputs.Docs))
  391. utils.Expect(t, "2", len(outputs.Tokens))
  392. utils.Expect(t, "2", outputs.NumDocs)
  393. }
  394. func TestSearchWithin(t *testing.T) {
  395. var engine Engine
  396. engine.Init(types.EngineInitOptions{
  397. SegmenterDictionaries: "../testdata/test_dict.txt",
  398. DefaultRankOptions: &types.RankOptions{
  399. ReverseOrder: true,
  400. OutputOffset: 0,
  401. MaxOutputs: 10,
  402. ScoringCriteria: &RankByTokenProximity{},
  403. },
  404. IndexerInitOptions: &types.IndexerInitOptions{
  405. IndexType: types.LocationsIndex,
  406. },
  407. })
  408. AddDocs(&engine)
  409. docIds := make(map[uint64]bool)
  410. docIds[5] = true
  411. docIds[1] = true
  412. outputs := engine.Search(types.SearchRequest{
  413. Text: "中国人口",
  414. DocIds: docIds,
  415. })
  416. utils.Expect(t, "2", len(outputs.Tokens))
  417. utils.Expect(t, "中国", outputs.Tokens[0])
  418. utils.Expect(t, "人口", outputs.Tokens[1])
  419. utils.Expect(t, "2", len(outputs.Docs))
  420. utils.Expect(t, "1", outputs.Docs[0].DocId)
  421. utils.Expect(t, "76", int(outputs.Docs[0].Scores[0]*1000))
  422. utils.Expect(t, "[0 18]", outputs.Docs[0].TokenSnippetLocations)
  423. utils.Expect(t, "5", outputs.Docs[1].DocId)
  424. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  425. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  426. }