matlab – 10-fold cross-validation in one-against-all SVM (using LibSVM)

There are mainly two reasons we do cross-validation:

> As a testing method that gives us a nearly unbiased estimate of the generalization power of our model (by avoiding overfitting)

> As a form of model selection (e.g., to find the best C and gamma parameters over the training data; see this post for an example)

For the first case, which is the one we are interested in here, the procedure involves training k models, one per fold, and then training a final model over the entire training set.

We report the average accuracy over the k folds.

Now, since we are using a one-vs-all approach to handle the multi-class problem, each model consists of N support vector machines (one per class).
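To make the relabeling concrete (a minimal sketch, not part of the wrappers below): the k-th binary subproblem is built by recoding the target vector as class-k-vs-rest, which is exactly what double(y==labels(k)) does in the training function:

y = [1; 2; 3; 2; 1];          %# multi-class labels
labels = unique(y);           %# [1; 2; 3]
k = 2;
yk = double(y==labels(k));    %# [0; 1; 0; 1; 0], i.e. "class 2 vs rest"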

The following are wrapper functions implementing the one-vs-all approach:

function mdl = libsvmtrain_ova(y, X, opts)
    if nargin < 3, opts = ''; end

    %# classes
    labels = unique(y);
    numLabels = numel(labels);

    %# train one-against-all models
    models = cell(numLabels,1);
    for k=1:numLabels
        models{k} = libsvmtrain(double(y==labels(k)), X, strcat(opts,' -b 1 -q'));
    end
    mdl = struct('models',{models}, 'labels',labels);
end
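A side note on the struct call above: wrapping the cell array as {models} is deliberate. Passing a cell array directly to struct produces a struct array with one element per cell, while the extra braces store the entire cell array in a single field. A quick illustration in the console:

s = struct('f', {1,2,3});     %# 1x3 struct array (one element per cell)
numel(s)                      %# ans = 3
s = struct('f', {{1,2,3}});   %# 1x1 struct whose field holds the cell array
numel(s.f)                    %# ans = 3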

function [pred,acc,prob] = libsvmpredict_ova(y, X, mdl)
    %# classes
    labels = mdl.labels;
    numLabels = numel(labels);

    %# get probability estimates of test instances using each 1-vs-all model
    prob = zeros(size(X,1), numLabels);
    for k=1:numLabels
        [~,~,p] = libsvmpredict(double(y==labels(k)), X, mdl.models{k}, '-b 1 -q');
        prob(:,k) = p(:, mdl.models{k}.Label==1);
    end

    %# predict the class with the highest probability
    [~,pred] = max(prob, [], 2);

    %# compute classification accuracy
    acc = mean(pred == y);
end
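One detail worth noting in the prediction loop: LibSVM orders the columns of the returned probability matrix according to model.Label, which records the order in which labels were encountered during training, so the positive-class column is not guaranteed to be the first one. The indexing p(:, mdl.models{k}.Label==1) picks the correct column regardless. A small sketch to inspect this on a trained binary model (assumes mdl was produced by libsvmtrain_ova above):

m = mdl.models{1};
disp(m.Label')                %# e.g. [0 1] or [1 0], depending on the data
posCol = find(m.Label == 1);  %# probability column for the positive class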

And here are the functions supporting cross-validation:

function acc = libsvmcrossval_ova(y, X, opts, nfold, indices)
    if nargin < 3, opts = ''; end
    if nargin < 4, nfold = 10; end
    if nargin < 5, indices = crossvalidation(y, nfold); end

    %# N-fold cross-validation testing
    acc = zeros(nfold,1);
    for i=1:nfold
        testIdx = (indices == i); trainIdx = ~testIdx;
        mdl = libsvmtrain_ova(y(trainIdx), X(trainIdx,:), opts);
        [~,acc(i)] = libsvmpredict_ova(y(testIdx), X(testIdx,:), mdl);
    end
    acc = mean(acc);    %# average accuracy
end
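Although the question is about the first use case (testing), the same function can also drive the second one, model selection, e.g. a simple grid search over C and gamma. A minimal sketch, assuming labels and data as defined in the demo below:

%# hypothetical grid search over the RBF hyper-parameters
[C,gamma] = meshgrid(2.^(-5:2:15), 2.^(-15:2:3));
cv_acc = zeros(numel(C),1);
for i=1:numel(C)
    opts = sprintf('-s 0 -t 2 -c %f -g %f', C(i), gamma(i));
    cv_acc(i) = libsvmcrossval_ova(labels, data, opts, 5);
end
[~,idx] = max(cv_acc);        %# pick the best-performing pair
best_opts = sprintf('-s 0 -t 2 -c %f -g %f', C(idx), gamma(idx));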

function indices = crossvalidation(y, nfold)
    %# stratified n-fold cross-validation
    %#indices = crossvalind('Kfold', y, nfold);   %# Bioinformatics toolbox
    cv = cvpartition(y, 'kfold',nfold);           %# Statistics toolbox
    indices = zeros(size(y));
    for i=1:nfold
        indices(cv.test(i)) = i;
    end
end
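Note that cvpartition(y, 'kfold', nfold) stratifies by the class labels in y, so each fold roughly preserves the class proportions. A quick sanity check on balanced synthetic labels (assumes the Statistics toolbox and that crossvalidation above is on the path):

%# each fold of 15 should hold about 5 instances per class
y = [ones(50,1); 2*ones(50,1); 3*ones(50,1)];
idx = crossvalidation(y, 10);
for c=1:3
    fprintf('fold 1, class %d: %d instances\n', c, sum(y(idx==1)==c));
end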

Finally, here is a simple demo illustrating the usage:

%# load dataset
S = load('fisheriris');
data = zscore(S.meas);
labels = grp2idx(S.species);

%# cross-validate using one-vs-all approach
opts = '-s 0 -t 2 -c 1 -g 0.25';   %# libsvm training options
nfold = 10;
acc = libsvmcrossval_ova(labels, data, opts, nfold);
fprintf('Cross Validation Accuracy = %.4f%%\n', 100*mean(acc));

%# compute final model over the entire dataset
mdl = libsvmtrain_ova(labels, data, opts);
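The final model can then be applied with the prediction wrapper; as a quick sketch, re-predicting on the training data itself (this measures training accuracy, not generalization):

[pred,acc,prob] = libsvmpredict_ova(labels, data, mdl);
fprintf('Training Accuracy = %.4f%%\n', 100*acc);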

Compare this against the one-vs-one approach that libsvm uses by default:

acc = libsvmtrain(labels, data, sprintf('%s -v %d -q', opts, nfold));
model = libsvmtrain(labels, data, strcat(opts,' -q'));
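For completeness, the one-vs-one model returned by the last call can be evaluated in the usual way (a sketch; note that libsvmpredict returns accuracy as the first element of a 3-element vector):

[pred,acc,~] = libsvmpredict(labels, data, model, '-q');
fprintf('One-vs-one Training Accuracy = %.4f%%\n', acc(1));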