码迷,mamicode.com
首页 > 其他好文 > 详细

[Project] SpellCorrect源码详解

时间:2015-08-18 21:04:42      阅读:117      评论:0      收藏:0      [点我收藏+]

标签:

该Project原来的应用场景是对电商网站中输入一个错误的商品名称进行智能纠错,比如iphoae纠错为iphone。以下介绍的这个版本对其作了简化,项目源代码地址参见我的github:https://github.com/jianxinzhou/MyProject_1/tree/uint32

该Project的主要思想是利用字符串编辑距离来实现拼写纠错。每当客户端来一个查询词,服务器返回与其编辑距离在2以内的单词中词频最高的那个单词。以下是对该项目的简要介绍与分析,具体代码仍以github中的为准。

0. 项目技术以及网络框架

项目技术:UDP通讯,线程池框架,编辑距离算法,倒排索引,cache优化

网络框架:客户端通过UDP数据报向服务器发送查询请求,服务器收到请求后将查询词与客户端地址打包成task扔进任务队列,线程池中的工作线程从任务队列中取出并执行任务,最后计算结果由工作线程返回。

项目遇到的难点:utf-8存储下的单词之间的编辑距离计算;计算查询词与词频词典中的每个单词的编辑距离的时间总和过长,导致查询速度太慢;硬盘cache与工作线程内存cache的同步。

1. 字符编码问题

UTF-8是Unicode一种的实现方式。UTF-8是一种变长的编码方式,使用1~4个字节表示一个符号,根据不同的符号而变化字节长度。

UTF-8的编码规则很简单,只有二条:

1)对于单字节的符号,字节的第一位设为0,后面7位为这个符号的unicode码。因此对于英语字母,UTF-8编码和ASCII码是相同的。

2)对于n字节的符号(n>1),第一个字节的前n位都设为1,第n+1位设为0,后面字节的前两位一律设为10。剩下的没有提及的二进制位,全部为这个符号的unicode码。

技术分享

在项目中,词频文件采用的文件格式为:word /t frequence /t,示例如下:

apple            789

iphone          60000

手机              80000

问题在于词频文件采用UTF-8格式存储,如果用一个string来存一个单词的话,就没有办法正确计算查询词与词库中单词的编辑距离。

使用string来存储单词,只可以正确计算两个英文单词之间的编辑距离,因为UTF-8格式下,英文仍然是采用一个字节(char)来存储一个字母。但是,对于由汉字构成的单词是无法正确计算的,因为一个汉字可以占1~4个字节。

为了正确计算编辑距离,同时考虑到UTF-8使用1~4个字节表示一个符号,因此我想到的办法是统一将单词中的一个字母(或者中文的汉字)使用uint32_t来存储,而一个单词可以使用vector<uint32_t>来存储。

string转换成vector<uint32_t>

那么如何将单词由string转换成vector<uint32_t>来存储呢?为了便于下文叙述,此处统一将一个字母或者汉字记为letter。

只需遍历string即可,由UTF-8编码规则可知,要知道一个letter由几个字节构成,只要看组成这个letter的第一个字节低地址(左边)有几个连续的1即可(没有1,就是一个字节表示;有2个1,2字节表示;3个1,3字节表示;4个1,4字节表示),之后根据字节数量信息将其拼接成一个uint32_t即可。代码如下:

    // 计算UTF8编码所占的字节
    int getLenOfUTF8(unsigned char c)
    {
        int cnt = 0;
        while(c & (1 << (7-cnt)))
            ++cnt;
        return cnt; 
    }

    // 每个vector代表一个word
    // 把字符串解析成uint32_t数组
    void parseUTF8String(const std::string &s, std::vector<uint32_t> &vec)
    {
        vec.clear();
        for(std::string::size_type ix = 0; ix < s.size(); ++ix)
        {
            int len = getLenOfUTF8(s[ix]);
            uint32_t t = (unsigned char)s[ix]; /*e5*/
            if(len > 1)
            {
                --len;
                /*拼接剩余的字节*/
                while(len--)
                {
                    t = (t << 8) + (unsigned char)s[++ix];
                }
            }
            vec.push_back(t);
        }
    }

调整后的计算单词编辑距离的方法

代码如下:

