标签:predictionio e-commerce recommendation 源码分析
Algorithm 类
/**
 * Trains the recommendation model from the prepared training data.
 *
 * <p>Pipeline, as implemented below:
 * <ol>
 *   <li>Give every user id and every item id a dense 0-based integer index via
 *       {@code zipWithIndex} (MLlib {@code Rating} takes int ids).</li>
 *   <li>Convert view events into implicit ratings: the rating of a (user, item) pair is the
 *       number of view events for that pair.</li>
 *   <li>Train an implicit-feedback ALS model and extract per-user and per-item feature vectors.</li>
 *   <li>Count buy events per item as a popularity score, used as a cold-start fallback.</li>
 *   <li>Join item ids with their item feature vectors.</li>
 * </ol>
 *
 * <p>NOTE(review): userIndexRDD and itemIndexRDD are not cached and are evaluated more than once
 * (collectAsMap here, and again wherever the Model uses them). The collected maps and the RDDs
 * stored in the Model agree only if data.getUsers()/getItems() yield the same order on every
 * evaluation. Confirm this, or cache/persist them.
 *
 * @param sc Spark context (not referenced in this method body)
 * @param preparedData wrapper around the {@code TrainingData} to learn from
 * @return the trained {@code Model}
 */
@Override
public Model train(SparkContext sc, PreparedData preparedData) {
TrainingData data = preparedData.getTrainingData();
// Model training.
// Build the user index: user id -> 0-based position assigned by zipWithIndex.
JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
@Override
public String call(Tuple2<String, User> idUser) throws Exception {
return idUser._1();
}
}).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
@Override
public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
// NOTE(review): intValue() keeps only the low 32 bits; indices above Integer.MAX_VALUE would overflow silently.
return new Tuple2<>(element._1(), element._2().intValue());
}
});
// Collect into a driver-side Java Map so the closures below can look users up by id.
final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();
// Result looks like u1->0, u2->1 (zipWithIndex starts at 0).
// Build the item index the same way: item id -> 0-based position.
JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
@Override
public String call(Tuple2<String, Item> idItem) throws Exception {
return idItem._1();
}
}).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
@Override
public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
return new Tuple2<>(element._1(), element._2().intValue());
}
});
// Result looks like i1->0, i2->1.
final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
@Override
public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
return element.swap();
}
});
// Inverted index (item index -> item id), so an item can later be found from its index.
final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();
// Build implicit ratings. Each view event contributes 1 to its (userIndex, itemIndex) pair.
// Events whose user or item is not indexed are mapped to null and filtered out.
JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
@Override
public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
Integer userIndex = userIndexMap.get(viewEvent.getUser());
Integer itemIndex = itemIndexMap.get(viewEvent.getItem());
return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
}
}).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
@Override
public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
return (element != null);
}
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception {
return integer + integer2;
}
}).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
@Override
public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue());
}
});
// Result looks like (u1,i1)->1, (u1,i2)->2, i.e. the view count per pair, wrapped as a Rating.
// Train MLlib ALS on the implicit ratings. Per ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha, seed):
// blocks = -1 lets MLlib choose; alpha = 1.0. `ap` holds algorithm params and is declared outside this view.
MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
@Override
public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
return new Tuple2<>((Integer) element._1(), element._2());
}
}); // Per-user ALS feature vectors, keyed by user index.
JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
@Override
public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
return new Tuple2<>((Integer) element._1(), element._2());
}
}); // Per-item ALS feature vectors, keyed by item index.
// Cold-start fallback: item popularity is the number of buy events per item.
// Buy events whose user or item is not indexed are dropped.
JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
@Override
public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception {
Integer userIndex = userIndexMap.get(buyEvent.getUser());
Integer itemIndex = itemIndexMap.get(buyEvent.getItem());
return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
}
}).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
@Override
public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
return (element != null);
}
}).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
@Override
public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
return new Tuple2<>(element._1()._2(), element._2());
}
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception {
return integer + integer2;
}
}).map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
@Override
public ItemScore call(Tuple2<Integer, Integer> element) throws Exception {
return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue());
}
});
// Result looks like i1->1, i2->2 (buy count per item id).
// Join item ids with their feature vectors: itemIndex -> (itemId, features).
// This is an inner join, so items without ALS features are dropped.
JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
// Training finished.
// NOTE(review): `buyItemForUser` is not declared in this method. It must come from elsewhere (e.g. a field),
// otherwise this does not compile. Confirm against the Model constructor signature.
return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap(),buyItemForUser);
}
/**
 * Produces recommendations for the user named in {@code query}.
 *
 * <p>Strategy, in order:
 * <ol>
 *   <li>If the user is in the training user index AND has an ALS feature vector, score every
 *       item by dot product with that vector ({@code topItemsForUser}).</li>
 *   <li>Otherwise, if the user has recent events on items that have feature vectors, recommend
 *       items similar to those ({@code similarItems}).</li>
 *   <li>Otherwise fall back to the most popular (most bought) items.</li>
 * </ol>
 *
 * <p>A user can be indexed without having an ALS feature vector. ALS only produces vectors for
 * users that appear in the ratings, and users without view events do not. Such users take the
 * fallback paths instead of failing.
 *
 * @param model the trained model
 * @param query the user id, result size and filtering options
 * @return the predicted item scores
 */
