
Matrix Multiplication in the MapReduce Framework: Algorithm Ideas and Implementation

October 3, 2013 ⁄ General

 

There are two approaches to implementing the multiplication of two matrices (A*B) in the MapReduce framework.

 

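First, the basic matrix multiplication algorithm taught in school multiplies the rows of A by the columns of B, which requires that the number of columns of A equal the number of rows of B. This leads to the following basic idea:

Let each map task compute one row of A times one column of B, and let reduce sum the products and output the result. This is the most basic implementation: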

 

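Assume A is m*n and B is n*s.

The map input has the format <<x,y>,<Ax,By>>, where 0<=x<m, 0<=y<s, 0<=z<n.

Here <x,y> is the key: x is the row number in A and y is the column number in B. <Ax,By> is the value: Ax is the element in row x, column z of A, and By is the element in row z, column y of B.

One row of A and one column of B are fed to a single map task. For each key-value pair we only need to multiply the two numbers in the value and output <<x,y>,Ax*By>.

Then, in the shuffle phase, records with the same key are sent to the same reduce task, and reduce only needs to sum the Ax*By values in the list for each key. The algorithm is simple to describe; the main problem is controlling what goes into each split.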
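To make the data flow concrete, here is a minimal, Hadoop-free sketch that simulates the map, shuffle, and reduce steps described above inside a single JVM. The class name `NaiveMatrixMapReduce` and the use of a `HashMap` to stand in for the shuffle are assumptions made for illustration only; they are not part of the original implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical local simulation of the naive scheme; no Hadoop involved.
class NaiveMatrixMapReduce {

    static int[][] multiply(int[][] a, int[][] b) {
        int m = a.length, n = b.length, s = b[0].length;
        if (a[0].length != n) {
            throw new IllegalArgumentException("columns of A must equal rows of B");
        }

        // Map + shuffle: every input record <<x,y>,<Ax,By>> yields <<x,y>,Ax*By>,
        // and records with the same key <x,y> are grouped together.
        Map<List<Integer>, List<Integer>> shuffled = new HashMap<>();
        for (int x = 0; x < m; x++) {
            for (int y = 0; y < s; y++) {
                for (int z = 0; z < n; z++) {
                    int product = a[x][z] * b[z][y];
                    shuffled.computeIfAbsent(Arrays.asList(x, y), k -> new ArrayList<>())
                            .add(product);
                }
            }
        }

        // Reduce: sum the value list of each key to get one output element.
        int[][] c = new int[m][s];
        for (Map.Entry<List<Integer>, List<Integer>> e : shuffled.entrySet()) {
            int sum = 0;
            for (int v : e.getValue()) {
                sum += v;
            }
            c[e.getKey().get(0)][e.getKey().get(1)] = sum;
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};       // 2 x 3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};  // 3 x 2
        System.out.println(Arrays.deepToString(multiply(a, b))); // [[58, 64], [139, 154]]
    }
}
```

Each output element needs n map records, so the scheme produces m*s*n intermediate pairs in total. That volume is why controlling the split contents matters so much here.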

 

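First, InputSplit, InputFormat, and Partitioner need to be overridden to control the flow of data. On the data structure side, we need a class that implements the WritableComparable interface to hold two integers (both the key and the value above consist of two integers) and that can also be sorted.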

IntPair.class implementation

Java code
package com.zxx.matrix;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.WritableComparable;

// Holds two ints. Used both as the key (position) and as the value (pair of elements).
public class IntPair implements WritableComparable<IntPair>
{
    private int right = 0;
    private int left = 0;

    // Hadoop needs the no-argument constructor to create instances when deserializing.
    public IntPair() {}

    public IntPair(int right, int left) {
        this.right = right;
        this.left = left;
    }

    public int getRight() {
        return right;
    }

    public int getLeft() {
        return left;
    }

    public void setRight(int right) {
        this.right = right;
    }

    public void setLeft(int left) {
        this.left = left;
    }

    @Override
    public String toString() {
        return left + "," + right;
    }

    @Override
    public void readFields(DataInput in) throws IOException
    {
        // Fields must be read in the same order that write() emits them.
        right = in.readInt();
        left = in.readInt();
    }

    @Override
    public void write(DataOutput out) throws IOException
    {
        out.writeInt(right);
        out.writeInt(left);
    }

    @Override
    public int compareTo(IntPair o)
    {
        // Order by right first, then by left.
        int c = Integer.compare(right, o.right);
        return c != 0 ? c : Integer.compare(left, o.left);
    }

    // equals and hashCode are kept consistent with compareTo.
    @Override
    public boolean equals(Object obj) {
        return obj instanceof IntPair && compareTo((IntPair) obj) == 0;
    }

    @Override
    public int hashCode() {
        return 31 * right + left;
    }
}
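The write/readFields pair and the sort order are easy to check without a cluster. The sketch below uses a Hadoop-free stand-in, `PlainIntPair` (a hypothetical name introduced here), that copies the same field order and comparison rule as IntPair above and runs them against plain java.io streams. It is only meant to illustrate the contract, not to replace the Hadoop class.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical stand-in for IntPair that does not depend on Hadoop.
class PlainIntPair implements Comparable<PlainIntPair> {
    int right;
    int left;

    PlainIntPair() {}

    PlainIntPair(int right, int left) {
        this.right = right;
        this.left = left;
    }

    // Same field order as IntPair.write: right, then left.
    void write(DataOutput out) throws IOException {
        out.writeInt(right);
        out.writeInt(left);
    }

    // Must mirror write(): right, then left.
    void readFields(DataInput in) throws IOException {
        right = in.readInt();
        left = in.readInt();
    }

    // Same ordering as IntPair.compareTo: right first, then left.
    @Override
    public int compareTo(PlainIntPair o) {
        int c = Integer.compare(right, o.right);
        return c != 0 ? c : Integer.compare(left, o.left);
    }

    // Serialize to bytes and read back into a fresh object.
    static PlainIntPair roundTrip(PlainIntPair p) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            p.write(new DataOutputStream(bytes));
            PlainIntPair copy = new PlainIntPair();
            copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
            return copy;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        PlainIntPair copy = roundTrip(new PlainIntPair(2, 5));
        System.out.println("right=" + copy.right + ", left=" + copy.left);
    }
}
```

If readFields read the two ints in the opposite order from write, the round trip would silently swap the fields, which is the most common mistake in hand-written Writable classes.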
InputSplit.class (sample)

This class uses an ArrayWritable to hold the position of the output element together with the actual element values.

Java code
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

import org.apache.hadoop.io.ArrayWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputSplit;

// One split carries one row of the left matrix, one column of the right matrix,
// and the position of the output element they produce.
// "matrix" is the author's own class (not shown here); its field m holds the elements.
public class matrixInputSplit extends InputSplit implements Writable
{
    private IntPair[] t;               // the element pairs, followed by the location
    private IntPair location;          // the key: position of the output element
    private ArrayWritable intPairArray;

    public matrixInputSplit()
    {
    }

    public matrixInputSplit(int row, matrix left, int col, matrix right)
    {
        // Fill intPairArray: n element pairs first, then the location in the last slot.
        // The original sample hard-coded n = 3; here n is taken from the row length.
        int n = left.m[row].length;
        intPairArray = new ArrayWritable(IntPair.class);
        t = new IntPair[n + 1];
        location = new IntPair(row, col);
        for (int j = 0; j < n; j++)
        {
            IntPair intPair = new IntPair();
            intPair.setLeft(left.m[row][j]);
            intPair.setRight(right.m[j][col]);
            t[j] = intPair;
        }
        t[n] = location;
        intPairArray.set(t);
    }

    @Override
    public long getLength() throws IOException, InterruptedException
    {
        return 0;
    }

    @Override
    public String[] getLocations() throws IOException, InterruptedException
    {
        return new String[]{};   // return empty so that JobClient does not read the split from a file
    }

    @Override
    public void readFields(DataInput in) throws IOException
    {
        this.intPairArray = new ArrayWritable(IntPair.class);
        this.intPairArray.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException
    {
        intPairArray.write(out);
    }

    // The location (key) is stored in the last slot of the array.
    public IntPair getLocation()
    {
        IntPair[] all = (IntPair[]) intPairArray.toArray();
        return all[all.length - 1];
    }

    // Everything before the last slot is an element pair.
    public IntPair[] getIntPairs()
    {
        IntPair[] all = (IntPair[]) intPairArray.toArray();
        return Arrays.copyOf(all, all.length - 1);
    }
}
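InputFormat.class

This class is fairly simple: it only needs to implement the getSplits method. The user does, however, have to write a custom method that parses the matrices from the path obtained via getInputfile and feeds them into the splits.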
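The post does not show the input format class or say how the matrices are stored on disk, so the following is only a sketch under stated assumptions: each matrix is a text file with one row per line and whitespace-separated integers, and one split is planned per output element (row of A, column of B). The names `SplitPlanSketch`, `parseMatrix`, and `planSplits` are hypothetical. In a real getSplits, each planned {row, col} entry would be turned into one matrixInputSplit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, Hadoop-free sketch of the parsing and split planning step.
class SplitPlanSketch {

    // Assumed file format: one matrix row per line, ints separated by whitespace.
    static int[][] parseMatrix(List<String> lines) {
        List<int[]> rows = new ArrayList<>();
        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // skip blank lines
            }
            String[] tokens = trimmed.split("\\s+");
            int[] row = new int[tokens.length];
            for (int i = 0; i < tokens.length; i++) {
                row[i] = Integer.parseInt(tokens[i]);
            }
            rows.add(row);
        }
        return rows.toArray(new int[0][]);
    }

    // One {row, col} entry per element of the product, i.e. m*s splits in total.
    static List<int[]> planSplits(int[][] a, int[][] b) {
        if (a[0].length != b.length) {
            throw new IllegalArgumentException("columns of A must equal rows of B");
        }
        List<int[]> plan = new ArrayList<>();
        for (int row = 0; row < a.length; row++) {
            for (int col = 0; col < b[0].length; col++) {
                plan.add(new int[] {row, col});
            }
        }
        return plan;
    }

    public static void main(String[] args) {
        int[][] a = parseMatrix(Arrays.asList("1 2 3", "4 5 6"));
        int[][] b = parseMatrix(Arrays.asList("7 8", "9 10", "11 12"));
        for (int[] rc : planSplits(a, b)) {
            System.out.println("split for row " + rc[0] + ", col " + rc[1]);
        }
    }
}
```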

matrixMul.class

Java code
import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class MatrixNew
{

    public static class MatrixMapper extends Mapper<IntPair, IntPair, IntPair, IntWritable>
    {
        @Override
        public void map(IntPair key, IntPair value, Context context)
                throws IOException, InterruptedException
        {
            // The key is passed through unchanged; the two ints in the value are multiplied.
            IntWritable result = new IntWritable(value.getLeft() * value.getRight());
            context.write(key, result); // emit the key-value pair
        }
    }

    public static class MatrixReducer extends Reducer<IntPair, IntWritable, IntPair, IntWritable>
    {
        private IntWritable result = new IntWritable();

        @Override
        public void reduce(IntPair key, Iterable<IntWritable> values, Context context)
                throws IOException, InterruptedException
        {
            // Sum all partial products that belong to the same output element.
            int sum = 0;
            for (IntWritable val : values)
            {
                sum += val.get();
            }
            result.set(sum);
            context.write(key, result);
        }
    }

    public static class FirstPartitioner extends Partitioner<IntPair, IntWritable>
    {
        @Override
        public int getPartition(IntPair key, IntWritable value, int numPartitions)
        {
            // numPartitions is the number of reduce tasks.
            // Masking the sign bit keeps the result non-negative even for Integer.MIN_VALUE,
            // where Math.abs would still return a negative number.
            return (key.getLeft() & Integer.MAX_VALUE) % numPartitions;
        }
    }

    public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException
    {
        Configuration conf = new Configuration();
        // Let Hadoop consume its generic options and keep only the job's own arguments.
        String[] remainingArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        FileSystem fs = FileSystem.get(conf);
        Job job = new Job(conf, "New Matrix Multiply Job ");
        job.setJarByClass(MatrixNew.class);
        job.setNumReduceTasks(1);
        job.setInputFormatClass(matrixInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setPartitionerClass(FirstPartitioner.class);
        job.setMapOutputKeyClass(IntPair.class);
        job.setMapOutputValueClass(IntWritable.class);
        job.setOutputKeyClass(IntPair.class);
        job.setOutputValueClass(IntWritable.class);

        // setInputPath is a static method of the custom matrixInputFormat class.
        matrixInputFormat.setInputPath(remainingArgs[0]);

        FileOutputFormat.setOutputPath(job, fs.makeQualified(new Path("/newMartixoutput")));

        boolean ok = job.waitForCompletion(true);
        if (ok) {
            // delete temporary files
        }
        System.exit(ok ? 0 : 1);
    }

}
The code above has only been lightly tested. If you find any problems, corrections are welcome. Thanks in advance!

The second approach is block matrix multiplication. Source code for that algorithm has already been published online by others.
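The post does not describe the block algorithm any further, so the following is only a generic, single-machine sketch of block matrix multiplication, not the published implementation referred to above. The class name `BlockMultiplySketch` and the `blockSize` parameter are assumptions introduced for illustration. One plausible mapping to MapReduce would be to treat each block product as the work of one map task and to sum the partial blocks in reduce.

```java
import java.util.Arrays;

// Hypothetical sketch: multiply A (m x n) by B (n x s) one block at a time.
class BlockMultiplySketch {

    static int[][] multiply(int[][] a, int[][] b, int blockSize) {
        int m = a.length, n = b.length, s = b[0].length;
        if (a[0].length != n || blockSize <= 0) {
            throw new IllegalArgumentException("incompatible shapes or bad block size");
        }
        int[][] c = new int[m][s];
        // (i0, j0, k0) selects one block of A and one block of B.
        for (int i0 = 0; i0 < m; i0 += blockSize) {
            for (int j0 = 0; j0 < s; j0 += blockSize) {
                for (int k0 = 0; k0 < n; k0 += blockSize) {
                    int iEnd = Math.min(i0 + blockSize, m);
                    int jEnd = Math.min(j0 + blockSize, s);
                    int kEnd = Math.min(k0 + blockSize, n);
                    // Multiply the two blocks and accumulate the partial result into C.
                    for (int i = i0; i < iEnd; i++) {
                        for (int j = j0; j < jEnd; j++) {
                            int partial = 0;
                            for (int k = k0; k < kEnd; k++) {
                                partial += a[i][k] * b[k][j];
                            }
                            c[i][j] += partial;
                        }
                    }
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b, 2))); // [[58, 64], [139, 154]]
    }
}
```

Compared with the first approach, working on blocks means far fewer, larger units of work, since a task handles a whole sub-matrix instead of a single row-column pair.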

 

Reposted from http://zxxapple.iteye.com/blog/1405209