int edit_distance_uint_32(const std::vector<uint32_t> &w1, const std::vector<uint32_t> &w2) 
{
    int len_a = w1.size();
    int len_b = w2.size();
    int memo[100][100];
    memset(memo, 0x00, 100 * 100 * sizeof(int));
    for (int i = 1; i <= len_a; ++i) 
    {
        memo[i][0] = i;
    }
    for (int j = 1; j <= len_b; ++j) 
    {
        memo[0][j] = j;
    }
    for (int i = 1; i <= len_a; ++i) 
    {
        for (int j = 1; j <= len_b; ++j) 
        {
            if (w1[i - 1] == w2[j - 1]) 
            {
                memo[i][j] = memo[i - 1][j - 1];
            } 
            else 
            {
                memo[i][j] = MIN(memo[i - 1][j - 1], memo[i][j - 1],memo[i - 1][j]) + 1;
            }
        }
    }
    return memo[len_a][len_b];
}

2 配置文件

配置文件内容如下:

/* SpellCorrect.conf */

my_ip   192.168.153.131
my_port 5080
my_dict /home/purple/SpellCheck/data/dict.dat
my_cache /home/purple/SpellCheck/data/cache.dat

包含了SpellCorrect服务器的IP地址、端口、数据词典以及缓存。

3 main函数

主要内容下面用了比较详细的注释,总体思想是服务器在某一个端口上接收来自用户的udp数据报请求,之后将用户的查询词和地址(IP和端口)打包封装成一个task,并将该task扔进线程池中的任务队列中(扮演生产者的角色),由工作线程(扮演消费者角色)负责从任务队列中取出任务并执行,并将执行结果send回给客户端。代码如下:

/* main.cc */

#include "ThreadPool.h"
#include "MySocket.h"
#include "MyConf.h"
#include "MyCache.h"

int main(int argc, char* argv[])
{
    /* 初始化配置文件, MyConf类会根据读入的配置文件,生成词频词典以及倒排索引等,如下所示 */
    //    std::map<std::string, std::string> mapConf_ ;                     // 配置文件
    //    std::vector<std::pair<std::string, int> > strDict_;               // 原始词频词典
    //    std::vector< std::pair<std::vector<uint32_t>, int> > vecDict_;    // 经转换后的词频字典
    //    std::map<uint32_t, std::set<int> > mapIndex_;                     // 倒排索引(将包含这个letter的单词所在vector的下标放入set中)
    MyConf conf(argv[1]); 

    
    /* 初始化线程池对象,线程池对象将持有以下内容 */
    //    MyConf &conf_;                        // 配置对象的引用 
                                                // 线程池需要持有配置文件对象的引用,因为:
                                                // 配置文件中拥有硬盘cache的地址,工作线程启动时,需要将硬盘cache读入每个工作线程自身的内存cache中    
    
    //    std::vector<MyThread> vecThreads_ ;   // 存放工作线程的容器
    //    std::queue<MyTask>    queueTasks_ ;   // 存放任务的队列
        
    //    MyLock queueTaskslock_ ;              // 用于工作线程之间同步的互斥锁         
    //    MyCondition queueTasksCond_ ;         // 用于工作线程之间同步的条件变量
        
    //    bool isStarted_ ;                     // 用于标识线程池是否开启的变量
        
    //    MyCacheThread cacheThread_ ;          // 定时扫描内存cache的线程
    ThreadPool apool(conf) ;    
    
    /* 初始化用于UDP通信的socket对象 */
    //    int peerfd_ ;               // 用于标识socket的描述符
    //    struct sockaddr_in addr_ ;  // 用于保存服务器端或客户端ip和端口号信息
    //    socklen_t addrLen_ ;        // 用于保存struct sockaddr_in 的长度
    //    MyConf& conf_ ;             // 配置对象的引用(需要该引用,是因为服务器IP地址和端口存放在配置对象中)
    MySocket socket(conf);
    
    /* 开启线程中的工作线程以及cacche扫描线程 */
    apool.on();
    



    const int len = 1024 ;   
    char buf[len];
    int iret ;
    
    // 主循环,不断接收客户端的udp数据报
    while(true) {
        memset(buf, 0, len);
        iret = socket.recv_message(buf, len) ;
        std::cout << "main" << buf <<"len: "<< iret << std::endl ;
        
        // 将客户端的查询词以及地址封装成task放入线程池中的任务队列(生产者)
        // 工作线程将会从任务队列中取出任务执行,执行完直接由工作线程将结果返回给客户端
        MyTask task(buf,socket.get_addr(), conf);
        apool.allocate_task(task);
    }
    
    apool.off();
    return 0 ;
}

