Pārlūkot izejas kodu

允许引擎返回符合搜索条件的文档个数

geili 10 gadi atpakaļ
vecāks
revīzija
1d098e6b31

+ 44 - 28
core/ranker.go

@@ -12,6 +12,7 @@ type Ranker struct {
 	lock struct {
 		sync.RWMutex
 		fields map[uint64]interface{}
+		docs   map[uint64]bool
 	}
 	initialized bool
 }
@@ -23,68 +24,83 @@ func (ranker *Ranker) Init() {
 	ranker.initialized = true
 
 	ranker.lock.fields = make(map[uint64]interface{})
+	ranker.lock.docs = make(map[uint64]bool)
 }
 
 // 给某个文档添加评分字段
-func (ranker *Ranker) AddScoringFields(docId uint64, fields interface{}) {
+func (ranker *Ranker) AddDoc(docId uint64, fields interface{}) {
 	if ranker.initialized == false {
 		log.Fatal("排序器尚未初始化")
 	}
 
 	ranker.lock.Lock()
 	ranker.lock.fields[docId] = fields
+	ranker.lock.docs[docId] = true
 	ranker.lock.Unlock()
 }
 
 // 删除某个文档的评分字段
-func (ranker *Ranker) RemoveScoringFields(docId uint64) {
+func (ranker *Ranker) RemoveDoc(docId uint64) {
 	if ranker.initialized == false {
 		log.Fatal("排序器尚未初始化")
 	}
 
 	ranker.lock.Lock()
 	delete(ranker.lock.fields, docId)
+	delete(ranker.lock.docs, docId)
 	ranker.lock.Unlock()
 }
 
 // 给文档评分并排序
 func (ranker *Ranker) Rank(
-	docs []types.IndexedDocument, options types.RankOptions) (outputDocs types.ScoredDocuments) {
+	docs []types.IndexedDocument, options types.RankOptions, countDocsOnly bool) (types.ScoredDocuments, int) {
 	if ranker.initialized == false {
 		log.Fatal("排序器尚未初始化")
 	}
 
 	// 对每个文档评分
+	var outputDocs types.ScoredDocuments
+	numDocs := 0
 	for _, d := range docs {
 		ranker.lock.RLock()
-		fs := ranker.lock.fields[d.DocId]
-		ranker.lock.RUnlock()
-		// 计算评分并剔除没有分值的文档
-		scores := options.ScoringCriteria.Score(d, fs)
-		if len(scores) > 0 {
-			outputDocs = append(outputDocs, types.ScoredDocument{
-				DocId:                 d.DocId,
-				Scores:                scores,
-				TokenSnippetLocations: d.TokenSnippetLocations,
-				TokenLocations:        d.TokenLocations})
+		// 判断doc是否存在
+		if _, ok := ranker.lock.docs[d.DocId]; ok {
+			fs := ranker.lock.fields[d.DocId]
+			ranker.lock.RUnlock()
+			// 计算评分并剔除没有分值的文档
+			scores := options.ScoringCriteria.Score(d, fs)
+			if len(scores) > 0 {
+				if !countDocsOnly {
+					outputDocs = append(outputDocs, types.ScoredDocument{
+						DocId:                 d.DocId,
+						Scores:                scores,
+						TokenSnippetLocations: d.TokenSnippetLocations,
+						TokenLocations:        d.TokenLocations})
+				}
+				numDocs++
+			}
+		} else {
+			ranker.lock.RUnlock()
 		}
 	}
 
 	// 排序
-	if options.ReverseOrder {
-		sort.Sort(sort.Reverse(outputDocs))
-	} else {
-		sort.Sort(outputDocs)
-	}
-
-	// 当用户要求只返回部分结果时返回部分结果
-	var start, end int
-	if options.MaxOutputs != 0 {
-		start = utils.MinInt(options.OutputOffset, len(outputDocs))
-		end = utils.MinInt(options.OutputOffset+options.MaxOutputs, len(outputDocs))
-	} else {
-		start = utils.MinInt(options.OutputOffset, len(outputDocs))
-		end = len(outputDocs)
+	if !countDocsOnly {
+		if options.ReverseOrder {
+			sort.Sort(sort.Reverse(outputDocs))
+		} else {
+			sort.Sort(outputDocs)
+		}
+		// 当用户要求只返回部分结果时返回部分结果
+		var start, end int
+		if options.MaxOutputs != 0 {
+			start = utils.MinInt(options.OutputOffset, len(outputDocs))
+			end = utils.MinInt(options.OutputOffset+options.MaxOutputs, len(outputDocs))
+		} else {
+			start = utils.MinInt(options.OutputOffset, len(outputDocs))
+			end = len(outputDocs)
+		}
+		return outputDocs[start:end], numDocs
 	}
-	return outputDocs[start:end]
+	return outputDocs, numDocs
 }

+ 24 - 19
core/ranker_test.go

@@ -33,91 +33,96 @@ func (criteria DummyScoringCriteria) Score(
 func TestRankDocument(t *testing.T) {
 	var ranker Ranker
 	ranker.Init()
-	scoredDocs := ranker.Rank([]types.IndexedDocument{
+	ranker.AddDoc(1, DummyScoringFields{})
+	ranker.AddDoc(3, DummyScoringFields{})
+	ranker.AddDoc(4, DummyScoringFields{})
+
+	scoredDocs, _ := ranker.Rank([]types.IndexedDocument{
 		types.IndexedDocument{DocId: 1, BM25: 6},
 		types.IndexedDocument{DocId: 3, BM25: 24},
 		types.IndexedDocument{DocId: 4, BM25: 18},
-	}, types.RankOptions{ScoringCriteria: types.RankByBM25{}})
+	}, types.RankOptions{ScoringCriteria: types.RankByBM25{}}, false)
 	utils.Expect(t, "[3 [24000 ]] [4 [18000 ]] [1 [6000 ]] ", scoredDocsToString(scoredDocs))
 
-	scoredDocs = ranker.Rank([]types.IndexedDocument{
+	scoredDocs, _ = ranker.Rank([]types.IndexedDocument{
 		types.IndexedDocument{DocId: 1, BM25: 6},
 		types.IndexedDocument{DocId: 3, BM25: 24},
 		types.IndexedDocument{DocId: 2, BM25: 0},
 		types.IndexedDocument{DocId: 4, BM25: 18},
-	}, types.RankOptions{ScoringCriteria: types.RankByBM25{}, ReverseOrder: true})
-	utils.Expect(t, "[2 [0 ]] [1 [6000 ]] [4 [18000 ]] [3 [24000 ]] ", scoredDocsToString(scoredDocs))
+	}, types.RankOptions{ScoringCriteria: types.RankByBM25{}, ReverseOrder: true}, false)
+	// doc0因为没有AddDoc所以没有添加进来
+	utils.Expect(t, "[1 [6000 ]] [4 [18000 ]] [3 [24000 ]] ", scoredDocsToString(scoredDocs))
 }
 
 func TestRankWithCriteria(t *testing.T) {
 	var ranker Ranker
 	ranker.Init()
-	ranker.AddScoringFields(1, DummyScoringFields{
+	ranker.AddDoc(1, DummyScoringFields{
 		label:   "label3",
 		counter: 3,
 		amount:  22.3,
 	})
-	ranker.AddScoringFields(2, DummyScoringFields{
+	ranker.AddDoc(2, DummyScoringFields{
 		label:   "label4",
 		counter: 1,
 		amount:  2,
 	})
-	ranker.AddScoringFields(3, DummyScoringFields{
+	ranker.AddDoc(3, DummyScoringFields{
 		label:   "label1",
 		counter: 7,
 		amount:  10.3,
 	})
-	ranker.AddScoringFields(4, DummyScoringFields{
+	ranker.AddDoc(4, DummyScoringFields{
 		label:   "label1",
 		counter: -1,
 		amount:  2.3,
 	})
 
 	criteria := DummyScoringCriteria{}
-	scoredDocs := ranker.Rank([]types.IndexedDocument{
+	scoredDocs, _ := ranker.Rank([]types.IndexedDocument{
 		types.IndexedDocument{DocId: 1, TokenProximity: 6},
 		types.IndexedDocument{DocId: 2, TokenProximity: -1},
 		types.IndexedDocument{DocId: 3, TokenProximity: 24},
 		types.IndexedDocument{DocId: 4, TokenProximity: 18},
-	}, types.RankOptions{ScoringCriteria: criteria})
+	}, types.RankOptions{ScoringCriteria: criteria}, false)
 	utils.Expect(t, "[1 [25300 ]] [3 [17300 ]] [2 [3000 ]] [4 [1300 ]] ", scoredDocsToString(scoredDocs))
 
 	criteria.Threshold = 4
-	scoredDocs = ranker.Rank([]types.IndexedDocument{
+	scoredDocs, _ = ranker.Rank([]types.IndexedDocument{
 		types.IndexedDocument{DocId: 1, TokenProximity: 6},
 		types.IndexedDocument{DocId: 2, TokenProximity: -1},
 		types.IndexedDocument{DocId: 3, TokenProximity: 24},
 		types.IndexedDocument{DocId: 4, TokenProximity: 18},
-	}, types.RankOptions{ScoringCriteria: criteria})
+	}, types.RankOptions{ScoringCriteria: criteria}, false)
 	utils.Expect(t, "[1 [25300 ]] [3 [17300 ]] ", scoredDocsToString(scoredDocs))
 }
 
 func TestRemoveDocument(t *testing.T) {
 	var ranker Ranker
 	ranker.Init()
-	ranker.AddScoringFields(1, DummyScoringFields{
+	ranker.AddDoc(1, DummyScoringFields{
 		label:   "label3",
 		counter: 3,
 		amount:  22.3,
 	})
-	ranker.AddScoringFields(2, DummyScoringFields{
+	ranker.AddDoc(2, DummyScoringFields{
 		label:   "label4",
 		counter: 1,
 		amount:  2,
 	})
-	ranker.AddScoringFields(3, DummyScoringFields{
+	ranker.AddDoc(3, DummyScoringFields{
 		label:   "label1",
 		counter: 7,
 		amount:  10.3,
 	})
-	ranker.RemoveScoringFields(3)
+	ranker.RemoveDoc(3)
 
 	criteria := DummyScoringCriteria{}
-	scoredDocs := ranker.Rank([]types.IndexedDocument{
+	scoredDocs, _ := ranker.Rank([]types.IndexedDocument{
 		types.IndexedDocument{DocId: 1, TokenProximity: 6},
 		types.IndexedDocument{DocId: 2, TokenProximity: -1},
 		types.IndexedDocument{DocId: 3, TokenProximity: 24},
 		types.IndexedDocument{DocId: 4, TokenProximity: 18},
-	}, types.RankOptions{ScoringCriteria: criteria})
+	}, types.RankOptions{ScoringCriteria: criteria}, false)
 	utils.Expect(t, "[1 [25300 ]] [2 [3000 ]] ", scoredDocsToString(scoredDocs))
 }

+ 49 - 36
engine/engine.go

@@ -40,14 +40,14 @@ type Engine struct {
 	dbs        []storage.Storage
 
 	// 建立索引器使用的通信通道
-	segmenterChannel               chan segmenterRequest
-	indexerAddDocumentChannels     []chan indexerAddDocumentRequest
-	rankerAddScoringFieldsChannels []chan rankerAddScoringFieldsRequest
+	segmenterChannel           chan segmenterRequest
+	indexerAddDocumentChannels []chan indexerAddDocumentRequest
+	rankerAddDocChannels       []chan rankerAddDocRequest
 
 	// 建立排序器使用的通信通道
-	indexerLookupChannels             []chan indexerLookupRequest
-	rankerRankChannels                []chan rankerRankRequest
-	rankerRemoveScoringFieldsChannels []chan rankerRemoveScoringFieldsRequest
+	indexerLookupChannels   []chan indexerLookupRequest
+	rankerRankChannels      []chan rankerRankRequest
+	rankerRemoveDocChannels []chan rankerRemoveDocRequest
 
 	// 建立持久存储使用的通信通道
 	persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
@@ -102,21 +102,21 @@ func (engine *Engine) Init(options types.EngineInitOptions) {
 	}
 
 	// 初始化排序器通道
-	engine.rankerAddScoringFieldsChannels = make(
-		[]chan rankerAddScoringFieldsRequest, options.NumShards)
+	engine.rankerAddDocChannels = make(
+		[]chan rankerAddDocRequest, options.NumShards)
 	engine.rankerRankChannels = make(
 		[]chan rankerRankRequest, options.NumShards)
-	engine.rankerRemoveScoringFieldsChannels = make(
-		[]chan rankerRemoveScoringFieldsRequest, options.NumShards)
+	engine.rankerRemoveDocChannels = make(
+		[]chan rankerRemoveDocRequest, options.NumShards)
 	for shard := 0; shard < options.NumShards; shard++ {
-		engine.rankerAddScoringFieldsChannels[shard] = make(
-			chan rankerAddScoringFieldsRequest,
+		engine.rankerAddDocChannels[shard] = make(
+			chan rankerAddDocRequest,
 			options.RankerBufferLength)
 		engine.rankerRankChannels[shard] = make(
 			chan rankerRankRequest,
 			options.RankerBufferLength)
-		engine.rankerRemoveScoringFieldsChannels[shard] = make(
-			chan rankerRemoveScoringFieldsRequest,
+		engine.rankerRemoveDocChannels[shard] = make(
+			chan rankerRemoveDocRequest,
 			options.RankerBufferLength)
 	}
 
@@ -141,8 +141,8 @@ func (engine *Engine) Init(options types.EngineInitOptions) {
 	// 启动索引器和排序器
 	for shard := 0; shard < options.NumShards; shard++ {
 		go engine.indexerAddDocumentWorker(shard)
-		go engine.rankerAddScoringFieldsWorker(shard)
-		go engine.rankerRemoveScoringFieldsWorker(shard)
+		go engine.rankerAddDocWorker(shard)
+		go engine.rankerRemoveDocWorker(shard)
 
 		for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
 			go engine.indexerLookupWorker(shard)
@@ -240,15 +240,14 @@ func (engine *Engine) internalIndexDocument(docId uint64, data types.DocumentInd
 // 输入参数:
 // 	docId	标识文档编号,必须唯一
 //
-// 注意:这个函数仅从排序器中删除文档的自定义评分字段,索引器不会发生变化。所以
-// 你的自定义评分字段必须能够区别评分字段为nil的情况,并将其从排序结果中删除。
+// 注意:这个函数仅从排序器中删除文档,索引器不会发生变化。
 func (engine *Engine) RemoveDocument(docId uint64) {
 	if !engine.initialized {
 		log.Fatal("必须先初始化引擎")
 	}
 
 	for shard := 0; shard < engine.initOptions.NumShards; shard++ {
-		engine.rankerRemoveScoringFieldsChannels[shard] <- rankerRemoveScoringFieldsRequest{docId: docId}
+		engine.rankerRemoveDocChannels[shard] <- rankerRemoveDocRequest{docId: docId}
 	}
 
 	if engine.initOptions.UsePersistentStorage {
@@ -308,11 +307,13 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 
 	// 生成查找请求
 	lookupRequest := indexerLookupRequest{
+		countDocsOnly:       request.CountDocsOnly,
 		tokens:              tokens,
 		labels:              request.Labels,
 		docIds:              request.DocIds,
 		options:             rankOptions,
-		rankerReturnChannel: rankerReturnChannel}
+		rankerReturnChannel: rankerReturnChannel,
+	}
 
 	// 向索引器发送查找请求
 	for shard := 0; shard < engine.initOptions.NumShards; shard++ {
@@ -320,6 +321,7 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 	}
 
 	// 从通信通道读取排序器的输出
+	numDocs := 0
 	rankOutput := types.ScoredDocuments{}
 	timeout := request.Timeout
 	isTimeout := false
@@ -327,9 +329,12 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 		// 不设置超时
 		for shard := 0; shard < engine.initOptions.NumShards; shard++ {
 			rankerOutput := <-rankerReturnChannel
-			for _, doc := range rankerOutput.docs {
-				rankOutput = append(rankOutput, doc)
+			if !request.CountDocsOnly {
+				for _, doc := range rankerOutput.docs {
+					rankOutput = append(rankOutput, doc)
+				}
 			}
+			numDocs += rankerOutput.numDocs
 		}
 	} else {
 		// 设置超时
@@ -337,9 +342,12 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 		for shard := 0; shard < engine.initOptions.NumShards; shard++ {
 			select {
 			case rankerOutput := <-rankerReturnChannel:
-				for _, doc := range rankerOutput.docs {
-					rankOutput = append(rankOutput, doc)
+				if !request.CountDocsOnly {
+					for _, doc := range rankerOutput.docs {
+						rankOutput = append(rankOutput, doc)
+					}
 				}
+				numDocs += rankerOutput.numDocs
 			case <-time.After(deadline.Sub(time.Now())):
 				isTimeout = true
 				break
@@ -348,23 +356,28 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 	}
 
 	// 再排序
-	if rankOptions.ReverseOrder {
-		sort.Sort(sort.Reverse(rankOutput))
-	} else {
-		sort.Sort(rankOutput)
+	if !request.CountDocsOnly {
+		if rankOptions.ReverseOrder {
+			sort.Sort(sort.Reverse(rankOutput))
+		} else {
+			sort.Sort(rankOutput)
+		}
 	}
 
 	// 准备输出
 	output.Tokens = tokens
-	var start, end int
-	if rankOptions.MaxOutputs == 0 {
-		start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
-		end = len(rankOutput)
-	} else {
-		start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
-		end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
+	if !request.CountDocsOnly {
+		var start, end int
+		if rankOptions.MaxOutputs == 0 {
+			start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
+			end = len(rankOutput)
+		} else {
+			start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
+			end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
+		}
+		output.Docs = rankOutput[start:end]
 	}
-	output.Docs = rankOutput[start:end]
+	output.NumDocs = numDocs
 	output.Timeout = isTimeout
 	return
 }

+ 24 - 0
engine/engine_test.go

@@ -362,3 +362,27 @@ func TestEngineIndexDocumentWithPersistentStorage(t *testing.T) {
 	engine1.Close()
 	os.RemoveAll("wukong.persistent")
 }
+
+func TestCountDocsOnly(t *testing.T) {
+	var engine Engine
+	engine.Init(types.EngineInitOptions{
+		SegmenterDictionaries: "../testdata/test_dict.txt",
+		DefaultRankOptions: &types.RankOptions{
+			ReverseOrder:    true,
+			OutputOffset:    0,
+			MaxOutputs:      1,
+			ScoringCriteria: &RankByTokenProximity{},
+		},
+		IndexerInitOptions: &types.IndexerInitOptions{
+			IndexType: types.LocationsIndex,
+		},
+	})
+
+	AddDocs(&engine)
+	engine.RemoveDocument(4)
+
+	outputs := engine.Search(types.SearchRequest{Text: "中国人口", CountDocsOnly: true})
+	utils.Expect(t, "0", len(outputs.Docs))
+	utils.Expect(t, "2", len(outputs.Tokens))
+	utils.Expect(t, "2", outputs.NumDocs)
+}

+ 4 - 1
engine/indexer_worker.go

@@ -10,6 +10,7 @@ type indexerAddDocumentRequest struct {
 }
 
 type indexerLookupRequest struct {
+	countDocsOnly       bool
 	tokens              []string
 	labels              []string
 	docIds              []uint64
@@ -48,9 +49,11 @@ func (engine *Engine) indexerLookupWorker(shard int) {
 		}
 
 		rankerRequest := rankerRankRequest{
+			countDocsOnly:       request.countDocsOnly,
 			docs:                docs,
 			options:             request.options,
-			rankerReturnChannel: request.rankerReturnChannel}
+			rankerReturnChannel: request.rankerReturnChannel,
+		}
 		engine.rankerRankChannels[shard] <- rankerRequest
 	}
 }

+ 13 - 11
engine/ranker_worker.go

@@ -4,7 +4,7 @@ import (
 	"github.com/huichen/wukong/types"
 )
 
-type rankerAddScoringFieldsRequest struct {
+type rankerAddDocRequest struct {
 	docId  uint64
 	fields interface{}
 }
@@ -13,20 +13,22 @@ type rankerRankRequest struct {
 	docs                []types.IndexedDocument
 	options             types.RankOptions
 	rankerReturnChannel chan rankerReturnRequest
+	countDocsOnly       bool
 }
 
 type rankerReturnRequest struct {
-	docs types.ScoredDocuments
+	docs    types.ScoredDocuments
+	numDocs int
 }
 
-type rankerRemoveScoringFieldsRequest struct {
+type rankerRemoveDocRequest struct {
 	docId uint64
 }
 
-func (engine *Engine) rankerAddScoringFieldsWorker(shard int) {
+func (engine *Engine) rankerAddDocWorker(shard int) {
 	for {
-		request := <-engine.rankerAddScoringFieldsChannels[shard]
-		engine.rankers[shard].AddScoringFields(request.docId, request.fields)
+		request := <-engine.rankerAddDocChannels[shard]
+		engine.rankers[shard].AddDoc(request.docId, request.fields)
 	}
 }
 
@@ -37,14 +39,14 @@ func (engine *Engine) rankerRankWorker(shard int) {
 			request.options.MaxOutputs += request.options.OutputOffset
 		}
 		request.options.OutputOffset = 0
-		outputDocs := engine.rankers[shard].Rank(request.docs, request.options)
-		request.rankerReturnChannel <- rankerReturnRequest{docs: outputDocs}
+		outputDocs, numDocs := engine.rankers[shard].Rank(request.docs, request.options, request.countDocsOnly)
+		request.rankerReturnChannel <- rankerReturnRequest{docs: outputDocs, numDocs: numDocs}
 	}
 }
 
-func (engine *Engine) rankerRemoveScoringFieldsWorker(shard int) {
+func (engine *Engine) rankerRemoveDocWorker(shard int) {
 	for {
-		request := <-engine.rankerRemoveScoringFieldsChannels[shard]
-		engine.rankers[shard].RemoveScoringFields(request.docId)
+		request := <-engine.rankerRemoveDocChannels[shard]
+		engine.rankers[shard].RemoveDoc(request.docId)
 	}
 }

+ 2 - 2
engine/segmenter_worker.go

@@ -65,8 +65,8 @@ func (engine *Engine) segmenterWorker() {
 			iTokens++
 		}
 		engine.indexerAddDocumentChannels[shard] <- indexerRequest
-		rankerRequest := rankerAddScoringFieldsRequest{
+		rankerRequest := rankerAddDocRequest{
 			docId: request.docId, fields: request.data.Fields}
-		engine.rankerAddScoringFieldsChannels[shard] <- rankerRequest
+		engine.rankerAddDocChannels[shard] <- rankerRequest
 	}
 }

+ 3 - 0
types/search_request.go

@@ -21,6 +21,9 @@ type SearchRequest struct {
 	// 超时,单位毫秒(千分之一秒)。此值小于等于零时不设超时。
 	// 搜索超时的情况下仍有可能返回部分排序结果。
 	Timeout int
+
+	// 设为true时仅统计搜索到的文档个数,不返回具体的文档
+	CountDocsOnly bool
 }
 
 type RankOptions struct {

+ 3 - 0
types/search_response.go

@@ -13,6 +13,9 @@ type SearchResponse struct {
 
 	// 搜索是否超时。超时的情况下也可能会返回部分结果
 	Timeout bool
+
+	// 搜索到的文档个数。注意这是全部文档中满足条件的个数,可能比返回的文档数要大
+	NumDocs int
 }
 
 type ScoredDocument struct {