Request for Changes-65: Add Shark Random forests implementation

From OTBWiki

Status

  • Author: Julien Michel, Jordi Inglada
  • Additional Contributors: Victor Poughon, Emmanuelle Sarrazin
  • Submitted on 23.09.2016
  • Proposed target release: 5.8
  • Adopted (+4 from Julien, Victor, Rémi, Guillaume)
  • Merged: 77f80161e5bfb31a5ee5245a9250d8bd49bb73d2

Summary

This RFC adds the Shark Machine Learning library as a new third-party module. It provides a new specialization of MachineLearningModel that uses the Shark implementation of the Random Forests algorithm. This work also required modifying the internal API of the MachineLearningModel class, so as to benefit from the parallel prediction capabilities of the Shark implementation. The learning applications have been updated to allow the new machine learning model to be used, and the Superbuild now builds Shark.

Rationale

See the corresponding Request for Comments for rationale on using Shark for machine learning tasks.

Implementation details

CMake

A fix has been made to CMake/OTBModuleEnablement.cmake to include optional module dependencies in cycle detection.

Modules

ThirdParty/Shark

Import of Shark third party.

Classes and files

otb::MachineLearningModel

The rationale for changing the internal API of otb::MachineLearningModel is as follows:

  • The PredictBatch() method is not const, and therefore not thread-safe
  • Public interface was virtual, which is not recommended
  • Depending on the algorithm (and often on the third party), inner implementations have different capabilities:
    •  Predict a single sample only (OpenCV, LibSVM)
    •  Predict a batch of samples at once (not encountered yet)
    •  Predict a batch of samples, in parallel (Shark)

We had to find a design that works for all of these cases. Here is what we did:

  1. Two new public, non-virtual const methods have been added:
    1. Predict(), which predicts a single sample
    2. PredictBatch(), which predicts a batch of samples
  2. Two new private, virtual const methods have been added:
    1. DoPredict(), which is the actual implementation of single sample prediction. This is a pure virtual method.
    2. DoPredictBatch() which is the actual implementation of batch sample prediction. Default implementation calls DoPredict() iteratively.
  3. Public Predict() method calls DoPredict()
  4. Public PredictBatch() method calls DoPredictBatch(). Its behaviour differs depending on the protected flag m_IsDoPredictBatchMultiThreaded:
    1. If true, PredictBatch() will call DoPredictBatch() directly
    2. If false, and if OTB is built with the -fopenmp flag, PredictBatch() will split the input batch into as many pieces as there are threads available, and perform the prediction of those pieces in parallel.
  5. PredictAll() method is marked as deprecated and its implementation now calls PredictBatch()
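
The design described above can be sketched in a few lines of standalone C++. This is only an illustration under simplifying assumptions, not the actual OTB code: the class name ModelSketch, the Sample and Label types, the signature of DoPredictBatch() and the two toy derived models are all hypothetical. Only the method names Predict(), PredictBatch(), DoPredict(), DoPredictBatch() and the flag m_IsDoPredictBatchMultiThreaded come from the text.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// Simplified stand-ins for the OTB sample types (assumption, illustration only).
using Sample = std::vector<double>;
using Label  = int;

class ModelSketch
{
public:
  virtual ~ModelSketch() = default;

  // Public, non-virtual, const: single-sample prediction.
  Label Predict(const Sample& s) const { return DoPredict(s); }

  // Public, non-virtual, const: batch prediction.
  std::vector<Label> PredictBatch(const std::vector<Sample>& in) const
  {
    std::vector<Label> out(in.size());
    if (in.empty())
      return out;
    if (m_IsDoPredictBatchMultiThreaded)
    {
      // The implementation is parallel on its own: hand over the whole batch.
      DoPredictBatch(in, 0, in.size(), out);
      return out;
    }
#ifdef _OPENMP
    // Split the batch into at most one piece per available thread.
    const std::size_t nbPieces  = std::min<std::size_t>(omp_get_max_threads(), in.size());
    const std::size_t pieceSize = (in.size() + nbPieces - 1) / nbPieces;
    #pragma omp parallel for
    for (long p = 0; p < static_cast<long>(nbPieces); ++p)
    {
      const std::size_t start = static_cast<std::size_t>(p) * pieceSize;
      const std::size_t end   = std::min(start + pieceSize, in.size());
      if (start < end)
        DoPredictBatch(in, start, end - start, out);
    }
#else
    // Without OpenMP this sketch simply processes the batch in one call.
    DoPredictBatch(in, 0, in.size(), out);
#endif
    return out;
  }

protected:
  bool m_IsDoPredictBatchMultiThreaded = false;

private:
  // Actual single-sample implementation: pure virtual.
  virtual Label DoPredict(const Sample& s) const = 0;

  // Actual batch implementation: by default, iterate over DoPredict().
  virtual void DoPredictBatch(const std::vector<Sample>& in, std::size_t start,
                              std::size_t size, std::vector<Label>& out) const
  {
    for (std::size_t i = start; i < start + size; ++i)
      out[i] = DoPredict(in[i]);
  }
};

// Toy model that can only predict single samples: implements DoPredict() only.
class SignModel : public ModelSketch
{
private:
  Label DoPredict(const Sample& s) const override
  {
    return (!s.empty() && s[0] >= 0.0) ? 1 : -1;
  }
};

// Toy model standing in for a third party that parallelises batches itself:
// implements both methods and sets the flag.
class BatchSignModel : public ModelSketch
{
public:
  BatchSignModel() { m_IsDoPredictBatchMultiThreaded = true; }

private:
  Label DoPredict(const Sample& s) const override
  {
    return (!s.empty() && s[0] >= 0.0) ? 1 : -1;
  }
  void DoPredictBatch(const std::vector<Sample>& in, std::size_t start,
                      std::size_t size, std::vector<Label>& out) const override
  {
    // A real implementation would forward the whole range to the third party here.
    for (std::size_t i = start; i < start + size; ++i)
      out[i] = DoPredict(in[i]);
  }
};
```

In this sketch, callers only ever see Predict() and PredictBatch(); whether the work is split by the base class or delegated to the third party is decided by the flag set in the derived constructor.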

Here is a table summing up what method should be implemented depending on the algorithm / third party capabilities:

Capabilities                                   | Implementation
Can only predict single samples                | Implement DoPredict()
Can predict a batch of samples, no parallelism | Implement both DoPredict() and DoPredictBatch()
Can predict a batch of samples, in parallel    | Implement both DoPredict() and DoPredictBatch(), and set m_IsDoPredictBatchMultiThreaded to true

In any case, PredictBatch() runs in parallel provided that OTB has been built with the -fopenmp flag. This new design has been applied to all existing models, as well as to the new Shark one.

Special thanks to Victor for helping me sort things out and come up with this design.

otb::ImageClassificationFilter

This filter has been modified to use the const PredictBatch() method instead of calling Predict() for each pixel. If OTB is built with the -fopenmp flag, multi-threading is delegated to the PredictBatch() method. Otherwise, classical ITK multi-threading is used. This way, the ImageClassificationFilter is always multi-threaded, transparently to the user.
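
The idea of predicting a whole set of pixels in one call, rather than one pixel at a time, can be illustrated as follows. This is a minimal sketch and not the filter's actual code: the function ClassifyBuffer, the Pixel and BatchPredictor types and the interleaved buffer layout are assumptions made for the example, and the model is abstracted as a callable that plays the role of PredictBatch().

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// One pixel is a vector of band values (assumption, illustration only).
using Pixel = std::vector<float>;

// Hypothetical stand-in for a model's PredictBatch() method.
using BatchPredictor = std::function<std::vector<int>(const std::vector<Pixel>&)>;

// Classify a band-interleaved buffer holding nbBands values per pixel.
// All pixels are gathered first, then predicted with a single batch call.
std::vector<int> ClassifyBuffer(const std::vector<float>& buffer,
                                std::size_t nbBands,
                                const BatchPredictor& predictBatch)
{
  if (nbBands == 0)
    return {};

  const std::size_t nbPixels = buffer.size() / nbBands;

  std::vector<Pixel> batch;
  batch.reserve(nbPixels);
  for (std::size_t i = 0; i < nbPixels; ++i)
  {
    const auto first = buffer.begin() + static_cast<std::ptrdiff_t>(i * nbBands);
    const auto last  = first + static_cast<std::ptrdiff_t>(nbBands);
    batch.emplace_back(first, last);
  }

  // One call for the whole region: any parallelism happens inside the predictor.
  return predictBatch(batch);
}
```

With this structure, the caller does not need to know whether the predictor is multi-threaded; it only hands over the batch and receives one label per pixel.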

otb::SharkRandomForestsMachineLearningModel

Concrete implementation of the Random Forests model, which delegates to Shark.

Not much to say about it.

Note that since Shark uses boost::archive to serialize and deserialize the model file, it is not possible to implement a light CanRead() method.

Also, we had to call the .name() method of the Shark model in CanRead(), because Boost was happily deserializing invalid files into an invalid class instance, which threw an exception as soon as it was used.
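
The "deserialize, then probe the instance" pattern described above can be sketched without Shark or Boost. Everything in this example is hypothetical: LoadedModelSketch mimics a lenient deserializer that accepts invalid input and only fails when the object is used, and CanReadSketch shows how probing the instance with a cheap call such as name() turns that late failure into a clean false.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a model restored by a lenient deserializer.
class LoadedModelSketch
{
public:
  // Never throws, even on garbage input: the instance is silently left invalid.
  void Load(const std::string& content)
  {
    std::istringstream is(content);
    is >> m_NbTrees;
    m_Valid = !is.fail() && m_NbTrees > 0;
  }

  // Using an invalid instance throws, as described in the text.
  std::string name() const
  {
    if (!m_Valid)
      throw std::runtime_error("invalid model instance");
    return "ModelSketch";
  }

private:
  int  m_NbTrees = 0;
  bool m_Valid   = false;
};

// Returns true only if the content loads AND the resulting instance is usable.
bool CanReadSketch(const std::string& content)
{
  try
  {
    LoadedModelSketch model;
    model.Load(content); // succeeds even for invalid content
    model.name();        // probe: throws if the instance is invalid
    return true;
  }
  catch (const std::exception&)
  {
    return false;
  }
}
```

The cost of this approach is that the whole file has to be deserialized just to answer CanRead(), which is why a light check is not possible.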

otb::SharkRandomForestsMachineLearningModelFactory

Factory to build SharkRandomForestsMachineLearningModel from file.

We had to move this factory up before the KNN one, because the KNN factory accepts any kind of text file and thus prevented the correct factory from being found.
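
The ordering issue can be reproduced with a toy first-match lookup. This is a sketch under stated assumptions, not OTB's factory mechanism: FactorySketch, FindFactory, the two CanRead functions and the "STRICT_MODEL" header tag are hypothetical. The permissive factory stands in for one that accepts any text file, the strict one for a factory that recognises only its own format.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical factory description: a name and a CanRead-like predicate.
struct FactorySketch
{
  std::string                             name;
  std::function<bool(const std::string&)> canRead;
};

// Return the name of the first factory that accepts the content, or "" if none does.
std::string FindFactory(const std::vector<FactorySketch>& factories,
                        const std::string& fileContent)
{
  for (const auto& f : factories)
  {
    if (f.canRead(fileContent))
      return f.name;
  }
  return "";
}

// Permissive check: accepts any non-empty text.
bool PermissiveCanRead(const std::string& content)
{
  return !content.empty();
}

// Strict check: accepts only content starting with a hypothetical header tag.
bool StrictCanRead(const std::string& content)
{
  return content.rfind("STRICT_MODEL", 0) == 0;
}
```

When the permissive factory is registered first, it claims files meant for the strict one; registering the strict factory first lets both formats be resolved correctly.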

otbSharkUtils.h

Utility static methods to convert between ListSample and Shark structures.

Applications

Impacted applications are the TrainVectorsClassifier and TrainImagesClassifier. Shark Random Forests have been added as a new algorithm option, along with its parameters. There is no code duplication since the implementation relies on otb::LearningApplicationBase class.

Tests

Tests have been added for both the new Shark classes and the new algorithm option in the applications.

One caveat we found is that the Shark Random Forests implementation uses a random generator whose seed cannot be set. As a result, the produced model may differ slightly from one run to another, and it is not possible to do regression testing on the model file. Regression testing has therefore been deactivated for the Shark Random Forests test of the TrainImagesClassifier (other tests still do regression testing).

Branch is tested on the dashboard.

Note that only pc-christophe enables Shark in its configuration.

Documentation

List documentation modifications that were made (doxygen, example, software guide, application documentation, cookbook).

Additional notes

This is only a first step, as there are many interesting algorithms we could use in Shark.