package doc
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
object SQLMLlib {
  def main(args: Array[String]) {
    // Suppress unnecessary log output on the terminal
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    // Set up the runtime environment
    val sparkConf = new SparkConf().setAppName("SQLMLlib")
    val sc = new SparkContext(sparkConf)
    val hiveContext = new HiveContext(sc)
    // Use Spark SQL to query each store's sales quantity and sales amount
    hiveContext.sql("use saledata")
    val sqldata = hiveContext.sql("select a.locationid, sum(b.qty) totalqty, sum(b.amount) totalamount from tblStock a join tblstockdetail b on a.ordernumber=b.ordernumber group by a.locationid")
    // Convert the query result into vectors
    val parsedData = sqldata.map {
      case Row(_, totalqty, totalamount) =>
        val features = Array[Double](totalqty.toString.toDouble, totalamount.toString.toDouble)
        Vectors.dense(features)
    }
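    // Optional addition (not in the original): KMeans makes several passes over the input,
    // so caching the RDD avoids re-running the SQL query on every iteration.
    parsedData.cache()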
    // Cluster the data set into 3 clusters with 20 iterations to build the model
    // Note: the number of partitions is not set here, so the RDD keeps the 200 partitions
    // produced by the Spark SQL shuffle (spark.sql.shuffle.partitions defaults to 200)
    val numClusters = 3
    val numIterations = 20
    val model = KMeans.train(parsedData, numClusters, numIterations)
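    // Optional addition (not in the original): evaluate clustering quality with the
    // within-set sum of squared errors, if the MLlib version provides KMeansModel.computeCost.
    // val wssse = model.computeCost(parsedData)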
    // Use the model to assign a cluster to each record and write out the result
    // Since the number of partitions is never reduced, the output is 200 small files;
    // they can be merged and downloaded locally with bin/hdfs dfs -getmerge
    sqldata.map {
      case Row(locationid, totalqty, totalamount) =>
        val features = Array[Double](totalqty.toString.toDouble, totalamount.toString.toDouble)
        val lineVector = Vectors.dense(features)
        val prediction = model.predict(lineVector)
        locationid + " " + totalqty + " " + totalamount + " " + prediction
    }.saveAsTextFile(args(0))
    sc.stop()
  }
}

Compile and package the program, then run it, for example with spark-submit:
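The submit command below is only a sketch: the jar name, master URL and output directory are assumptions; the single program argument (args(0)) is the output path.

bin/spark-submit --master spark://master:7077 --class doc.SQLMLlib sqlmllib.jar /user/hadoop/sqlmllib_output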
package doc
// Sample data is not available for the moment, so this example only gives the skeleton; it may be completed later
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
object SQLGraphX {
  def main(args: Array[String]) {
    // Suppress unnecessary log output on the terminal
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    // Set up the runtime environment
    val sparkConf = new SparkConf().setAppName("SQLGraphX")
    val sc = new SparkContext(sparkConf)
    val hiveContext = new HiveContext(sc)
    // Switch to the sales database
    hiveContext.sql("use saledata")
    // Use Spark SQL to query each store's sales and inventory as the graph's vertices
    // locationid is the VertexId, and (sales, inventory) is the vertex attribute VD, typically of type (Int, Int)
    val vertexdata = hiveContext.sql("select a.locationid, b.saleQty, b.InvQty From a join b on a.col1=b.col2 where conditions")
    // Use Spark SQL to query the distance between stores (or other transfer-related attributes such as travel time) as the graph's edges
    // distance is the edge attribute ED, which can be Int, Long, Double, etc.
    val edgedata = hiveContext.sql("select srcid, distid, distance From distanceInfo")
    // Build the vertexRDD and edgeRDD
    val vertexRDD: RDD[(Long, (Int, Int))] = vertexdata.map(...)
    val edgeRDD: RDD[Edge[Int]] = edgedata.map(...)
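    // A possible way to fill in the two maps above (a sketch only; it assumes the queries
    // return rows shaped (Long, Int, Int) and (Long, Long, Int), and that Row is imported
    // as in the previous example):
    // val vertexRDD: RDD[(Long, (Int, Int))] = vertexdata.map {
    //   case Row(locationid, saleQty, invQty) =>
    //     (locationid.toString.toLong, (saleQty.toString.toInt, invQty.toString.toInt))
    // }
    // val edgeRDD: RDD[Edge[Int]] = edgedata.map {
    //   case Row(srcid, distid, distance) =>
    //     Edge(srcid.toString.toLong, distid.toString.toLong, distance.toString.toInt)
    // }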
    // Build the graph Graph[VD, ED]
    val graph: Graph[(Int, Int), Int] = Graph(vertexRDD, edgeRDD)
    // Process the graph according to the stock-transfer rules
    val initialGraph = graph.mapVertices(...)
    initialGraph.pregel(...)
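    // The Pregel call takes an initial message plus three functions; a sketch of its shape
    // (the message type and the bodies below are placeholders, not the original transfer logic):
    // val resultGraph = initialGraph.pregel(initialMsg = 0)(
    //   (id, attr, msg) => attr,                            // vertex program: update the vertex attribute from the incoming message
    //   triplet => Iterator((triplet.dstId, triplet.attr)), // send messages along edges
    //   (a, b) => math.min(a, b)                            // merge messages arriving at the same vertex
    // )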
    // Output
    sc.stop()
  }
}

Original article: http://blog.csdn.net/book_mmicky/article/details/39202093