Home > @RealAdaBooster > learn.m

learn

PURPOSE ^

function [ab, err, itrErrors] = learn(ab, trnExamples, targetOuts, ...

SYNOPSIS ^

function [rab, clErr, itrErrors] = learn(rab, trnExamples, targetOuts,nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)

DESCRIPTION ^

 function [rab, clErr, itrErrors] = learn(rab, trnExamples, targetOuts, ...
    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)

   learning function for the realAdaBooster

   Inputs:
       1 rab : real ada boost classifier instance
       2 trnExamples: examples for training the classifier.
       3 targetOuts: target classification outputs. its size must be the
           same as the number of examples.
       4 nStages: number of boosting stages required [default 10]
       5 verbose: prints progress information while the classifier is working
       6 crossvalset : cross validation set
       7 crossvalout : output of cross validation set
       8 reqDetRate: target detection rate for this classifier,
           i.e. the percentage of positive examples that are classified correctly
       9 reqErrBnd: bound on misclassification error; training stops when
           it is reached or the max number of stages is reached.
   Outputs:
       rab: trained realAdaBooster
       clErr: classification error of the trained classifier
       itrErrors : total error of real adaboost classifier at each iteration

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [rab, clErr, itrErrors] = learn(rab, trnExamples, targetOuts, ...
    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
% function [rab, clErr, itrErrors] = learn(rab, trnExamples, targetOuts, ...
%    nStages, verbose, crossvalset, crossvalout, reqDetRate, reqErrBnd)
%
%   learning function for the realAdaBooster (multi-class real AdaBoost
%   with SAMME-style stage weights and example-weight updates)
%
%   Inputs:
%       1 rab : real ada boost classifier instance
%       2 trnExamples: examples for training the classifier.
%       3 targetOuts: target classification outputs. its size must be the
%           same as the number of examples.
%       4 nStages: number of boosting stages required [default 10]
%       5 verbose: prints progress information while the classifier trains
%       6 crossvalset : cross validation set
%       7 crossvalout : output of cross validation set
%       8 reqDetRate: target detection rate for this classifier, i.e. the
%           percentage of positive examples that are classified correctly
%       9 reqErrBnd: bound on misclassification error; training stops when
%           it is reached or the max number of stages is reached.
%   Outputs:
%       rab: trained realAdaBooster
%       clErr: classification error of the trained classifier
%       itrErrors : total error of real adaboost classifier at each iteration

%% ----

% fprintf-like helper that prints only when condition (first parameter) is true
function printfTrue(condition, varargin)
    if condition
        fprintf(varargin{:});
    end
end

%% Check Validity of Input
if isempty(rab.weakCl)
    error('A weak classifier has not been specified');
end

if nargin < 3
    error('incorrect number of arguments');
end

if nargin < 4 || isempty(nStages)
    nStages = 10;
    rab.nStages = 10;
else
    rab.nStages = nStages;
end

if nargin < 5 || isempty(verbose)
    verbose = false;
end

% a cross validation set without its outputs is unusable
if nargin > 5 && nargin < 7
    error('cross validation output is not given');
end

% cross validation is used only when both the set and its outputs are given
useCrossValidationSet = nargin >= 7 && ...
    ~isempty(crossvalset) && ~isempty(crossvalout);

if nargin < 8 || isempty(reqDetRate)
    % NOTE(review): reqDetRate is accepted but never read below -- kept for
    % interface compatibility; confirm against callers before removing
    reqDetRate = NaN; %#ok<NASGU>
end

if nargin >= 9 && ~isempty(reqErrBnd)
    rab.errBound = reqErrBnd;
end

if nargout > 2
    if isinf(nStages)
        % fixed: previously called the undefined function "errors"
        error('Cant record errors with infinite number of stages');
    else
        itrErrors = zeros(nStages, 1);
    end
end

printfTrue(verbose, 'RealAdaBoost\n');
printfTrue(verbose, '============\n');

%% Initialize Example Weights
if isempty(rab.lastExWeights)
    % fixed: removed a stray trailing argument that was passed to fprintf
    printfTrue(verbose, 'RealAda boost Classifier wasnt trained before\n');
    % uniform initial weights over all training examples
    wts = ones(length(targetOuts), 1) * 1/length(targetOuts);
else
    % classifier was trained before for a different number of stages
    % initialize weights to the weights of the last stage trained before
    printfTrue(verbose, 'Real AdaBoost has been trained before\n');
    wts = rab.lastExWeights;
end

%% Boosting Loop - Begin

clErr = Inf;
rab.thresh = 0;

% resume stage numbering after any previously trained stages
iStage = length(rab.trndCls) + 1;

% with an infinite stage budget, loop until the error bound is met;
% otherwise run up to nStages stages
while (isinf(nStages) && clErr > rab.errBound) || ...
      (~isinf(nStages) && iStage <= nStages)

    printfTrue(verbose, '\n\t\t\t==========================\n');
    printfTrue(verbose, '\t\t\tRealAda Boosting Stage # %d\n', iStage);
    printfTrue(verbose, '\t\t\t==========================\n');

    % First : Fit classifier with prob estimates
    if useCrossValidationSet
        % the weak classifier must support cross validation set
        [trndCl, err] = learn(rab.weakCl, trnExamples, targetOuts, wts, ...
            crossvalset, crossvalout);
    else
        [trndCl, err] = learn(rab.weakCl, trnExamples, targetOuts, wts);
    end

    % avoid division by zero in the beta computation below
    if err < 1e-9
        err = 1e-9;
    end

    printfTrue(verbose, ...
            '\t\t\tWeak Classifier has been trained, err = %f\n', err);

    % add the trained classifier to the set of trained classifiers
    rab.trndCls{end+1} = trndCl;
    % stage weight (SAMME): log((1-err)/err) + log(K-1)
    beta = err / (1 - err);
    alpha = log(1 / beta) + log(getNumClasses(rab) - 1);
    rab.clsWeights = [rab.clsWeights, alpha];

    % Second : Get probabilty estimates of new weak classifier
    [~, prob] = computeOutputs(trndCl, trnExamples);

    % Third: update the weights of training examples to be ready for the next
    % stage

    % construct the Y Coding matrix (SAMME equation number 7):
    % row i is 1 in the true-class column and -1/(K-1) elsewhere
    K = getNumClasses(rab);
    Y_CODE = ones(length(targetOuts), K) * (-1/(K-1));
    Y_CODE(sub2ind(size(Y_CODE), (1:length(targetOuts))', targetOuts)) = 1;

    % avoid zero prob for the log and also negative prob (due to negative wts)
    prob(prob <= 0) = 1e-5;

    % dot product of Y_CODE and log Prob matrix multiplied with -1 * (k-1)/k
    estimator_wts = -1 * ((K-1)/K) * sum(Y_CODE .* log(prob), 2);

    % apply the exponential update only where estimator_wts < 0 or wts > 0
    % NOTE(review): this guard is preserved from the original; with
    % normalized non-negative weights it is almost always active -- confirm
    wts = wts .* exp( (estimator_wts .* (+(estimator_wts < 0) |+(wts > 0) ) ) );

    % finally normalize wts so they form a distribution
    wts = wts ./ sum(wts);

    if isinf(nStages) || nargout > 2
        % compute boosted classification error only if the number of stages
        % is not specified, otherwise, no need to compute it now
        abOuts = computeOutputs(rab, trnExamples);
        clErr = sum(abOuts ~= targetOuts) / length(abOuts);
        printfTrue(verbose, '\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
        if ~isinf(nStages)
            itrErrors(iStage, 1) = clErr;
        end
    end

    iStage = iStage + 1;
end

if ~isinf(nStages) && nargout < 3
    % compute boosted classification error only if the number of stages
    % is specified, otherwise, it is already computed in the boosting loop
    abOuts = computeOutputs(rab, trnExamples);
    clErr = sum(abOuts ~= targetOuts) / length(abOuts);
end
%% Boosting Loop - End
%% Set Number of Stages and Last Set of Weighs Computed
rab.lastExWeights = wts;

printfTrue(verbose, '\t\t\tRealAdaBoost Training is Done with err = %f\n', clErr)
end

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005