浏览代码

给悟空引擎添加持久存储(persistent storage)

Hui Chen 12 年之前
父节点
当前提交
64d9394605
共有 5 个文件被更改,包括 281 次插入和 9 次删除
  1. 86 1
      engine/engine.go
  2. 59 2
      engine/engine_test.go
  3. 84 0
      engine/persistent_storage_worker.go
  4. 42 6
      examples/benchmark.go
  5. 10 0
      types/engine_init_options.go

+ 86 - 1
engine/engine.go

@@ -2,14 +2,17 @@ package engine
 
 import (
 	"fmt"
+	"github.com/cznic/kv"
 	"github.com/huichen/murmur"
 	"github.com/huichen/sego"
 	"github.com/huichen/wukong/core"
 	"github.com/huichen/wukong/types"
 	"github.com/huichen/wukong/utils"
 	"log"
+	"os"
 	"runtime"
 	"sort"
+	"strconv"
 	"sync/atomic"
 	"time"
 )
@@ -23,6 +26,7 @@ type Engine struct {
 	numDocumentsIndexed uint64
 	numIndexingRequests uint64
 	numTokenIndexAdded  uint64
+	numDocumentsStored  uint64
 
 	// 记录初始化参数
 	initOptions types.EngineInitOptions
@@ -32,6 +36,7 @@ type Engine struct {
 	rankers    []core.Ranker
 	segmenter  sego.Segmenter
 	stopTokens StopTokens
+	dbs        []*kv.DB
 
 	// 建立索引器使用的通信通道
 	segmenterChannel               chan segmenterRequest
@@ -42,6 +47,10 @@ type Engine struct {
 	indexerLookupChannels             []chan indexerLookupRequest
 	rankerRankChannels                []chan rankerRankRequest
 	rankerRemoveScoringFieldsChannels []chan rankerRemoveScoringFieldsRequest
+
+	// 建立持久存储使用的通信通道
+	persistentStorageIndexDocumentChannel chan persistentStorageIndexDocumentRequest
+	persistentStorageInitChannel          chan bool
 }
 
 func (engine *Engine) Init(options types.EngineInitOptions) {
@@ -108,6 +117,15 @@ func (engine *Engine) Init(options types.EngineInitOptions) {
 			options.RankerBufferLength)
 	}
 
+	// 初始化持久化存储通道
+	if engine.initOptions.UsePersistentStorage {
+		engine.persistentStorageIndexDocumentChannel = make(
+			chan persistentStorageIndexDocumentRequest,
+			engine.initOptions.PersistentStorageShards)
+		engine.persistentStorageInitChannel = make(
+			chan bool, engine.initOptions.PersistentStorageShards)
+	}
+
 	// 启动分词器
 	for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
 		go engine.segmenterWorker()
@@ -126,6 +144,47 @@ func (engine *Engine) Init(options types.EngineInitOptions) {
 			go engine.rankerRankWorker(shard)
 		}
 	}
+
+	// 启动持久化存储工作协程
+	if engine.initOptions.UsePersistentStorage {
+		err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
+		if err != nil {
+			log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
+		}
+
+		// 打开或者创建数据库
+		engine.dbs = make([]*kv.DB, engine.initOptions.PersistentStorageShards)
+		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
+			dbPath := engine.initOptions.PersistentStorageFolder + "/persist." + strconv.Itoa(shard) + "-of-" + strconv.Itoa(engine.initOptions.PersistentStorageShards)
+			db, err := utils.OpenOrCreateKv(dbPath, &kv.Options{})
+			if db == nil || err != nil {
+				log.Fatal("无法打开数据库", dbPath, ": ", err)
+			}
+			engine.dbs[shard] = db
+		}
+
+		// 从数据库中恢复
+		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
+			go engine.persistentStorageInitWorker(shard)
+		}
+
+		// 等待恢复完成
+		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
+			<-engine.persistentStorageInitChannel
+		}
+		for {
+			runtime.Gosched()
+			if engine.numIndexingRequests == engine.numDocumentsIndexed {
+				break
+			}
+		}
+
+		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
+			go engine.persistentStorageIndexDocumentWorker(shard)
+		}
+	}
+
+	atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
 }
 
 // 将文档加入索引
