tch-rs
batch_accuracy_for_logits Errors out
I have an input of shape [128, 3, 224, 224] and an output of shape [128, 59]. This works for training and loss calculation, but when I call model.batch_accuracy_for_logits(input, output, dev, 128) with the same tensors used in training, I get the error below.
Torch("The size of tensor a (128) must match the size of tensor b (59) at non-singleton dimension 1. Exception raised from infer_size_impl
I may be misunderstanding how to supply the labels, since a batch size needs to be specified.
Would you have some minimal way to reproduce the issue? This would help track things down. Also note that batch_accuracy_for_logits is used in a bunch of examples, e.g. here, so you may want to print the shapes of the tensors passed as arguments there and the shapes of the tensors that your code passes, and look at how they differ.
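For anyone landing here with the same error, one plausible cause (an assumption, since no reproduction was posted) is the shape of the labels. As far as I understand it, batch_accuracy_for_logits takes the argmax of the model output over the last dimension, which gives a tensor of shape [128], and compares it element-wise with the labels. If the labels are passed as a [128, 59] one-hot or score tensor rather than [128] class indices, that comparison would try to broadcast [128] against [128, 59] and fail at dimension 1, which matches the message above. The sketch below shows the same logic with plain Vec values instead of tch tensors so it runs without libtorch; the helper names (argmax, accuracy_for_logits, one_hot_to_indices) are made up for illustration and are not the tch-rs API.

```rust
/// Index of the largest value in one row of logits (first index wins on ties).
fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of rows whose argmax equals the class index in `labels`.
/// `logits` stands in for an [N, C] tensor, `labels` for an [N] tensor.
fn accuracy_for_logits(logits: &[Vec<f32>], labels: &[usize]) -> f64 {
    assert_eq!(logits.len(), labels.len(), "expected one class index per sample");
    let mut correct = 0usize;
    for (row, &label) in logits.iter().zip(labels) {
        if argmax(row) == label {
            correct += 1;
        }
    }
    correct as f64 / labels.len() as f64
}

/// Converts [N, C] one-hot targets into [N] class indices.
fn one_hot_to_indices(targets: &[Vec<f32>]) -> Vec<usize> {
    targets.iter().map(|row| argmax(row)).collect()
}

fn main() {
    // Two samples, three classes: logits have shape [2, 3].
    let logits = vec![vec![0.1, 2.0, -1.0], vec![1.5, 0.2, 0.3]];
    // One-hot targets with the same [2, 3] shape as the logits.
    let one_hot = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];

    // Reduce the targets to shape [2] before comparing with the predictions.
    let labels = one_hot_to_indices(&one_hot);
    println!("labels = {:?}", labels);
    println!("accuracy = {}", accuracy_for_logits(&logits, &labels));
}
```

If that is indeed the situation, the corresponding step with tch tensors would be to reduce the [128, 59] targets to class indices (for example with an argmax over the last dimension) before passing them in. Printing the shapes of both tensors, as suggested above, should confirm or rule this out.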
Closing this as no update in a while.