adjustThreshold

PURPOSE

Adjust the AdaBooster threshold to achieve a required detection rate and report the resulting classification error.

SYNOPSIS

function [ab, clErr, resDetRate] = adjustThreshold(ab, trainingExamples, targetOuts, reqDetRate)

DESCRIPTION

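 function [ab, clErr, resDetRate] = adjustThreshold(ab, trainingExamples, targetOuts, reqDetRate)
    adjusts the threshold of ab so that at least a fraction reqDetRate of
    the positive examples is detected, and returns the resulting
    classification error and detection rate. The threshold is changed only
    if the current one does not already achieve the required rate.

   Inputs:
       ab: AdaBooster classifier whose threshold is adjusted.
       trainingExamples: examples on which the threshold is calibrated
           (e.g. the training set).
       targetOuts: target classification outputs. Its length must equal
           the number of examples.
       reqDetRate: required detection rate, a fraction in [0, 1].
   Outputs:
       ab: AdaBooster with the adjusted threshold; ab.detectionRate is set
           to resDetRate.
       clErr: classification error, i.e. the fraction of examples whose
           output differs from targetOuts.
       resDetRate: actual detection rate after adjustment, i.e. the
           fraction of positive examples classified as positive.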
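The threshold search described above can be illustrated on plain vectors. This is a sketch, not the class method: the scores are made up, and because computeOutputs is not shown on this page, the decision rule is an assumption (an example is labeled positive when its real output is >= thresh if posVal > negVal, and <= thresh otherwise). It shows how indexing the sorted positive scores at numPos - targetPos + 1 leaves at least targetPos positives on the positive side of the threshold, for both sort directions.

```matlab
% Made-up real outputs of the positive examples only.
posScores  = [0.9 -0.2 0.4 0.1 0.7];
reqDetRate = 0.8;                              % detect at least 80% of positives

numPos    = numel(posScores);                  % 5
targetPos = ceil(numPos * reqDetRate);         % 4 positives must be detected

% Case posVal > negVal: sort ascending; assumed rule is
% "positive when output >= thresh".
sortedAsc = sort(posScores, 'ascend');         % [-0.2 0.1 0.4 0.7 0.9]
threshUp  = sortedAsc(numPos - targetPos + 1); % 0.1
detUp     = sum(posScores >= threshUp) / numPos;        % 0.8

% Case posVal < negVal: here the positives score low (mirrored data),
% sort descending; assumed rule is "positive when output <= thresh".
mirrored   = -posScores;
sortedDesc = sort(mirrored, 'descend');        % [0.2 -0.1 -0.4 -0.7 -0.9]
threshDown = sortedDesc(numPos - targetPos + 1);        % -0.1
detDown    = sum(mirrored <= threshDown) / numPos;      % 0.8
```

With ties at the threshold score, more than targetPos positives can pass, which is why the resulting rate is "at least" reqDetRate.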

CROSS-REFERENCE INFORMATION

This function calls: computeOutputs, getPosVal, getNegVal
This function is called by: (none listed)

SOURCE CODE

function [ab, clErr, resDetRate] = adjustThreshold(ab, trainingExamples, targetOuts, reqDetRate)
% function [ab, clErr, resDetRate] = adjustThreshold(ab, trainingExamples, targetOuts, reqDetRate)
%    adjusts the threshold of ab so that at least a fraction reqDetRate of
%    the positive examples is detected, and returns the resulting
%    classification error and detection rate. The threshold is changed only
%    if the current one does not already achieve the required rate.
%
%   Inputs:
%       ab: AdaBooster classifier whose threshold is adjusted.
%       trainingExamples: examples on which the threshold is calibrated
%           (e.g. the training set).
%       targetOuts: target classification outputs. Its length must equal
%           the number of examples.
%       reqDetRate: required detection rate, a fraction in [0, 1].
%   Outputs:
%       ab: AdaBooster with the adjusted threshold; ab.detectionRate is set
%           to resDetRate.
%       clErr: classification error, i.e. the fraction of examples whose
%           output differs from targetOuts.
%       resDetRate: actual detection rate after adjustment, i.e. the
%           fraction of positive examples classified as positive.

% Current outputs: clOuts are class labels, realOuts the real-valued scores.
[clOuts, realOuts] = computeOutputs(ab, trainingExamples);
posVal = getPosVal(ab);
negVal = getNegVal(ab);

% Number of positives that must be detected to reach reqDetRate.
posIndxs = find(targetOuts == posVal);
numPos = length(posIndxs);
targetPos = ceil(numPos * reqDetRate);
numDetPos = sum(clOuts == posVal & targetOuts == posVal);

if targetPos > numDetPos
    % The threshold needs to be adjusted to achieve the required detection
    % rate. Sort the positives' real outputs in the direction given by the
    % sign of posVal - negVal, then place the threshold at the targetPos-th
    % score counted from the positive end.
    cmpSign = sign(posVal - negVal);
    if cmpSign > 0
        sortMode = 'ascend';
    else
        sortMode = 'descend';
    end
    sortedPosOuts = sort(realOuts(posIndxs), sortMode);
    ab.thresh = sortedPosOuts(numPos - targetPos + 1);
    clOuts = computeOutputs(ab, trainingExamples);
end

% Classification error and detection rate with the final threshold.
clErr = sum(clOuts ~= targetOuts) / length(clOuts);
resDetRate = sum((targetOuts == posVal) & (clOuts == posVal)) / ...
             sum((targetOuts == posVal));
ab.detectionRate = resDetRate;
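The listing can be traced end to end on toy data to see how clErr and resDetRate respond to the adjustment. In this sketch the label values (posVal = 1, negVal = -1), the starting threshold of 0.5, the data, and the anonymous function predictLabels standing in for computeOutputs are all hypothetical; the decision rule "output >= thresh" is assumed. Here the detection rate rises from 0.4 to 0.8, and the error falls from 3/9 to 2/9 because two recovered positives outweigh one new false positive.

```matlab
% Toy data (made up): five positives followed by four negatives.
posVal = 1;  negVal = -1;                         % assumed label values
realOuts   = [0.9 -0.2 0.4 0.1 0.7   0.3 -0.5 -0.3 -0.8];
targetOuts = [1    1   1   1   1    -1   -1   -1   -1 ];
reqDetRate = 0.8;

% Hypothetical stand-in for computeOutputs: predict posVal when the real
% output is >= thresh, negVal otherwise (assumed rule).
predictLabels = @(thresh) negVal + (posVal - negVal) * (realOuts >= thresh);

thresh = 0.5;                                     % hypothetical starting threshold
clOuts = predictLabels(thresh);                   % detects 2 of 5 positives

% Same adjustment steps as the listing (posVal > negVal, so sort ascending).
posIndxs  = find(targetOuts == posVal);
numPos    = length(posIndxs);
targetPos = ceil(numPos * reqDetRate);            % 4
numDetPos = sum(clOuts == posVal & targetOuts == posVal);
if targetPos > numDetPos
    sortedPosOuts = sort(realOuts(posIndxs), 'ascend');
    thresh = sortedPosOuts(numPos - targetPos + 1);   % 0.1
    clOuts = predictLabels(thresh);
end

% Metrics computed as at the end of the listing.
clErr = sum(clOuts ~= targetOuts) / length(clOuts);   % 2/9: one miss, one false positive
resDetRate = sum((targetOuts == posVal) & (clOuts == posVal)) / ...
             sum((targetOuts == posVal));             % 4/5 = 0.8
```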

Generated on Sun 29-Sep-2013 01:25:24 by m2html © 2005