4 Task类

任务对象

1. 工作线程从线程池中的任务队列中取出的任务是一个“任务对象”,之后执行任务task.excute(cache_) ;
2. 因此真正的计算逻辑以及返回给客户端的结果,都是task对象进行的
3. 工作线程中持有自身的内存cache,再执行任务时,需要将工作线程内存cache的引用传给任务的excute方法
4. 任务对象在进行计算时,需要词频词典,倒排索引,因此初始化一个任务对象时,需要传入配置文件对象的引用

头文件的关键部分如下:

#ifndef __MYTASK_H__
#define __MYTASK_H__

class MyCache ;

class MyTask
{
    public:
        MyTask( MyConf& conf);
        
        MyTask(const std::string &queryWord, 
               const struct sockaddr_in &addr ,  
               MyConf& conf);
        
        void excute(MyCache& cache) ;      // 执行函数。需要传递一个MyCache对象 。
        
        int length(const std::string& str) // 计算查询词的长度
        {
            int index ;
            int len = 0 ;
            for(index = 0 ; index != str.size(); index ++)
            {
                if(str[index] & (1 << 7))
                {
                    index ++ ;
                }
                len ++ ;
            }
            return len ;
        }
        
        ~MyTask()
        {
            close(peerfd_);
        }
        
        void satistic(std::set<int> & iset ); //计算vecDictPtr_指向的vector中下标在iset中的词与用户输入词的编辑距离 。
    
    
    private:
        std::string queryWord_;                  // 用户的查询词
        std::vector<uint32_t> vecQueryWord_;     // 经过转换后的用户的查询词
        struct sockaddr_in addr_;     // 用于保存用户端地址和端口号
        int peerfd_;                  // 与用户端通信的socket描述符
        
        std::vector<std::pair<std::vector<uint32_t>, int> > *vecDictPtr_;   // 指向保存数据词典的指针
        std::vector<std::pair<std::string, int>> *strDictPtr_;
        std::map<uint32_t, std::set<int> >* mapIndexPtr_;                   // 指向保存倒排索引的指针
        
        std::priority_queue<MyResult, std::vector<MyResult>, MyCompare> result_; // 用于保存查询结果的优先级队列
        
        void get_result(); // 根据用户的查询词获取最终结果。最终结果将放在优先级队列里
        int editdistance(const std::vector<uint32_t> &right); // 计算right与用户输入查询词的编辑距离
        
        int triple_min(const int &a, const int &b, const int& c ) // 返回3个数中的最小值
        {
            return a < b ? (a < c ? a : c) : (b < c ? b : c) ;
        }

};

#endif /* MyTask.h */

source文件如下:

#include "MyTask.h"

//匿名命名空间,存放一些辅助函数,用于将string格式的查询词转换为vector<uint32_t>来存储,以便正确计算编辑距离
namespace
{

int getLenOfUTF8(unsigned char c)
{
    int cnt = 0;
    while(c & (1 << (7-cnt)))
        ++cnt;
    return cnt; 
}


void parseStringToUTF8(const std::string &s, std::vector<uint32_t> &vec)
{
    vec.clear();
    for(std::string::size_type ix = 0; ix < s.size(); ++ix)
    {
        int len = getLenOfUTF8(s[ix]);
        uint32_t t = (unsigned char)s[ix]; /*e5*/
        if(len > 1)
        {
            --len; /*2*/
            /*拼接剩余的字节*/
            while(len--)
            {
                t = (t << 8) + (unsigned char)s[++ix];
            }
        }
        vec.push_back(t);
    }
}

inline int MIN(int a, int b, int c) 
{
    int ret = (a < b) ? a : b;
    ret = (ret < c) ? ret : c;
    return ret;
}

int edit_distance_uint_32(const std::vector<uint32_t> &w1, const std::vector<uint32_t> &w2) 
{
    int len_a = w1.size();
    int len_b = w2.size();
    int memo[100][100];
    memset(memo, 0x00, 100 * 100 * sizeof(int));
    for (int i = 1; i <= len_a; ++i) 
    {
        memo[i][0] = i;
    }
    for (int j = 1; j <= len_b; ++j) 
    {
        memo[0][j] = j;
    }
    for (int i = 1; i <= len_a; ++i) 
    {
        for (int j = 1; j <= len_b; ++j) 
        {
            if (w1[i - 1] == w2[j - 1]) 
            {
                memo[i][j] = memo[i - 1][j - 1];
            } 
            else 
            {
                memo[i][j] = MIN(memo[i - 1][j - 1], memo[i][j - 1],memo[i - 1][j]) + 1;
            }
        }
    }
    return memo[len_a][len_b];
}

}
// end namespace

