
MNIST Handwritten Digit Classification: KNN in MATLAB

December 16, 2017 ⁄ General

For details on the dataset and its files, see: http://blog.csdn.net/wangyuquanliuli/article/details/11606435

Here is the KNN implementation in MATLAB:

% Load MNIST with the loaders given at the end of this post.
% Images come back as 784 x N matrices (one 28x28 image per column, pixels in [0,1]).
trainImages = loadMNISTImages('train-images.idx3-ubyte');
trainLabels = loadMNISTLabels('train-labels.idx1-ubyte');
testImages  = loadMNISTImages('t10k-images.idx3-ubyte');
testLabels  = loadMNISTLabels('t10k-labels.idx1-ubyte');

K = 100;                              % number of neighbours; can be any other value
trainLength = size(trainImages, 2);   % number of training images (columns)
testLength  = size(testImages, 2);    % number of test images (columns)
testResults = zeros(testLength, 1);   % predicted label for each test image

tic;
for i = 1:testLength
    % L1 (Manhattan) distance from test image i to every training image
    diffs = abs(bsxfun(@minus, trainImages, testImages(:, i)));
    dist  = sum(diffs, 1);
    % Training indices ordered by increasing distance
    [~, ind] = sort(dist);
    % Labels of the K nearest neighbours
    compLabel = trainLabels(ind(1:K));
    % Majority vote; mode returns the smallest label when counts tie
    testResults(i) = mode(compLabel);
    fprintf('%d: predicted %d, true %d\n', i, testResults(i), testLabels(i));
end
elapsed = toc;

% Classification error on the test set
errorRate = mean(testResults ~= testLabels);
fprintf('Test error: %.4f\n', errorRate);
fprintf('Elapsed time: %.1f s\n', elapsed);
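The core of each loop iteration (L1 distance, sort, majority vote over the K nearest labels) can be checked on a toy problem before running it on all of MNIST. The sketch below uses made-up 2-pixel "images" and K = 3, both purely illustrative, and uses base-MATLAB `mode` for the vote, which returns the smallest label when several labels tie.

```matlab
% Toy training set (made up for illustration): each column is a 2-pixel "image"
toyTrain  = [0 0 1 5 5 6;
             0 1 0 5 6 5];
toyLabels = [0; 0; 0; 1; 1; 1];
testImage = [4.5; 5];
K = 3;

% L1 (Manhattan) distance to every training column, as in the main loop
dist = sum(abs(bsxfun(@minus, toyTrain, testImage)), 1);
[~, ind] = sort(dist);                 % nearest first
nearestLabels = toyLabels(ind(1:K));   % labels of the K nearest columns
predicted = mode(nearestLabels);       % majority vote
fprintf('predicted label: %d\n', predicted);

% Tie-breaking: with two labels equally frequent, mode picks the smaller one
tieVote = mode([7; 3; 7; 3]);
```

Here the three nearest columns are the ones around (5, 5), so the vote is unanimous for label 1.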
        

The training set contains 60,000 images and the test set 10,000.

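Why it runs slowly:

No principal component analysis (PCA) is applied, so every comparison uses all 784 (28 × 28) pixel dimensions. Classifying the whole test set takes 10,000 × 60,000 = 6 × 10^8 image-to-image distances, each summed over 784 pixels. Reducing the dimensionality first is an obvious place for improvement :)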
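A minimal sketch of the PCA step suggested above, done with an SVD of the mean-centred training matrix (one standard way to compute PCA; not code from the original post). The number of components `d` is a hypothetical free parameter, and the tiny synthetic matrix `X` stands in for the 784 x 60000 training matrix. Each L1 distance then costs d operations instead of 784.

```matlab
% Synthetic stand-in for trainImages: 4-D columns that, after centring,
% lie in a 2-D subspace, so 2 components keep all of the information.
t = 1:10;
X = [t; 2*t; t.^2; 3*t - t.^2];    % D x N, one sample per column
d = 2;                             % components to keep (hypothetical choice)

mu = mean(X, 2);                   % per-dimension mean (D x 1)
Xc = bsxfun(@minus, X, mu);        % centre the data
[U, ~, ~] = svd(Xc, 'econ');       % columns of U are the principal directions
W = U(:, 1:d);                     % D x d projection matrix
Z = W' * Xc;                       % d x N reduced training data

% Test samples must be projected with the SAME mu and W as the training data
xTest = [2.5; 5; 6.25; 1.25];
zTest = W' * (xTest - mu);

% Reconstruction error is ~0 here because the toy data really are 2-D
Xhat = W * Z + repmat(mu, 1, size(X, 2));
reconErr = max(abs(Xhat(:) - X(:)));
```

In the KNN loop one would then compare `zTest` against the columns of `Z` instead of comparing raw pixels.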

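Supporting code (the two MNIST loaders). Save each function in its own file, loadMNISTImages.m and loadMNISTLabels.m, on the MATLAB path: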

function images = loadMNISTImages(filename)
%loadMNISTImages returns a 784x[number of MNIST images] matrix of doubles in [0,1];
%each column is one 28x28 MNIST image, flattened

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);

numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

images = fread(fp, inf, 'unsigned char');
images = reshape(images, numCols, numRows, numImages);
images = permute(images,[2 1 3]);

fclose(fp);

% Reshape to #pixels x #examples
images = reshape(images, size(images, 1) * size(images, 2), size(images, 3));
% Convert to double and rescale to [0,1]
images = double(images) / 255;

end
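The `reshape`/`permute` pair in `loadMNISTImages` exists because the IDX file stores each image row by row, while MATLAB fills arrays column by column. The toy sketch below replays those two lines on a hypothetical 2 x 3 "image" whose bytes are 1 to 6, standing in for what `fread` would return.

```matlab
% Toy file contents: one 2 x 3 image stored row by row as bytes 1..6
bytes     = (1:6)';
numRows   = 2;
numCols   = 3;
numImages = 1;

img = reshape(bytes, numCols, numRows, numImages);  % each column holds one image row
img = permute(img, [2 1 3]);                        % now rows x cols: [1 2 3; 4 5 6]
disp(img)

% Flatten to #pixels x #examples, as the loader does (column-major pixel order)
flat = reshape(img, numRows * numCols, numImages);
```

The flattened vector is in column-major pixel order; that is harmless for KNN as long as training and test images are flattened the same way.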

function labels = loadMNISTLabels(filename)
%loadMNISTLabels returns a [number of MNIST images]x1 matrix containing
%the labels for the MNIST images

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);

numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

labels = fread(fp, inf, 'unsigned char');

assert(size(labels,1) == numLabels, 'Mismatch in label count');

fclose(fp);

end
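To check the header parsing without the real MNIST files, one can write a tiny label file in the layout `loadMNISTLabels` expects (big-endian int32 magic number 2049, big-endian int32 label count, then one unsigned byte per label) and read it back with the same `fread` calls. The label values and the temporary file name are arbitrary choices for this sketch.

```matlab
% Write a tiny IDX1 label file with made-up labels
labelsOut = uint8([5 0 4 1 9]);
fname = [tempname, '.idx1-ubyte'];
fp = fopen(fname, 'wb');
fwrite(fp, 2049, 'int32', 0, 'ieee-be');              % magic number
fwrite(fp, numel(labelsOut), 'int32', 0, 'ieee-be');  % number of labels
fwrite(fp, labelsOut, 'uint8');                       % one byte per label
fclose(fp);

% Read it back with the same calls as loadMNISTLabels
fp = fopen(fname, 'rb');
magic     = fread(fp, 1, 'int32', 0, 'ieee-be');
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
labelsIn  = fread(fp, inf, 'unsigned char');          % returned as a double column
fclose(fp);
delete(fname);
```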