@Override
public PredictedResult predict(Model model, final Query query) {
    final String userId = query.getUserEntityId();

    // Look up the user's index; take(1) avoids evaluating the filter twice (isEmpty + first).
    List<Tuple2<String, Integer>> matchedUsers = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
        @Override
        public Boolean call(Tuple2<String, Integer> userIndex) throws Exception {
            return userIndex._1().equals(userId);
        }
    }).take(1);

    double[] userFeature = null;
    if (!matchedUsers.isEmpty()) {
        final Integer matchedUserIndex = matchedUsers.get(0)._2();
        // Do not assume the indexed user has a feature vector: first() would throw on an empty RDD.
        List<Tuple2<Integer, double[]>> matchedFeatures = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Integer, double[]> element) throws Exception {
                return element._1().equals(matchedUserIndex);
            }
        }).take(1);
        if (!matchedFeatures.isEmpty()) {
            userFeature = matchedFeatures.get(0)._2();
        }
    }

    if (userFeature != null) {
        // Known user with a feature vector: regular ALS scoring.
        return new PredictedResult(topItemsForUser(userFeature, model, query));
    }
    // Feature vectors of the items this user recently interacted with.
    List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
    if (recentProductFeatures.isEmpty()) {
        // Nothing to go on: recommend the most popular items.
        return new PredictedResult(mostPopularItems(model, query));
    }
    // Recommend items similar to the recently viewed ones.
    return new PredictedResult(similarItems(recentProductFeatures, model, query));
}
/**
 * Regular recommendation path for a user with an ALS feature vector.
 *
 * <p>Scores every item by the dot product of the user's feature vector with the item's feature
 * vector, applies the query's filters via {@code validScores}, and returns the top
 * {@code query.getNumber()} items according to {@code sortAndTake}.
 *
 * @param userFeature the user's ALS feature vector
 * @param model the trained model supplying item feature vectors and item metadata
 * @param query supplies filters, the user id and the number of results
 * @return the highest-scoring items that pass the filters
 */
private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
    final DoubleMatrix userVector = new DoubleMatrix(userFeature);

    Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore> scoreAgainstUser =
            new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                @Override
                public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> indexedItem) throws Exception {
                    Tuple2<String, double[]> idAndFeatures = indexedItem._2();
                    String itemId = idAndFeatures._1();
                    double affinity = userVector.dot(new DoubleMatrix(idAndFeatures._2()));
                    return new ItemScore(itemId, affinity);
                }
            };

    JavaRDD<ItemScore> candidates = model.getIndexItemFeatures().map(scoreAgainstUser);
    // Drop items excluded by whitelist/blacklist/category rules (see validScores).
    JavaRDD<ItemScore> eligible = validScores(candidates, query.getWhitelist(), query.getBlacklist(),
            query.getCategories(), model.getItems(), query.getUserEntityId());
    return sortAndTake(eligible, query.getNumber());
}
/**
 * Cold-start fallback: recommends the most popular items.
 *
 * <p>Popularity scores are precomputed during training and read from the model. They are
 * filtered with the query's rules and the top {@code query.getNumber()} are returned.
 *
 * @param model the trained model holding precomputed popularity scores
 * @param query supplies filters, the user id and the number of results
 * @return the most popular items that pass the filters
 */
private List<ItemScore> mostPopularItems(Model model, Query query) {
    JavaRDD<ItemScore> popularity = model.getItemPopularityScore();
    JavaRDD<ItemScore> eligible = validScores(popularity, query.getWhitelist(), query.getBlacklist(),
            query.getCategories(), model.getItems(), query.getUserEntityId());
    return sortAndTake(eligible, query.getNumber());
}
/**
 * Returns the ALS feature vectors of items the user recently interacted with.
 *
 * <p>Reads up to 10 of the user's events whose names are in {@code ap.getSimilarItemEvents()}
 * and whose target entity type is "item". For each event target that exists in the item index
 * and has a feature vector, the vector is added to the result. Event targets the model has never
 * seen (e.g. items created after training) are skipped rather than failing the request.
 *
 * @param query supplies the user id
 * @param model the trained model (item index and item feature vectors)
 * @return the feature vectors found; possibly empty, never null
 * @throws RuntimeException wrapping any failure while reading events or querying the model
 */
