Home > @LogitBooster > learn.m



function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)


function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)


 function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)
   Learning function for the LogitBoost classifier.

       trnExamples: examples for training the classifier.
       targetOuts: target classification outputs. its size must be the
           same as the number of examples.
       nStages: number of boosting stages required [default 10]
       reqDetRate: target detection rate for this classifier
            which is the percentage of positive examples that are classified as positive
       reqErrBnd: bound on the misclassification error. When nStages is Inf,
           training stops once the training error is at or below this bound.
       lb: trained LogitBooster
       clErr: classification error of the trained classifier


This function calls: This function is called by:


function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)
% function [lb, clErr] = learn(lb, trnExamples, targetOuts, nStages, reqDetRate, reqErrBnd)
%   Learning function for the LogitBoost classifier (multi-class LogitBoost,
%   Friedman, Hastie & Tibshirani 2000). Each stage fits lb.regressor, a
%   weighted regressor, once per class and adds the centered per-class linear
%   weights to lb.F.
%
%   Inputs:
%       lb: LogitBooster to train. Training continues from lb.nStages, so a
%           previously trained classifier is extended rather than restarted.
%       trnExamples: examples for training the classifier, one example per
%           row. A cell array is converted with cell2mat.
%       targetOuts: target classification outputs. Its length must equal the
%           number of examples. Labels are compared against the class indices
%           1..J, where J = getNumClasses(lb).
%       nStages: total number of boosting stages required [default 10].
%           Inf means: keep boosting until the training error is at or below
%           lb.errBound.
%       reqDetRate: target detection rate for this classifier, i.e. the
%           fraction of positive examples that are classified as positive.
%           Used only for two-class problems [default NaN = no adjustment].
%       reqErrBnd: bound on the misclassification error. It is stored in
%           lb.errBound and is used as the stopping criterion when nStages
%           is Inf.
%   Outputs:
%       lb: trained LogitBooster
%       clErr: training-set classification error of the trained classifier
%% ----
disp('============');
disp('Logit Boost')
disp('============');

%% Check Validity of Input
if nargin < 3
    error('incorrect number of arguments');
end
if nargin < 4 || isempty(nStages)
    nStages = 10;
end
if nargin < 5 || isempty(reqDetRate)
    reqDetRate = NaN;
end
if nargin >= 6 && ~isempty(reqErrBnd)
    lb.errBound = reqErrBnd;
end

% Column copy of the labels. Comparing it with per-example column vectors
% never broadcasts into an m-by-m matrix, even if targetOuts is a row.
labels = targetOuts(:);

%% Initialize Example Weights
% NOTE(review): wts is not updated by the LogitBoost loop below. It is only
% stored back into lb.lastExWeights.
if isempty(lb.lastExWeights)
    wts = ones(length(targetOuts), 1) / length(targetOuts);
else
    % classifier was trained before for a different number of stages:
    % start from the weights of the last stage trained before
    disp('classifier was trained before');
    wts = lb.lastExWeights;
end

%% Boosting Loop - Begin
clErr = Inf;
lb.thresh = 0;
abOuts = [];   % classifier outputs on the training set, once computed

% handle cell input
trainData = trnExamples;
if iscell(trainData)
    trainData = cell2mat(trainData);
end
nEx = size(trainData, 1);
nW  = size(trainData, 2) + 1;       % one weight per feature plus a bias
Xb  = [trainData ones(nEx, 1)];     % examples with the bias column (loop-invariant)

% init F (J-by-nW linear weights per class) and P (nEx-by-J class probabilities)
J = getNumClasses(lb);
if lb.nStages == 0
    lb.F = zeros(J, nW);
    lb.P = ones(nEx, J) / J;
end

iStage = lb.nStages;   % stage iterator
while (isinf(nStages) && clErr > lb.errBound) || ...
      (~isinf(nStages) && iStage < nStages)

    fprintf('Iteration %d\n', iStage);
    Fj = zeros(J, nW);
    breakLoop = false;
    for j = 1:J
        pj = lb.P(:, j);
        YStar = (labels == j);       % 1 for examples of class j, 0 otherwise
        w = pj .* (1 - pj);          % working weights
        z = (YStar - pj) ./ w;       % working responses

        % stop this stage if nan values were detected
        if any(isnan(w)) || any(isnan(z))
            breakLoop = true;
            break;
        end

        regressor = learn(lb.regressor, trainData, z, w);
        Fj(j, :) = getWeights(regressor);

        % stop this stage if nan values were detected
        if any(isnan(Fj(j, :)))
            breakLoop = true;
            break;
        end
    end

    if ~breakLoop
        % Center the class updates over ALL classes:
        %   f_j <- (J-1)/J * (f_j - (1/J) * sum_k f_k)
        Fj = ((J - 1) / J) * bsxfun(@minus, Fj, mean(Fj, 1));
        lb.F = lb.F + Fj;
    else
        fprintf('next iterations (after %d) will have no effects due to nan\n' ...
         ,iStage+1);
        if isinf(nStages)
            % F (and therefore the error) can no longer change, so the error
            % bound can never be reached: stop instead of looping forever.
            break;
        end
    end

    % Class probabilities = row-wise softmax of the linear scores. Subtracting
    % each row's maximum avoids exp() overflow, and every row sum is then >= 1.
    Fvalues = Xb * (lb.F)';
    eFvalues = exp(bsxfun(@minus, Fvalues, max(Fvalues, [], 2)));
    lb.P = bsxfun(@rdivide, eFvalues, sum(eFvalues, 2));

    %-------------------------------------------------
    if isinf(nStages)
        % compute the boosted classification error only if the number of
        % stages is not specified; otherwise it is computed once after the loop
        abOuts = computeOutputs(lb, trainData);
        clErr = sum(abOuts(:) ~= labels) / numel(abOuts);
        fprintf('\t\t\tCurrently, boosted classifier''s error = %f\n', clErr);
    end

    iStage = iStage + 1;
    lb.nStages = lb.nStages + 1;
end
if ~isinf(nStages) || isempty(abOuts)
    % compute the boosted classification error when the number of stages is
    % specified, or when the loop did not produce it (e.g. no iteration ran)
    abOuts = computeOutputs(lb, trainData);
    clErr = sum(abOuts(:) ~= labels) / numel(abOuts);
end
%% Boosting Loop - End
%% Set Last Set of Weights Computed
lb.lastExWeights = wts;

%% Adjust Threshold if Necessary (valid only in binary classification)
if getNumClasses(lb) == 2
    if ~isnan(reqDetRate)
        % NOTE(review): useGlobalExamples is not defined in this file. It is
        % presumably a project-level flag function; confirm. The exist() guard
        % makes a missing definition fall back to the training examples
        % instead of raising an "undefined function" error.
        if exist('useGlobalExamples') > 1 && useGlobalExamples
            [lb, clErr] = adjustThreshold(lb, [], targetOuts, reqDetRate);
        else
            [lb, clErr] = adjustThreshold(lb, trnExamples, targetOuts, ...
                                          reqDetRate);
        end
    else
        posVal = getPosVal(lb);
        isPos = (labels == posVal);
        % fraction of positive examples that the classifier labels positive
        lb.detectionRate = sum(isPos & (abOuts(:) == posVal)) / sum(isPos);
    end
end
fprintf('\t\t\tLogit Boost Training is Done with err = %f\n', clErr);

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005