There are two approaches to implementing the multiplication of two matrices (A*B) in the MapReduce framework.
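First approach: the basic matrix multiplication algorithm taught in class multiplies rows of A by columns of B, which requires the number of columns of A to equal the number of rows of B. This suggests the following basic idea:
Each map task computes one row of A times one column of B, and the reduce step sums the products and writes the result. This is the most naive implementation: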
Assume A is m×n and B is n×s.
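For reference, the product C = A·B is m×s, and each of its elements is a sum of n products:

```latex
C_{xy} = \sum_{z=0}^{n-1} A_{xz}\,B_{zy}, \qquad 0 \le x < m,\; 0 \le y < s
```

That is m·n·s multiplications in total, which is also the number of records the map phase emits in the scheme described next.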
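The map input records have the form <<x,y>,<Ax,By>>, with 0 ≤ x < m, 0 ≤ y < s and 0 ≤ z < n.
Here <x,y> is the key, where x is the row index in A and y is the column index in B. <Ax,By> is the value, where Ax is the element in row x, column z of A, and By is the element in row z, column y of B.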
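One row of A and one column of B are fed to a single map task. The mapper only needs to multiply the two numbers in each value and emit <<x,y>, Ax*By>.
In the shuffle phase, all records with the same key are sent to the same reduce task, and the reducer simply sums the Ax*By values in the list for each key. The algorithm is easy to describe; the main difficulty is controlling what goes into each split.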
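To make the data flow concrete, here is a minimal in-memory simulation of the map, shuffle and reduce steps just described, with no Hadoop involved. The class name NaiveMatrixMR and the encoding of the key <x,y> as the single int x*s + y are illustrative assumptions, not part of the original code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical in-memory simulation of the naive row-times-column job (no Hadoop)
class NaiveMatrixMR {

    // Map input: one record {x, y, A[x][z], B[z][y]} for every x, y, z
    static List<int[]> inputRecords(int[][] a, int[][] b) {
        int m = a.length, n = b.length, s = b[0].length;
        List<int[]> records = new ArrayList<>();
        for (int x = 0; x < m; x++)
            for (int y = 0; y < s; y++)
                for (int z = 0; z < n; z++)
                    records.add(new int[] {x, y, a[x][z], b[z][y]});
        return records;
    }

    static int[][] multiply(int[][] a, int[][] b) {
        int m = a.length, s = b[0].length;
        // Map emits <<x,y>, Ax*By>; the shuffle groups the values by key.
        // The key <x,y> is encoded as x * s + y so a TreeMap can group and sort it.
        Map<Integer, List<Integer>> shuffled = new TreeMap<>();
        for (int[] r : inputRecords(a, b)) {
            int key = r[0] * s + r[1];
            shuffled.computeIfAbsent(key, k -> new ArrayList<>()).add(r[2] * r[3]);
        }
        // Reduce: sum the value list of each key
        int[][] c = new int[m][s];
        for (Map.Entry<Integer, List<Integer>> e : shuffled.entrySet()) {
            int sum = 0;
            for (int v : e.getValue()) sum += v;
            c[e.getKey() / s][e.getKey() % s] = sum;
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        System.out.println(java.util.Arrays.deepToString(multiply(a, b)));
        // prints [[19, 22], [43, 50]]
    }
}
```

Every key receives exactly n values, one per z, so each reduce call performs the sum from the formula for C_xy.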
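To do that, we need custom InputSplit, InputFormat and Partitioner classes to control the data flow. For the data structure, we need a class implementing the WritableComparable interface that holds two integers (both the key and the value above consist of two integers) and whose objects can be sorted.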
IntPair.class implementation
package com.zxx.matrix;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.WritableComparable;

// Holds two ints: used as the key <row, col> and as the value <A element, B element>
public class IntPair implements WritableComparable<IntPair>
{
    private int right = 0;
    private int left = 0;

    // No-arg constructor required by Hadoop for deserialization
    public IntPair() {}

    public IntPair(int right, int left)
    {
        this.right = right;
        this.left = left;
    }

    public int getRight() { return right; }
    public int getLeft() { return left; }
    public void setRight(int right) { this.right = right; }
    public void setLeft(int left) { this.left = left; }

    @Override
    public String toString()
    {
        return left + "," + right;
    }

    // Must read the fields in exactly the order write() writes them
    @Override
    public void readFields(DataInput in) throws IOException
    {
        right = in.readInt();
        left = in.readInt();
    }

    @Override
    public void write(DataOutput out) throws IOException
    {
        out.writeInt(right);
        out.writeInt(left);
    }

    // Sort by right first, then by left
    @Override
    public int compareTo(IntPair o)
    {
        if (right != o.right)
        {
            return Integer.compare(right, o.right);
        }
        return Integer.compare(left, o.left);
    }

    // equals/hashCode consistent with compareTo (the default HashPartitioner uses hashCode)
    @Override
    public boolean equals(Object obj)
    {
        if (!(obj instanceof IntPair))
        {
            return false;
        }
        IntPair o = (IntPair) obj;
        return right == o.right && left == o.left;
    }

    @Override
    public int hashCode()
    {
        return 31 * right + left;
    }
}
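Because the shuffle serializes every key and value, write and readFields must handle the two ints in the same order, and compareTo defines how keys are sorted. The Hadoop-free sketch below mirrors IntPair using java.io.DataOutputStream and DataInputStream (which implement DataOutput and DataInput); the class name PairDemo is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical Hadoop-free stand-in for IntPair: same fields, byte layout and ordering
class PairDemo implements Comparable<PairDemo> {
    final int right;
    final int left;

    PairDemo(int right, int left) {
        this.right = right;
        this.left = left;
    }

    // Mirrors IntPair.write: right first, then left, 4 big-endian bytes each
    byte[] serialize() {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(right);
            out.writeInt(left);
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for in-memory streams
        }
    }

    // Mirrors IntPair.readFields: reads in exactly the order serialize() wrote
    static PairDemo deserialize(byte[] data) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
            int right = in.readInt();
            int left = in.readInt();
            return new PairDemo(right, left);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Same ordering as IntPair.compareTo: by right, then by left
    @Override
    public int compareTo(PairDemo o) {
        if (right != o.right) {
            return Integer.compare(right, o.right);
        }
        return Integer.compare(left, o.left);
    }

    public static void main(String[] args) {
        PairDemo back = deserialize(new PairDemo(1, 2).serialize());
        System.out.println(back.right + " " + back.left); // prints 1 2
    }
}
```

If readFields read left before right, every deserialized pair would come back with its two fields swapped.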
InputSplit.class (example)
This class uses an ArrayWritable to store both the position of the output element and the actual element values.
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.hadoop.io.ArrayWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputSplit;

public class matrixInputSplit extends InputSplit implements Writable
{
    // Layout: entries [0, n) are the element pairs (A[row][j], B[j][col]);
    // the last entry is the key <row, col>, the position of the output element
    private ArrayWritable intPairArray;

    // No-arg constructor required by Hadoop to deserialize the split
    public matrixInputSplit()
    {
    }

    // matrix is the author's own matrix class (not shown); its field m is assumed to be int[][]
    public matrixInputSplit(int row, matrix left, int col, matrix right)
    {
        int n = left.m[row].length;            // columns of A == rows of B
        IntPair[] t = new IntPair[n + 1];
        for (int j = 0; j < n; j++)
        {
            IntPair intPair = new IntPair();
            intPair.setLeft(left.m[row][j]);   // A[row][j]
            intPair.setRight(right.m[j][col]); // B[j][col]
            t[j] = intPair;
        }
        t[n] = new IntPair(row, col);          // key: right = row, left = col
        intPairArray = new ArrayWritable(IntPair.class, t);
    }

    @Override
    public long getLength() throws IOException, InterruptedException
    {
        return 0;
    }

    @Override
    public String[] getLocations() throws IOException, InterruptedException
    {
        // Empty: no locality hint, so the split is not tied to any file block
        return new String[]{};
    }

    @Override
    public void readFields(DataInput in) throws IOException
    {
        intPairArray = new ArrayWritable(IntPair.class);
        intPairArray.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException
    {
        intPairArray.write(out);
    }

    // toArray() builds an array of the declared value class, so this cast is safe
    private IntPair[] all()
    {
        return (IntPair[]) intPairArray.toArray();
    }

    public IntPair getLocation()
    {
        IntPair[] t = all();
        return t[t.length - 1];
    }

    public IntPair[] getIntPairs()
    {
        IntPair[] t = all();
        return Arrays.copyOf(t, t.length - 1);
    }
}
InputFormat.class
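This class is fairly simple: its main job is getSplits. (The mapreduce InputFormat also declares createRecordReader abstract, so a record reader that hands each split's pairs to the mapper is needed as well.) The user also has to write a custom method that parses the matrices from the path obtained through getInputfile and builds the splits from them.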
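Since the InputFormat source is not given, here is a hypothetical Hadoop-free model of what getSplits and the record reader are expected to do, assuming the layout used by matrixInputSplit: one split per output element <x,y>, holding n pairs (A[x][z], B[z][y]), and a reader that turns those into n records sharing the key <x,y>. The names SplitPlanner, Split, getSplits and records are illustrative only.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, Hadoop-free model of getSplits plus the record reader
class SplitPlanner {

    // One split per output element: key <x, y> plus n value pairs {A[x][z], B[z][y]}
    static final class Split {
        final int x, y;
        final int[][] pairs;

        Split(int x, int y, int[][] pairs) {
            this.x = x;
            this.y = y;
            this.pairs = pairs;
        }
    }

    // getSplits: m * s splits, like matrixInputSplit(row, A, col, B) for every row and col
    static List<Split> getSplits(int[][] a, int[][] b) {
        int m = a.length, n = b.length, s = b[0].length;
        if (a[0].length != n) {
            throw new IllegalArgumentException("columns of A must equal rows of B");
        }
        List<Split> splits = new ArrayList<>();
        for (int x = 0; x < m; x++) {
            for (int y = 0; y < s; y++) {
                int[][] pairs = new int[n][];
                for (int z = 0; z < n; z++) {
                    pairs[z] = new int[] {a[x][z], b[z][y]};
                }
                splits.add(new Split(x, y, pairs));
            }
        }
        return splits;
    }

    // Record reader: each record is {x, y, A element, B element}; all share the split's key
    static List<int[]> records(Split split) {
        List<int[]> out = new ArrayList<>();
        for (int[] p : split.pairs) {
            out.add(new int[] {split.x, split.y, p[0], p[1]});
        }
        return out;
    }
}
```

Each InputSplit is processed by its own map task, so this design launches m·s map tasks; the per-task overhead grows quickly with matrix size.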
matrixMul.class
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class MatrixNew
{
    // Input: key = <row of A, column of B>, value = <A element, B element>
    public static class MatrixMapper extends Mapper<IntPair, IntPair, IntPair, IntWritable>
    {
        private final IntWritable result = new IntWritable();

        @Override
        public void map(IntPair key, IntPair value, Context context)
                throws IOException, InterruptedException
        {
            // Key unchanged; value = product of the two ints
            result.set(value.getLeft() * value.getRight());
            context.write(key, result); // emit the kv pair
        }
    }

    public static class MatrixReducer extends Reducer<IntPair, IntWritable, IntPair, IntWritable>
    {
        private final IntWritable result = new IntWritable();

        @Override
        public void reduce(IntPair key, Iterable<IntWritable> values, Context context)
                throws IOException, InterruptedException
        {
            // Sum all products for one output element
            int sum = 0;
            for (IntWritable val : values)
            {
                sum += val.get();
            }
            result.set(sum);
            context.write(key, result);
        }
    }

    public static class FirstPartitioner extends Partitioner<IntPair, IntWritable>
    {
        @Override
        public int getPartition(IntPair key, IntWritable value, int numPartitions)
        {
            // numPartitions is the number of reduce tasks; records with the same key
            // share the same left value, so they always reach the same reducer
            return (key.getLeft() & Integer.MAX_VALUE) % numPartitions;
        }
    }

    public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException
    {
        Configuration conf = new Configuration();
        // Keep the arguments left over after the generic Hadoop options are consumed
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        FileSystem fs = FileSystem.get(conf);
        Job job = new Job(conf, "New Matrix Multiply Job");
        job.setJarByClass(MatrixNew.class);
        job.setNumReduceTasks(1);
        job.setInputFormatClass(matrixInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setPartitionerClass(FirstPartitioner.class);
        job.setMapOutputKeyClass(IntPair.class);
        job.setMapOutputValueClass(IntWritable.class);
        job.setOutputKeyClass(IntPair.class);
        job.setOutputValueClass(IntWritable.class);
        matrixInputFormat.setInputPath(otherArgs[0]); // the author's custom static method
        FileOutputFormat.setOutputPath(job, fs.makeQualified(new Path("/newMartixoutput")));
        boolean ok = job.waitForCompletion(true);
        if (ok)
        {
            // delete temporary files
        }
    }
}
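FirstPartitioner chooses a reducer from the left field of the key alone. Because every product for one output element carries the same key, all of them reach the same reducer, which is all the summation needs. The hypothetical PartitionDemo below compares the Math.abs form with the sign-bit mask that Hadoop's own HashPartitioner applies to hashCode(); for the non-negative indices used here the two agree.

```java
// Hypothetical, Hadoop-free check of the partitioning rule used by FirstPartitioner
class PartitionDemo {

    // Math.abs form: Math.abs(left) % numPartitions
    static int absPartition(int left, int numPartitions) {
        return Math.abs(left) % numPartitions;
    }

    // Sign-bit mask, as HashPartitioner does with (hashCode & Integer.MAX_VALUE)
    static int maskPartition(int left, int numPartitions) {
        return (left & Integer.MAX_VALUE) % numPartitions;
    }

    public static void main(String[] args) {
        // Column indices are non-negative, so both rules give the same partition
        for (int col = 0; col < 5; col++) {
            System.out.println(col + " -> " + maskPartition(col, 3));
        }
        // Math.abs(Integer.MIN_VALUE) overflows and stays negative
        System.out.println(absPartition(Integer.MIN_VALUE, 3)); // prints -2
        System.out.println(maskPartition(Integer.MIN_VALUE, 3)); // prints 0
    }
}
```

A negative partition number is invalid, so the mask form is the safer habit when the partitioned value could ever be a hash code rather than an index.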
The code above was only a quick test. If you spot any problems, corrections are welcome. Thanks in advance!
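The second approach is block matrix multiplication; experienced developers have already published source code for this algorithm online.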
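The block approach is not implemented in this post. As a rough sketch of the underlying idea only (not the published code mentioned above), A and B are cut into bs×bs blocks and each block of C is the sum over K of A(I,K)·B(K,J). In a MapReduce version, each (I, J, K) block product would presumably be one unit of map work and the reducer would add the partial blocks for (I, J); the hypothetical BlockMatrixMul below just runs the same loop nest in memory.

```java
// Hypothetical in-memory sketch of block multiplication: C(I,J) = sum over K of A(I,K) * B(K,J)
class BlockMatrixMul {

    // a is m x n, b is n x s, bs is the block size; edge blocks may be smaller
    static int[][] multiply(int[][] a, int[][] b, int bs) {
        int m = a.length, n = b.length, s = b[0].length;
        int[][] c = new int[m][s];
        for (int bi = 0; bi < m; bi += bs)              // block row I of A and C
            for (int bj = 0; bj < s; bj += bs)          // block column J of B and C
                for (int bk = 0; bk < n; bk += bs)      // shared block index K
                    // Add the partial product A(I,K) * B(K,J) into block C(I,J)
                    for (int i = bi; i < Math.min(bi + bs, m); i++)
                        for (int j = bj; j < Math.min(bj + bs, s); j++) {
                            int sum = 0;
                            for (int k = bk; k < Math.min(bk + bs, n); k++)
                                sum += a[i][k] * b[k][j];
                            c[i][j] += sum;
                        }
        return c;
    }
}
```

With block size bs, the number of independent work units drops from m·s (one per output element) to roughly (m/bs)·(s/bs)·(n/bs), each doing more work.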