learn

PURPOSE

LEARN learning function of svm classifier

SYNOPSIS

function [svmCl, clErr] = learn(svmCl, trnExamples, targetOuts, wts)

DESCRIPTION

 LEARN learning function of svm classifier
 function [svmCl, clErr] = learn(svmCl, trnExamples, targetOuts, wts)
   learning function for the SVM classifier

   Inputs:
       svmCl: SVMClassifier object to train.
       trnExamples: examples for training the classifier. A cell array
           is converted to a matrix with cell2mat.
       targetOuts: target classification outputs. Its size must be the
           same as the number of examples.
       wts: weights of training examples (used to weight the examples
           during training and to compute the weighted classification
           error). Optional; if omitted or empty, class-balanced
           weights are used.
   Outputs:
       svmCl: trained SVMClassifier
       clErr: classification error of the trained classifier
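
The two forms of classification error that the source computes can be illustrated on toy data. This is a minimal sketch, not part of learn.m: the label and weight values are made up for illustration, and `outs` stands in for the predictions that svmpredict would return.

```matlab
% Toy data (hypothetical values, for illustration only)
targetOuts = [1; 1; 2; 2];            % true class labels
outs       = [1; 2; 2; 2];            % stand-in predictions; example 2 is wrong
wts        = [0.1; 0.4; 0.25; 0.25];  % per-example weights summing to 1

% Unweighted error: fraction of misclassified examples
plainErr = sum(outs ~= targetOuts) / numel(targetOuts);

% Weighted error: total weight of the misclassified examples
weightedErr = (outs ~= targetOuts)' * wts;
```

Here only the second example is misclassified, so the unweighted error is 1/4 while the weighted error equals that example's weight, 0.4. The weighted form assumes `wts` is a column vector with one entry per example.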

SOURCE CODE

function [svmCl, clErr] = learn(svmCl, trnExamples, targetOuts, wts)
% LEARN learning function of svm classifier
% function [svmCl, clErr] = learn(svmCl, trnExamples, targetOuts, wts)
%   learning function for the SVM classifier
%
%   Inputs:
%       svmCl: SVMClassifier object to train.
%       trnExamples: examples for training the classifier.
%       targetOuts: target classification outputs. Its size must be the
%           same as the number of examples.
%       wts: weights of training examples (used to compute the weighted
%           classification error). Optional; if omitted or empty,
%           class-balanced weights are used.
%   Outputs:
%       svmCl: trained SVMClassifier
%       clErr: classification error of the trained classifier

%% Deal with Cell Array Input
if iscell(trnExamples)
    trnExamples = cell2mat(trnExamples);
end

%% Create and Train the SVM
% construct the weight vector for the svm
nExamples = numel(targetOuts);
if nargin < 4 || isempty(wts)
    % default: class-balanced weights, so each class has total weight 1/nc
    wts = zeros(nExamples, 1);
    nc = getNumClasses(svmCl);
    for c = 1:nc
        inds = targetOuts == c;
        numc = sum(inds);
        wts(inds) = 1 / (nc * numc);
    end
%     wts = ones(nExamples, 1) / nExamples;
end
mxWts = max(wts);
mnWts = min(wts);
if mnWts == mxWts
    % make them all 10
    svmWts = repmat(10, nExamples, 1);
else
    % use linear mapping from 1 to 20
    svmWts = 19 * (wts - mnWts) / (mxWts - mnWts) + 1;
end
% svmWts = max(20 * wts / max(wts), 1e-2);
% NOTE: libsvm-weights does not work properly if the weights are small
% values. In the example they have, all the weights are between 1 and
% 20. It seems that when the weights add up to one, the computed
% misclassification error becomes too small and the trainer terminates
% without learning anything. So, here, I adjust the weights to be 20 at
% maximum. The minimum weight of 1e-2 was chosen to match a restriction
% in a different SVM implementation.
% NOTE 2: I changed the weighting so that the minimum is 1, not 1e-2,
% because the behavior of the SVM was very poor when the difference
% between the maximum and minimum weight was very large.

% create and train the SVM
svmCl.trainedSVM = svmtrain(svmWts, targetOuts, trnExamples, svmCl.libSvmTrnOpts);

%% Compute Classification Error
if nargout > 1
    [outs, ~, ~] = svmpredict(targetOuts, trnExamples, svmCl.trainedSVM, svmCl.libSvmPrdOpts);
    if nargin < 4
        % no weights were passed: plain misclassification rate
        clErr = sum(outs ~= targetOuts) / nExamples;
    else
        % NOTE: an explicitly empty wts (nargin == 4) also takes this
        % branch, using the default class-balanced weights built above
        clErr = (outs ~= targetOuts)' * wts;
    end
    if isnan(clErr)
        % was a bare `pause` (debugging leftover that blocks execution)
        warning('SVMClassifier:learn:nanError', 'Classification error is NaN.');
    end
end
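
The weight construction in the listing above can be traced on a small example. The following is a sketch under stated assumptions rather than a replacement for the code: the labels are invented, and `nc` is set by hand here, whereas learn.m obtains it from getNumClasses.

```matlab
% Hypothetical labels: three examples of class 1, one of class 2
targetOuts = [1; 1; 1; 2];
nc = 2;  % number of classes, set by hand for this sketch

% Class-balanced weights: each class receives total weight 1/nc
wts = zeros(numel(targetOuts), 1);
for c = 1:nc
    inds = (targetOuts == c);
    wts(inds) = 1 / (nc * sum(inds));
end
% wts is now [1/6; 1/6; 1/6; 1/2]

% Linear map onto [1, 20]; constant weights fall back to 10
if max(wts) == min(wts)
    svmWts = repmat(10, numel(wts), 1);
else
    svmWts = 19 * (wts - min(wts)) / (max(wts) - min(wts)) + 1;
end
% svmWts is now [1; 1; 1; 20]
```

Because the mapping is relative to the minimum and maximum, the smallest weight always becomes 1 and the largest always becomes 20. In this example a 3:1 ratio between the class weights turns into 20:1 after rescaling, which may be worth keeping in mind when interpreting the trained model.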

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005