
Add softmax cross entropy

Open TedThemistokleous opened this issue 9 months ago • 7 comments

Adding a softmaxcrossentropy op in MIGraphX. This is used in Llama v2.

https://github.com/onnx/onnx/blob/main/docs/Operators.md#softmaxcrossentropyloss

This function is a bit complicated, but it can be built up from smaller ONNX ops.

I've added the script I used to generate the gold data for the verify tests (all verify tests pass as of now).

I still need to update the parser tests to match the result and add K-dimension verify tests, but I'm looking to get this into review sooner rather than later in case there's anything egregious or something that needs to be changed here, since I'll be gone as of this Friday for two weeks.

The initial motivation to implement this came from the ONNX reference implementation: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_softmax_cross_entropy_loss.py

Operators like where and take() were used there, but it seems I was able to get around this:

  1. Use the fact that the default weights are all 1: for ignore_index, set the corresponding weight for the ignored class index to 0 and scale. This avoids the need for a where operator and reduces things to a simple mul. It also works in the case where we have non-default (not all 1) weights.

  2. Use gather on the weights, indexed by the labels, to mirror what take() is doing. (A NumPy sketch of both tricks follows this list.)
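
A minimal NumPy sketch of the 2-D (N, C) case using both tricks; the function name, signature, and reduction handling here are illustrative only, not the MIGraphX implementation:

```python
import numpy as np

def softmax_cross_entropy_loss(scores, labels, weights=None,
                               ignore_index=None, reduction="mean"):
    """Illustrative 2-D (N, C) case only; not the MIGraphX implementation."""
    n, c = scores.shape
    # log-softmax over the class axis, in its numerically stable form
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Trick 1: default weights are all 1, so zeroing the ignored class weight
    # removes the need for a `where`; everything stays a simple multiply.
    w = np.ones(c, dtype=scores.dtype) if weights is None else weights.copy()
    if ignore_index is not None and 0 <= ignore_index < c:
        w[ignore_index] = 0

    # Trick 2: gather the per-sample weight from the labels (mirrors take()).
    sample_weight = w[labels]                      # (N,)
    picked = log_softmax[np.arange(n), labels]     # (N,)
    loss = -sample_weight * picked

    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / sample_weight.sum()        # weighted mean
```

With default weights and no ignore_index, the weighted mean reduces to the usual mean of the negative log-softmax at the target class.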

Test verification was done as follows:

  • Because softmax is used, we can sanity-check this mathematically: all-zero and all-one inputs resolve to the same values.
  • Mean and sum shouldn't change based on scatter choices, so more tests have been added for the no-reduction cases.
  • Ignore-index and weight-of-1 cases are also added to scale classes accordingly.
  • Use softmaxcrossentropyloss_gold_data_check.py to generate the appropriate gold data and verify the result (it can be removed before the PR is merged); a sketch of that kind of check follows this list.
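
The gold-data check itself boils down to something like the following sketch; the model file name and the input names (scores, labels) are placeholders rather than the actual script's values:

```python
import numpy as np
import onnxruntime as ort

# Model file and input names are placeholders; the real script may differ.
sess = ort.InferenceSession("softmaxcrossentropyloss.onnx",
                            providers=["CPUExecutionProvider"])

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 5)).astype(np.float32)  # (batch, classes)
labels = rng.integers(0, 5, size=(4,)).astype(np.int64)  # each in [0, classes)

ort_loss = sess.run(None, {"scores": scores, "labels": labels})[0]

# NumPy gold data for reduction="mean" with default (all-1) weights
shifted = scores - scores.max(axis=1, keepdims=True)
log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
gold = -log_softmax[np.arange(4), labels].mean()

assert np.allclose(ort_loss, gold, rtol=1e-5, atol=1e-6)
```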

As some of the banter between @pfultz2 and me has suggested, we can also reuse a lot of this and mirror these changes to get NegativeLogLikelihoodLoss (separate PR) once we decide on this, by removing the softmax front end when calculating the loss.

TedThemistokleous avatar Apr 26 '24 19:04 TedThemistokleous

So SoftmaxCrossEntropyLoss is an ONNX function, which means it can be implemented as a composite of other ONNX operators (like DynamicQuantizeLinear), so we shouldn't need to add a new operator for this. I couldn't find a reference listing the ONNX operators needed. However, onnxruntime has a transformation of LogSoftmax->[Cast]->NegativeLogLikelihoodLoss into SoftmaxCrossEntropyLoss:

https://github.com/microsoft/onnxruntime/blob/3e4db2c6869975a4e34a71a36c67c5a825f669e2/orttraining/orttraining/core/optimizer/loss_rewriter.h#L11

Furthermore, NegativeLogLikelihoodLoss is also an ONNX function, so we don't need to write a custom operator for it either. I couldn't find a reference listing the ONNX operators, but looking at the reference implementation here:

https://github.com/onnx/onnx/blob/4e7289d05a08453b3f20efeb774d4590b5428903/onnx/reference/ops/op_negative_log_likelihood_loss.py#L74

It looks like this can be implemented using some combination of scatter+where+reshape+reduce (np.take can be implemented using scatter).
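
Sketching that decomposition in plain NumPy (illustrative only: no class weights, reduction fixed to mean; take/where/reduce stand in for the equivalent ONNX gather/scatter, where, and reduce ops):

```python
import numpy as np

def negative_log_likelihood_loss(log_probs, labels, ignore_index=None):
    """Sketch only: (N, C) log-probabilities and (N,) integer labels."""
    n, c = log_probs.shape
    # "take": pick the log-probability of the target class for every sample,
    # using flat indices much like a gather/scatter would in ONNX.
    flat_idx = np.arange(n) * c + labels
    picked = np.take(log_probs.reshape(-1), flat_idx)        # (N,)

    # "where": zero out samples whose label equals ignore_index.
    keep = np.ones(n, dtype=bool) if ignore_index is None else labels != ignore_index
    loss = np.where(keep, -picked, 0.0)

    # "reduce": mean over the samples that were not ignored.
    return loss.sum() / np.maximum(keep.sum(), 1)
```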

pfultz2 avatar Apr 29 '24 16:04 pfultz2

> So SoftmaxCrossEntropyLoss is an ONNX function, which means it can be implemented as a composite of other ONNX operators (like DynamicQuantizeLinear), so we shouldn't need to add a new operator for this. I couldn't find a reference listing the ONNX operators needed. However, onnxruntime has a transformation of LogSoftmax->[Cast]->NegativeLogLikelihoodLoss into SoftmaxCrossEntropyLoss:
>
> https://github.com/microsoft/onnxruntime/blob/3e4db2c6869975a4e34a71a36c67c5a825f669e2/orttraining/orttraining/core/optimizer/loss_rewriter.h#L11

Ah, I was set on finding a CrossEntropyLoss and didn't see anything; that's why I was writing the cross-entropy piece.

> Furthermore, NegativeLogLikelihoodLoss is also an ONNX function, so we don't need to write a custom operator for it either. I couldn't find a reference listing the ONNX operators, but looking at the reference implementation here:
>
> https://github.com/onnx/onnx/blob/4e7289d05a08453b3f20efeb774d4590b5428903/onnx/reference/ops/op_negative_log_likelihood_loss.py#L74
>
> It looks like this can be implemented using some combination of scatter+where+reshape+reduce (np.take can be implemented using scatter).

Okay, so I guess we'll add a NegativeLogLikelihoodLoss then for the softmax case and leverage the existing ops plus a softmax on the input?

I know you've mentioned SoftmaxLog() and splitting things up by using softmax + log instead, so in this case should we do a separate scatter+where+reshape+reduce for each op, or is it better to perform Softmax + NegativeLogLikelihoodLoss?

TedThemistokleous avatar Apr 29 '24 16:04 TedThemistokleous

> Furthermore, NegativeLogLikelihoodLoss is also an ONNX function, so we don't need to write a custom operator for it either. I couldn't find a reference listing the ONNX operators, but looking at the reference implementation here: https://github.com/onnx/onnx/blob/4e7289d05a08453b3f20efeb774d4590b5428903/onnx/reference/ops/op_negative_log_likelihood_loss.py#L74 It looks like this can be implemented using some combination of scatter+where+reshape+reduce (np.take can be implemented using scatter).

> Okay, so I guess we'll add a NegativeLogLikelihoodLoss then for the softmax case and leverage the existing ops plus a softmax on the input?

There's no need to create a new operator such as NegativeLogLikelihoodLoss, since that can also be implemented using our current operators. We may want to make a note of this, because we could easily add support for the NegativeLogLikelihoodLoss operator in the future by reusing the same set of operators.

pfultz2 avatar Apr 29 '24 18:04 pfultz2

> Furthermore, NegativeLogLikelihoodLoss is also an ONNX function, so we don't need to write a custom operator for it either. I couldn't find a reference listing the ONNX operators, but looking at the reference implementation here: https://github.com/onnx/onnx/blob/4e7289d05a08453b3f20efeb774d4590b5428903/onnx/reference/ops/op_negative_log_likelihood_loss.py#L74 It looks like this can be implemented using some combination of scatter+where+reshape+reduce (np.take can be implemented using scatter).

> Okay, so I guess we'll add a NegativeLogLikelihoodLoss then for the softmax case and leverage the existing ops plus a softmax on the input?

> There's no need to create a new operator such as NegativeLogLikelihoodLoss, since that can also be implemented using our current operators. We may want to make a note of this, because we could easily add support for the NegativeLogLikelihoodLoss operator in the future by reusing the same set of operators.

That's what I figured after going over the links you showed me. I guess this is kind of a 2-for-1 then. I've broken this out into a separate issue for now.

https://github.com/ROCm/AMDMIGraphX/issues/3013

TedThemistokleous avatar Apr 29 '24 18:04 TedThemistokleous

Codecov Report

Attention: Patch coverage is 91.91176% with 11 lines in your changes missing coverage. Please review.

Project coverage is 92.05%. Comparing base (7c2fdf5) to head (af54c44). Report is 3 commits behind head on develop.

Files with missing lines | Patch % | Lines
src/onnx/parse_softmaxcrossentropyloss.cpp | 91.91% | 11 Missing :warning:
Additional details and impacted files
@@            Coverage Diff            @@
##           develop    #3008    +/-   ##
=========================================
  Coverage    92.04%   92.05%            
=========================================
  Files          505      506     +1     
  Lines        20699    20835   +136     
=========================================
+ Hits         19052    19179   +127     
- Misses        1647     1656     +9     


codecov[bot] avatar Apr 30 '24 17:04 codecov[bot]

Rebasing this off master to restart work on this tomorrow

TedThemistokleous avatar Aug 21 '24 01:08 TedThemistokleous

I still need to fix the parser tests, but the verify tests all pass and have been validated using the gold data. The gold data check leveraged ONNX Runtime and NumPy to ensure our answers are consistent with both in each case.

Let me know if you see/want any additional coverage on verification, or if there are things we want to clean up further.

TedThemistokleous avatar Aug 27 '24 16:08 TedThemistokleous

> Is there any verify test which compares the ref value with the equivalent GPU-based output? Thanks.

Not needed here, as this is an ONNX function made up of the simpler ONNX operator pieces, which have themselves been validated between ref and GPU. The comparison I've done is to use gold data from ONNX Runtime CPU and NumPy. This is similar to how we validate output in our driver, and it is used as part of the verify tests.

TedThemistokleous avatar Sep 09 '24 13:09 TedThemistokleous

> Is there any verify test which compares the ref value with the equivalent GPU-based output? Thanks.

> Not needed here, as this is an ONNX function made up of the simpler ONNX operator pieces, which have themselves been validated between ref and GPU. The comparison I've done is to use gold data from ONNX Runtime CPU and NumPy. This is similar to how we validate output in our driver, and it is used as part of the verify tests.

There are composite-layered verify tests, e.g. test_transpose_reshape_add_sub_mul and test_var_sl_gru_bidirect, that I'm just pointing to at random. I think verification is important and worthwhile to add here as well. Thanks.

lakhinderwalia avatar Sep 09 '24 14:09 lakhinderwalia

> Is there any verify test which compares the ref value with the equivalent GPU-based output? Thanks.

> Not needed here, as this is an ONNX function made up of the simpler ONNX operator pieces, which have themselves been validated between ref and GPU. The comparison I've done is to use gold data from ONNX Runtime CPU and NumPy. This is similar to how we validate output in our driver, and it is used as part of the verify tests.

> There are composite-layered verify tests, e.g. test_transpose_reshape_add_sub_mul and test_var_sl_gru_bidirect, that I'm just pointing to at random. I think verification is important and worthwhile to add here as well. Thanks.

@lakhinderwalia Added some of the composite tests in for the K-dimension and 2-D cases, i.e. (batch, class_size, ...) and (batch, class_size). I may need to limit the labels to a literal in this case instead of a parameter, since I can generate random literals in a certain range and labels have to be within [0, class_size). Let me know if you want me to parameterize these further / template them based on reduction type as well. Will see in the morning.

For the K-dimensional case I'd need to sort out the intermediate dimensions/gather calls for larger dimensions, since we need to loop over and adjust the indices before the concat (a quick sketch of that index adjustment is below). Not hard, but I wanted to mirror the tests we already do as a baseline for the verify ONNX tests.
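
A quick sketch of that index adjustment for the K-dimensional case, with illustrative shapes of (N, C, D1, D2) for the scores and (N, D1, D2) for the labels:

```python
import numpy as np

N, C, D1, D2 = 2, 3, 4, 5
rng = np.random.default_rng(1)
log_softmax = rng.standard_normal((N, C, D1, D2))   # stand-in for log-softmax output
labels = rng.integers(0, C, size=(N, D1, D2))       # each label in [0, C)

# Move the class axis last and flatten the (N, D1, D2) positions so a single
# gather over the flat buffer picks the target class at every position; this
# is the per-dimension index adjustment that has to happen before the concat.
ls_flat = np.moveaxis(log_softmax, 1, -1).reshape(-1, C)   # (N*D1*D2, C)
flat_labels = labels.reshape(-1)                           # (N*D1*D2,)
picked = ls_flat[np.arange(flat_labels.size), flat_labels]
loss = -picked.reshape(N, D1, D2)                          # per-position loss
```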

TedThemistokleous avatar Sep 13 '24 04:09 TedThemistokleous

> @lakhinderwalia Added some of the composite tests in for the K-dimension and 2-D cases, i.e. (batch, class_size, ...) and (batch, class_size). I may need to limit the labels to a literal in this case instead of a parameter, since I can generate random literals in a certain range and labels have to be within [0, class_size). Let me know if you want me to parameterize these further / template them based on reduction type as well. Will see in the morning.

Thanks, @TedThemistokleous. I was suggesting very simple tests: take a test graph with an ONNX operator "softmaxcrossentropy" (no need to expose its entrails), give it some sample data, pass it through ref and gpu, and then just compare those results, instead of comparing against the (pre-calculated) golden data. Templated tests are just fine; how you test it programmatically isn't the focus here.
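
Roughly, that ref-vs-gpu comparison could look like the following through the MIGraphX Python API; the model file and input names are placeholders (the in-tree verify tests themselves are C++):

```python
import numpy as np
import migraphx

def run(target_name, onnx_file, inputs):
    # Parse the ONNX graph, compile it for the requested target, and run it.
    prog = migraphx.parse_onnx(onnx_file)
    prog.compile(migraphx.get_target(target_name))
    results = prog.run({k: migraphx.argument(v) for k, v in inputs.items()})
    return np.array(results[0])  # convert the first output back to NumPy

rng = np.random.default_rng(0)
inputs = {
    "scores": rng.standard_normal((4, 5)).astype(np.float32),
    "labels": rng.integers(0, 5, size=(4,)).astype(np.int64),  # in [0, 5)
}

ref_out = run("ref", "softmaxcrossentropyloss.onnx", inputs)
gpu_out = run("gpu", "softmaxcrossentropyloss.onnx", inputs)
assert np.allclose(ref_out, gpu_out, rtol=1e-3, atol=1e-3)
```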

lakhinderwalia avatar Sep 13 '24 13:09 lakhinderwalia

Got all the end-to-end stuff working, but I'm seeing some odd behavior with the dnn1-based tests for double/float. It looks like their version of log doesn't support double, which causes issues, and I'm also seeing a small difference in the mean/sum K-dimension cases that incurs enough error to trigger a tolerance failure.

TedThemistokleous avatar Sep 15 '24 15:09 TedThemistokleous

Test | Batch | Rate new (af54c4) | Rate old (9dcea5) | Diff | Compare
torchvision-resnet50 64 3,249.89 3,252.64 -0.08% :white_check_mark:
torchvision-resnet50_fp16 64 6,990.42 6,985.91 0.06% :white_check_mark:
torchvision-densenet121 32 2,435.63 2,433.66 0.08% :white_check_mark:
torchvision-densenet121_fp16 32 4,115.55 4,106.94 0.21% :white_check_mark:
torchvision-inceptionv3 32 1,636.74 1,634.68 0.13% :white_check_mark:
torchvision-inceptionv3_fp16 32 2,739.42 2,740.90 -0.05% :white_check_mark:
cadene-inceptionv4 16 776.43 775.86 0.07% :white_check_mark:
cadene-resnext64x4 16 809.44 808.52 0.11% :white_check_mark:
slim-mobilenet 64 7,455.44 7,453.75 0.02% :white_check_mark:
slim-nasnetalarge 64 208.22 208.36 -0.07% :white_check_mark:
slim-resnet50v2 64 3,435.30 3,434.58 0.02% :white_check_mark:
bert-mrpc-onnx 8 1,149.98 1,153.57 -0.31% :white_check_mark:
bert-mrpc-tf 1 315.63 308.33 2.37% :white_check_mark:
pytorch-examples-wlang-gru 1 421.63 417.95 0.88% :white_check_mark:
pytorch-examples-wlang-lstm 1 381.67 391.78 -2.58% :white_check_mark:
torchvision-resnet50_1 1 815.47 817.84 -0.29% :white_check_mark:
cadene-dpn92_1 1 400.37 399.56 0.20% :white_check_mark:
cadene-resnext101_1 1 381.02 381.54 -0.14% :white_check_mark:
onnx-taau-downsample 1 344.36 343.92 0.13% :white_check_mark:
dlrm-criteoterabyte 1 35.10 35.06 0.12% :white_check_mark:
dlrm-criteoterabyte_fp16 1 58.10 58.10 -0.00% :white_check_mark:
agentmodel 1 8,129.82 7,712.60 5.41% :high_brightness:
unet_fp16 2 58.05 58.01 0.08% :white_check_mark:
resnet50v1_fp16 1 933.12 921.90 1.22% :white_check_mark:
resnet50v1_int8 1 919.35 925.93 -0.71% :white_check_mark:
bert_base_cased_fp16 64 1,152.56 1,153.34 -0.07% :white_check_mark:
bert_large_uncased_fp16 32 355.90 355.71 0.05% :white_check_mark:
bert_large_fp16 1 210.08 211.92 -0.87% :white_check_mark:
distilgpt2_fp16 16 2,157.81 2,156.57 0.06% :white_check_mark:
yolov5s 1 534.30 534.42 -0.02% :white_check_mark:
tinyllama 1 43.44 43.34 0.23% :white_check_mark:
vicuna-fastchat 1 172.29 173.26 -0.56% :white_check_mark:
whisper-tiny-encoder 1 417.67 417.96 -0.07% :white_check_mark:
whisper-tiny-decoder 1 425.37 423.78 0.38% :white_check_mark:

Check results before merge :high_brightness:

migraphx-bot avatar Sep 17 '24 11:09 migraphx-bot


     :white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
     :white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
     :white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
     :white_check_mark: unet: PASSED: MIGraphX meets tolerance
     :white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
     :red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
     :white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
     :white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
     :white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
     :white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot avatar Sep 17 '24 11:09 migraphx-bot