Request for Changes-65: Add Shark Random forests implementation
- 1 Status
- 2 Summary
- 3 Rationale
- 4 Implementation details
- 4.1 CMake
- 4.2 Modules
- 4.3 Classes and files
- 4.4 Applications
- 4.5 Tests
- 4.6 Documentation
- 5 Additional notes
Status
- Author: Julien Michel, Jordi Inglada
- Additional contributors: Victor Poughon, Emmanuelle Sarrazin
- Submitted on 23.09.2016
- Proposed target release: 5.8
- Adopted (+4 from Julien, Victor, Rémi, Guillaume)
- Merged: 77f80161e5bfb31a5ee5245a9250d8bd49bb73d2
Summary
This RFC adds the Shark machine learning library as a new third-party module. It provides a new specialization of MachineLearningModel that uses the Shark implementation of the Random Forests algorithm. This work also required modifications to the internal API of the MachineLearningModel class, so as to benefit from the parallel prediction capabilities of the Shark implementation. The learning applications have been updated to allow using the new machine learning model, and the Superbuild now builds Shark.
Rationale
See the corresponding Request for Comments for the rationale on using Shark for machine learning tasks.
Implementation details
CMake
A fix has been made to CMake/OTBModuleEnablement.cmake to include optional module dependencies in cycle detection.
Modules
Import of the Shark third party.
Classes and files
The rationale for changing the internal API of otb::MachineLearningModel is as follows:
- The PredictBatch() method was not const, and therefore not thread-safe
- The public interface was virtual, which is not recommended
- Depending on the algorithm (and often on the third party), inner implementations have different capabilities:
  - Predict a single sample only (OpenCV, LibSVM)
  - Predict a batch of samples at once (not encountered yet)
  - Predict a batch of samples, in parallel (Shark)
We had to find a design that works for all cases. Here is what we did:
- Two new public, non-virtual, const methods have been added:
  - Predict(), which predicts a single sample
  - PredictBatch(), which predicts a batch of samples
- Two new private, virtual, const methods have been added:
  - DoPredict(), which is the actual implementation of single-sample prediction. This is a pure virtual method.
  - DoPredictBatch(), which is the actual implementation of batch prediction. The default implementation calls DoPredict() on each sample of the batch.
- The behaviour of PredictBatch() depends on a protected flag with which each model declares whether its DoPredictBatch() is itself multi-threaded:
  - If true, PredictBatch() forwards the whole batch to DoPredictBatch() in a single call, and parallelism is handled by the model implementation.
  - If false, and if OTB is built with -fopenmp, PredictBatch() will split the input batch into as many pieces as there are threads available, and perform parallel prediction.
- The PredictAll() method is marked as deprecated and its implementation now calls PredictBatch().
Here is a table summing up which methods should be implemented depending on the algorithm / third-party capabilities:

| Capability | Methods to implement |
|---|---|
| Can only predict single samples | Implement DoPredict() |
| Can predict a batch of samples, no parallelism | Implement both DoPredict() and DoPredictBatch() |
| Can predict a batch of samples, in parallel | Implement both DoPredict() and DoPredictBatch(), and declare DoPredictBatch() as multi-threaded |
In any case, PredictBatch() is parallel provided that OTB has been built with -fopenmp.
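The design above can be sketched as a non-virtual-interface pattern. The following is a toy illustration, not the actual OTB code: class and flag names are simplified stand-ins, the batch splitting is sequential here (OTB parallelizes the pieces with OpenMP), and the "model" just averages its inputs.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

using Sample = std::vector<double>;

class Model
{
public:
  virtual ~Model() = default;

  // Public, non-virtual, const entry points: safe to call concurrently.
  double Predict(const Sample & s) const { return DoPredict(s); }

  std::vector<double> PredictBatch(const std::vector<Sample> & batch) const
  {
    if (m_IsDoPredictBatchMultiThreaded)
    {
      // The model's DoPredictBatch() handles parallelism itself
      // (the Shark case): forward the whole batch in a single call.
      return DoPredictBatch(batch, 0, batch.size());
    }
    // Otherwise, split the batch into pieces and predict each one.
    // OTB runs these pieces in parallel with OpenMP; kept sequential here.
    const std::size_t pieces = 4; // stand-in for the number of threads
    std::vector<double> out(batch.size());
    for (std::size_t p = 0; p < pieces; ++p)
    {
      const std::size_t begin = p * batch.size() / pieces;
      const std::size_t end = (p + 1) * batch.size() / pieces;
      std::vector<double> piece = DoPredictBatch(batch, begin, end - begin);
      std::copy(piece.begin(), piece.end(), out.begin() + begin);
    }
    return out;
  }

protected:
  // Models whose DoPredictBatch() is internally parallel set this to true.
  bool m_IsDoPredictBatchMultiThreaded = false;

private:
  // Pure virtual: single-sample prediction must always be implemented.
  virtual double DoPredict(const Sample & s) const = 0;

  // Default batch implementation: loop DoPredict() over the range.
  virtual std::vector<double> DoPredictBatch(const std::vector<Sample> & batch,
                                             std::size_t start, std::size_t count) const
  {
    std::vector<double> out;
    out.reserve(count);
    for (std::size_t i = start; i < start + count; ++i)
      out.push_back(DoPredict(batch[i]));
    return out;
  }
};

// A model that can only predict single samples (the OpenCV/LibSVM row of
// the table): implementing DoPredict() alone is enough to get PredictBatch().
class MeanModel : public Model
{
private:
  double DoPredict(const Sample & s) const override
  {
    double sum = 0.0;
    for (double v : s) sum += v;
    return s.empty() ? 0.0 : sum / static_cast<double>(s.size());
  }
};
```

A batch-parallel model would additionally override DoPredictBatch() and set the protected flag to true in its constructor.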
This new design has been applied to all existing models, plus the new Shark one.
Special thanks to Victor for helping me sort things out and come up with this design.
The otb::ImageClassificationFilter has been modified to use the const PredictBatch() method instead of calling Predict() for each pixel. If OTB is built with the -fopenmp flag, multi-threading is delegated to the PredictBatch() method; otherwise, classical ITK multi-threading is used. This way, the ImageClassificationFilter is always seamlessly multi-threaded.
The SharkRandomForestsMachineLearningModel is the practical implementation of the Shark Random Forests model, and calls Shark to do the work. Not much to say about it.
Note that since Shark uses boost::archive to serialize and deserialize the model file, it is not possible to implement a lightweight CanRead() check. Also, we had to call the .name() method of the Shark model in CanRead(), because Boost was happily deserializing invalid files into an invalid class instance, which threw an exception as soon as it was used.
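The guard described above boils down to "deserialize, then verify the object identifies itself as the expected model before trusting it". Here is a minimal sketch of that pattern on a hypothetical text format; it is not the actual Boost/Shark code, and the function and format are invented for illustration.

```cpp
#include <sstream>
#include <string>

// Toy CanRead()-style check: deserialize an identifier from the file
// contents and verify it matches the expected model name, instead of
// assuming that successful deserialization implies a valid model file.
bool CanReadModelFile(const std::string & contents, const std::string & expectedName)
{
  std::istringstream in(contents);
  std::string name;
  in >> name;                   // "deserialize" the leading model identifier
  return name == expectedName;  // reject files that decode to something else
}
```

In the real implementation the same idea is applied by comparing the .name() reported by the deserialized Shark model, which catches invalid files before they throw later during prediction.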
Factory to build SharkRandomForestsMachineLearningModel from file. We had to move this factory up before the KNN one, because the KNN factory accepts any kind of text file and thus prevented the correct factory from being found.
Utility static methods to convert between ListSample and Shark structures.
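A sketch of the kind of conversion such a utility performs, under stated assumptions: the function name and types here are illustrative only (a real ListSample is an itk::Statistics::ListSample and Shark has its own batch containers); the idea shown is flattening a list of fixed-size samples into one contiguous row-major buffer, the dense batch layout a library like Shark consumes.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical converter: list of samples -> contiguous row-major batch.
// Sample i occupies the half-open range [i * dim, (i + 1) * dim).
std::vector<double> ListSampleToDenseBatch(const std::vector<std::vector<double>> & samples)
{
  if (samples.empty())
    return {};
  const std::size_t dim = samples.front().size(); // assume uniform dimension
  std::vector<double> batch;
  batch.reserve(samples.size() * dim);
  for (const auto & s : samples)
    for (std::size_t d = 0; d < dim; ++d)
      batch.push_back(s[d]); // append row i after row i-1
  return batch;
}
```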
Applications
The impacted applications are TrainVectorsClassifier and TrainImagesClassifier. Shark Random Forests has been added as a new algorithm option, along with its parameters. There is no code duplication, since the option is implemented once in the code shared by both applications.
Tests
Tests have been added for both the new Shark classes and the new option of the applications.
One caveat we found is that the Shark Random Forests implementation uses a random generator whose seed cannot be set. As a result, the produced model may differ slightly from one run to another, and it is not possible to do regression testing on the model file. Regression testing has therefore been deactivated for the Shark Random Forests test of TrainImagesClassifier (the other tests still do regression testing).
The branch is tested on the dashboard. Note that only pc-christophe enables Shark in its configuration.
Documentation
List of documentation modifications that were made (doxygen, example, software guide, application documentation, cookbook).
Additional notes
This is only a first step, as there are many other interesting algorithms in Shark that we could use.