AMDMIGraphX
Add softmax cross entropy
Adding the SoftmaxCrossEntropyLoss op to MIGraphX. This is used in Llama v2.
https://github.com/onnx/onnx/blob/main/docs/Operators.md#softmaxcrossentropyloss
This function is a bit complicated but can be built up from smaller ONNX ops.
I've added the script I used to generate the gold data for the verify tests (all verify tests pass as of now).
I still need to update the parser tests to match the result and add K-dimension verify tests, but I'm looking to get this into review sooner rather than later in case there's anything egregious or something that needs to be changed here, since I'll be gone for two weeks starting this Friday.
The initial motivation to implement this came from the onnx ref implementation: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_softmax_cross_entropy_loss.py
That reference uses operators like where and take(), but it seems I was able to get around needing them:
- Use the fact that default weights are all 1: for ignore_index, set the weight for the ignored class index to 0 and scale. This avoids the need for a where operator and reduces things to a simple mul. It also works in the case where we have non-default (non-1) weights.
- Use gather on the weights based on the labels to mirror what take() is doing (a rough sketch of both ideas is shown below).
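The following is a minimal NumPy sketch of these two tricks, not the parser code itself; the shapes and names are illustrative only.

```python
import numpy as np

# Illustrative only -- not the parser code. Shows the two tricks above:
# (1) handle ignore_index by zeroing that entry of the weight vector,
# (2) use a gather on the weights with the labels in place of take().
scores = np.random.rand(4, 3).astype(np.float32)   # (batch, class_size)
labels = np.array([0, 2, 1, 2])                    # class index per sample
ignore_index = 2

weights = np.ones(3, dtype=np.float32)             # default weights are all 1
weights[ignore_index] = 0.0                        # ignored class now contributes nothing

# log-softmax (no max-subtraction here, just to keep the sketch short)
log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))

sample_weights = weights[labels]                   # gather(weights, labels) ~ take()
per_sample_loss = -log_probs[np.arange(len(labels)), labels] * sample_weights
```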
The test verification was done as follows:
- Due to softmax being used, we can sanity-check this mathematically, as all-zero and all-one inputs resolve to the same values.
- Mean and sum shouldn't change based on scatter choices, so more tests have been added for the no-reduction cases.
- Ignore index and weights of 1 are also added to scale classes accordingly.
- Use softmaxcrossentropyloss_gold_data_check.py to generate the appropriate gold data and verify the result (the script can be removed before the PR is merged); a rough sketch of this kind of check is shown after this list.
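The actual softmaxcrossentropyloss_gold_data_check.py lives in the PR and is not reproduced here; below is only a rough sketch of how gold data like this could be produced with onnxruntime and numpy. The model, tensor names, and shapes are assumptions for illustration.

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

# Hypothetical stand-in for the gold-data script: build a tiny
# SoftmaxCrossEntropyLoss model and record onnxruntime's output.
batch, class_size = 4, 3
node = helper.make_node("SoftmaxCrossEntropyLoss",
                        inputs=["scores", "labels"],
                        outputs=["loss"],
                        reduction="mean")
graph = helper.make_graph(
    [node], "scel_gold",
    [helper.make_tensor_value_info("scores", TensorProto.FLOAT, [batch, class_size]),
     helper.make_tensor_value_info("labels", TensorProto.INT64, [batch])],
    [helper.make_tensor_value_info("loss", TensorProto.FLOAT, [])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

scores = np.random.rand(batch, class_size).astype(np.float32)
labels = np.random.randint(0, class_size, size=batch).astype(np.int64)

sess = ort.InferenceSession(model.SerializeToString())
gold = sess.run(None, {"scores": scores, "labels": labels})[0]
print(gold)  # values that would be embedded in the verify test
```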
As some of the banter between @pfultz2 and me has suggested, we can also reuse a lot of this and mirror these changes to get NegativeLogLikelihoodLoss (separate PR) once we decide on this, by removing the softmax front end used to calculate the loss.
So SoftmaxCrossEntropyLoss is an ONNX function, which means it can be implemented as a composite of other ONNX operators (like DynamicQuantizeLinear), so we shouldn't need to add a new operator for this. I couldn't find a reference of the ONNX operators needed. However, onnxruntime has a transformation of LogSoftmax->[Cast]->NegativeLogLikelihoodLoss into SoftmaxCrossEntropyLoss:
https://github.com/microsoft/onnxruntime/blob/3e4db2c6869975a4e34a71a36c67c5a825f669e2/orttraining/orttraining/core/optimizer/loss_rewriter.h#L11
Furthermore, NegativeLogLikelihoodLoss is also an ONNX function, so we don't need to write a custom operator for it either. I couldn't find a reference of the ONNX operators for it either, but looking at the ref implementation here:
https://github.com/onnx/onnx/blob/4e7289d05a08453b3f20efeb774d4590b5428903/onnx/reference/ops/op_negative_log_likelihood_loss.py#L74
it looks like it can be implemented using some combination of scatter + where + reshape + reduce (np.take can be implemented using scatter).
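For reference, here is a small NumPy sketch (not from the PR) of the composition the linked ref implementation suggests: the take along the class axis is just a gather with the label indices, and where masks out the ignored samples.

```python
import numpy as np

# Illustrative sketch of NegativeLogLikelihoodLoss built from
# gather/where/reduce, mirroring the onnx reference implementation above.
log_probs = np.log(np.random.rand(4, 3).astype(np.float32))  # (batch, class_size)
labels = np.array([0, 2, 1, 1])
ignore_index = 1

picked = log_probs[np.arange(log_probs.shape[0]), labels]    # take() == indexed gather
mask = np.where(labels == ignore_index, 0.0, 1.0)            # drop ignored samples
loss = -(picked * mask).sum() / np.maximum(mask.sum(), 1.0)  # mean reduction
```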
Ah, I was set on finding a CrossEntropyLoss and didn't see anything, and that's why I was writing the cross entropy piece myself.
Okay, so I guess we'll add a NegativeLogLikelihoodLoss for the softmax case then, and leverage the existing ops + softmax on the input?
I know you've mentioned SoftmaxLog() and splitting things up by using softmax + log instead, so in this case should we do a separate scatter + where + reshape + reduce for each op, or is it better to perform Softmax + NegativeLogLikelihoodLoss then?
No need to create a new operator such as NegativeLogLikelihoodLoss, since this can also be implemented using our current operators. We may want to make a note of this, because we could easily add support for the NegativeLogLikelihoodLoss operator in the future by reusing the same set of operators.
That's what I figured after going over the links you shared. I guess this is kind of a two-for-one then. I've broken this out into a separate issue for now:
https://github.com/ROCm/AMDMIGraphX/issues/3013
Codecov Report
Attention: Patch coverage is 91.91176% with 11 lines in your changes missing coverage. Please review.
Project coverage is 92.05%. Comparing base (7c2fdf5) to head (af54c44). Report is 3 commits behind head on develop.
Files with missing lines | Patch % | Lines |
---|---|---|
src/onnx/parse_softmaxcrossentropyloss.cpp | 91.91% | 11 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## develop #3008 +/- ##
=========================================
Coverage 92.04% 92.05%
=========================================
Files 505 506 +1
Lines 20699 20835 +136
=========================================
+ Hits 19052 19179 +127
- Misses 1647 1656 +9
Rebasing this off master to restart work on this tomorrow
Need to fix the parser tests, but the verify tests all pass and have been validated using the gold data. The gold data check leveraged onnxruntime and numpy to ensure our answers are consistent with both in each case.
Let me know if you see/want any additional coverage on verification, or if there are things we want to clean up further.
Is there any verify test which compares ref value with the equivalent gpu based output? Thanks.
Not needed here, as this is an ONNX function made up of simpler ONNX operator pieces which have themselves been validated between ref and gpu. The comparison I've done uses gold data from onnxruntime CPU and numpy. This is similar to how we validate output in our driver, and those values are used as part of the verify tests.
There are composite-layered verify tests, e.g. test_transpose_reshape_add_sub_mul and test_var_sl_gru_bidirect, that I just randomly point to. I think verification is important, and worthwhile to add here as well. Thanks.
@lakhinderwalia Added some of the composite tests in for the k-dimension and 2D cases, i.e. (batch, class_size, ...) and (batch, class_size). I may need to limit labels to a literal here instead of a parameter, since I can generate random literals in a certain range and labels have to be within [0, class_size). Let me know if you want me to parameterize these further or also template them based on reduction type. Will see in the morning.
For the k-dimensional case I'd need to sort out the intermediate dimensions/gather calls for larger dimensions, since we need to loop over and adjust the indices before the concat. Not hard, but I wanted to mirror the tests we already do as a baseline for the verify ONNX tests.
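A tiny sketch of how such a label literal could be generated; the shapes here are made up for illustration and are not the actual test values.

```python
import numpy as np

# Labels must be valid class indices, so generate them once as a literal
# in [0, class_size) rather than leaving them as a runtime parameter.
np.random.seed(0)                        # keep the verify test deterministic
batch, class_size, d1 = 2, 4, 5          # illustrative (N, C, D1) k-dimension shape
labels = np.random.randint(0, class_size, size=(batch, d1)).astype(np.int64)
```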
Thanks, @TedThemistokleous. I was suggesting very simple tests: take a test graph with an ONNX operator "softmaxcrossentropy" -- no need to expose its entrails -- and just give it some sample data. Pass it through ref and gpu, and then just compare those results, instead of comparing to the (pre-calculated) golden data. Templated tests are just fine; how you test it programmatically isn't the focus here.
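As a rough illustration of the kind of check being suggested (not code from the PR): run the same ONNX file through the ref and gpu targets with the MIGraphX Python module and compare the outputs. The file name and parameter names below are hypothetical, and the exact Python API usage may differ.

```python
import numpy as np
import migraphx

# Hypothetical end-to-end check: same ONNX file, ref vs. gpu, compared directly.
def run(target_name, inputs):
    prog = migraphx.parse_onnx("softmaxcrossentropyloss_test.onnx")  # hypothetical file
    prog.compile(migraphx.get_target(target_name))
    return np.array(prog.run(inputs)[0])

scores = np.random.rand(4, 3).astype(np.float32)
labels = np.random.randint(0, 3, size=4).astype(np.int64)
# "scores"/"labels" assume the ONNX inputs are named this way.
inputs = {"scores": migraphx.argument(scores), "labels": migraphx.argument(labels)}

assert np.allclose(run("ref", inputs), run("gpu", inputs), rtol=1e-3, atol=1e-3)
```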
Got all the end-to-end stuff working; seeing some odd behavior with the dnn1-based tests for double/float. It looks like their version of log doesn't support double, which causes issues, along with a small difference in the mean/sum k-dimension cases that incurs enough error to trigger a tolerance failure.
Test | Batch | Rate new af54c4 | Rate old 9dcea5 | Diff | Compare |
---|---|---|---|---|---|
torchvision-resnet50 | 64 | 3,249.89 | 3,252.64 | -0.08% | :white_check_mark: |
torchvision-resnet50_fp16 | 64 | 6,990.42 | 6,985.91 | 0.06% | :white_check_mark: |
torchvision-densenet121 | 32 | 2,435.63 | 2,433.66 | 0.08% | :white_check_mark: |
torchvision-densenet121_fp16 | 32 | 4,115.55 | 4,106.94 | 0.21% | :white_check_mark: |
torchvision-inceptionv3 | 32 | 1,636.74 | 1,634.68 | 0.13% | :white_check_mark: |
torchvision-inceptionv3_fp16 | 32 | 2,739.42 | 2,740.90 | -0.05% | :white_check_mark: |
cadene-inceptionv4 | 16 | 776.43 | 775.86 | 0.07% | :white_check_mark: |
cadene-resnext64x4 | 16 | 809.44 | 808.52 | 0.11% | :white_check_mark: |
slim-mobilenet | 64 | 7,455.44 | 7,453.75 | 0.02% | :white_check_mark: |
slim-nasnetalarge | 64 | 208.22 | 208.36 | -0.07% | :white_check_mark: |
slim-resnet50v2 | 64 | 3,435.30 | 3,434.58 | 0.02% | :white_check_mark: |
bert-mrpc-onnx | 8 | 1,149.98 | 1,153.57 | -0.31% | :white_check_mark: |
bert-mrpc-tf | 1 | 315.63 | 308.33 | 2.37% | :white_check_mark: |
pytorch-examples-wlang-gru | 1 | 421.63 | 417.95 | 0.88% | :white_check_mark: |
pytorch-examples-wlang-lstm | 1 | 381.67 | 391.78 | -2.58% | :white_check_mark: |
torchvision-resnet50_1 | 1 | 815.47 | 817.84 | -0.29% | :white_check_mark: |
cadene-dpn92_1 | 1 | 400.37 | 399.56 | 0.20% | :white_check_mark: |
cadene-resnext101_1 | 1 | 381.02 | 381.54 | -0.14% | :white_check_mark: |
onnx-taau-downsample | 1 | 344.36 | 343.92 | 0.13% | :white_check_mark: |
dlrm-criteoterabyte | 1 | 35.10 | 35.06 | 0.12% | :white_check_mark: |
dlrm-criteoterabyte_fp16 | 1 | 58.10 | 58.10 | -0.00% | :white_check_mark: |
agentmodel | 1 | 8,129.82 | 7,712.60 | 5.41% | :high_brightness: |
unet_fp16 | 2 | 58.05 | 58.01 | 0.08% | :white_check_mark: |
resnet50v1_fp16 | 1 | 933.12 | 921.90 | 1.22% | :white_check_mark: |
resnet50v1_int8 | 1 | 919.35 | 925.93 | -0.71% | :white_check_mark: |
bert_base_cased_fp16 | 64 | 1,152.56 | 1,153.34 | -0.07% | :white_check_mark: |
bert_large_uncased_fp16 | 32 | 355.90 | 355.71 | 0.05% | :white_check_mark: |
bert_large_fp16 | 1 | 210.08 | 211.92 | -0.87% | :white_check_mark: |
distilgpt2_fp16 | 16 | 2,157.81 | 2,156.57 | 0.06% | :white_check_mark: |
yolov5s | 1 | 534.30 | 534.42 | -0.02% | :white_check_mark: |
tinyllama | 1 | 43.44 | 43.34 | 0.23% | :white_check_mark: |
vicuna-fastchat | 1 | 172.29 | 173.26 | -0.56% | :white_check_mark: |
whisper-tiny-encoder | 1 | 417.67 | 417.96 | -0.07% | :white_check_mark: |
whisper-tiny-decoder | 1 | 425.37 | 423.78 | 0.38% | :white_check_mark: |
Check results before merge :high_brightness:
:red_circle:bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output