Request for Changes-65: Add Shark Random forests implementation
Status
- Author: Julien Michel, Jordi Inglada
- Additional Contributors: Victor Poughon, Emmanuelle Sarrazin
- Submitted on 23.09.2016
- Proposed target release: 5.8
- Adopted (+4 from Julien, Victor, Rémi, Guillaume)
- Merged : 77f80161e5bfb31a5ee5245a9250d8bd49bb73d2
Summary
This RFC adds the Shark Machine Learning library as a new third party module. It provides a new specialization of MachineLearningModel that uses the Shark implementation of the Random Forests algorithm. This work also required modifications of the internal API of the MachineLearningModel class, so as to benefit from the parallel prediction capabilities of the Shark implementation. Learning applications have been updated to allow using the new machine learning model, and Superbuild now builds Shark.
Rationale
See the corresponding Request for Comments for rationale on using Shark for machine learning tasks.
Implementation details
CMake
A fix has been made to `CMake/OTBModuleEnablement.cmake` to include optional module dependencies in cycle detection.
Modules
ThirdParty/Shark
Import of Shark third party.
Classes and files
otb::MachineLearningModel
The rationale for changing the internal API of `otb::MachineLearningModel` is as follows:
- The `PredictBatch()` method is not const, and therefore not thread-safe
- Public interface was virtual, which is not recommended
- Depending on the algorithm (and often on the third party), inner implementations have different capabilities:
  - Predict a single sample only (OpenCV, LibSVM)
  - Predict a batch of samples at once (not encountered yet)
  - Predict a batch of samples, in parallel (Shark)
We had to find a design that will work for all cases. Here is what we did:
- Two new public, non-virtual const methods have been added:
  - `Predict()`, which predicts a single sample
  - `PredictBatch()`, which predicts a batch of samples
- Two new private, virtual const methods have been added:
  - `DoPredict()`, which is the actual implementation of single sample prediction. This is a pure virtual method.
  - `DoPredictBatch()`, which is the actual implementation of batch sample prediction. The default implementation calls `DoPredict()` iteratively.
- The public `Predict()` method calls `DoPredict()`.
- The public `PredictBatch()` method calls `DoPredictBatch()`. Its behaviour differs depending on the protected flag `m_IsDoPredictBatchMultiThreaded`:
  - If true, `PredictBatch()` calls `DoPredictBatch()` directly.
  - If false, and if OTB is built with the `-fopenmp` flag, `PredictBatch()` splits the input batch into as many pieces as there are threads available, and performs parallel prediction.
- The `PredictAll()` method is marked as deprecated and its implementation now calls `PredictBatch()`.
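The dispatch described in this list can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual OTB code: `ModelSketch`, `SignModel` and the `Sample` / `Label` aliases are hypothetical stand-ins, the `DoPredictBatch()` signature is simplified, and the chunking shown is only one plausible way to split a batch across OpenMP threads.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// Simplified stand-ins for the OTB sample and label types (assumptions).
using Sample = std::vector<double>;
using Label  = int;

class ModelSketch
{
public:
  virtual ~ModelSketch() = default;

  // Public, non-virtual, const: single sample prediction.
  Label Predict(const Sample& s) const { return DoPredict(s); }

  // Public, non-virtual, const: batch prediction.
  std::vector<Label> PredictBatch(const std::vector<Sample>& in) const
  {
    std::vector<Label> out(in.size());
    if (m_IsDoPredictBatchMultiThreaded)
    {
      // The implementation is parallel on its own: hand over the whole batch.
      DoPredictBatch(in, 0, in.size(), out);
      return out;
    }
#ifdef _OPENMP
    // Otherwise, split the batch into one contiguous chunk per thread.
    const std::size_t nbThreads = static_cast<std::size_t>(omp_get_max_threads());
    const std::size_t chunk     = (in.size() + nbThreads - 1) / nbThreads;
#pragma omp parallel for
    for (long t = 0; t < static_cast<long>(nbThreads); ++t)
    {
      const std::size_t begin = std::min(in.size(), static_cast<std::size_t>(t) * chunk);
      const std::size_t end   = std::min(in.size(), begin + chunk);
      DoPredictBatch(in, begin, end - begin, out);
    }
#else
    // Built without OpenMP: process the batch sequentially.
    DoPredictBatch(in, 0, in.size(), out);
#endif
    return out;
  }

protected:
  bool m_IsDoPredictBatchMultiThreaded = false;

private:
  // Pure virtual: every model must at least predict a single sample.
  virtual Label DoPredict(const Sample& s) const = 0;

  // Default batch implementation: call DoPredict() on each sample of the range.
  virtual void DoPredictBatch(const std::vector<Sample>& in, std::size_t start,
                              std::size_t size, std::vector<Label>& out) const
  {
    for (std::size_t i = start; i < start + size; ++i)
      out[i] = DoPredict(in[i]);
  }
};

// Hypothetical model that can only predict one sample at a time.
class SignModel : public ModelSketch
{
private:
  Label DoPredict(const Sample& s) const override
  {
    assert(!s.empty()); // precondition: at least one feature
    return s[0] >= 0.0 ? 1 : -1;
  }
};
```

With this layout, a caller only ever uses `Predict()` and `PredictBatch()`; a model such as `SignModel` gets batch prediction (and OpenMP splitting, when available) without overriding anything beyond `DoPredict()`.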
Here is a table summing up what method should be implemented depending on the algorithm / third party capabilities:
Capabilities | Implementation |
---|---|
Can only predict single samples | Implement `DoPredict()` |
Can predict a batch of samples, no parallelism | Implement both `DoPredict()` and `DoPredictBatch()` |
Can predict a batch of samples, in parallel | Implement both `DoPredict()` and `DoPredictBatch()`, and set `m_IsDoPredictBatchMultiThreaded` to true |
In any case, `PredictBatch()` is parallel provided that OTB has been built with the `-fopenmp` flag.
This new design has been applied to all existing models, plus the new Shark one.
Special thanks to Victor for helping me sort things out and come up with this design.
otb::ImageClassificationFilter
This filter has been modified to use the const `PredictBatch()` method instead of calling `Predict()` for each pixel. If OTB is built with the `-fopenmp` flag, multi-threading is delegated to the `PredictBatch()` method. Otherwise, classical ITK multi-threading is used. This way, the `ImageClassificationFilter` is always multi-threaded, transparently to the caller.
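The switch from per-pixel to batch prediction can be pictured with the sketch below. It is an illustration only and does not reproduce the filter's code: `ClassifyRegion`, the `BatchPredictor` alias and the band-interleaved buffer layout are assumptions made for the example, with `BatchPredictor` standing in for a model's const `PredictBatch()` method.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Pixel = std::vector<double>; // one multi-band pixel (assumption)
using Label = int;

// Stand-in for a model's const PredictBatch() method.
using BatchPredictor = std::function<std::vector<Label>(const std::vector<Pixel>&)>;

// Classify one image region stored as a band-interleaved buffer:
// gather all pixels into a batch, predict them in a single call,
// and return the labels in pixel order.
std::vector<Label> ClassifyRegion(const std::vector<double>& buffer,
                                  std::size_t nbBands,
                                  const BatchPredictor& predictBatch)
{
  assert(nbBands > 0 && buffer.size() % nbBands == 0);
  const std::size_t nbPixels = buffer.size() / nbBands;

  std::vector<Pixel> batch;
  batch.reserve(nbPixels);
  for (std::size_t i = 0; i < nbPixels; ++i)
  {
    const auto first = buffer.begin() + static_cast<std::ptrdiff_t>(i * nbBands);
    batch.emplace_back(first, first + static_cast<std::ptrdiff_t>(nbBands));
  }

  // One call for the whole region instead of one Predict() call per pixel.
  std::vector<Label> labels = predictBatch(batch);
  assert(labels.size() == nbPixels);
  return labels;
}
```

Handing the whole region to the model in one call is what lets the model decide how to parallelize the work.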
otb::SharkRandomForestsMachineLearningModel
Practical implementation of the Shark Random Forests model, calls Shark.
Not much to say about it.
Note that since Shark uses boost::archive to serialize and deserialize the model file, it is not possible to implement a lightweight `CanRead()` method.
Also, we had to call the `.name()` method of the Shark model in `CanRead()`, because Boost was happily deserializing invalid files into an invalid class instance, which threw an exception as soon as it was used.
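The "deserialize, then exercise the instance" idea can be sketched as below. None of this is Shark or Boost code: `FakeForest` is a hypothetical stand-in whose `Load()` accepts anything and whose `Name()` throws on an invalid instance, and this `CanRead()` takes the file content as a string only to keep the example self-contained.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a deserialized third-party model. Loading never
// fails, but an invalid instance throws as soon as it is used.
class FakeForest
{
public:
  void Load(std::istream& in)
  {
    assert(!in.bad());           // precondition: the stream itself is usable
    std::getline(in, m_Header);  // accepts any content without complaint
  }

  std::string Name() const
  {
    if (m_Header != "FakeForest")
      throw std::runtime_error("invalid model instance");
    return m_Header;
  }

private:
  std::string m_Header;
};

// CanRead()-style check: deserialize, then call a method on the instance so
// that an invalid file is rejected here rather than at prediction time.
bool CanRead(const std::string& content)
{
  std::istringstream in(content);
  FakeForest model;
  try
  {
    model.Load(in);
    model.Name();
  }
  catch (const std::exception&)
  {
    return false;
  }
  return true;
}
```

The cost of this approach is the one noted above: the check has to go through full deserialization, so it cannot be lightweight.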
otb::SharkRandomForestsMachineLearningModelFactory
Factory to build SharkRandomForestsMachineLearningModel from file.
We had to move this factory up, before the KNN one, because the KNN factory accepts any kind of text file and thus prevented the correct factory from being found.
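Why registration order matters can be shown with a small sketch. It assumes a lookup that returns the first factory whose reader accepts the file, which is how the problem above is described. `FactorySketch`, `FindFactory` and the two predicates are hypothetical, and the names "KNN" and "SharkRF" are used only as labels.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical factory entry: a label plus a CanRead()-style predicate.
struct FactorySketch
{
  std::string name;
  std::function<bool(const std::string&)> canRead;
};

// Return the label of the first registered factory that accepts the content,
// or an empty string if none does. Registration order decides ties.
std::string FindFactory(const std::vector<FactorySketch>& factories,
                        const std::string& content)
{
  for (const FactorySketch& f : factories)
  {
    assert(static_cast<bool>(f.canRead)); // every entry must carry a predicate
    if (f.canRead(content))
      return f.name;
  }
  return std::string();
}

// A permissive reader that accepts any non-empty text...
inline bool AcceptsAnyText(const std::string& c) { return !c.empty(); }

// ...and a strict one that requires a specific header.
inline bool AcceptsForestHeader(const std::string& c)
{
  return c.rfind("FakeForest", 0) == 0;
}
```

If the permissive entry is registered first, it claims every text file, including those meant for the strict reader; registering the strict entry first lets each file reach the right factory.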
otbSharkUtils.h
Utility static methods to convert between `ListSample` and Shark structures.
Applications
Impacted applications are the TrainVectorsClassifier and TrainImagesClassifier. Shark Random Forests have been added as a new algorithm option, along with its parameters. There is no code duplication since the implementation relies on the `otb::LearningApplicationBase` class.
Tests
Tests have been added for both the new Shark classes and the new option of the application.
One caveat we found is that the Shark Random Forests implementation uses a random generator for which it is impossible to set the seed. As a result, the produced model may differ slightly from one run to another, and it is not possible to do regression testing on the model file. Regression testing has therefore been deactivated for the Shark Random Forests test of the TrainImagesClassifier (other tests still do regression testing).
Branch is tested on the dashboard.
Note that only pc-christophe enables Shark in its configuration.
Documentation
List documentation modifications that were made (doxygen, example, software guide, application documentation, cookbook).
Additional notes
This is only a first step, as there are many interesting algorithms we could use in Shark.