MyTask::MyTask( MyConf& conf)
    : queryWord_(""),
    strDictPtr_(&(conf.strDict_)),
    vecDictPtr_(&(conf.vecDict_)),
    mapIndexPtr_(&conf.mapIndex_)
{
    memset(&addr_, 0, sizeof(addr_));
}

MyTask::MyTask(const std::string &queryWord, 
        const struct sockaddr_in &addr ,  
        MyConf& conf)
    : queryWord_(queryWord),
      addr_(addr),
      strDictPtr_(&(conf.strDict_)),
      vecDictPtr_(&conf.vecDict_), 
      mapIndexPtr_(&conf.mapIndex_)
{
    parseStringToUTF8(queryWord_, vecQueryWord_);
}

// 执行任务,并将结果发回客户端
void MyTask::excute(MyCache& cache) // cache_通过工作线程传入
{
    peerfd_ = socket(AF_INET, SOCK_DGRAM, 0);
    std::cout << "Task excute" << std::endl ;
    
    std::unordered_map<std::string, std::string>::iterator iter;
    iter =  cache.isMapped(queryWord_);
    // 如果在工作线程中的cache_中可以找到,那么直接返回
    if(iter != cache.hashmap_.end())
    {
        std::cout << " cached "  << std::endl;
        int iret = sendto(peerfd_, (iter -> second).c_str(), 
                (iter -> second).size(), 0, 
                (struct sockaddr*)&addr_, sizeof(addr_));
        std::cout <<"send: " << iret << std::endl ;
    }
    else // 否则在词频词典中进行计算后,返回最佳匹配的单词
    {
        std::cout << " no cached " << std::endl ;
        get_result();
        //std::cout << inet_ntoa(m_addr.sin_addr) << std::endl ;
        if(result_.empty())
        {
            std::string res = "no anwser !" ;
            int iret = sendto(peerfd_, res.c_str(), 
                    res.size(), 0, 
                    (struct sockaddr*)&addr_, sizeof(addr_));
            std::cout <<"send: " << iret << std::endl;
        }
        else 
        {
            MyResult res = result_.top();
            int iret = sendto(peerfd_, res.word_.c_str(), 
                       res.word_.size(), 0, 
                       (struct sockaddr*)&addr_, sizeof(addr_));
            std::cout <<"send:" << iret << std::endl ;
            cache.map_to_cache(queryWord_, res.word_);            // 注意:需要更新当前工作线程的cache
        }
    }
}


// 遍历查询词的每一个letter,经由倒排索引,统计出编辑距离小于3的单词放入优先级队列result_中
void MyTask::get_result()
{
    uint32_t ch ;
    int index ;
    for(index = 0 ; index != vecQueryWord_.size(); index ++ )
    {
        ch = vecQueryWord_[index];
        if( ( *mapIndexPtr_ ).count(ch) )
        {
            std::cout << "map_ cout return true " << std::endl ;
            statistic( (*mapIndexPtr_)[ch] ) ;
        }
    } 
}

// 传入参数为相应letter对应的单词在所在vector中的下标结合
// 将这些单词中,编辑距离与查询词在3以内的单词放入优先级队列result_中
void MyTask::statistic(std::set<int> & iset)
{
    std::set<int>::iterator iter ;
    for( iter = iset.begin() ;  iter != iset.end() ;  iter ++)
    {
        int dist = editdistance(  ((*vecDictPtr_)[ *iter ]).first  );
        if(dist < 3)
        {
            MyResult res ;
            res.word_ = ((*strDictPtr_)[ *iter ]).first ;
            res.distance_ = dist ;
            res.frequence_ = ((*vecDictPtr_)[ *iter ]).second ; 
            result_.push( res );
        }
    }
    
}

// 计算编辑距离
int MyTask::editdistance(const std::vector<uint32_t> &right) 
{
    return edit_distance_uint_32(vecQueryWord_, right);
}

5 线程池、工作线程、扫描线程

普通线程

为了实现复用,在封装工作线程之前,我们先封装一个普通的线程,之后的工作线程以及扫描线程只需要继承该普通线程即可。代码如下:

