Broadcasting seems to mess with gradients
Hello,
I'm not entirely sure whether this is a bug or whether I'm just misusing the API.
I am training a pretty simple fully connected network on some data that is irrelevant to this issue (the issue persists with a small amount of random data). My inputs and training targets are of shape [x, 773] and [x, 1] respectively, where x is the number of samples.
I have noticed that the network really struggles to converge when the target has shape [x] instead of [x, 1]. Here is my training code.
Notice that on line 56 I set the training target to random data, for the sake of this example.
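Roughly, the training setup looks like the sketch below (the layer sizes, names, optimizer, and data generation are placeholders rather than my exact code; the real data loading is omitted):

```rust
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    let net = nn::seq()
        .add(nn::linear(&root / "l1", 773, 256, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(&root / "out", 256, 1, Default::default()));
    let mut opt = nn::Adam::default().build(&vs, 1e-3).unwrap();

    // Random stand-in data: inputs [x, 773], targets reshaped to [x, 1].
    let x_full = Tensor::randn(&[1024, 773], (Kind::Float, Device::Cpu));
    let y_full = Tensor::randn(&[1024], (Kind::Float, Device::Cpu)).view([-1, 1]);

    for epoch in 1..=100 {
        let pred = net.forward(&x_full); // shape [batch_size, 1]
        let loss = pred.mse_loss(&y_full, tch::Reduction::Mean);
        opt.backward_step(&loss);
        if epoch % 10 == 0 {
            println!("epoch {}: loss {:.5}", epoch, loss.double_value(&[]));
        }
    }
}
```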
Here is the loss if the y_full Tensor is changed to shape [x, 1] instead of [x]:

And here is the loss if the y_full Tensor is of shape [x] (remove the .view([-1, 1]) from line 57).

I assume this is due to broadcasting, as the output tensor of the network (the forward call on line 63) is of shape [batch_size, 1]. Could that be the case?
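To illustrate what I suspect is happening, here is a small sketch (not my training code) that checks how the shapes interact in mse_loss; if broadcasting kicks in, the per-element loss for a [batch, 1] prediction against a [batch] target should come out as [batch, batch]:

```rust
use tch::{Device, Kind, Reduction, Tensor};

fn main() {
    let pred = Tensor::randn(&[4, 1], (Kind::Float, Device::Cpu)); // network output [batch, 1]
    let target_2d = Tensor::randn(&[4, 1], (Kind::Float, Device::Cpu)); // target [batch, 1]
    let target_1d = target_2d.view([-1]); // same target, shape [batch]

    // Matching shapes: per-element loss stays [4, 1].
    let loss_ok = pred.mse_loss(&target_2d, Reduction::None);
    println!("{:?}", loss_ok.size());

    // Mismatched shapes: if [4, 1] vs [4] broadcasts, every prediction is
    // compared against every target value and the loss becomes [4, 4],
    // so the mean (and its gradients) is computed over the wrong pairs.
    let loss_broadcast = pred.mse_loss(&target_1d, Reduction::None);
    println!("{:?}", loss_broadcast.size());
}
```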
Maybe you want to try to reproduce the same behavior with the Python API, to see whether it's something wrong with this crate or with the way PyTorch is used in your example.