
learn

PURPOSE

function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)

SYNOPSIS

function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)

DESCRIPTION

 function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, ...
     nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
   learning function for the adaBooster

   Inputs:
       trnExamples: examples for training the classifier.
       targetOuts: target classification outputs. Its length must equal
           the number of examples.
       nStages: number of boosting stages required [default 10]. Pass Inf
           to boost until reqErrBnd is reached.
       verbose: if true, progress messages are printed [default false].
       crossvalset: optional cross-validation examples, passed through to
           the weak classifier's learn function.
       crossvalout: target outputs for crossvalset; required whenever
           crossvalset is given.
       reqDetRate: target detection rate for this classifier, i.e. the
           fraction of positive examples that must be classified as
           positive (binary classification only).
       reqErrBnd: bound on the misclassification error; training stops
           when it is reached or the maximum number of stages is reached.
   Outputs:
       ab: trained adaBooster
       clErr: classification error of the trained classifier
       itrErrors: classification error after each boosting stage
           (only available when nStages is finite)
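
   Example (a minimal sketch; the AdaBooster constructor call and the
   data variables below are hypothetical, not defined in this file):

       ab = AdaBooster(weakCl);   % hypothetical constructor wrapping a weak classifier
       % train for 20 stages with verbose progress output, recording
       % the training error after each stage in itrErrors
       [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, 20, true);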

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

function printfTrue(condition, varargin)

SOURCE CODE

function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, ...
    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
% function [ab, clErr, itrErrors] = learn(ab, trnExamples, targetOuts, ...
%     nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
%   learning function for the adaBooster
%
%   Inputs:
%       trnExamples: examples for training the classifier.
%       targetOuts: target classification outputs. Its length must equal
%           the number of examples.
%       nStages: number of boosting stages required [default 10]. Pass Inf
%           to boost until reqErrBnd is reached.
%       verbose: if true, progress messages are printed [default false].
%       crossvalset: optional cross-validation examples, passed through to
%           the weak classifier's learn function.
%       crossvalout: target outputs for crossvalset; required whenever
%           crossvalset is given.
%       reqDetRate: target detection rate for this classifier, i.e. the
%           fraction of positive examples that must be classified as
%           positive (binary classification only).
%       reqErrBnd: bound on the misclassification error; training stops
%           when it is reached or the maximum number of stages is reached.
%   Outputs:
%       ab: trained adaBooster
%       clErr: classification error of the trained classifier
%       itrErrors: classification error after each boosting stage
%           (only available when nStages is finite)

%% ---
% fprintf-like function that prints only if condition (first parameter) is true
function printfTrue(condition, varargin)
    if condition
        fprintf(varargin{:});
    end
end

%% Check Validity of Input
if isempty(ab.weakCl)
    error('A weak classifier has not been specified');
end

if nargin < 3
    error('incorrect number of arguments');
end

if nargin < 4 || isempty(nStages)
    nStages = 10;
end

if nargin < 5
    verbose = false;
end

if nargin > 5 && nargin < 7
    error('cross validation output is not given');
end

useCrossValidationSet = false;

if nargin >= 6 && ~isempty(crossvalset) && ~isempty(crossvalout)
    useCrossValidationSet = true;
end

if nargin < 8 || isempty(reqDetRate)
    reqDetRate = NaN;
end

if nargin >= 9 && ~isempty(reqErrBnd)
    ab.errBound = reqErrBnd;
end

if nargout > 2
    if isinf(nStages)
        error('Can''t record errors with infinite number of stages');
    else
        itrErrors = zeros(nStages, 1);
    end
end

printfTrue(verbose, 'AdaBoost\n');
printfTrue(verbose, '========\n');

%% Initialize Example Weights
if isempty(ab.lastExWeights)
    % that classifier was not trained before

    % initialize weights to be uniform over each class
    wts = zeros(length(targetOuts), 1);
    nc = getNumClasses(ab);
    for c = 1:nc
        inds = targetOuts == c;
        numc = sum(inds);
        wts(inds) = 1 / (nc * numc);
    end
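    % this way every class receives the same total weight (1/nc), so
    % classes with fewer examples are not dominated by larger ones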
%     % initialize weights to be uniform over all training examples
%     wts = repmat(1 / length(targetOuts), length(targetOuts), 1);
else
    % classifier was trained before for a different number of stages
    % initialize weights to the weights of the last stage trained before
    wts = ab.lastExWeights;
end

%% Boosting Loop - Begin
%%
clErr = Inf;
ab.thresh = 0;
while (isinf(nStages) && clErr > ab.errBound) || ...
      (~isinf(nStages) && length(ab.clsWeights) < nStages)

    printfTrue(verbose, '\n\t\t\t==================\n');
    printfTrue(verbose, '\t\t\tBoosting Stage # %d\n', length(ab.clsWeights) + 1);
    printfTrue(verbose, '\t\t\t==================\n\n');

    if useCrossValidationSet
        % the weak classifier must support a cross-validation set
        [trndCl, err] = learn(ab.weakCl, trnExamples, targetOuts, wts, ...
            crossvalset, crossvalout);
    else
        [trndCl, err] = learn(ab.weakCl, trnExamples, targetOuts, wts);
    end

    if err < 1e-9 % that is to avoid division by zero later on
        % Division by zero may occur also if err is almost 1.
        % However, I assume here that the weak classifier does not produce
        %   an error rate higher than 1-(1/number of classes).
        err = 1e-9;
    end

    printfTrue(verbose, ...
            '\t\t\tWeak Classifier has been trained, err = %f\n', err);

    beta = err / (1 - err);
    % add the trained classifier to the set of trained classifiers
    ab.trndCls{end+1} = trndCl;
    % compute the weight of this classifier
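    % this is the multi-class (SAMME) form of the AdaBoost stage weight,
    %   alpha = log((1 - err) / err) + log(K - 1),
    % which reduces to the classic binary AdaBoost weight when K = 2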
    alpha = log(1 / beta) + log(getNumClasses(ab) - 1);
    ab.clsWeights = [ab.clsWeights, alpha];
    %ab.thresh = sum(ab.clsWeights) * (posVal + negVal) / 2;
    % update the weights of the training examples to be ready for the
    % next stage

    outs = computeOutputs(trndCl, trnExamples);

    e = outs ~= targetOuts;
    wts = wts .* exp(alpha .* e);
    % normalize the weights
    wts = wts / sum(wts);
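    % only misclassified examples (e == 1) get scaled up, by a factor of
    % exp(alpha) = (K - 1) * (1 - err) / err > 1, so after normalization
    % the next stage concentrates on the examples this stage got wrong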
    if isinf(nStages) || nargout > 2
        % compute the boosted classification error here only if the number
        % of stages is not specified or per-stage errors were requested;
        % otherwise there is no need to compute it now
        abOuts = computeOutputs(ab, trnExamples);
        clErr = sum(abOuts ~= targetOuts) / length(abOuts);
        printfTrue(verbose, ...
            '\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
        if ~isinf(nStages)
            itrErrors(length(ab.clsWeights), 1) = clErr;
        end
    end
end
if ~isinf(nStages) && nargout < 3
    % compute the boosted classification error only if the number of
    % stages is specified; otherwise it was already computed in the
    % boosting loop
    abOuts = computeOutputs(ab, trnExamples);
    clErr = sum(abOuts ~= targetOuts) / length(abOuts);
end
%% Boosting Loop - End
%% Set Number of Stages and Last Set of Weights Computed
ab.lastExWeights = wts;
%% Adjust Threshold if Necessary (valid only in binary classification)
if getNumClasses(ab) == 2
    if ~isnan(reqDetRate)
        [ab, clErr] = adjustThreshold(ab, trnExamples, targetOuts, ...
                                          reqDetRate);
    else
        ab.detectionRate = ...
            sum((targetOuts == getPosVal(ab)) & (abOuts == getPosVal(ab))) / ...
            sum((targetOuts == getPosVal(ab)));
    end
end

printfTrue(verbose, ...
        '\t\t\tAdaBoost Training is Done with err = %f\n', clErr);
end
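
Because the last set of example weights is kept in ab.lastExWeights, calling
learn again on an already trained booster resumes boosting from where the
previous call stopped. A minimal sketch (data variables are hypothetical):

    [ab, clErr] = learn(ab, trnExamples, targetOuts, 10);   % first 10 stages
    [ab, clErr] = learn(ab, trnExamples, targetOuts, 20);   % adds stages 11-20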

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005