#ifndef __THREAD_H__
#define __THREAD_H__
#include <iostream>
#include <stdlib.h>

class Thread
{
    public:
        Thread()
            :threadId_(0),isRunning_(false)
        {
            if(pthread_attr_init(&threadAttr))
            {
                std::cout << __DATE__ << " " << __TIME__ << " " 
                          << __FILE__ << " " << __LINE__ << ":" 
                          << "pthread_attr_init" << std::endl;
                exit(-1) ;
            }
        }
        
        ~Thread()
        {
            pthread_attr_destroy(&threadAttr);
        }
        
        void start(void* arg = NULL)
        {
            if(isRunning_)
                return;
            
            isRunning_ = true ;
            
            // 将线程设置为detach
            if(pthread_attr_setdetachstate(&threadAttr, PTHREAD_CREATE_DETACHED))
            {
                std::cout << __DATE__ << " " << __TIME__ << " " 
                          << __FILE__ << " " << __LINE__ << ":" 
                          << "pthread_attr_setdetachstate" << std::endl ;
                exit(-1) ;
            }
            
            // 创建线程
            if(pthread_create(&threadId_, &threadAttr, Thread::runInThread, this))
            {
                std::cout << __DATE__ << " " << __TIME__ << " " 
                          << __FILE__ << " " << __LINE__ << ":" 
                          << "pthread_create" << std::endl;
                exit(-1) ;
            }
        }
    

    private:
        static void* runInThread(void* arg)
        {
            Thread* p = (Thread*)arg;
            p -> run();
            
            return NULL;
        }
        
        // 在工作线程中只需要重写该函数即可实现自己的线程例程
        virtual void run() = 0;
        
        bool isRunning_;
        pthread_t threadId_;
        pthread_attr_t threadAttr;
};


#endif

显然,我们只需要重写run函数,即可实现工作线程和扫描线程的工作。

工作线程

工作线程持有线程池对象的指针,以及自身的内存cache,这是因为:工作线程的任务,就是不断的从线程池的任务队列中取出任务,当工作线程持有线程池对象的指针时,就可以调用线程池对象的get_task方法,然后执行取出的任务。持有自身的内存cache,这就再自然不过了,当工作线程执行查询词匹配时,首先会先从自身的内存cache中进行查找。

头文件如下:

#ifndef __MYTHREAD_H__
#define __MYTHREAD_H__
#include "Thread.h"
#include "MyCache.h"

class ThreadPool;

// 继承抽象类Thread
class MyThread : public Thread 
{
    public:
        void get_related(ThreadPool* p)
        {
            threadPoolPtr_ = p ;
        }
    
    private:
        void run(); // 需要自己实现虚函数
        
        // 由于线程池对象中持有任务队列,工作线程持有线程池对象指针,就可以方便的从线程池对象的任务队列中取出任务执行
        ThreadPool * threadPoolPtr_ ; 
        MyCache cache_ ;
        
        friend class  MyCacheThread ;

};
#endif

源文件如下:

void MyThread::run()
{
    std::cout << "run" << std::endl ;
    // 工作线程刚启动时,会将硬盘中的缓存文件更新到工作线程中的内存cache
    cache_.read_from_file( (threadPoolPtr_ -> conf_).getMapConf()["my_cache"].c_str());
    // 从任务队列中取任务,执行任务。
    while(true)
    {
        MyTask task(threadPoolPtr_-> conf_) ;    // 任务的执行(编辑距离的计算)需要词频词典以及倒排索引
        if(!(threadPoolPtr_ -> get_task(task)) )
        {
            break ;
        }
        task.excute(cache_) ;
    }
}

扫描线程

在介绍扫描线程前,我们先来看看对cache类的封装:

// 数据成员为
// unorderer_map
// 控制互斥访问unordered_map的锁:hashmapLock_
class MyCache
{
    public:
        std::unordered_map<std::string, std::string> hashmap_;
        
        /**
         * 工作线程应该与扫描线程互斥的访问工作线程的内存cache_
         */
        void map_to_cache(std::string& key, std::string& value )
        {
            hashmapLock_.lock();
            hashmap_[key] = value;
            hashmapLock_.unlock();
        }
        
        std::unordered_map<std::string, std::string>::iterator 
        isMapped(const std::string& word)
        {
            hashmapLock_.lock();
            return hashmap_.find(word);
            hashmapLock_.unlock();
        } 
        
