package main import ( "./db" "./es" //"encoding/json" "fmt" "reflect" "time" ) type PrdES struct { DB *prd.Mysql ES *es.Elastic } // func (this *PrdES) Handle(result []*prd.Series) { // // for _, value := range result { // // this.DB.FormatData(value) // // //json, _ := json.Marshal(value) // // //fmt.Println(string(json)) // // } // //写入ES,以多线程的方式执行,最多保持5个线程 // this.ES.DoBulk(result) // } func (this *PrdES) Run() { count := 50 offset := 0 maxCount := 20 //create channel chs := make([]chan []*prd.Series, maxCount) selectCase := make([]reflect.SelectCase, maxCount) for i := 0; i < maxCount; i++ { offset = count * i fmt.Println("offset:", offset) //init channel chs[i] = make(chan []*prd.Series) //set select case selectCase[i].Dir = reflect.SelectRecv selectCase[i].Chan = reflect.ValueOf(chs[i]) //运行 go this.DB.GetData(offset, count, chs[i]) } var result []*prd.Series for { //wait data return chosen, recv, ok := reflect.Select(selectCase) if ok { fmt.Println("channel id:", chosen) result = recv.Interface().([]*prd.Series) //读取数据从mysql go this.DB.GetData(offset, count, chs[chosen]) //写入ES,以多线程的方式执行,最多保持15个线程 this.ES.DoBulk(result) //update offset offset = offset + len(result) //判断是否到达数据尾部,最后一次查询 if len(result) < count { fmt.Println("read end of DB") //等所有的任务执行完毕 this.ES.Over() fmt.Println("MySQL Total:", this.DB.GetTotal(), ",Elastic Total:", this.ES.GetTotal()) return } } } } func main() { s := time.Now() fmt.Println("start") pe := new(PrdES) pe.DB = prd.NewDB() pe.ES = es.NewES() //fmt.Println("mysql info:") //fmt.Println("ES info:") pe.Run() fmt.Println("time out:", time.Since(s).Seconds(), "(s)") fmt.Println("Over!") }
在Run函数里可以看到使用了reflect.SelectCase。使用reflect.SelectCase的原因是：读取MySQL数据的是多个协程，无法预知哪个会先返回，而reflect.Select会在任意一个channel可以接收数据时立即返回（普通select语句的case必须在编译时写死，而这里channel的数量由maxCount决定）。MySQL读取的数据放在channel中，当Select函数返回时（chosen, recv, ok := reflect.Select(selectCase)），先判断ok是否为true（为false表示该channel已被关闭）；chosen是就绪case的下标，也就是对应协程的编号；再通过result = recv.Interface().([]*prd.Series)获得返回的数据。因为MySQL读取的数据是对象的结果集，所以使用recv.Interface()再做类型断言；如果是简单类型，可以使用recv.Int()、recv.String()等方法直接获取channel返回的数据。
package es import ( "../db" //"encoding/json" "fmt" elastigo "github.com/mattbaird/elastigo/lib" //elastigo "github.com/Uncodin/elastigo/lib" //"github.com/Uncodin/elastigo/core" "time" //"bytes" "flag" "sync" //"github.com/fatih/structs" ) var ( //开发测试库 //host = flag.String("host", "", "Elasticsearch Host") //C平台线上 host = flag.String("host", "", "Elasticsearch Host") port = flag.String("port", "9200", "Elasticsearch port") ) //indexor := core.NewBulkIndexorErrors(10, 60) // func init() { // //connect to elasticsearch // fmt.Println("connecting es") // //api.Domain = *host //"" // //api.Port = "9300" // } //save thread count var counter int type Elastic struct { //Seq int64 c *elastigo.Conn lock *sync.Mutex lockTotal *sync.Mutex wg *sync.WaitGroup total int64 } func (this *Elastic) Conn() { this.c = elastigo.NewConn() this.c.Domain = *host this.c.Port = *port //NewClient(fmt.Sprintf("%s:%d", *host, *port)) } func (this *Elastic) CreateLock() { this.lock = &sync.Mutex{} this.lockTotal = &sync.Mutex{} this.wg = &sync.WaitGroup{} counter = 0 this.total = 0 } func NewES() (es *Elastic) { //connect elastic es = new(Elastic) es.Conn() //create lock es.CreateLock() return es } func (this *Elastic) DoBulk(series []*prd.Series) { for true { this.lock.Lock() if counter < 25 { //跳出,执行任务 break } else { this.lock.Unlock() //等待100毫秒 //fmt.Println("wait counter less than 25, counter:", counter) time.Sleep(1e8) } } this.lock.Unlock() //执行任务 go this.bulk(series, this.lock) } func (this *Elastic) Over() { this.wg.Wait() /*for { this.lock.Lock() if counter <= 0 { this.lock.Unlock() break } this.lock.Unlock() } */ } func (this *Elastic) GetTotal() (t int64) { this.lockTotal.Lock() t = this.total this.lockTotal.Unlock() return t } func (this *Elastic) bulk(series []*prd.Series, lock *sync.Mutex) (succCount int64) { //增加计数器 this.wg.Add(1) //减少计数器 defer this.wg.Done() //加计数器 lock.Lock() counter++ fmt.Println("add task, coutner:", counter) lock.Unlock() //设置初始成功写入的数量 succCount = 0 for _, 
value := range series { //json, _ := json.Marshal(value) //fmt.Println(string(json)) if value.ServiceGroup != nil { fmt.Println("series code:", value.Code, ",ServiceGroup:", value.ServiceGroup) resp, err := this.c.Index("guttv", "series", value.Code, nil, *value) if err != nil { panic(err) } else { //fmt.Println(value.Code + " write to ES succsessful!") fmt.Println(resp) succCount++ } } else { fmt.Println("series code:", value.Code, "service group is null") } } //计数器减一 lock.Lock() counter-- fmt.Println("reduce task, coutner:", counter, ",success count:", succCount) lock.Unlock() this.lockTotal.Lock() this.total = this.total + succCount this.lockTotal.Unlock() return succCount }
defer this.wg.Done()
package prd import ( "fmt" "github.com/astaxie/beego/orm" _ "github.com/go-sql-driver/mysql" // import your used driver "strings" "sync" "time" ) func init() { orm.RegisterDataBase("default", "mysql", "@tcp(", 30) orm.RegisterModelWithPrefix("t_", new(Series), new(Product), new(ServiceGroup)) orm.RunSyncdb("default", false, false) } type Mysql struct { sql string total int64 lock *sync.Mutex } func (this *Mysql) New() { //this.sql = "SELECT s.*, p.code ProductCode, p.name pName FROM guttv_vod.t_series s inner join guttv_vod.t_product p on p.itemcode=s.code and p.isdelete=0 limit ?,?" this.sql = "SELECT s.*, p.code ProductCode, p.name pName FROM guttv_vod.t_series s , guttv_vod.t_product p where p.itemcode=s.code and p.isdelete=0 limit ?,?" this.total = 0 this.lock = &sync.Mutex{} } func NewDB() (db *Mysql) { db = new(Mysql) db.New() return db } func (this *Mysql) GetTotal() (t int64) { t = 0 this.lock.Lock() t = this.total this.lock.Unlock() return t } func (this *Mysql) toTime(toBeCharge string) int64 { timeLayout := "2006-01-02 15:04:05" loc, _ := time.LoadLocation("Local") theTime, _ := time.ParseInLocation(timeLayout, toBeCharge, loc) sr := theTime.Unix() if sr < 0 { sr = 0 } return sr } func (this *Mysql) getSGCode(seriesCode string) (result []string, num int64) { sql := "select distinct ref.servicegroupcode code from t_servicegroup_reference_category ref " sql = sql + "left join t_category_product cp on cp.categorycode=ref.categorycode " sql = sql + "left join t_package pkg on pkg.code = cp.assetcode " sql = sql + "left join t_package_product pp on pp.parentcode=pkg.code " sql = sql + "left join t_product prd on prd.code = pp.assetcode " sql = sql + "where prd.itemcode=?" 
o := orm.NewOrm() var sg []*ServiceGroup num, err := o.Raw(sql, seriesCode).QueryRows(&sg) if err == nil { //fmt.Println(num) for _, value := range sg { //fmt.Println(value.Code) result = append(result, value.Code) } } else { fmt.Println(err) } //fmt.Println(result) return result, num } func (this *Mysql) formatData(value *Series) { //设置业务分组数据 sg, _ := this.getSGCode(value.Code) //fmt.Println(sg) value.ServiceGroup = []string{} value.ServiceGroup = sg[0:] //更改OnlineTime为整数 value.OnlineTimeInt = this.toTime(value.OnlineTime) //分解地区 value.OriginalCountryArr = strings.Split(value.OriginalCountry, "|") //分解二级分类 value.ProgramType2Arr = strings.Split(value.ProgramType2, "|") //写入记录内容 value.Description = strings.Replace(value.Description, "\n", "", -1) } func (this *Mysql) GetData(offset int, size int, ch chan []*Series) { var result []*Series o := orm.NewOrm() num, err := o.Raw(this.sql, offset, size). QueryRows(&result) if err != nil { fmt.Println("read DB err") panic(err) //return //err, nil } for _, value := range result { this.formatData(value) //json, _ := json.Marshal(value) //fmt.Println(string(json)) //fmt.Println(value.ServiceGroup) } this.lock.Lock() this.total += num this.lock.Unlock() fmt.Println("read count :", num) //, "Total:", Total) //return nil, result ch <- result }