Files
mindoc/vendor/github.com/huichen/wukong/engine/engine.go
2017-04-29 22:06:11 +08:00

447 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package engine
import (
"fmt"
"github.com/huichen/murmur"
"github.com/huichen/sego"
"github.com/huichen/wukong/core"
"github.com/huichen/wukong/storage"
"github.com/huichen/wukong/types"
"github.com/huichen/wukong/utils"
"log"
"os"
"runtime"
"sort"
"strconv"
"sync/atomic"
"time"
)
const (
// NumNanosecondsInAMillisecond is the number of nanoseconds in one
// millisecond. Search multiplies it by the request timeout to build a
// nanosecond-based time.Duration.
NumNanosecondsInAMillisecond = 1000000
// PersistentStorageFilePrefix is the file-name prefix of every persistent
// storage shard; Init opens "<folder>/<prefix>.<shard number>".
PersistentStorageFilePrefix = "wukong"
)
// Engine holds the state of the search engine: statistics counters, the
// options it was initialized with, the per-shard indexers and rankers, the
// segmenter, the stop-token set, the persistent storage handles, and the
// channels through which requests are handed to the worker goroutines
// started by Init.
type Engine struct {
// Counters recording how many documents have been indexed, removed, etc.
// They are updated with sync/atomic in this file.
// NOTE(review): presumably kept at the very top of the struct so that they
// are 64-bit aligned on 32-bit platforms, as sync/atomic requires — confirm
// before reordering or inserting fields above them.
numDocumentsIndexed uint64
numDocumentsRemoved uint64
numDocumentsForceUpdated uint64
numIndexingRequests uint64
numRemovingRequests uint64
numForceUpdatingRequests uint64
numTokenIndexAdded uint64
numDocumentsStored uint64
// The options the engine was initialized with, and whether Init has run.
initOptions types.EngineInitOptions
initialized bool
// One indexer and one ranker per shard (appended in Init).
indexers []core.Indexer
rankers []core.Ranker
segmenter sego.Segmenter
stopTokens StopTokens
// One storage handle per persistent storage shard.
dbs []storage.Storage
// Channels used while building the index.
segmenterChannel chan segmenterRequest
indexerAddDocChannels []chan indexerAddDocumentRequest
indexerRemoveDocChannels []chan indexerRemoveDocRequest
rankerAddDocChannels []chan rankerAddDocRequest
// Channels used by lookup and ranking.
indexerLookupChannels []chan indexerLookupRequest
rankerRankChannels []chan rankerRankRequest
rankerRemoveDocChannels []chan rankerRemoveDocRequest
// Channels used by persistent storage.
persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
persistentStorageInitChannel chan bool
}
// Init configures the engine from options and starts its worker goroutines.
//
// In order, it:
//  1. records the options (after calling options.Init()) and marks the
//     engine as initialized; a second call terminates the process through
//     log.Fatal;
//  2. loads the segmenter dictionaries and stop tokens unless
//     options.NotUsingSegmenter is set;
//  3. creates one indexer and one ranker per shard together with all request
//     channels;
//  4. starts the segmenter, indexer and ranker workers;
//  5. when persistent storage is enabled, creates the storage folder, opens
//     one database per storage shard, replays their contents into the index,
//     waits until every replayed document has been indexed, reopens the
//     databases and starts the storage writer workers.
//
// Any failure to create the folder or open a database is fatal.
func (engine *Engine) Init(options types.EngineInitOptions) {
	// Set the number of threads to the number of CPUs.
	runtime.GOMAXPROCS(runtime.NumCPU())

	// Record the initialization options.
	if engine.initialized {
		log.Fatal("请勿重复初始化引擎")
	}
	options.Init()
	engine.initOptions = options
	engine.initialized = true

	if !options.NotUsingSegmenter {
		// Load the segmenter dictionaries.
		engine.segmenter.LoadDictionary(options.SegmenterDictionaries)

		// Initialize the stop tokens.
		engine.stopTokens.Init(options.StopTokenFile)
	}

	// Initialize the indexers and rankers.
	for shard := 0; shard < options.NumShards; shard++ {
		engine.indexers = append(engine.indexers, core.Indexer{})
		engine.indexers[shard].Init(*options.IndexerInitOptions)

		engine.rankers = append(engine.rankers, core.Ranker{})
		engine.rankers[shard].Init()
	}

	// Initialize the segmenter channel.
	engine.segmenterChannel = make(
		chan segmenterRequest, options.NumSegmenterThreads)

	// Initialize the indexer channels.
	engine.indexerAddDocChannels = make(
		[]chan indexerAddDocumentRequest, options.NumShards)
	engine.indexerRemoveDocChannels = make(
		[]chan indexerRemoveDocRequest, options.NumShards)
	engine.indexerLookupChannels = make(
		[]chan indexerLookupRequest, options.NumShards)
	for shard := 0; shard < options.NumShards; shard++ {
		engine.indexerAddDocChannels[shard] = make(
			chan indexerAddDocumentRequest,
			options.IndexerBufferLength)
		engine.indexerRemoveDocChannels[shard] = make(
			chan indexerRemoveDocRequest,
			options.IndexerBufferLength)
		engine.indexerLookupChannels[shard] = make(
			chan indexerLookupRequest,
			options.IndexerBufferLength)
	}

	// Initialize the ranker channels.
	engine.rankerAddDocChannels = make(
		[]chan rankerAddDocRequest, options.NumShards)
	engine.rankerRankChannels = make(
		[]chan rankerRankRequest, options.NumShards)
	engine.rankerRemoveDocChannels = make(
		[]chan rankerRemoveDocRequest, options.NumShards)
	for shard := 0; shard < options.NumShards; shard++ {
		engine.rankerAddDocChannels[shard] = make(
			chan rankerAddDocRequest,
			options.RankerBufferLength)
		engine.rankerRankChannels[shard] = make(
			chan rankerRankRequest,
			options.RankerBufferLength)
		engine.rankerRemoveDocChannels[shard] = make(
			chan rankerRemoveDocRequest,
			options.RankerBufferLength)
	}

	// Initialize the persistent storage channels.
	if engine.initOptions.UsePersistentStorage {
		engine.persistentStorageIndexDocumentChannels =
			make([]chan persistentStorageIndexDocumentRequest,
				engine.initOptions.PersistentStorageShards)
		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			engine.persistentStorageIndexDocumentChannels[shard] = make(
				chan persistentStorageIndexDocumentRequest)
		}
		engine.persistentStorageInitChannel = make(
			chan bool, engine.initOptions.PersistentStorageShards)
	}

	// Start the segmenter workers.
	for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
		go engine.segmenterWorker()
	}

	// Start the indexer and ranker workers.
	for shard := 0; shard < options.NumShards; shard++ {
		go engine.indexerAddDocumentWorker(shard)
		go engine.indexerRemoveDocWorker(shard)
		go engine.rankerAddDocWorker(shard)
		go engine.rankerRemoveDocWorker(shard)

		for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
			go engine.indexerLookupWorker(shard)
		}
		for i := 0; i < options.NumRankerThreadsPerShard; i++ {
			go engine.rankerRankWorker(shard)
		}
	}

	// Start the persistent storage workers.
	if engine.initOptions.UsePersistentStorage {
		err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
		if err != nil {
			// Include the underlying error so the failure can be diagnosed.
			log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder, ": ", err)
		}

		// Open or create the databases.
		engine.dbs = make([]storage.Storage, engine.initOptions.PersistentStorageShards)
		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
			db, err := storage.OpenStorage(dbPath)
			if db == nil || err != nil {
				log.Fatal("无法打开数据库", dbPath, ": ", err)
			}
			engine.dbs[shard] = db
		}

		// Restore the index from the databases.
		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			go engine.persistentStorageInitWorker(shard)
		}

		// Wait for every restore worker to report completion.
		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			<-engine.persistentStorageInitChannel
		}

		// Spin until every restored document has been indexed. The counters
		// are written with atomic.AddUint64 by other goroutines, so they must
		// be read atomically as well.
		for {
			runtime.Gosched()
			if atomic.LoadUint64(&engine.numIndexingRequests) == atomic.LoadUint64(&engine.numDocumentsIndexed) {
				break
			}
		}

		// Close and reopen the databases.
		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			engine.dbs[shard].Close()
			dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
			db, err := storage.OpenStorage(dbPath)
			if db == nil || err != nil {
				log.Fatal("无法打开数据库", dbPath, ": ", err)
			}
			engine.dbs[shard] = db
		}

		for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
			go engine.persistentStorageIndexDocumentWorker(shard)
		}
	}

	// Documents replayed from storage are already stored: start the "stored"
	// counter at the number of indexing requests issued so far.
	atomic.AddUint64(&engine.numDocumentsStored, atomic.LoadUint64(&engine.numIndexingRequests))
}
// IndexDocument adds a document to the index.
//
// Parameters:
//
//	docId        identifies the document and must be unique. docId == 0 denotes
//	             an invalid document (used to force an index flush); values in
//	             [1, +oo) are valid documents.
//	data         see the comments on DocumentIndexData.
//	forceUpdate  whether to force a cache flush. When true the document is
//	             added to the index as soon as possible; otherwise it is added
//	             in bulk once the cache is full.
//
// Notes:
//  1. This function is thread-safe; call it concurrently to speed up indexing.
//  2. The call is asynchronous: when it returns the document may not be in
//     the index yet, so an immediate Search may not find it. Call FlushIndex
//     to force the index to be flushed.
func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData, forceUpdate bool) {
	engine.internalIndexDocument(docId, data, forceUpdate)

	// Only real documents are persisted, and only when storage is enabled.
	// Returning early also avoids hashing (and a modulo by
	// PersistentStorageShards) when the result would not be used.
	if !engine.initOptions.UsePersistentStorage || docId == 0 {
		return
	}

	// NOTE: fmt.Sprint does not interpret "%d", so the hashed key is the
	// literal "%d" followed by the decimal docId. It is kept as is on purpose:
	// the shard chosen here must match the one RemoveDocument computes with
	// the same expression, and the one existing on-disk data was written to.
	hash := murmur.Murmur3([]byte(fmt.Sprint("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
	engine.persistentStorageIndexDocumentChannels[hash] <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
}
// internalIndexDocument queues a document for segmentation and indexing
// without touching persistent storage.
//
// It terminates the process through log.Fatal when the engine has not been
// initialized. It counts the request (numIndexingRequests for docId != 0,
// numForceUpdatingRequests when forceUpdate is set), hashes the document ID
// together with its content, and sends a segmenterRequest carrying that hash
// on the segmenter channel. The send blocks while the channel is full.
func (engine *Engine) internalIndexDocument(
	docId uint64, data types.DocumentIndexData, forceUpdate bool) {
	if !engine.initialized {
		log.Fatal("必须先初始化引擎")
	}

	if docId != 0 {
		atomic.AddUint64(&engine.numIndexingRequests, 1)
	}
	if forceUpdate {
		atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
	}

	// Sprintf (not Sprint) so that "%d%s" is treated as a format string
	// rather than being hashed as a literal prefix.
	hash := murmur.Murmur3([]byte(fmt.Sprintf("%d%s", docId, data.Content)))
	engine.segmenterChannel <- segmenterRequest{
		docId: docId, hash: hash, data: data, forceUpdate: forceUpdate}
}
// RemoveDocument removes a document from the index.
//
// Parameters:
//
//	docId        identifies the document and must be unique. docId == 0 denotes
//	             an invalid document (used to force an index flush); values in
//	             [1, +oo) are valid documents.
//	forceUpdate  whether to force a cache flush. When true the removal is
//	             applied as soon as possible; otherwise it is applied in bulk
//	             once the cache is full.
//
// Notes:
//  1. This function is thread-safe; call it concurrently to improve
//     throughput.
//  2. The call is asynchronous: when it returns the change may not be
//     reflected in the index yet, so an immediate Search may still see the
//     old state. Call FlushIndex to force the index to be flushed.
func (engine *Engine) RemoveDocument(docId uint64, forceUpdate bool) {
	if !engine.initialized {
		log.Fatal("必须先初始化引擎")
	}

	isRealDoc := docId != 0
	if isRealDoc {
		atomic.AddUint64(&engine.numRemovingRequests, 1)
	}
	if forceUpdate {
		atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
	}

	// Every shard's indexer is notified; rankers only hear about real
	// documents.
	numShards := engine.initOptions.NumShards
	for i := 0; i < numShards; i++ {
		engine.indexerRemoveDocChannels[i] <- indexerRemoveDocRequest{docId: docId, forceUpdate: forceUpdate}
		if isRealDoc {
			engine.rankerRemoveDocChannels[i] <- rankerRemoveDocRequest{docId: docId}
		}
	}

	if !isRealDoc || !engine.initOptions.UsePersistentStorage {
		return
	}

	// Remove the document from the database shard it was stored in. The key
	// is built exactly as in IndexDocument so both pick the same shard.
	key := []byte(fmt.Sprint("%d", docId))
	dbShard := murmur.Murmur3(key) % uint32(engine.initOptions.PersistentStorageShards)
	go engine.persistentStorageRemoveDocumentWorker(docId, dbShard)
}
// Search looks up the documents matching request. It is safe for concurrent
// use.
//
// The query tokens come from segmenting request.Text (stop tokens dropped)
// or, when Text is empty, from request.Tokens. A lookup request is sent to
// every shard and the rankers' answers are merged. Unless CountDocsOnly or
// Orderless is set, the merged documents are sorted (reversed when
// rankOptions.ReverseOrder is set) and then cut down according to
// OutputOffset and MaxOutputs.
//
// When request.Timeout is positive it is a budget in milliseconds for the
// whole collection phase: as soon as it expires Search stops waiting for the
// remaining shards, returns what it has gathered so far and sets
// output.Timeout. A non-positive Timeout waits for all shards.
//
// Search terminates the process through log.Fatal when the engine has not
// been initialized.
func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
	if !engine.initialized {
		log.Fatal("必须先初始化引擎")
	}

	// Pick the ranking options: the request's own, or the engine defaults.
	var rankOptions types.RankOptions
	if request.RankOptions == nil {
		rankOptions = *engine.initOptions.DefaultRankOptions
	} else {
		rankOptions = *request.RankOptions
	}
	if rankOptions.ScoringCriteria == nil {
		rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
	}

	// Collect the query tokens.
	tokens := []string{}
	if request.Text != "" {
		querySegments := engine.segmenter.Segment([]byte(request.Text))
		for _, s := range querySegments {
			token := s.Token().Text()
			if !engine.stopTokens.IsStopToken(token) {
				tokens = append(tokens, token)
			}
		}
	} else {
		for _, t := range request.Tokens {
			tokens = append(tokens, t)
		}
	}

	// Channel on which the rankers return their results. It is buffered
	// with one slot per shard so rankers never block on it, even if Search
	// has already given up because of a timeout.
	rankerReturnChannel := make(
		chan rankerReturnRequest, engine.initOptions.NumShards)

	// Build the lookup request.
	lookupRequest := indexerLookupRequest{
		countDocsOnly:       request.CountDocsOnly,
		tokens:              tokens,
		labels:              request.Labels,
		docIds:              request.DocIds,
		options:             rankOptions,
		rankerReturnChannel: rankerReturnChannel,
		orderless:           request.Orderless,
	}

	// Send the lookup request to every indexer.
	for shard := 0; shard < engine.initOptions.NumShards; shard++ {
		engine.indexerLookupChannels[shard] <- lookupRequest
	}

	// A nil channel never becomes ready in a select, so leaving timeoutC nil
	// means "no timeout". With a timeout, one timer covers the whole
	// collection phase instead of allocating a new one per shard.
	var timeoutC <-chan time.Time
	if request.Timeout > 0 {
		timer := time.NewTimer(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
		defer timer.Stop()
		timeoutC = timer.C
	}

	// Read the rankers' output.
	numDocs := 0
	rankOutput := types.ScoredDocuments{}
	isTimeout := false
collect:
	for shard := 0; shard < engine.initOptions.NumShards; shard++ {
		select {
		case rankerOutput := <-rankerReturnChannel:
			if !request.CountDocsOnly {
				for _, doc := range rankerOutput.docs {
					rankOutput = append(rankOutput, doc)
				}
			}
			numDocs += rankerOutput.numDocs
		case <-timeoutC:
			isTimeout = true
			// A bare break would only leave the select; the label makes
			// it leave the shard loop as well.
			break collect
		}
	}

	// Sort the merged results.
	if !request.CountDocsOnly && !request.Orderless {
		if rankOptions.ReverseOrder {
			sort.Sort(sort.Reverse(rankOutput))
		} else {
			sort.Sort(rankOutput)
		}
	}

	// Prepare the output.
	output.Tokens = tokens
	// output.Docs is only filled in when CountDocsOnly is false.
	if !request.CountDocsOnly {
		if request.Orderless {
			// Unordered results are not cut by OutputOffset/MaxOutputs.
			output.Docs = rankOutput
		} else {
			start := utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
			end := len(rankOutput)
			if rankOptions.MaxOutputs != 0 {
				// MaxOutputs == 0 means "no limit".
				end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
			}
			output.Docs = rankOutput[start:end]
		}
	}
	output.NumDocs = numDocs
	output.Timeout = isTimeout
	return
}
// FlushIndex blocks until all pending index changes have been applied.
//
// It first spins (yielding with runtime.Gosched) until the number of
// documents indexed equals the number of indexing requests, the number of
// removals processed equals the number of removal requests times the number
// of shards, and — when persistent storage is enabled — the number of
// documents stored equals the number of indexing requests. It then issues a
// forced update with docId 0 so that it is the last request, and spins again
// until the forced-update count equals the forced-update requests times the
// number of shards.
//
// The counters are written by other goroutines with atomic.AddUint64, so
// they are read here with atomic.LoadUint64 to avoid a data race.
func (engine *Engine) FlushIndex() {
	numShards := uint64(engine.initOptions.NumShards)

	for {
		runtime.Gosched()
		indexingRequests := atomic.LoadUint64(&engine.numIndexingRequests)
		if indexingRequests == atomic.LoadUint64(&engine.numDocumentsIndexed) &&
			atomic.LoadUint64(&engine.numRemovingRequests)*numShards == atomic.LoadUint64(&engine.numDocumentsRemoved) &&
			(!engine.initOptions.UsePersistentStorage || indexingRequests == atomic.LoadUint64(&engine.numDocumentsStored)) {
			// Every request already queued on the channels has been handled.
			break
		}
	}

	// Force an update and make sure it is the last request.
	engine.IndexDocument(0, types.DocumentIndexData{}, true)
	for {
		runtime.Gosched()
		if atomic.LoadUint64(&engine.numForceUpdatingRequests)*numShards == atomic.LoadUint64(&engine.numDocumentsForceUpdated) {
			return
		}
	}
}
// Close shuts the engine down: it flushes the index and then, when
// persistent storage is enabled, closes every storage handle.
func (engine *Engine) Close() {
	engine.FlushIndex()

	if !engine.initOptions.UsePersistentStorage {
		return
	}
	for i := range engine.dbs {
		engine.dbs[i].Close()
	}
}
// getShard maps a text hash to the shard it is assigned to: the hash modulo
// the configured number of shards.
func (engine *Engine) getShard(hash uint32) int {
	numShards := uint32(engine.initOptions.NumShards)
	return int(hash % numShards)
}