engine_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. package engine
  2. import (
  3. "encoding/gob"
  4. "os"
  5. "reflect"
  6. "testing"
  7. "github.com/huichen/wukong/types"
  8. "github.com/huichen/wukong/utils"
  9. )
  // ScoringFields is the per-document payload stored in
// types.DocumentIndexData.Fields by these tests and read back by the custom
// scoring criteria below.
type ScoringFields struct {
	A float32
	B float32
	C float32
}
  13. func AddDocs(engine *Engine) {
  14. docId := uint64(1)
  15. engine.IndexDocument(docId, types.DocumentIndexData{
  16. Content: "中国有十三亿人口人口",
  17. Fields: ScoringFields{1, 2, 3},
  18. }, false)
  19. docId++
  20. engine.IndexDocument(docId, types.DocumentIndexData{
  21. Content: "中国人口",
  22. Fields: nil,
  23. }, false)
  24. docId++
  25. engine.IndexDocument(docId, types.DocumentIndexData{
  26. Content: "有人口",
  27. Fields: ScoringFields{2, 3, 1},
  28. }, false)
  29. docId++
  30. engine.IndexDocument(docId, types.DocumentIndexData{
  31. Content: "有十三亿人口",
  32. Fields: ScoringFields{2, 3, 3},
  33. }, false)
  34. docId++
  35. engine.IndexDocument(docId, types.DocumentIndexData{
  36. Content: "中国十三亿人口",
  37. Fields: ScoringFields{0, 9, 1},
  38. }, false)
  39. engine.FlushIndex()
  40. }
  41. func addDocsWithLabels(engine *Engine) {
  42. docId := uint64(1)
  43. engine.IndexDocument(docId, types.DocumentIndexData{
  44. Content: "此次百度收购将成中国互联网最大并购",
  45. Labels: []string{"百度", "中国"},
  46. }, false)
  47. docId++
  48. engine.IndexDocument(docId, types.DocumentIndexData{
  49. Content: "百度宣布拟全资收购91无线业务",
  50. Labels: []string{"百度"},
  51. }, false)
  52. docId++
  53. engine.IndexDocument(docId, types.DocumentIndexData{
  54. Content: "百度是中国最大的搜索引擎",
  55. Labels: []string{"百度"},
  56. }, false)
  57. docId++
  58. engine.IndexDocument(docId, types.DocumentIndexData{
  59. Content: "百度在研制无人汽车",
  60. Labels: []string{"百度"},
  61. }, false)
  62. docId++
  63. engine.IndexDocument(docId, types.DocumentIndexData{
  64. Content: "BAT是中国互联网三巨头",
  65. Labels: []string{"百度"},
  66. }, false)
  67. engine.FlushIndex()
  68. }
  69. type RankByTokenProximity struct {
  70. }
  71. func (rule RankByTokenProximity) Score(
  72. doc types.IndexedDocument, fields interface{}) []float32 {
  73. if doc.TokenProximity < 0 {
  74. return []float32{}
  75. }
  76. return []float32{1.0 / (float32(doc.TokenProximity) + 1)}
  77. }
  78. func TestEngineIndexDocument(t *testing.T) {
  79. var engine Engine
  80. engine.Init(types.EngineInitOptions{
  81. SegmenterDictionaries: "../testdata/test_dict.txt",
  82. DefaultRankOptions: &types.RankOptions{
  83. OutputOffset: 0,
  84. MaxOutputs: 10,
  85. ScoringCriteria: &RankByTokenProximity{},
  86. },
  87. IndexerInitOptions: &types.IndexerInitOptions{
  88. IndexType: types.LocationsIndex,
  89. },
  90. })
  91. AddDocs(&engine)
  92. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  93. utils.Expect(t, "2", len(outputs.Tokens))
  94. utils.Expect(t, "中国", outputs.Tokens[0])
  95. utils.Expect(t, "人口", outputs.Tokens[1])
  96. utils.Expect(t, "3", len(outputs.Docs))
  97. utils.Expect(t, "2", outputs.Docs[0].DocId)
  98. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  99. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  100. utils.Expect(t, "5", outputs.Docs[1].DocId)
  101. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  102. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  103. utils.Expect(t, "1", outputs.Docs[2].DocId)
  104. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  105. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  106. }
  107. func TestReverseOrder(t *testing.T) {
  108. var engine Engine
  109. engine.Init(types.EngineInitOptions{
  110. SegmenterDictionaries: "../testdata/test_dict.txt",
  111. DefaultRankOptions: &types.RankOptions{
  112. ReverseOrder: true,
  113. OutputOffset: 0,
  114. MaxOutputs: 10,
  115. ScoringCriteria: &RankByTokenProximity{},
  116. },
  117. IndexerInitOptions: &types.IndexerInitOptions{
  118. IndexType: types.LocationsIndex,
  119. },
  120. })
  121. AddDocs(&engine)
  122. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  123. utils.Expect(t, "3", len(outputs.Docs))
  124. utils.Expect(t, "1", outputs.Docs[0].DocId)
  125. utils.Expect(t, "5", outputs.Docs[1].DocId)
  126. utils.Expect(t, "2", outputs.Docs[2].DocId)
  127. }
  128. func TestOffsetAndMaxOutputs(t *testing.T) {
  129. var engine Engine
  130. engine.Init(types.EngineInitOptions{
  131. SegmenterDictionaries: "../testdata/test_dict.txt",
  132. DefaultRankOptions: &types.RankOptions{
  133. ReverseOrder: true,
  134. OutputOffset: 1,
  135. MaxOutputs: 3,
  136. ScoringCriteria: &RankByTokenProximity{},
  137. },
  138. IndexerInitOptions: &types.IndexerInitOptions{
  139. IndexType: types.LocationsIndex,
  140. },
  141. })
  142. AddDocs(&engine)
  143. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  144. utils.Expect(t, "2", len(outputs.Docs))
  145. utils.Expect(t, "5", outputs.Docs[0].DocId)
  146. utils.Expect(t, "2", outputs.Docs[1].DocId)
  147. }
  148. type TestScoringCriteria struct {
  149. }
  150. func (criteria TestScoringCriteria) Score(
  151. doc types.IndexedDocument, fields interface{}) []float32 {
  152. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  153. return []float32{}
  154. }
  155. fs := fields.(ScoringFields)
  156. return []float32{float32(doc.TokenProximity)*fs.A + fs.B*fs.C}
  157. }
  158. func TestSearchWithCriteria(t *testing.T) {
  159. var engine Engine
  160. engine.Init(types.EngineInitOptions{
  161. SegmenterDictionaries: "../testdata/test_dict.txt",
  162. DefaultRankOptions: &types.RankOptions{
  163. ScoringCriteria: TestScoringCriteria{},
  164. },
  165. IndexerInitOptions: &types.IndexerInitOptions{
  166. IndexType: types.LocationsIndex,
  167. },
  168. })
  169. AddDocs(&engine)
  170. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  171. utils.Expect(t, "2", len(outputs.Docs))
  172. utils.Expect(t, "1", outputs.Docs[0].DocId)
  173. utils.Expect(t, "18000", int(outputs.Docs[0].Scores[0]*1000))
  174. utils.Expect(t, "5", outputs.Docs[1].DocId)
  175. utils.Expect(t, "9000", int(outputs.Docs[1].Scores[0]*1000))
  176. }
  177. func TestCompactIndex(t *testing.T) {
  178. var engine Engine
  179. engine.Init(types.EngineInitOptions{
  180. SegmenterDictionaries: "../testdata/test_dict.txt",
  181. DefaultRankOptions: &types.RankOptions{
  182. ScoringCriteria: TestScoringCriteria{},
  183. },
  184. })
  185. AddDocs(&engine)
  186. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  187. utils.Expect(t, "2", len(outputs.Docs))
  188. utils.Expect(t, "5", outputs.Docs[0].DocId)
  189. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  190. utils.Expect(t, "1", outputs.Docs[1].DocId)
  191. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  192. }
  193. type BM25ScoringCriteria struct {
  194. }
  195. func (criteria BM25ScoringCriteria) Score(doc types.IndexedDocument, fields interface{}) []float32 {
  196. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  197. return []float32{}
  198. }
  199. return []float32{doc.BM25}
  200. }
  201. func TestFrequenciesIndex(t *testing.T) {
  202. var engine Engine
  203. engine.Init(types.EngineInitOptions{
  204. SegmenterDictionaries: "../testdata/test_dict.txt",
  205. DefaultRankOptions: &types.RankOptions{
  206. ScoringCriteria: BM25ScoringCriteria{},
  207. },
  208. IndexerInitOptions: &types.IndexerInitOptions{
  209. IndexType: types.FrequenciesIndex,
  210. },
  211. })
  212. AddDocs(&engine)
  213. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  214. utils.Expect(t, "2", len(outputs.Docs))
  215. utils.Expect(t, "5", outputs.Docs[0].DocId)
  216. utils.Expect(t, "2349", int(outputs.Docs[0].Scores[0]*1000))
  217. utils.Expect(t, "1", outputs.Docs[1].DocId)
  218. utils.Expect(t, "2320", int(outputs.Docs[1].Scores[0]*1000))
  219. }
  220. func TestRemoveDocument(t *testing.T) {
  221. var engine Engine
  222. engine.Init(types.EngineInitOptions{
  223. SegmenterDictionaries: "../testdata/test_dict.txt",
  224. DefaultRankOptions: &types.RankOptions{
  225. ScoringCriteria: TestScoringCriteria{},
  226. },
  227. })
  228. AddDocs(&engine)
  229. engine.RemoveDocument(5, false)
  230. engine.RemoveDocument(6, false)
  231. engine.FlushIndex()
  232. engine.IndexDocument(6, types.DocumentIndexData{
  233. Content: "中国人口有十三亿",
  234. Fields: ScoringFields{0, 9, 1},
  235. }, false)
  236. engine.FlushIndex()
  237. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  238. utils.Expect(t, "2", len(outputs.Docs))
  239. utils.Expect(t, "6", outputs.Docs[0].DocId)
  240. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  241. utils.Expect(t, "1", outputs.Docs[1].DocId)
  242. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  243. }
  244. func TestEngineIndexDocumentWithTokens(t *testing.T) {
  245. var engine Engine
  246. engine.Init(types.EngineInitOptions{
  247. SegmenterDictionaries: "../testdata/test_dict.txt",
  248. DefaultRankOptions: &types.RankOptions{
  249. OutputOffset: 0,
  250. MaxOutputs: 10,
  251. ScoringCriteria: &RankByTokenProximity{},
  252. },
  253. IndexerInitOptions: &types.IndexerInitOptions{
  254. IndexType: types.LocationsIndex,
  255. },
  256. })
  257. docId := uint64(1)
  258. engine.IndexDocument(docId, types.DocumentIndexData{
  259. Content: "",
  260. Tokens: []types.TokenData{
  261. {"中国", []int{0}},
  262. {"人口", []int{18, 24}},
  263. },
  264. Fields: ScoringFields{1, 2, 3},
  265. }, false)
  266. docId++
  267. engine.IndexDocument(docId, types.DocumentIndexData{
  268. Content: "",
  269. Tokens: []types.TokenData{
  270. {"中国", []int{0}},
  271. {"人口", []int{6}},
  272. },
  273. Fields: ScoringFields{1, 2, 3},
  274. }, false)
  275. docId++
  276. engine.IndexDocument(docId, types.DocumentIndexData{
  277. Content: "中国十三亿人口",
  278. Fields: ScoringFields{0, 9, 1},
  279. }, false)
  280. engine.FlushIndex()
  281. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  282. utils.Expect(t, "2", len(outputs.Tokens))
  283. utils.Expect(t, "中国", outputs.Tokens[0])
  284. utils.Expect(t, "人口", outputs.Tokens[1])
  285. utils.Expect(t, "3", len(outputs.Docs))
  286. utils.Expect(t, "2", outputs.Docs[0].DocId)
  287. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  288. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  289. utils.Expect(t, "3", outputs.Docs[1].DocId)
  290. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  291. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  292. utils.Expect(t, "1", outputs.Docs[2].DocId)
  293. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  294. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  295. }
  296. func TestEngineIndexDocumentWithContentAndLabels(t *testing.T) {
  297. var engine1, engine2 Engine
  298. engine1.Init(types.EngineInitOptions{
  299. SegmenterDictionaries: "../data/dictionary.txt",
  300. IndexerInitOptions: &types.IndexerInitOptions{
  301. IndexType: types.LocationsIndex,
  302. },
  303. })
  304. engine2.Init(types.EngineInitOptions{
  305. SegmenterDictionaries: "../data/dictionary.txt",
  306. IndexerInitOptions: &types.IndexerInitOptions{
  307. IndexType: types.DocIdsIndex,
  308. },
  309. })
  310. addDocsWithLabels(&engine1)
  311. addDocsWithLabels(&engine2)
  312. outputs1 := engine1.Search(types.SearchRequest{Text: "百度"})
  313. outputs2 := engine2.Search(types.SearchRequest{Text: "百度"})
  314. utils.Expect(t, "1", len(outputs1.Tokens))
  315. utils.Expect(t, "1", len(outputs2.Tokens))
  316. utils.Expect(t, "百度", outputs1.Tokens[0])
  317. utils.Expect(t, "百度", outputs2.Tokens[0])
  318. utils.Expect(t, "5", len(outputs1.Docs))
  319. utils.Expect(t, "5", len(outputs2.Docs))
  320. }
  321. func TestEngineIndexDocumentWithPersistentStorage(t *testing.T) {
  322. gob.Register(ScoringFields{})
  323. var engine Engine
  324. engine.Init(types.EngineInitOptions{
  325. SegmenterDictionaries: "../testdata/test_dict.txt",
  326. DefaultRankOptions: &types.RankOptions{
  327. OutputOffset: 0,
  328. MaxOutputs: 10,
  329. ScoringCriteria: &RankByTokenProximity{},
  330. },
  331. IndexerInitOptions: &types.IndexerInitOptions{
  332. IndexType: types.LocationsIndex,
  333. },
  334. UsePersistentStorage: true,
  335. PersistentStorageFolder: "wukong.persistent",
  336. PersistentStorageShards: 2,
  337. })
  338. AddDocs(&engine)
  339. engine.RemoveDocument(5, true)
  340. engine.Close()
  341. var engine1 Engine
  342. engine1.Init(types.EngineInitOptions{
  343. SegmenterDictionaries: "../testdata/test_dict.txt",
  344. DefaultRankOptions: &types.RankOptions{
  345. OutputOffset: 0,
  346. MaxOutputs: 10,
  347. ScoringCriteria: &RankByTokenProximity{},
  348. },
  349. IndexerInitOptions: &types.IndexerInitOptions{
  350. IndexType: types.LocationsIndex,
  351. },
  352. UsePersistentStorage: true,
  353. PersistentStorageFolder: "wukong.persistent",
  354. PersistentStorageShards: 2,
  355. })
  356. engine1.FlushIndex()
  357. outputs := engine1.Search(types.SearchRequest{Text: "中国人口"})
  358. utils.Expect(t, "2", len(outputs.Tokens))
  359. utils.Expect(t, "中国", outputs.Tokens[0])
  360. utils.Expect(t, "人口", outputs.Tokens[1])
  361. utils.Expect(t, "2", len(outputs.Docs))
  362. utils.Expect(t, "2", outputs.Docs[0].DocId)
  363. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  364. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  365. utils.Expect(t, "1", outputs.Docs[1].DocId)
  366. utils.Expect(t, "76", int(outputs.Docs[1].Scores[0]*1000))
  367. utils.Expect(t, "[0 18]", outputs.Docs[1].TokenSnippetLocations)
  368. engine1.Close()
  369. os.RemoveAll("wukong.persistent")
  370. }
  371. func TestCountDocsOnly(t *testing.T) {
  372. var engine Engine
  373. engine.Init(types.EngineInitOptions{
  374. SegmenterDictionaries: "../testdata/test_dict.txt",
  375. DefaultRankOptions: &types.RankOptions{
  376. ReverseOrder: true,
  377. OutputOffset: 0,
  378. MaxOutputs: 1,
  379. ScoringCriteria: &RankByTokenProximity{},
  380. },
  381. IndexerInitOptions: &types.IndexerInitOptions{
  382. IndexType: types.LocationsIndex,
  383. },
  384. })
  385. AddDocs(&engine)
  386. engine.RemoveDocument(5, false)
  387. engine.FlushIndex()
  388. outputs := engine.Search(types.SearchRequest{Text: "中国人口", CountDocsOnly: true})
  389. utils.Expect(t, "0", len(outputs.Docs))
  390. utils.Expect(t, "2", len(outputs.Tokens))
  391. utils.Expect(t, "2", outputs.NumDocs)
  392. }
  393. func TestSearchWithin(t *testing.T) {
  394. var engine Engine
  395. engine.Init(types.EngineInitOptions{
  396. SegmenterDictionaries: "../testdata/test_dict.txt",
  397. DefaultRankOptions: &types.RankOptions{
  398. ReverseOrder: true,
  399. OutputOffset: 0,
  400. MaxOutputs: 10,
  401. ScoringCriteria: &RankByTokenProximity{},
  402. },
  403. IndexerInitOptions: &types.IndexerInitOptions{
  404. IndexType: types.LocationsIndex,
  405. },
  406. })
  407. AddDocs(&engine)
  408. docIds := make(map[string]bool)
  409. docIds["5"] = true
  410. docIds["1"] = true
  411. outputs := engine.Search(types.SearchRequest{
  412. Text: "中国人口",
  413. DocIds: docIds,
  414. })
  415. utils.Expect(t, "2", len(outputs.Tokens))
  416. utils.Expect(t, "中国", outputs.Tokens[0])
  417. utils.Expect(t, "人口", outputs.Tokens[1])
  418. utils.Expect(t, "2", len(outputs.Docs))
  419. utils.Expect(t, "1", outputs.Docs[0].DocId)
  420. utils.Expect(t, "76", int(outputs.Docs[0].Scores[0]*1000))
  421. utils.Expect(t, "[0 18]", outputs.Docs[0].TokenSnippetLocations)
  422. utils.Expect(t, "5", outputs.Docs[1].DocId)
  423. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  424. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  425. }