engine_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. package engine
  2. import (
  3. "encoding/gob"
  4. "github.com/huichen/wukong/types"
  5. "github.com/huichen/wukong/utils"
  6. "os"
  7. "reflect"
  8. "testing"
  9. )
  10. type ScoringFields struct {
  11. A, B, C float32
  12. }
  13. func AddDocs(engine *Engine) {
  14. docId := uint64(1)
  15. // 因为需要保证文档全部被加入到索引中,所以 forceUpdate 全部设置成 true
  16. engine.IndexDocument(docId, types.DocumentIndexData{
  17. Content: "中国有十三亿人口人口",
  18. Fields: ScoringFields{1, 2, 3},
  19. }, true)
  20. docId++
  21. engine.IndexDocument(docId, types.DocumentIndexData{
  22. Content: "中国人口",
  23. Fields: nil,
  24. }, true)
  25. docId++
  26. engine.IndexDocument(docId, types.DocumentIndexData{
  27. Content: "有人口",
  28. Fields: ScoringFields{2, 3, 1},
  29. }, true)
  30. docId++
  31. engine.IndexDocument(docId, types.DocumentIndexData{
  32. Content: "有十三亿人口",
  33. Fields: ScoringFields{2, 3, 3},
  34. }, true)
  35. docId++
  36. engine.IndexDocument(docId, types.DocumentIndexData{
  37. Content: "中国十三亿人口",
  38. Fields: ScoringFields{0, 9, 1},
  39. }, true)
  40. engine.FlushIndex()
  41. }
  42. type RankByTokenProximity struct {
  43. }
  44. func (rule RankByTokenProximity) Score(
  45. doc types.IndexedDocument, fields interface{}) []float32 {
  46. if doc.TokenProximity < 0 {
  47. return []float32{}
  48. }
  49. return []float32{1.0 / (float32(doc.TokenProximity) + 1)}
  50. }
  51. func TestEngineIndexDocument(t *testing.T) {
  52. var engine Engine
  53. engine.Init(types.EngineInitOptions{
  54. SegmenterDictionaries: "../testdata/test_dict.txt",
  55. DefaultRankOptions: &types.RankOptions{
  56. OutputOffset: 0,
  57. MaxOutputs: 10,
  58. ScoringCriteria: &RankByTokenProximity{},
  59. },
  60. IndexerInitOptions: &types.IndexerInitOptions{
  61. IndexType: types.LocationsIndex,
  62. },
  63. })
  64. AddDocs(&engine)
  65. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  66. utils.Expect(t, "2", len(outputs.Tokens))
  67. utils.Expect(t, "中国", outputs.Tokens[0])
  68. utils.Expect(t, "人口", outputs.Tokens[1])
  69. utils.Expect(t, "3", len(outputs.Docs))
  70. utils.Expect(t, "2", outputs.Docs[0].DocId)
  71. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  72. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  73. utils.Expect(t, "5", outputs.Docs[1].DocId)
  74. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  75. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  76. utils.Expect(t, "1", outputs.Docs[2].DocId)
  77. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  78. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  79. }
  80. func TestReverseOrder(t *testing.T) {
  81. var engine Engine
  82. engine.Init(types.EngineInitOptions{
  83. SegmenterDictionaries: "../testdata/test_dict.txt",
  84. DefaultRankOptions: &types.RankOptions{
  85. ReverseOrder: true,
  86. OutputOffset: 0,
  87. MaxOutputs: 10,
  88. ScoringCriteria: &RankByTokenProximity{},
  89. },
  90. IndexerInitOptions: &types.IndexerInitOptions{
  91. IndexType: types.LocationsIndex,
  92. },
  93. })
  94. AddDocs(&engine)
  95. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  96. utils.Expect(t, "3", len(outputs.Docs))
  97. utils.Expect(t, "1", outputs.Docs[0].DocId)
  98. utils.Expect(t, "5", outputs.Docs[1].DocId)
  99. utils.Expect(t, "2", outputs.Docs[2].DocId)
  100. }
  101. func TestOffsetAndMaxOutputs(t *testing.T) {
  102. var engine Engine
  103. engine.Init(types.EngineInitOptions{
  104. SegmenterDictionaries: "../testdata/test_dict.txt",
  105. DefaultRankOptions: &types.RankOptions{
  106. ReverseOrder: true,
  107. OutputOffset: 1,
  108. MaxOutputs: 3,
  109. ScoringCriteria: &RankByTokenProximity{},
  110. },
  111. IndexerInitOptions: &types.IndexerInitOptions{
  112. IndexType: types.LocationsIndex,
  113. },
  114. })
  115. AddDocs(&engine)
  116. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  117. utils.Expect(t, "2", len(outputs.Docs))
  118. utils.Expect(t, "5", outputs.Docs[0].DocId)
  119. utils.Expect(t, "2", outputs.Docs[1].DocId)
  120. }
  121. type TestScoringCriteria struct {
  122. }
  123. func (criteria TestScoringCriteria) Score(
  124. doc types.IndexedDocument, fields interface{}) []float32 {
  125. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  126. return []float32{}
  127. }
  128. fs := fields.(ScoringFields)
  129. return []float32{float32(doc.TokenProximity)*fs.A + fs.B*fs.C}
  130. }
  131. func TestSearchWithCriteria(t *testing.T) {
  132. var engine Engine
  133. engine.Init(types.EngineInitOptions{
  134. SegmenterDictionaries: "../testdata/test_dict.txt",
  135. DefaultRankOptions: &types.RankOptions{
  136. ScoringCriteria: TestScoringCriteria{},
  137. },
  138. IndexerInitOptions: &types.IndexerInitOptions{
  139. IndexType: types.LocationsIndex,
  140. },
  141. })
  142. AddDocs(&engine)
  143. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  144. utils.Expect(t, "2", len(outputs.Docs))
  145. utils.Expect(t, "1", outputs.Docs[0].DocId)
  146. utils.Expect(t, "18000", int(outputs.Docs[0].Scores[0]*1000))
  147. utils.Expect(t, "5", outputs.Docs[1].DocId)
  148. utils.Expect(t, "9000", int(outputs.Docs[1].Scores[0]*1000))
  149. }
  150. func TestCompactIndex(t *testing.T) {
  151. var engine Engine
  152. engine.Init(types.EngineInitOptions{
  153. SegmenterDictionaries: "../testdata/test_dict.txt",
  154. DefaultRankOptions: &types.RankOptions{
  155. ScoringCriteria: TestScoringCriteria{},
  156. },
  157. })
  158. AddDocs(&engine)
  159. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  160. utils.Expect(t, "2", len(outputs.Docs))
  161. utils.Expect(t, "5", outputs.Docs[0].DocId)
  162. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  163. utils.Expect(t, "1", outputs.Docs[1].DocId)
  164. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  165. }
  166. type BM25ScoringCriteria struct {
  167. }
  168. func (criteria BM25ScoringCriteria) Score(
  169. doc types.IndexedDocument, fields interface{}) []float32 {
  170. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  171. return []float32{}
  172. }
  173. return []float32{doc.BM25}
  174. }
  175. func TestFrequenciesIndex(t *testing.T) {
  176. var engine Engine
  177. engine.Init(types.EngineInitOptions{
  178. SegmenterDictionaries: "../testdata/test_dict.txt",
  179. DefaultRankOptions: &types.RankOptions{
  180. ScoringCriteria: BM25ScoringCriteria{},
  181. },
  182. IndexerInitOptions: &types.IndexerInitOptions{
  183. IndexType: types.FrequenciesIndex,
  184. },
  185. })
  186. AddDocs(&engine)
  187. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  188. utils.Expect(t, "2", len(outputs.Docs))
  189. utils.Expect(t, "5", outputs.Docs[0].DocId)
  190. utils.Expect(t, "2349", int(outputs.Docs[0].Scores[0]*1000))
  191. utils.Expect(t, "1", outputs.Docs[1].DocId)
  192. utils.Expect(t, "2320", int(outputs.Docs[1].Scores[0]*1000))
  193. }
  194. func TestRemoveDocument(t *testing.T) {
  195. var engine Engine
  196. engine.Init(types.EngineInitOptions{
  197. SegmenterDictionaries: "../testdata/test_dict.txt",
  198. DefaultRankOptions: &types.RankOptions{
  199. ScoringCriteria: TestScoringCriteria{},
  200. },
  201. })
  202. AddDocs(&engine)
  203. engine.RemoveDocument(5, true)
  204. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  205. utils.Expect(t, "1", len(outputs.Docs))
  206. utils.Expect(t, "1", outputs.Docs[0].DocId)
  207. utils.Expect(t, "6000", int(outputs.Docs[0].Scores[0]*1000))
  208. }
  209. func TestEngineIndexDocumentWithTokens(t *testing.T) {
  210. var engine Engine
  211. engine.Init(types.EngineInitOptions{
  212. SegmenterDictionaries: "../testdata/test_dict.txt",
  213. DefaultRankOptions: &types.RankOptions{
  214. OutputOffset: 0,
  215. MaxOutputs: 10,
  216. ScoringCriteria: &RankByTokenProximity{},
  217. },
  218. IndexerInitOptions: &types.IndexerInitOptions{
  219. IndexType: types.LocationsIndex,
  220. },
  221. })
  222. docId := uint64(1)
  223. engine.IndexDocument(docId, types.DocumentIndexData{
  224. Content: "",
  225. Tokens: []types.TokenData{
  226. {"中国", []int{0}},
  227. {"人口", []int{18, 24}},
  228. },
  229. Fields: ScoringFields{1, 2, 3},
  230. }, true)
  231. docId++
  232. engine.IndexDocument(docId, types.DocumentIndexData{
  233. Content: "",
  234. Tokens: []types.TokenData{
  235. {"中国", []int{0}},
  236. {"人口", []int{6}},
  237. },
  238. Fields: ScoringFields{1, 2, 3},
  239. }, true)
  240. docId++
  241. engine.IndexDocument(docId, types.DocumentIndexData{
  242. Content: "中国十三亿人口",
  243. Fields: ScoringFields{0, 9, 1},
  244. }, true)
  245. engine.FlushIndex()
  246. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  247. utils.Expect(t, "2", len(outputs.Tokens))
  248. utils.Expect(t, "中国", outputs.Tokens[0])
  249. utils.Expect(t, "人口", outputs.Tokens[1])
  250. utils.Expect(t, "3", len(outputs.Docs))
  251. utils.Expect(t, "2", outputs.Docs[0].DocId)
  252. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  253. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  254. utils.Expect(t, "3", outputs.Docs[1].DocId)
  255. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  256. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  257. utils.Expect(t, "1", outputs.Docs[2].DocId)
  258. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  259. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  260. }
  261. func TestEngineIndexDocumentWithPersistentStorage(t *testing.T) {
  262. gob.Register(ScoringFields{})
  263. var engine Engine
  264. engine.Init(types.EngineInitOptions{
  265. SegmenterDictionaries: "../testdata/test_dict.txt",
  266. DefaultRankOptions: &types.RankOptions{
  267. OutputOffset: 0,
  268. MaxOutputs: 10,
  269. ScoringCriteria: &RankByTokenProximity{},
  270. },
  271. IndexerInitOptions: &types.IndexerInitOptions{
  272. IndexType: types.LocationsIndex,
  273. },
  274. UsePersistentStorage: true,
  275. PersistentStorageFolder: "wukong.persistent",
  276. PersistentStorageShards: 2,
  277. })
  278. AddDocs(&engine)
  279. engine.RemoveDocument(5, true)
  280. engine.Close()
  281. var engine1 Engine
  282. engine1.Init(types.EngineInitOptions{
  283. SegmenterDictionaries: "../testdata/test_dict.txt",
  284. DefaultRankOptions: &types.RankOptions{
  285. OutputOffset: 0,
  286. MaxOutputs: 10,
  287. ScoringCriteria: &RankByTokenProximity{},
  288. },
  289. IndexerInitOptions: &types.IndexerInitOptions{
  290. IndexType: types.LocationsIndex,
  291. },
  292. UsePersistentStorage: true,
  293. PersistentStorageFolder: "wukong.persistent",
  294. PersistentStorageShards: 2,
  295. })
  296. engine1.FlushIndex()
  297. outputs := engine1.Search(types.SearchRequest{Text: "中国人口"})
  298. utils.Expect(t, "2", len(outputs.Tokens))
  299. utils.Expect(t, "中国", outputs.Tokens[0])
  300. utils.Expect(t, "人口", outputs.Tokens[1])
  301. utils.Expect(t, "2", len(outputs.Docs))
  302. utils.Expect(t, "2", outputs.Docs[0].DocId)
  303. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  304. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  305. utils.Expect(t, "1", outputs.Docs[1].DocId)
  306. utils.Expect(t, "76", int(outputs.Docs[1].Scores[0]*1000))
  307. utils.Expect(t, "[0 18]", outputs.Docs[1].TokenSnippetLocations)
  308. engine1.Close()
  309. os.RemoveAll("wukong.persistent")
  310. }
  311. func TestCountDocsOnly(t *testing.T) {
  312. var engine Engine
  313. engine.Init(types.EngineInitOptions{
  314. SegmenterDictionaries: "../testdata/test_dict.txt",
  315. DefaultRankOptions: &types.RankOptions{
  316. ReverseOrder: true,
  317. OutputOffset: 0,
  318. MaxOutputs: 1,
  319. ScoringCriteria: &RankByTokenProximity{},
  320. },
  321. IndexerInitOptions: &types.IndexerInitOptions{
  322. IndexType: types.LocationsIndex,
  323. },
  324. })
  325. AddDocs(&engine)
  326. engine.RemoveDocument(5, true)
  327. outputs := engine.Search(types.SearchRequest{Text: "中国人口", CountDocsOnly: true})
  328. utils.Expect(t, "0", len(outputs.Docs))
  329. utils.Expect(t, "2", len(outputs.Tokens))
  330. utils.Expect(t, "2", outputs.NumDocs)
  331. }
  332. func TestSearchWithin(t *testing.T) {
  333. var engine Engine
  334. engine.Init(types.EngineInitOptions{
  335. SegmenterDictionaries: "../testdata/test_dict.txt",
  336. DefaultRankOptions: &types.RankOptions{
  337. ReverseOrder: true,
  338. OutputOffset: 0,
  339. MaxOutputs: 10,
  340. ScoringCriteria: &RankByTokenProximity{},
  341. },
  342. IndexerInitOptions: &types.IndexerInitOptions{
  343. IndexType: types.LocationsIndex,
  344. },
  345. })
  346. AddDocs(&engine)
  347. docIds := make(map[uint64]bool)
  348. docIds[5] = true
  349. docIds[1] = true
  350. outputs := engine.Search(types.SearchRequest{
  351. Text: "中国人口",
  352. DocIds: docIds,
  353. })
  354. utils.Expect(t, "2", len(outputs.Tokens))
  355. utils.Expect(t, "中国", outputs.Tokens[0])
  356. utils.Expect(t, "人口", outputs.Tokens[1])
  357. utils.Expect(t, "2", len(outputs.Docs))
  358. utils.Expect(t, "1", outputs.Docs[0].DocId)
  359. utils.Expect(t, "76", int(outputs.Docs[0].Scores[0]*1000))
  360. utils.Expect(t, "[0 18]", outputs.Docs[0].TokenSnippetLocations)
  361. utils.Expect(t, "5", outputs.Docs[1].DocId)
  362. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  363. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  364. }