        // 将内存cache写入硬盘
        void write_to_file(std::ofstream& outfile)
        {
            hashmapLock_.lock();
            for(std::unordered_map<std::string, std::string>::iterator iter = hashmap_.begin(); 
                iter != hashmap_.end();
                ++iter)
            {
                outfile << iter -> first << "\t" << iter -> second << std::endl ;
            }
            hashmapLock_.unlock();
        }
        
        // 从硬盘cache读入内存
        void read_from_file(const std::string &fileName)
        {
            hashmapLock_.lock();
            std::ifstream infile(fileName.c_str());
            if(!infile)
            {
                std::cout << "cache file: " << fileName << std::endl ; 
                throw std::runtime_error("open cache file fail !");
            }
            std::string query , result;
            while(infile >> query >> result)
            {
                hashmap_.insert(std::make_pair(query, result));
            }
            infile.close();
            hashmapLock_.lock();
        }
    
    private:
        MyLock hashmapLock_;

};


#endif

由于扫描线程每过60秒,就会依次同步工作线程的内存cache与硬盘cache,而工作线程在执行任务时同样会访问内存cache,因此必须使得工作线程与扫描线程互斥的访问内存cache。扫描线程代码如下:

头文件

#ifndef __MYCACHETHREAD_H__
#define __MYCACHETHREAD_H__
#include "Thread.h"
#include <vector>
class ThreadPool;
class MyCache;
class MyThread;


class MyCacheThread : public Thread
{
    public:
        MyCacheThread(const int& num = 12)
            : Thread(), vecWorkThreadPtr_(num)
        { }
        
        void get_related(ThreadPool* threadPoolPtr);
    
    private:
        void run() ;
        void scan_cache() ;
        
        ThreadPool* threadPoolPtr_;                // 指向线程池的指针
        std::vector<MyThread*> vecWorkThreadPtr_ ; // 含有指向工作线程指针的vector 。

};
#endif

源文件

#include "MyCacheThread.h"
#include "ThreadPool.h"
#include "MyThread.h"
#include "MyCache.h"
#include <unistd.h>
#include <fstream>

void MyCacheThread::run()
{
    while(true)
    {
        sleep(60);
        scan_cache();   
        std::cout << "scan cache" << std::endl ;
    }
}

// 线程池对象中拥有一个扫描线程对象
// 线程池初始化时,会调用该函数,使该扫描线程对象持有指向线程池对象的指针
// 并使该扫描线程对象持有线程池中所有工作对象的指针
void MyCacheThread::get_related(ThreadPool* threadPoolPtr)
{
    threadPoolPtr_ = threadPoolPtr;
    
    std::vector<MyThread>::iterator  iter1  = (threadPoolPtr_ -> vecThreads_).begin();
    std::vector<MyThread*>::iterator iter2  = vecWorkThreadPtr_.begin() ;
    
    while(iter2 != vecWorkThreadPtr_.end() && 
          iter1 != (threadPoolPtr_ -> vecThreads_).end() )
    {
        *iter2 = &(*iter1);
        iter1++;
        iter2++;
    }
}

// 同步每一个工作线程的内存cache与硬盘cache
void MyCacheThread::scan_cache()
{
    std::vector<MyThread*>::iterator iter = vecWorkThreadPtr_.begin();
    for(; 
        iter != vecWorkThreadPtr_.end(); 
        ++iter)
    {
        ( (*iter) -> cache_ ).read_from_file( (threadPoolPtr_ -> conf_).getMapConf()["my_cache"].c_str()) ;
        
        std::ofstream outfile( (threadPoolPtr_ -> conf_).getMapConf()["my_cache"].c_str() ) ;
        if(!outfile)
        {
            throw std::runtime_error("scan cache : open cache failed");
        }       
        
        ( (*iter ) -> cache_ ).write_to_file(outfile) ;
        outfile.close();
    }
}

线程池

头文件

#ifndef __THREADPOOL_H__
#define __THREADPOOL_H__

class ThreadPool
{
    public:
        
        friend class MyCacheThread ;
        
