Request for Changes-16: Enhance confidence map for Random Forest classification



Status

  • Author: Jordi Inglada
  • Additional Contributors:
  • Submitted on 21.10.2015
  • Proposed target release : 5.2
  • Git branch : rfc-16-rfconfmap

Summary

Provide a confidence map for Random Forest classification for multi-class problems.

Rationale

There is an inconsistency in the Random Forest classifier. For all other classifiers that support the computation of a confidence map, the confidence value is higher when the chosen label is more reliable: the sum of votes for Boosting, the distance to the margin for SVM, the difference between the two highest responses for ANN.

However, in the RF case, the value used is the proportion of trees in the forest that chose the second class (the CvRTrees::predict_prob member function).

This second class is not the class with fewer votes, but simply the second class in terms of label order (class_idx). Since this only works for 2-class problems, a value of, for instance, 0.1 is actually a high confidence value, since the other (first) class will be selected by the classifier. But a value of 0.9 is also a high confidence value, since the second class will then be chosen. Another way to understand the problem is to imagine a perfect classifier: the confidence map would be a binary image, with ones for the pixels of the second class and zeroes for the pixels of the first class.

In order to get a proper confidence value, the return value of CvRTrees::predict_prob has to be remapped, for instance as follows:

2*fabs(predict_prob - 0.5)

This yields a normalised confidence value between 0 and 1.
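
For illustration, here is a minimal sketch of this remapping for the current 2-class behaviour (the function name and the trained model rf are placeholders; only CvRTrees::predict_prob comes from OpenCV):

#include <cmath>
#include <opencv2/ml/ml.hpp>

// rf is a CvRTrees already trained on a 2-class problem,
// sample is a single feature vector stored as a cv::Mat row.
float TwoClassConfidence(const CvRTrees& rf, const cv::Mat& sample)
{
  // predict_prob returns the proportion of trees voting for the 2nd class,
  // so both 0.0 and 1.0 mean "confident" and 0.5 means maximal ambiguity.
  float p = rf.predict_prob(sample);
  // Fold the interval around 0.5 to obtain a confidence value in [0, 1].
  return 2.0f * std::fabs(p - 0.5f);
}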

Another issue is that OpenCV only allows predict_prob to be used for 2-class problems. I don't know about other users, but we seldom use only 2 classes.

Creating a derived class makes it possible to compute a confidence value for any number of classes. It is enough to retrieve the proportion of trees having selected the majority class. Another possibility is to compute the difference between the votes of the two most voted classes. This second option also provides information about the level of conflict between the two most likely classes, and it covers the 2-class case as well.
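
As a toy illustration of the two options (the vote counts below are made up and the program is only a sketch, not part of the patch), consider a forest of 100 trees and 3 classes:

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

int main()
{
  // Hypothetical vote histogram: 100 trees, 3 classes.
  std::vector<unsigned int> votes;
  votes.push_back(60);
  votes.push_back(35);
  votes.push_back(5);
  const float ntrees = 100.f;

  // Sort in decreasing order so that votes[0] and votes[1] are the two
  // most voted classes.
  std::sort(votes.begin(), votes.end(), std::greater<unsigned int>());

  // Option 1: proportion of trees voting for the majority class.
  float majorityProportion = votes[0] / ntrees;          // 0.60
  // Option 2 (the one retained here): normalized difference between the
  // two most voted classes, which also measures the conflict between them.
  float voteDifference = (votes[0] - votes[1]) / ntrees; // 0.25

  std::cout << majorityProportion << " " << voteDifference << std::endl;
  return 0;
}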

Implementation details

The class otb::CvRTreesWrapper derives from OpenCV's CvRTrees and implements the predict_confidence method. The return value is the normalized difference between the votes of the two most voted classes.

The otb::RandomForestsMachineLearningModel class has been updated to use this wrapper and to call the predict_confidence method instead of predict_prob.
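
For reference, a minimal usage sketch of the wrapper from client code (the function below is hypothetical; predict comes from CvRTrees and predict_confidence is the method added by the patch below):

#include "otbCvRTreesWrapper.h"

// forest is assumed to be already trained; sample is a feature row vector.
float ClassifyWithConfidence(const otb::CvRTreesWrapper& forest,
                             const cv::Mat& sample,
                             float& confidence)
{
  // Standard CvRTrees prediction: majority-vote label of the forest.
  float label = forest.predict(sample);
  // Wrapper addition: normalized difference between the votes of the two
  // most voted classes, in [0, 1].
  confidence = forest.predict_confidence(sample);
  return label;
}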

Classes and files

N       Modules/Learning/Supervised/include/otbCvRTreesWrapper.h
M       Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h
M       Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx      

Applications

Tests

Documentation

Additional notes

Full patch

diff --git a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h
new file mode 100644
index 0000000..47f4560
--- /dev/null
+++ b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h
@@ -0,0 +1,69 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbCvRTreesWrapper_h
+#define __otbCvRTreesWrapper_h
+
+#include "otbOpenCVUtils.h"
+#include <vector>
+#include <algorithm>
+
+
+namespace otb
+{
+
+/** \class CvRTreesWrapper
+ * \brief Wrapper for OpenCV Random Trees
+ *
+ * \ingroup OTBSupervised
+ */
+class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees
+{
+public:
+  CvRTreesWrapper(){};
+  virtual ~CvRTreesWrapper(){};
+  
+  /** Predict the confidence of the classification by computing the 
+      difference in votes between the first and second most voted classes.
+      This measure is preferred to the proportion of votes of the majority
+      class, since it provides information about the conflict between the
+      most likely classes.
+  */
+  float predict_confidence(const cv::Mat& sample, 
+                           const cv::Mat& missing = 
+                           cv::Mat()) const
+  {
+    std::vector<unsigned int> classVotes(nclasses);
+    for( int k = 0; k < ntrees; k++ )
+      {
+      CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
+      int class_idx = predicted_node->class_idx;
+      CV_Assert( 0 <= class_idx && class_idx < nclasses );
+      ++classVotes[class_idx];
+      }
+    // We only sort the 2 greatest elements
+    std::nth_element(classVotes.begin(), classVotes.begin()+1, 
+                     classVotes.end(), std::greater<unsigned int>());
+    float confidence = static_cast<float>(classVotes[0]-classVotes[1])/ntrees;
+    return confidence;
+  };
+
+};
+
+}
+
+#endif
diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h
index ec62b76..5872688 100644
--- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h
+++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h
@@ -24,8 +24,9 @@
 #include "itkFixedArray.h"
 #include "otbMachineLearningModel.h"
 #include "itkVariableSizeMatrix.h"
+#include "otbCvRTreesWrapper.h"
 
-class CvRTrees;
+class CvRTreesWrapper;
 
 namespace otb
 {
@@ -53,7 +54,7 @@ public:
 
 
   //opencv typedef
-  typedef CvRTrees RFType;
+  typedef CvRTreesWrapper RFType;
 
   /** Run-time type information (and related methods). */
   itkNewMacro(Self);
@@ -145,7 +146,7 @@ private:
   RandomForestsMachineLearningModel(const Self &); //purposely not implemented
   void operator =(const Self&); //purposely not implemented
 
-  CvRTrees * m_RFModel;
+  CvRTreesWrapper * m_RFModel;
   /** The depth of the tree. A low value will likely underfit and conversely a
    * high value will likely overfit. The optimal value can be obtained using cross
    * validation or other suitable methods. */
@@ -189,7 +190,7 @@ private:
    * first category. */
   std::vector<float> m_Priors;
   /** If true then variable importance will be calculated and then it can be
-   * retrieved by CvRTrees::get_var_importance(). */
+   * retrieved by CvRTreesWrapper::get_var_importance(). */
   bool m_CalculateVariableImportance;
   /** The size of the randomly selected subset of features at each tree node and
    * that are used to find the best split(s). If you set it to 0 then the size will
diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx
index 78642f1..67d21f8 100644
--- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx
+++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx
@@ -29,17 +29,17 @@ namespace otb
 template <class TInputValue, class TOutputValue>
 RandomForestsMachineLearningModel<TInputValue,TOutputValue>
 ::RandomForestsMachineLearningModel() :
- m_RFModel (new CvRTrees),
- m_MaxDepth(5),
- m_MinSampleCount(10),
- m_RegressionAccuracy(0.01),
- m_ComputeSurrogateSplit(false),
- m_MaxNumberOfCategories(10),
- m_CalculateVariableImportance(false),
- m_MaxNumberOfVariables(0),
- m_MaxNumberOfTrees(100),
- m_ForestAccuracy(0.01),
- m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
+  m_RFModel (new CvRTreesWrapper),
+  m_MaxDepth(5),
+  m_MinSampleCount(10),
+  m_RegressionAccuracy(0.01),
+  m_ComputeSurrogateSplit(false),
+  m_MaxNumberOfCategories(10),
+  m_CalculateVariableImportance(false),
+  m_MaxNumberOfVariables(0),
+  m_MaxNumberOfTrees(100),
+  m_ForestAccuracy(0.01),
+  m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
 {
   this->m_ConfidenceIndex = true;
   this->m_IsRegressionSupported = true;
@@ -125,7 +125,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
 
   if (quality != NULL)
     {
-    (*quality) = m_RFModel->predict_prob(sample);
+    (*quality) = m_RFModel->predict_confidence(sample);
     }
 
   return target[0];
@@ -158,23 +158,23 @@ bool
 RandomForestsMachineLearningModel<TInputValue,TOutputValue>
 ::CanReadFile(const std::string & file)
 {
-   std::ifstream ifs;
-   ifs.open(file.c_str());
+  std::ifstream ifs;
+  ifs.open(file.c_str());
 
-   if(!ifs)
-   {
-      std::cerr<<"Could not read file "<<file<<std::endl;
-      return false;
-   }
+  if(!ifs)
+    {
+    std::cerr<<"Could not read file "<<file<<std::endl;
+    return false;
+    }
 
 
-   while (!ifs.eof())
-   {
-      std::string line;
-      std::getline(ifs, line);
+  while (!ifs.eof())
+    {
+    std::string line;
+    std::getline(ifs, line);
 
-      //if (line.find(m_RFModel->getName()) != std::string::npos)
-      if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
+    //if (line.find(m_RFModel->getName()) != std::string::npos)
+    if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
       {
          //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
          return true;