码迷,mamicode.com
首页 > 其他好文 > 详细

hadoop 大矩阵相乘

时间:2016-05-13 01:37:36      阅读:286      评论:0      收藏:0      [点我收藏+]

标签:

package org.bigdata.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

/**
 * 大矩阵相乘
 * 
 * @author wwhhf
 * 
 */
public class MatrixMapReduce {

    /**
     * Immutable sparse-matrix entry: the value {@code val} stored at
     * row {@code i}, column {@code j}.
     */
    public static class Node {
        // final: a matrix entry never changes once read from the input split
        private final Integer i;
        private final Integer j;
        private final Long val;

        /**
         * @param i   row index (1-based in the input format)
         * @param j   column index (1-based in the input format)
         * @param val cell value
         */
        public Node(Integer i, Integer j, Long val) {
            this.i = i;
            this.j = j;
            this.val = val;
        }

        public Integer getI() {
            return i;
        }

        public Integer getJ() {
            return j;
        }

        public Long getVal() {
            return val;
        }

        @Override
        public String toString() {
            return "Node [i=" + i + ", j=" + j + ", val=" + val + "]";
        }
    }

    /**
     * Orders matrix entries by row index first, then by column index.
     *
     * <p>Fixes two defects in the original: {@code ==} on boxed
     * {@code Integer} compares object identity (reliable only for the
     * JVM's -128..127 cache), and subtraction-based comparison can
     * overflow. {@link Integer#compare(int, int)} avoids both.
     */
    public static class MatrixComparator implements Comparator<Node> {

        @Override
        public int compare(Node o1, Node o2) {
            int byRow = Integer.compare(o1.getI(), o2.getI());
            if (byRow != 0) {
                return byRow;
            }
            return Integer.compare(o1.getJ(), o2.getJ());
        }

    }

    /**
     * Fans out each sparse-matrix entry to every result cell it can
     * contribute to. Input lines look like {@code "x,y v"}; the source
     * matrix is identified by the input file name ("M..." or otherwise),
     * and the emitted value is the raw record tagged with "M" or "N".
     */
    public static class MatrixMapper extends
            Mapper<LongWritable, Text, Text, Text> {

        // Result dimensions (rows of M, columns of N), read in setup().
        private int M = 0;
        private int N = 0;

        @Override
        protected void setup(Context context) throws IOException,
                InterruptedException {
            Configuration config = context.getConfiguration();
            M = config.getInt("M", 0);
            N = config.getInt("N", 0);
        }

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String record = value.toString();
            String[] parts = record.split(" ");
            String[] coords = parts[0].split(",");
            int row = Integer.valueOf(coords[0]);
            int col = Integer.valueOf(coords[1]);

            FileSplit split = (FileSplit) context.getInputSplit();
            boolean fromM = split.getPath().getName().startsWith("M");

            if (fromM) {
                // M(row, col) contributes to result cells (row, 1..N).
                for (int c = 1; c <= N; c++) {
                    context.write(new Text(row + "," + c),
                            new Text("M" + record));
                }
            } else {
                // N(row, col) contributes to result cells (1..M, col).
                for (int r = 1; r <= M; r++) {
                    context.write(new Text(r + "," + col),
                            new Text("N" + record));
                }
            }
        }

    }

    /**
     * Computes one result cell (i,j) = sum over k of M(i,k) * N(k,j).
     *
     * <p>Fixes the original reduce, which multiplied <em>every</em> M
     * entry with <em>every</em> N entry (yielding (sum M)·(sum N)) and
     * silently emitted nothing whenever the two lists differed in size
     * (dropping cells of sparse matrices). The correct pairing joins an
     * M entry (i,k) with the N entry (k,j) on the shared index k.
     */
    public static class MatrixReducer extends
            Reducer<Text, Text, Text, LongWritable> {

        // Kept for configuration parity with MatrixMapper; not needed
        // by the join itself.
        private int M = 0;
        private int N = 0;

        @Override
        protected void reduce(Text key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {
            // For result cell (i,j): values tagged "M" are row i of M,
            // keyed here by their column k; values tagged "N" are column
            // j of N, keyed by their row k.
            Map<Integer, Long> rowOfM = new HashMap<>();
            Map<Integer, Long> colOfN = new HashMap<>();
            for (Text value : values) {
                String record = value.toString();
                String terms[] = record.substring(1).split(" ");
                String xy[] = terms[0].split(",");
                int x = Integer.valueOf(xy[0]);
                int y = Integer.valueOf(xy[1]);
                // Long.parseLong: the original Integer.valueOf truncated
                // values outside the int range.
                long val = Long.parseLong(terms[1]);
                if (record.startsWith("M")) {
                    rowOfM.put(y, val); // k is the column index in M
                } else {
                    colOfN.put(x, val); // k is the row index in N
                }
            }
            long sum = 0L;
            for (Map.Entry<Integer, Long> entry : rowOfM.entrySet()) {
                Long other = colOfN.get(entry.getKey());
                if (other != null) { // missing partner = implicit zero
                    sum += entry.getValue() * other;
                }
            }
            context.write(key, new LongWritable(sum));
        }

        @Override
        protected void setup(Context context) throws IOException,
                InterruptedException {
            Configuration config = context.getConfiguration();
            M = config.getInt("M", 0);
            N = config.getInt("N", 0);
        }
    }

    /**
     * Reads the dimensions M, K, N from stdin, configures and submits the
     * matrix-multiplication job over /matrix, writing to /matrix_out/.
     * Exits 0 on job success, 1 on failure.
     */
    public static void main(String[] args) {
        // try-with-resources: the original leaked the Scanner.
        try (Scanner cin = new Scanner(System.in)) {
            Configuration cfg = HadoopCfg.getConfiguration();
            cfg.setInt("M", cin.nextInt()); // rows of the left matrix
            cfg.setInt("K", cin.nextInt()); // shared inner dimension
            cfg.setInt("N", cin.nextInt()); // columns of the right matrix

            Job job = Job.getInstance(cfg);
            job.setJobName("Matrix");
            job.setJarByClass(MatrixMapReduce.class);

            // mapper
            job.setMapperClass(MatrixMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(Text.class);

            // reducer
            job.setReducerClass(MatrixReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(LongWritable.class);

            FileInputFormat.addInputPath(job, new Path("/matrix"));
            FileOutputFormat.setOutputPath(job, new Path("/matrix_out/"));

            System.exit(job.waitForCompletion(true) ? 0 : 1);

        } catch (IllegalStateException | IllegalArgumentException
                | ClassNotFoundException | IOException | InterruptedException e) {
            e.printStackTrace();
        }
    }
}

hadoop 大矩阵相乘

标签:

原文地址:http://blog.csdn.net/qq_17612199/article/details/51345240

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!