        ThreadPool(MyConf &conf, int size = 12)
            :vecThreads_(size),
             queueTaskslock_(), 
             queueTasksCond_(queueTaskslock_),
             isStarted_(false),
             conf_(conf),
             cacheThread_(size)
        {
            std::vector<MyThread>::iterator iter ;
            for(iter = vecThreads_.begin(); 
                iter != vecThreads_.end(); 
                ++iter)
            {
                iter -> get_related(this);           // 使线程池中的每一个工作线程持有线程池对象的指针
            }
                cacheThread_.get_related(this);      // 使线程池中的扫描线程持有线程池对象的指针
        }
        
        void on()
        {
            if(isStarted_)
            {
                return ;
            }
            isStarted_ = true ;
            std::vector<MyThread>::iterator iter ;
            for(iter = vecThreads_.begin(); iter != vecThreads_.end(); iter ++)
            {
                iter -> start();    // 开启工作线程
            }
            cacheThread_.start();   // 开启扫描线程
        }
        
        void off()
        {
            if(isStarted_)
            {
                isStarted_ = false ;
                queueTasksCond_.broadcast();
                while(!queueTasks_.empty())
                {
                    queueTasks_.pop();
                } 
            }
        }
        
        void allocate_task( MyTask& task)
        {
            queueTaskslock_.lock();
            std::cout << "Add Task" << std::endl ;
            queueTasks_.push(task);
            queueTaskslock_.unlock();
            queueTasksCond_.broadcast();
        }
        
        bool get_task(MyTask &task)
        {
            queueTaskslock_.lock();
            while(isStarted_ && queueTasks_.empty())
            {
                queueTasksCond_.wait();
            }
            if(!isStarted_)
            { 
                queueTaskslock_.unlock();
                queueTasksCond_.broadcast();
                return false ;
            }
            task = queueTasks_.front();
            queueTasks_.pop();
            queueTaskslock_.unlock();
            queueTasksCond_.broadcast();
            std::cout << "get task" << std::endl ;
            return true ;
        }
        
        MyConf &conf_; // 配置对象的引用                       
    
    private:
        // 禁止赋值和复制
        ThreadPool(const ThreadPool& obj) ;
        ThreadPool& operator = (const ThreadPool& obj) ;
        
        std::vector<MyThread> vecThreads_ ;   // 存放工作线程的容器
        std::queue<MyTask>    queueTasks_ ;   // 存放任务的队列
        
        MyLock queueTaskslock_ ;
        MyCondition queueTasksCond_ ;
        
        bool isStarted_ ;                     // 用于标识线程池是否开启的变量
        
        MyCacheThread cacheThread_ ;          // 定时扫描内存cache的线程
};
#endif

6 注意点

配置文件对象持有词频词典,倒排索引以及硬盘cache。

线程池对象持有任务队列,工作线程,扫描线程,配置文件对象引用(方便使用配置文件对象的资源)。当线程池启动时,线程池对象会逐个开启工作线程以及扫描线程。工作线程负责从任务队列中取出任务并执行任务。扫描线程负责每隔60秒同步工作线程的cache与硬盘cache。

工作线程中持有线程池对象的引用,这样工作线程可以直接调用线程池对象的get_task方法来取任务。

扫描线程持有线程池对象的指针(线程池对象拥有配置文件对象引用,配置文件对象中拥有硬盘cache),以及每个工作线程对象的指针,便于完成工作线程的内存cache与硬盘cache的同步。

7 需要同步的两处地方

1. 工作线程从任务队列中取任务的时候,需要对任务队列上锁。主线程收到客户端请求时,往任务队列中push任务时,也需要加锁。

2. 工作线程与扫描线程对内存cache的访问需要同步。

    1)工作线程查询cache时

    2)当查询词在cache中找不到时,工作线程将会计算查询词与词频词典中单词的编辑距离,得到的结果需要写回内存cache

    3)扫描线程用硬盘cache来更新工作线程的内存cache。

    4)扫描线程用工作线程的内存cache来更新硬盘cache。

8 优化

如果客户端每来一个查询词,工作线程都需要到词频词典中去与所有的单词进行编辑距离的计算的话,效率实现太低。因此,我们需要加速拼写纠错的过程。

1. 每个工作线程加入了cache,每当来一个查询词,工作线程先到cache中去查找。

2. 使用倒排索引,减少了需要与查询词进行编辑距离计算的单词。实际上,可以进一步的优化,如果规定编辑距离的阈值为2,那么我们只需要对其中的任意三个letter对应的单词的并集做计算即可。

 

[Project] SpellCorrect源码详解

标签:

原文地址:http://www.cnblogs.com/jianxinzhou/p/4740392.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!