open for inheritance
To implement an ONNX export (see https://github.com/deepjavalibrary/djl/discussions/1420#discussioncomment-1931364) I would like to extend the export capability of the existing block classes through inheritance. For that it would be very helpful if the DJL classes were opened up more to inheritance (both the builders and the classes themselves).
This pull request is a suggestion for how that could be done. Since it only widens access for the sake of inheritance, it should have no impact on existing code that uses the library.
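To make the intent concrete, here is a minimal sketch of the kind of access change this pull request proposes; the class below is made up for illustration and is not DJL's actual API.

```java
// Hypothetical block class illustrating the proposed access changes;
// not DJL's actual API.
public class SomeBlock {

    // Before: the constructor was package-private and the Builder was final,
    // so SomeBlock could not be subclassed outside its package.

    // After: protected constructor, so subclasses can call super(builder).
    protected SomeBlock(Builder builder) {
        // initialize from builder ...
    }

    // After: the Builder is no longer final, so it can be extended too.
    public static class Builder {
        public SomeBlock build() {
            return new SomeBlock(this);
        }
    }
}
```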
Codecov Report
Merging #1451 (0ef5369) into master (bb5073f) will decrease coverage by 1.41%. The diff coverage is 49.53%.
@@ Coverage Diff @@
## master #1451 +/- ##
============================================
- Coverage 72.08% 70.67% -1.42%
- Complexity 5126 5297 +171
============================================
Files 473 493 +20
Lines 21970 23222 +1252
Branches 2351 2533 +182
============================================
+ Hits 15838 16413 +575
- Misses 4925 5529 +604
- Partials 1207 1280 +73
Impacted Files | Coverage Δ
---|---
api/src/main/java/ai/djl/modality/cv/Image.java | 69.23% <ø> (-4.11%) ↓
...i/djl/modality/cv/translator/BigGANTranslator.java | 21.42% <ø> (-5.24%) ↓
...odality/cv/translator/BigGANTranslatorFactory.java | 33.33% <0.00%> (+8.33%) ↑
...nslator/InstanceSegmentationTranslatorFactory.java | 14.28% <0.00%> (-3.90%) ↓
.../modality/cv/translator/YoloTranslatorFactory.java | 8.33% <0.00%> (-1.67%) ↓
...i/djl/modality/cv/translator/YoloV5Translator.java | 5.69% <0.00%> (ø)
...odality/cv/translator/YoloV5TranslatorFactory.java | 8.33% <0.00%> (-1.67%) ↓
...pi/src/main/java/ai/djl/ndarray/BytesSupplier.java | 54.54% <0.00%> (-12.13%) ↓
...i/src/main/java/ai/djl/ndarray/NDArrayAdapter.java | 15.23% <0.00%> (-0.83%) ↓
...l/training/loss/SigmoidBinaryCrossEntropyLoss.java | 64.00% <0.00%> (ø)
... and 173 more
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ca691da...0ef5369.
In the example I set up for the ONNX export I am facing the following specific inheritance problems:
The screenshot shows a particular problem with Conv2d: package-private constructors and a final Builder class.
You can see Conv2d, Linear, LayerNorm, and LambdaBlock, which I could not easily inherit from. In the pull request I also included other classes with likely similar situations, so that I do not have to bother you daily with more classes.
To achieve an ONNX export for my MuZero implementation
https://enpasos.ai/muzero/How#onnx
I only "opened" the four classes Conv2d, Linear, LayerNorm, and LambdaBlock a little, as mentioned above.
What do you think about:
- I close this Pull Request
- I open a new Pull Request with changes only in the four classes
That would help me because I could then remove my patches of the four classes from my working MuZero ONNX export - which is likely not the best way to do it in general, but at least works for this case.
@enpasos
Sorry for the delay. I finally got some time to look into this. I looked a bit into your repo, and I think the inheritance solution doesn't provide a good user experience: it requires users to rewrite their model. I made some changes, and it seems we can achieve the ONNX conversion by implementing a set of OnnxConverters, see: https://github.com/frankfliu/DJL2OnnxExample/tree/converter/onnxModelGen/src/main/java/ai/enpasos/mnist/blocks
Of course, we need to make some changes to the DJL Block API to expose some internal information. The change can be minimal: https://github.com/frankfliu/djl/commit/a1b789f2f7379b91ada7bc7e001591e4fe7db20a
Let me know what you think.
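To make the idea concrete, here is a rough sketch of what such a converter-based design could look like; the interface and method names are guesses for illustration, not the code from the linked repo.

```java
import ai.djl.nn.Block;

// Sketch of a converter-based ONNX export. OnnxConverter and its method
// are illustrative names, not the API from the linked example repo.
interface OnnxConverter {

    /**
     * Emits the ONNX representation of one DJL block. A composite block
     * would be handled by recursing through block.getChildren() and
     * converting each child with its own converter.
     */
    String convert(Block block);
}
```

The appeal of this shape is that export knowledge lives outside the block classes, which avoids the model rewrite that the inheritance approach required.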
@frankfliu
Thank you for your expert look into the subject. Your converter approach gives a nice separation of concerns - much better than the inheritance approach I used. It looks promising for achieving a good UX.
My approach was to see if I could manage to run MuZero inference in the browser. Result: it works fine on the CPU via WebAssembly, but not on the GPU via WebGL, because too few ONNX operators are supported in WebGL.
I would like to add some thoughts from the experience I gained along the way:
- Translating the network blocks was not always straightforward. For example, the translation of the layer norm was not nice, because there is not yet a corresponding ONNX component. I ended up with a far-from-nice solution that produced the right mathematical behaviour by misusing a BatchNormalization component (the math being emulated is sketched after this list). From this experience I deduce that it is important to allow particular converters to be replaced in the future.
- It could also be that the target system where the ONNX model is used cannot work with a particular ONNX component or version. That also suggests giving the end user the freedom to exchange a converter implementation if needed.
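For reference, here is a minimal sketch of the math that had to be emulated, written with DJL NDArray ops and assuming normalization over the last axis (the gamma/beta scaling is omitted for brevity):

```java
import ai.djl.ndarray.NDArray;

final class LayerNormMath {

    // Layer norm over the last axis: normalize each sample by the mean and
    // variance computed across its features. ONNX at the time had no
    // LayerNormalization operator, so this behaviour had to be reproduced
    // by repurposing a BatchNormalization node.
    static NDArray layerNorm(NDArray x, float eps) {
        NDArray mean = x.mean(new int[] {-1}, true);
        NDArray centered = x.sub(mean);
        NDArray variance = centered.pow(2).mean(new int[] {-1}, true);
        return centered.div(variance.add(eps).sqrt());
    }
}
```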
I think this could be achieved by allowing the user to provide an implementation of the factory method OnnxConverter getConverter(Block block).
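One possible shape for that, sketched with hypothetical names (reusing the OnnxConverter interface sketched above; none of this is existing DJL API): a registry whose defaults the user can override per block type.

```java
import java.util.HashMap;
import java.util.Map;
import ai.djl.nn.Block;

// Hypothetical converter factory with user-overridable entries;
// none of these names are existing DJL API.
final class OnnxConverterFactory {

    private final Map<Class<? extends Block>, OnnxConverter> converters = new HashMap<>();

    /** Lets the user swap in their own converter for a given block type. */
    public void register(Class<? extends Block> blockType, OnnxConverter converter) {
        converters.put(blockType, converter);
    }

    public OnnxConverter getConverter(Block block) {
        OnnxConverter converter = converters.get(block.getClass());
        if (converter == null) {
            throw new IllegalArgumentException(
                    "No ONNX converter registered for " + block.getClass().getName());
        }
        return converter;
    }
}
```

A user whose target runtime rejects a particular ONNX component could then register a replacement converter for just that block type, without touching the rest of the export.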
Two more thoughts:
- The ParallelBlock takes an input for the join function. The solution needs the flexibility that the writer of the join function can also provide the corresponding translation to ONNX.
- I got a comment from Craigacp, who pointed out that they worked on an ONNX solution in Tribuo that might help here. I didn't consider switching from my ONNX approach to the Tribuo ideas because I had almost achieved my MuZero export goal at that time.
@frankfliu
For me it would be perfect to switch to your approach.
To be able to map an ONNX join function to a DJL join function in ParallelBlocks, it would be nice to have the option to provide a name string for the join function that can be used for the mapping.
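For illustration, one way such a name string could travel with the join function (all names here are hypothetical): wrap the function in a small named type that still satisfies the Function<List<NDList>, NDList> shape that ParallelBlock's join function has.

```java
import java.util.List;
import java.util.function.Function;
import ai.djl.ndarray.NDList;

// Hypothetical sketch: a join function that carries a name string, so an
// ONNX converter can map it to the corresponding ONNX subgraph.
final class NamedJoin implements Function<List<NDList>, NDList> {

    private final String name; // e.g. "concat" or "add", used for the mapping
    private final Function<List<NDList>, NDList> join;

    NamedJoin(String name, Function<List<NDList>, NDList> join) {
        this.name = name;
        this.join = join;
    }

    String getName() {
        return name;
    }

    @Override
    public NDList apply(List<NDList> inputs) {
        return join.apply(inputs);
    }
}
```

A converter encountering a ParallelBlock could then check whether the join function is a NamedJoin and dispatch on getName(), instead of trying to inspect an opaque lambda.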
I do not know what the best way to proceed would be. If you like, you can close this and open a new issue, or rename this issue.
What do you think?
Two reasons why I closed the old issue and opened a new one (https://github.com/deepjavalibrary/djl/pull/2231):
- I was reorganizing the branches of my fork
- I reduced the content of the change to what I really need