engine_test.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package engine
  2. import (
  3. "github.com/huichen/wukong/types"
  4. "github.com/huichen/wukong/utils"
  5. "reflect"
  6. "testing"
  7. )
  8. type ScoringFields struct {
  9. a, b, c float32
  10. }
  11. func AddDocs(engine *Engine) {
  12. docId := uint64(0)
  13. engine.IndexDocument(docId, types.DocumentIndexData{
  14. Content: "中国有十三亿人口人口",
  15. Fields: ScoringFields{1, 2, 3},
  16. })
  17. docId++
  18. engine.IndexDocument(docId, types.DocumentIndexData{
  19. Content: "中国人口",
  20. Fields: nil,
  21. })
  22. docId++
  23. engine.IndexDocument(docId, types.DocumentIndexData{
  24. Content: "有人口",
  25. Fields: ScoringFields{2, 3, 1},
  26. })
  27. docId++
  28. engine.IndexDocument(docId, types.DocumentIndexData{
  29. Content: "有十三亿人口",
  30. Fields: ScoringFields{2, 3, 3},
  31. })
  32. docId++
  33. engine.IndexDocument(docId, types.DocumentIndexData{
  34. Content: "中国十三亿人口",
  35. Fields: ScoringFields{0, 9, 1},
  36. })
  37. engine.FlushIndex()
  38. }
  39. type RankByTokenProximity struct {
  40. }
  41. func (rule RankByTokenProximity) Score(
  42. doc types.IndexedDocument, fields interface{}) []float32 {
  43. if doc.TokenProximity < 0 {
  44. return []float32{}
  45. }
  46. return []float32{1.0 / (float32(doc.TokenProximity) + 1)}
  47. }
  48. func TestEngineIndexDocument(t *testing.T) {
  49. var engine Engine
  50. engine.Init(types.EngineInitOptions{
  51. SegmenterDictionaries: "../testdata/test_dict.txt",
  52. DefaultRankOptions: &types.RankOptions{
  53. OutputOffset: 0,
  54. MaxOutputs: 10,
  55. ScoringCriteria: &RankByTokenProximity{},
  56. },
  57. IndexerInitOptions: &types.IndexerInitOptions{
  58. IndexType: types.LocationsIndex,
  59. },
  60. })
  61. AddDocs(&engine)
  62. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  63. utils.Expect(t, "2", len(outputs.Tokens))
  64. utils.Expect(t, "中国", outputs.Tokens[0])
  65. utils.Expect(t, "人口", outputs.Tokens[1])
  66. utils.Expect(t, "3", len(outputs.Docs))
  67. utils.Expect(t, "1", outputs.Docs[0].DocId)
  68. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  69. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  70. utils.Expect(t, "4", outputs.Docs[1].DocId)
  71. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  72. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  73. utils.Expect(t, "0", outputs.Docs[2].DocId)
  74. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  75. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  76. }
  77. func TestReverseOrder(t *testing.T) {
  78. var engine Engine
  79. engine.Init(types.EngineInitOptions{
  80. SegmenterDictionaries: "../testdata/test_dict.txt",
  81. DefaultRankOptions: &types.RankOptions{
  82. ReverseOrder: true,
  83. OutputOffset: 0,
  84. MaxOutputs: 10,
  85. ScoringCriteria: &RankByTokenProximity{},
  86. },
  87. IndexerInitOptions: &types.IndexerInitOptions{
  88. IndexType: types.LocationsIndex,
  89. },
  90. })
  91. AddDocs(&engine)
  92. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  93. utils.Expect(t, "3", len(outputs.Docs))
  94. utils.Expect(t, "0", outputs.Docs[0].DocId)
  95. utils.Expect(t, "4", outputs.Docs[1].DocId)
  96. utils.Expect(t, "1", outputs.Docs[2].DocId)
  97. }
  98. func TestOffsetAndMaxOutputs(t *testing.T) {
  99. var engine Engine
  100. engine.Init(types.EngineInitOptions{
  101. SegmenterDictionaries: "../testdata/test_dict.txt",
  102. DefaultRankOptions: &types.RankOptions{
  103. ReverseOrder: true,
  104. OutputOffset: 1,
  105. MaxOutputs: 3,
  106. ScoringCriteria: &RankByTokenProximity{},
  107. },
  108. IndexerInitOptions: &types.IndexerInitOptions{
  109. IndexType: types.LocationsIndex,
  110. },
  111. })
  112. AddDocs(&engine)
  113. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  114. utils.Expect(t, "2", len(outputs.Docs))
  115. utils.Expect(t, "4", outputs.Docs[0].DocId)
  116. utils.Expect(t, "1", outputs.Docs[1].DocId)
  117. }
  118. type TestScoringCriteria struct {
  119. }
  120. func (criteria TestScoringCriteria) Score(
  121. doc types.IndexedDocument, fields interface{}) []float32 {
  122. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  123. return []float32{}
  124. }
  125. fs := fields.(ScoringFields)
  126. return []float32{float32(doc.TokenProximity)*fs.a + fs.b*fs.c}
  127. }
  128. func TestSearchWithCriteria(t *testing.T) {
  129. var engine Engine
  130. engine.Init(types.EngineInitOptions{
  131. SegmenterDictionaries: "../testdata/test_dict.txt",
  132. DefaultRankOptions: &types.RankOptions{
  133. ScoringCriteria: TestScoringCriteria{},
  134. },
  135. IndexerInitOptions: &types.IndexerInitOptions{
  136. IndexType: types.LocationsIndex,
  137. },
  138. })
  139. AddDocs(&engine)
  140. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  141. utils.Expect(t, "2", len(outputs.Docs))
  142. utils.Expect(t, "0", outputs.Docs[0].DocId)
  143. utils.Expect(t, "18000", int(outputs.Docs[0].Scores[0]*1000))
  144. utils.Expect(t, "4", outputs.Docs[1].DocId)
  145. utils.Expect(t, "9000", int(outputs.Docs[1].Scores[0]*1000))
  146. }
  147. func TestCompactIndex(t *testing.T) {
  148. var engine Engine
  149. engine.Init(types.EngineInitOptions{
  150. SegmenterDictionaries: "../testdata/test_dict.txt",
  151. DefaultRankOptions: &types.RankOptions{
  152. ScoringCriteria: TestScoringCriteria{},
  153. },
  154. })
  155. AddDocs(&engine)
  156. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  157. utils.Expect(t, "2", len(outputs.Docs))
  158. utils.Expect(t, "4", outputs.Docs[0].DocId)
  159. utils.Expect(t, "9000", int(outputs.Docs[0].Scores[0]*1000))
  160. utils.Expect(t, "0", outputs.Docs[1].DocId)
  161. utils.Expect(t, "6000", int(outputs.Docs[1].Scores[0]*1000))
  162. }
  163. type BM25ScoringCriteria struct {
  164. }
  165. func (criteria BM25ScoringCriteria) Score(
  166. doc types.IndexedDocument, fields interface{}) []float32 {
  167. if reflect.TypeOf(fields) != reflect.TypeOf(ScoringFields{}) {
  168. return []float32{}
  169. }
  170. return []float32{doc.BM25}
  171. }
  172. func TestFrequenciesIndex(t *testing.T) {
  173. var engine Engine
  174. engine.Init(types.EngineInitOptions{
  175. SegmenterDictionaries: "../testdata/test_dict.txt",
  176. DefaultRankOptions: &types.RankOptions{
  177. ScoringCriteria: BM25ScoringCriteria{},
  178. },
  179. IndexerInitOptions: &types.IndexerInitOptions{
  180. IndexType: types.FrequenciesIndex,
  181. },
  182. })
  183. AddDocs(&engine)
  184. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  185. utils.Expect(t, "2", len(outputs.Docs))
  186. utils.Expect(t, "4", outputs.Docs[0].DocId)
  187. utils.Expect(t, "2311", int(outputs.Docs[0].Scores[0]*1000))
  188. utils.Expect(t, "0", outputs.Docs[1].DocId)
  189. utils.Expect(t, "2211", int(outputs.Docs[1].Scores[0]*1000))
  190. }
  191. func TestRemoveDocument(t *testing.T) {
  192. var engine Engine
  193. engine.Init(types.EngineInitOptions{
  194. SegmenterDictionaries: "../testdata/test_dict.txt",
  195. DefaultRankOptions: &types.RankOptions{
  196. ScoringCriteria: TestScoringCriteria{},
  197. },
  198. })
  199. AddDocs(&engine)
  200. engine.RemoveDocument(4)
  201. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  202. utils.Expect(t, "1", len(outputs.Docs))
  203. utils.Expect(t, "0", outputs.Docs[0].DocId)
  204. utils.Expect(t, "6000", int(outputs.Docs[0].Scores[0]*1000))
  205. }
  206. func TestEngineIndexDocumentWithTokens(t *testing.T) {
  207. var engine Engine
  208. engine.Init(types.EngineInitOptions{
  209. SegmenterDictionaries: "../testdata/test_dict.txt",
  210. DefaultRankOptions: &types.RankOptions{
  211. OutputOffset: 0,
  212. MaxOutputs: 10,
  213. ScoringCriteria: &RankByTokenProximity{},
  214. },
  215. IndexerInitOptions: &types.IndexerInitOptions{
  216. IndexType: types.LocationsIndex,
  217. },
  218. })
  219. docId := uint64(0)
  220. engine.IndexDocument(docId, types.DocumentIndexData{
  221. Content: "",
  222. Tokens: []types.TokenData{
  223. {"中国", []int{0}},
  224. {"人口", []int{18, 24}},
  225. },
  226. Fields: ScoringFields{1, 2, 3},
  227. })
  228. docId++
  229. engine.IndexDocument(docId, types.DocumentIndexData{
  230. Content: "",
  231. Tokens: []types.TokenData{
  232. {"中国", []int{0}},
  233. {"人口", []int{6}},
  234. },
  235. Fields: ScoringFields{1, 2, 3},
  236. })
  237. docId++
  238. engine.IndexDocument(docId, types.DocumentIndexData{
  239. Content: "中国十三亿人口",
  240. Fields: ScoringFields{0, 9, 1},
  241. })
  242. engine.FlushIndex()
  243. outputs := engine.Search(types.SearchRequest{Text: "中国人口"})
  244. utils.Expect(t, "2", len(outputs.Tokens))
  245. utils.Expect(t, "中国", outputs.Tokens[0])
  246. utils.Expect(t, "人口", outputs.Tokens[1])
  247. utils.Expect(t, "3", len(outputs.Docs))
  248. utils.Expect(t, "1", outputs.Docs[0].DocId)
  249. utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
  250. utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
  251. utils.Expect(t, "2", outputs.Docs[1].DocId)
  252. utils.Expect(t, "100", int(outputs.Docs[1].Scores[0]*1000))
  253. utils.Expect(t, "[0 15]", outputs.Docs[1].TokenSnippetLocations)
  254. utils.Expect(t, "0", outputs.Docs[2].DocId)
  255. utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
  256. utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
  257. }