@@ -139,6 +198,13 @@ func (engine *Engine) Init(options types.EngineInitOptions) {
 // 	2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
 //         如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
 func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData) {
+	engine.internalIndexDocument(docId, data) // hand the document to the in-memory indexing path first
+	if engine.initOptions.UsePersistentStorage {
+		engine.persistentStorageIndexDocumentChannel <- persistentStorageIndexDocumentRequest{docId: docId, data: data} // also queue it for a storage worker; this send blocks while the channel's buffer is full
+	}
+}
+
+func (engine *Engine) internalIndexDocument(docId uint64, data types.DocumentIndexData) {
 	if !engine.initialized {
 		log.Fatal("必须先初始化引擎")
 	}
@@ -164,13 +230,22 @@ func (engine *Engine) RemoveDocument(docId uint64) {
 	for shard := 0; shard < engine.initOptions.NumShards; shard++ {
 		engine.rankerRemoveScoringFieldsChannels[shard] <- rankerRemoveScoringFieldsRequest{docId: docId}
 	}
+
+	if engine.initOptions.UsePersistentStorage {
+		// 从数据库中删除
+		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
+			go engine.persistentStorageRemoveDocumentWorker(docId, shard)
+		}
+	}
 }
 
 // 阻塞等待直到所有索引添加完毕
 func (engine *Engine) FlushIndex() {
 	for {
 		runtime.Gosched()
-		if engine.numIndexingRequests == engine.numDocumentsIndexed {
+		if engine.numIndexingRequests == engine.numDocumentsIndexed &&
+			(!engine.initOptions.UsePersistentStorage ||
+				engine.numIndexingRequests == engine.numDocumentsStored) {
 			return
 		}
 	}
@@ -275,6 +350,16 @@ func (engine *Engine) Search(request types.SearchRequest) (output types.SearchRe
 	return
 }
 
+// Close shuts the engine down: it calls FlushIndex and then, when persistent storage is enabled, closes every database in engine.dbs.
+func (engine *Engine) Close() {
+	engine.FlushIndex() // FlushIndex spins until its completion condition holds, so this call can block
+	if engine.initOptions.UsePersistentStorage {
+		for _, db := range engine.dbs {
+			db.Close() // NOTE(review): the result of Close is discarded — presumably an error value; confirm against the kv API
+		}
+	}
+}
+
 // 从文本hash得到要分配到的shard
 func (engine *Engine) getShard(hash uint32) int {
 	return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))

+ 59 - 2
engine/engine_test.go

@@ -1,14 +1,16 @@
 package engine
 
 import (
+	"encoding/gob"
 	"github.com/huichen/wukong/types"
 	"github.com/huichen/wukong/utils"
+	"os"
 	"reflect"
 	"testing"
 )
 
 type ScoringFields struct {
-	a, b, c float32
+	A, B, C float32
 }
 
 func AddDocs(engine *Engine) {
@@ -145,7 +147,7 @@ func (criteria TestScoringCriteria) Score(
 		return []float32{}
 	}
 	fs := fields.(ScoringFields)
-	return []float32{float32(doc.TokenProximity)*fs.a + fs.b*fs.c}
+	return []float32{float32(doc.TokenProximity)*fs.A + fs.B*fs.C}
 }
 
 func TestSearchWithCriteria(t *testing.T) {
@@ -305,3 +307,58 @@ func TestEngineIndexDocumentWithTokens(t *testing.T) {
 	utils.Expect(t, "76", int(outputs.Docs[2].Scores[0]*1000))
 	utils.Expect(t, "[0 18]", outputs.Docs[2].TokenSnippetLocations)
 }
+
+func TestEngineIndexDocumentWithPersistentStorage(t *testing.T) { // indexes documents with persistent storage on, closes the engine, then checks that a second engine restores them from the same folder
+	gob.Register(ScoringFields{}) // presumably required because documents carry ScoringFields behind an interface value that gob must encode — confirm against AddDocs
+	var engine Engine
+	engine.Init(types.EngineInitOptions{
+		SegmenterDictionaries: "../testdata/test_dict.txt",
+		DefaultRankOptions: &types.RankOptions{
+			OutputOffset:    0,
+			MaxOutputs:      10,
+			ScoringCriteria: &RankByTokenProximity{},
+		},
+		IndexerInitOptions: &types.IndexerInitOptions{
+			IndexType: types.LocationsIndex,
+		},
+		UsePersistentStorage:    true,
+		PersistentStorageFolder: "wukong.persistent",
+		PersistentStorageShards: 2,
+	})
+	AddDocs(&engine)
+	engine.RemoveDocument(4) // the removal is expected to be reflected in storage as well as in the index
+	engine.Close() // flush and close the databases before a second engine opens the same folder
+
+	var engine1 Engine // second engine, initialized with the same options and the same storage folder
+	engine1.Init(types.EngineInitOptions{
+		SegmenterDictionaries: "../testdata/test_dict.txt",
+		DefaultRankOptions: &types.RankOptions{
+			OutputOffset:    0,
+			MaxOutputs:      10,
+			ScoringCriteria: &RankByTokenProximity{},
+		},
+		IndexerInitOptions: &types.IndexerInitOptions{
+			IndexType: types.LocationsIndex,
+		},
+		UsePersistentStorage:    true,
+		PersistentStorageFolder: "wukong.persistent",
+		PersistentStorageShards: 2,
+	})
+
+	outputs := engine1.Search(types.SearchRequest{Text: "中国人口"}) // no documents are added to engine1 in this test, so any hits must come from the restored data
+	utils.Expect(t, "2", len(outputs.Tokens))
+	utils.Expect(t, "中国", outputs.Tokens[0])
+	utils.Expect(t, "人口", outputs.Tokens[1])
+	utils.Expect(t, "2", len(outputs.Docs))
+
+	utils.Expect(t, "1", outputs.Docs[0].DocId)
+	utils.Expect(t, "1000", int(outputs.Docs[0].Scores[0]*1000))
+	utils.Expect(t, "[0 6]", outputs.Docs[0].TokenSnippetLocations)
+
+	utils.Expect(t, "0", outputs.Docs[1].DocId)
+	utils.Expect(t, "76", int(outputs.Docs[1].Scores[0]*1000))
+	utils.Expect(t, "[0 18]", outputs.Docs[1].TokenSnippetLocations)
+
+	engine1.Close()
+	os.RemoveAll("wukong.persistent") // clean up the on-disk databases created by this test
+}

+ 84 - 0
engine/persistent_storage_worker.go

@@ -0,0 +1,84 @@
+package engine
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/gob"
+	"github.com/huichen/wukong/types"
+	"io"
+	"log"
+	"sync/atomic"
+)
+
+type persistentStorageIndexDocumentRequest struct { // request sent over persistentStorageIndexDocumentChannel to a storage worker
+	docId uint64 // document ID; the worker varint-encodes it to form the database key
+	data  types.DocumentIndexData // document payload; the worker gob-encodes it to form the database value
+}
+
+func (engine *Engine) persistentStorageIndexDocumentWorker(shard int) { // loops forever, writing queued documents into engine.dbs[shard]
+	for {
+		request := <-engine.persistentStorageIndexDocumentChannel // a single channel feeds every worker, so whichever worker receives the request decides the shard it is stored in
+
+		// Build the key: docId as an unsigned varint; only b[0:length] is used.
+		b := make([]byte, 8) // NOTE(review): a uint64 varint can need up to binary.MaxVarintLen64 (10) bytes and PutUvarint panics on a too-small buffer, i.e. for docId >= 1<<56
+		length := binary.PutUvarint(b, request.docId)
+
+		// Build the value: request.data serialized with gob.
+		var buf bytes.Buffer
+		enc := gob.NewEncoder(&buf)
+		err := enc.Encode(request.data)
+		if err != nil {
+			atomic.AddUint64(&engine.numDocumentsStored, 1) // counted even on failure so the stored counter keeps up with the request counter; the document is skipped without logging
+			continue
+		}
+
+		// Write the key-value pair to this shard's database.
+		engine.dbs[shard].Set(b[0:length], buf.Bytes()) // NOTE(review): the result of Set is not checked
+		atomic.AddUint64(&engine.numDocumentsStored, 1)
+	}
+}
+
+func (engine *Engine) persistentStorageRemoveDocumentWorker(docId uint64, shard int) { // deletes the entry keyed by docId from engine.dbs[shard]
+	// Build the key: docId as an unsigned varint; only b[0:length] is used.
+	b := make([]byte, 8) // NOTE(review): a uint64 varint can need up to binary.MaxVarintLen64 (10) bytes and PutUvarint panics on a too-small buffer, i.e. for docId >= 1<<56
+	length := binary.PutUvarint(b, docId)
+
+	// Delete the key from this shard's database.
+	engine.dbs[shard].Delete(b[0:length]) // NOTE(review): the result of Delete is not checked
+}
+
+func (engine *Engine) persistentStorageInitWorker(shard int) { // replays every record of engine.dbs[shard] into the index, then signals persistentStorageInitChannel
+	iter, err := engine.dbs[shard].SeekFirst()
+	if err == io.EOF { // treated as "nothing to restore": signal completion and return
+		engine.persistentStorageInitChannel <- true
+		return
+	} else if err != nil {
+		engine.persistentStorageInitChannel <- true
+		log.Fatal("无法遍历数据库")
+	}
+
+	for {
+		key, value, err := iter.Next()
+		if err == io.EOF { // end of iteration
+			break
+		} else if err != nil {
+			continue // NOTE(review): other errors are skipped silently; if Next keeps failing this loop would never end — confirm the iterator's behavior
+		}
+
+		// Decode the docId from the varint key.
+		docId, _ := binary.Uvarint(key) // the byte count returned by Uvarint (which also signals malformed input) is discarded
+
+		// Decode the document data from the gob-encoded value.
+		buf := bytes.NewReader(value)
+		dec := gob.NewDecoder(buf)
+		var data types.DocumentIndexData
+		err = dec.Decode(&data)
+		if err != nil {
+			continue // records that fail to decode are skipped without logging
+		}
+
+		// Add the document to the index.
+		engine.internalIndexDocument(docId, data) // uses the internal path, so the restored record is not queued for storage again
+	}
+	engine.persistentStorageInitChannel <- true // signal that this shard has finished replaying
+}

+ 42 - 6
examples/benchmark.go

@@ -36,10 +36,13 @@ var (
 		"stop_token_file",
 		"../data/stop_tokens.txt",
 		"停用词文件")
-	cpuprofile      = flag.String("cpuprofile", "", "处理器profile文件")
-	memprofile      = flag.String("memprofile", "", "内存profile文件")
-	num_repeat_text = flag.Int("num_repeat_text", 10, "文本重复加入多少次")
-	index_type      = flag.Int("index_type", types.DocIdsIndex, "索引类型")
+	cpuprofile                = flag.String("cpuprofile", "", "处理器profile文件")
+	memprofile                = flag.String("memprofile", "", "内存profile文件")
+	num_repeat_text           = flag.Int("num_repeat_text", 10, "文本重复加入多少次")
+	index_type                = flag.Int("index_type", types.DocIdsIndex, "索引类型")
+	use_persistent            = flag.Bool("use_persistent", false, "是否使用持久存储")
+	persistent_storage_folder = flag.String("persistent_storage_folder", "benchmark.persistent", "持久存储数据库保存的目录")
+	persistent_storage_shards = flag.Int("persistent_storage_shards", 0, "持久数据库存储裂分数目")
 
 	searcher = engine.Engine{}
 	options  = types.RankOptions{
@@ -59,15 +62,21 @@ func main() {
 	log.Printf("待搜索的关键词为\"%s\"", searchQueries)
 
 	// 初始化
+	tBeginInit := time.Now()
 	searcher.Init(types.EngineInitOptions{
 		SegmenterDictionaries: *dictionaries,
 		StopTokenFile:         *stop_token_file,
 		IndexerInitOptions: &types.IndexerInitOptions{
 			IndexType: *index_type,
 		},
-		NumShards:          NumShards,
-		DefaultRankOptions: &options,
+		NumShards:               NumShards,
+		DefaultRankOptions:      &options,
+		UsePersistentStorage:    *use_persistent,
+		PersistentStorageFolder: *persistent_storage_folder,
+		PersistentStorageShards: *persistent_storage_shards,
 	})
+	tEndInit := time.Now()
+	defer searcher.Close()
 
 	// 打开将要搜索的文件
 	file, err := os.Open(*weibo_data)
@@ -164,6 +173,33 @@ func main() {
 	log.Printf("搜索吞吐量每秒 %v 次查询",
 		float64(numRepeatQuery*numQueryThreads*len(searchQueries))/
 			t3.Sub(t2).Seconds())
+
+	if *use_persistent {
+		searcher.Close()
+		t4 := time.Now()
+		searcher1 := engine.Engine{}
+		searcher1.Init(types.EngineInitOptions{
+			SegmenterDictionaries: *dictionaries,
+			StopTokenFile:         *stop_token_file,
+			IndexerInitOptions: &types.IndexerInitOptions{
+				IndexType: *index_type,
+			},
+			NumShards:               NumShards,
+			DefaultRankOptions:      &options,
+			UsePersistentStorage:    *use_persistent,
+			PersistentStorageFolder: *persistent_storage_folder,
+			PersistentStorageShards: *persistent_storage_shards,
+		})
+		defer searcher1.Close()
+		t5 := time.Now()
+		t := t5.Sub(t4).Seconds() - tEndInit.Sub(tBeginInit).Seconds()
+		log.Print("从持久存储加入的索引总数", searcher1.NumTokenIndexAdded())
+		log.Printf("从持久存储建立索引花费时间 %v", t)
+		log.Printf("从持久存储建立索引速度每秒添加 %f 百万个索引",
+			float64(searcher1.NumTokenIndexAdded())/t/(1000000))
+
+	}
+	os.RemoveAll(*persistent_storage_folder)
 }
 
 func search(ch chan bool) {

+ 10 - 0
types/engine_init_options.go

@@ -24,6 +24,7 @@ var (
 		K1: 2.0,
 		B:  0.75,
 	}
+	defaultPersistentStorageShards = runtime.NumCPU()
 )
 
 type EngineInitOptions struct {
@@ -58,6 +59,11 @@ type EngineInitOptions struct {
 
 	// 默认的搜索选项
 	DefaultRankOptions *RankOptions
+
+	// 是否使用持久数据库,以及数据库文件保存的目录和裂分数目
+	UsePersistentStorage bool
+	PersistentStorageFolder string
+	PersistentStorageShards int
 }
 
 // 初始化EngineInitOptions,当用户未设定某个选项的值时用默认值取代
@@ -105,4 +111,8 @@ func (options *EngineInitOptions) Init() {
 	if options.DefaultRankOptions.ScoringCriteria == nil {
 		options.DefaultRankOptions.ScoringCriteria = defaultDefaultRankOptions.ScoringCriteria
 	}
+
+	if options.PersistentStorageShards == 0 {
+		options.PersistentStorageShards = defaultPersistentStorageShards
+	}
 }