
Bulk Batch creation

patins1 opened this pull request (status: Open)

This pull request optimizes Batch creation for ArrayDataset when using StackBatchifier. My DJL application now learns the data in 58550 ms rather than 79843 ms, i.e. 36% faster (79843 / 58550 ≈ 1.36). For applications that use indices-based rather than range-based indexing, my DJL application runs in 66341 ms, which is still 20% faster (79843 / 66341 ≈ 1.20). This holds for PyTorch; I get a test failure with MXNet, so this solution may be restricted to PyTorch.
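The idea behind bulk batch creation can be illustrated outside DJL. The plain-Java sketch below is an analogy under stated assumptions, not the PR's code: a row-major float[] stands in for the dataset's NDArray, and the class and method names (BulkBatchSketch, perSample, bulkRange, bulkGather) are hypothetical. It contrasts building a batch record by record and then stacking, with fetching the whole batch in one operation, either as one contiguous range or as a single gather over an index list.

```java
import java.util.Arrays;

// Hypothetical illustration only; not DJL's ArrayDataset/StackBatchifier code.
final class BulkBatchSketch {
    // Per-sample path: copy each record out separately, then stack the copies
    // into one batch (two copies per record plus one array per record).
    static float[] perSample(float[] data, int rowLen, int[] indices) {
        float[][] records = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            int start = indices[i] * rowLen;
            records[i] = Arrays.copyOfRange(data, start, start + rowLen);
        }
        float[] batch = new float[indices.length * rowLen];
        for (int i = 0; i < records.length; i++) {
            System.arraycopy(records[i], 0, batch, i * rowLen, rowLen);
        }
        return batch;
    }

    // Range-based bulk path: rows [from, to) are contiguous, so the batch is one slice.
    static float[] bulkRange(float[] data, int rowLen, int from, int to) {
        return Arrays.copyOfRange(data, from * rowLen, to * rowLen);
    }

    // Indices-based bulk path: one gather straight into the batch buffer,
    // with no intermediate per-record arrays.
    static float[] bulkGather(float[] data, int rowLen, int[] indices) {
        float[] batch = new float[indices.length * rowLen];
        for (int i = 0; i < indices.length; i++) {
            System.arraycopy(data, indices[i] * rowLen, batch, i * rowLen, rowLen);
        }
        return batch;
    }

    public static void main(String[] args) {
        int rowLen = 3;
        float[] data = new float[5 * rowLen]; // 5 samples, 3 features each
        for (int i = 0; i < data.length; i++) {
            data[i] = i;
        }
        // A contiguous batch (samples 1..3) can be taken as one slice.
        System.out.println(Arrays.equals(perSample(data, rowLen, new int[] {1, 2, 3}),
                bulkRange(data, rowLen, 1, 4))); // true
        // An arbitrary index list needs a gather, but still no per-record arrays.
        System.out.println(Arrays.toString(bulkGather(data, rowLen, new int[] {4, 0})));
        // [12.0, 13.0, 14.0, 0.0, 1.0, 2.0]
    }
}
```

In this analogy the contiguous slice does the least work, which is plausibly why the range-based case reported above is the fastest, while the gather still avoids per-record allocation.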

patins1, Aug 04 '22 13:08

Codecov Report

Merging #1869 (76364ff) into master (bb5073f) will decrease coverage by 1.98%. The diff coverage is 68.26%.

@@             Coverage Diff              @@
##             master    #1869      +/-   ##
============================================
- Coverage     72.08%   70.10%   -1.99%     
- Complexity     5126     5867     +741     
============================================
  Files           473      576     +103     
  Lines         21970    26024    +4054     
  Branches       2351     2810     +459     
============================================
+ Hits          15838    18245    +2407     
- Misses         4925     6406    +1481     
- Partials       1207     1373     +166     
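The totals in the diff are internally consistent (hits + misses + partials = lines on both sides) and the percentages can be re-derived from them. A minimal sketch, assuming Codecov's documented defaults: coverage = hits / (hits + misses + partials), reported with two decimals and rounded down. The class CoverageSketch is hypothetical.

```java
// Hypothetical helper; assumes Codecov's default precision (2) and rounding (down).
final class CoverageSketch {
    // Coverage in hundredths of a percent; integer division rounds down.
    static long coverageHundredths(long hits, long misses, long partials) {
        long lines = hits + misses + partials;
        return hits * 10_000 / lines;
    }

    // Formats e.g. 7010 as "70.10%".
    static String format(long hundredths) {
        return String.format("%d.%02d%%", hundredths / 100, hundredths % 100);
    }

    public static void main(String[] args) {
        long master = coverageHundredths(15838, 4925, 1207); // 21970 lines
        long pr = coverageHundredths(18245, 6406, 1373);     // 26024 lines
        System.out.println(format(master)); // 72.08%
        System.out.println(format(pr));     // 70.10%
    }
}
```

With the master totals this gives 72.08% and with the PR totals 70.10%, matching the report.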
Impacted Files                                          Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java         69.23% <ø> (-4.11%)
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java   76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java   71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java   100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java   72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java   21.42% <0.00%> (-5.24%)
...odality/cv/translator/BigGANTranslatorFactory.java   33.33% <0.00%> (+8.33%)
...nslator/InstanceSegmentationTranslatorFactory.java   14.28% <0.00%> (-3.90%)
.../cv/translator/SemanticSegmentationTranslator.java   0.00% <0.00%> (ø)
.../cv/translator/StyleTransferTranslatorFactory.java   40.00% <ø> (ø)
... and 479 more


codecov-commenter, Aug 06 '22 13:08

The problem with MXNet is fixed by using the "take" operation instead of "pick". TrainMnist now runs 4 times faster on MXNet.
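The two operations select different things, which is why swapping one for the other changes behavior and not just speed. As a plain-Java sketch (hypothetical class PickVsTake on a 2-D float[][]; not DJL or MXNet code): take along axis 0 returns whole rows at the given indices, as NumPy integer-array indexing a[idx] does, while pick along the last axis returns one element per row, a[i][idx[i]].

```java
import java.util.Arrays;

// Hypothetical illustration of take vs pick semantics on a 2-D array.
final class PickVsTake {
    // take along axis 0: out[i] is a copy of row idx[i].
    static float[][] take(float[][] a, int[] idx) {
        float[][] out = new float[idx.length][];
        for (int i = 0; i < idx.length; i++) {
            out[i] = a[idx[i]].clone();
        }
        return out;
    }

    // pick along the last axis: out[i] = a[i][idx[i]], so idx needs one entry per row.
    static float[] pick(float[][] a, int[] idx) {
        if (idx.length != a.length) {
            throw new IllegalArgumentException("pick needs one index per row");
        }
        float[] out = new float[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i][idx[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] a = {{0, 1, 2}, {3, 4, 5}};
        int[] idx = {1, 0};
        System.out.println(Arrays.deepToString(take(a, idx))); // [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]
        System.out.println(Arrays.toString(pick(a, idx)));      // [1.0, 3.0]
    }
}
```

For batch creation, take is the operation that matches "fetch these samples", since each index selects a whole sample rather than one value inside it.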

patins1, Aug 07 '22 04:08

Thanks for your contribution. Since all engines are fixed, we will take a look. So far LGTM

lanking520, Aug 12 '22 21:08

@KexinFeng can you help to check on the NDIndex part changes?

lanking520, Aug 12 '22 21:08

@patins1 It looks like BulkDataIterable is not covered by any unit test. You mentioned that using BulkDataIterable is more efficient. Could you add a unit test that covers this class you added?

KexinFeng, Aug 25 '22 21:08

> The problem with MXNet is fixed by using the "take" operation instead of "pick". TrainMnist now runs 4 times faster on MXNet.

Could you integrate it with the existing take function in PR?

I tried to use take for the MXNet engine too (see here), and all the PR tests passed, but it caused the issue. In your version, setting fullPick.setIndexTake(true); will trigger this problem again.

KexinFeng, Aug 25 '22 21:08

@patins1 Thanks for dealing with this issue so promptly! However, the unit test that covers the new file api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java still seems to be missing. You mentioned that you tested it and saw an efficiency increase. Could you add that test to the appropriate unit test file?

You also mentioned:

> The problem with MXNet is fixed by using the "take" operation instead of "pick". TrainMnist now runs 4 times faster on MXNet.

It would be better to have a unit test for this too, so that future changes do not break this fix.

Thanks!

KexinFeng, Aug 27 '22 20:08

Tests added

patins1, Aug 28 '22 01:08

@siddvenk Hi Siddarth, in this PR we have changed the definition of get(NDArray index) from pick to take, and have added a warning. This will affect the result you mentioned in https://github.com/deepjavalibrary/djl/issues/1800. See the test in commit "add index test and clean code" 8ffac00e6584c96500e6ac96ba35cc45513f1f4e. Pick can still be used via addPickDim(). The change makes get(NDArray index) consistent with NumPy and the PyTorch engine, and it is also more efficient, as shown here. Should we update the relevant part of the Dive into Deep Learning for Java book mentioned in https://github.com/deepjavalibrary/djl/issues/1800?

KexinFeng, Aug 28 '22 08:08