private List<double[]> getRecentProductFeatures(Query query, Model model) {
    try {
        List<double[]> result = new ArrayList<>();
        // Fetch this user's events targeting items.
        // NOTE(review): the arguments presumably follow LJavaEventStore.findByEntity(appName,
        // entityType, entityId, channelName, eventNames, targetEntityType, targetEntityId,
        // startTime, untilTime, limit, latest, timeout). Confirm against the PredictionIO version in use.
        List<Event> events = LJavaEventStore.findByEntity(
                ap.getAppName(),
                "user",
                query.getUserEntityId(),
                OptionHelper.<String>none(),
                OptionHelper.some(ap.getSimilarItemEvents()),
                OptionHelper.some(OptionHelper.some("item")),
                OptionHelper.<Option<String>>none(),
                OptionHelper.<DateTime>none(),
                OptionHelper.<DateTime>none(),
                OptionHelper.some(10),
                true,
                Duration.apply(10, TimeUnit.SECONDS));
        for (Event event : events) {
            if (!event.targetEntityId().isDefined()) {
                continue;
            }
            // Hoisted so the Spark closure captures only the id string, not the whole Event.
            final String targetItemId = event.targetEntityId().get();
            List<Tuple2<String, Integer>> matchedItems = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
                @Override
                public Boolean call(Tuple2<String, Integer> element) throws Exception {
                    return element._1().equals(targetItemId);
                }
            }).take(1);
            // An item unknown to the model has no index. Check before reading the match,
            // because first() on an empty RDD throws and would fail the whole request.
            if (matchedItems.isEmpty()) {
                continue;
            }
            final Integer itemIndex = matchedItems.get(0)._2();
            List<Tuple2<Integer, Tuple2<String, double[]>>> itemFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() {
                @Override
                public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                    return itemIndex.equals(element._1());
                }
            }).take(1);
            if (!itemFeatures.isEmpty()) {
                // This item's ALS feature vector; similarItems compares other items against it.
                result.add(itemFeatures.get(0)._2()._2());
            }
        }
        return result;
    } catch (Exception e) {
        logger.error("Error reading recent events for user " + query.getUserEntityId(), e);
        throw new RuntimeException(e.getMessage(), e);
    }
}
/**
 * Similarity-based recommendation from the feature vectors of recently viewed items.
 *
 * <p>Each item's score is the sum of the cosine similarities between its feature vector and
 * every vector in {@code recentProductFeatures}, accumulated in list order. The scores are then
 * filtered with the query's rules, and the top {@code query.getNumber()} items are returned.
 *
 * @param recentProductFeatures feature vectors of the user's recently viewed items
 * @param model the trained model supplying item feature vectors and item metadata
 * @param query supplies filters, the user id and the number of results
 * @return the most similar items that pass the filters
 */
private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
    Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore> toSimilarityScore =
            new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                @Override
                public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> indexedItem) throws Exception {
                    Tuple2<String, double[]> idAndFeatures = indexedItem._2();
                    double[] candidateVector = idAndFeatures._2();
                    double accumulated = 0.0;
                    for (double[] viewedVector : recentProductFeatures) {
                        accumulated += cosineSimilarity(candidateVector, viewedVector);
                    }
                    return new ItemScore(idAndFeatures._1(), accumulated);
                }
            };

    JavaRDD<ItemScore> scored = model.getIndexItemFeatures().map(toSimilarityScore);
    JavaRDD<ItemScore> eligible = validScores(scored, query.getWhitelist(), query.getBlacklist(),
            query.getCategories(), model.getItems(), query.getUserEntityId());
    return sortAndTake(eligible, query.getNumber());
}
/**
 * Computes the cosine similarity of two feature vectors: dot(a, b) / (|a| * |b|).
 *
 * <p>When either vector has zero Euclidean norm, the cosine is undefined. The plain formula would
 * compute 0/0 = NaN, and that NaN would poison the summed similarity score of an item in
 * {@code similarItems}. In that case 0.0 ("no similarity") is returned instead.
 *
 * @param a first vector
 * @param b second vector, presumably of the same length as {@code a}
 * @return the cosine similarity, or 0.0 when either vector has zero norm
 */
private double cosineSimilarity(double[] a, double[] b) {
    DoubleMatrix matrixA = new DoubleMatrix(a);
    DoubleMatrix matrixB = new DoubleMatrix(b);
    double normProduct = matrixA.norm2() * matrixB.norm2();
    if (normProduct == 0.0) {
        return 0.0;
    }
    return matrixA.dot(matrixB) / normProduct;
}
由此来看该例子还是比较简单,适合用于二次开发。下面是一些基础知识
predictionIO E-Commerce Recommendation 源码分析
标签:predictionio e-commerce recommendation 源码分析
原文地址:http://12597095.blog.51cto.com/12587095/1981378