cca_zoo
Is the loss function calculated on the raw data?
Sorry to disturb you again, but I am confused while coding the training loop myself: is the loss calculated on the input data rather than on the latent representation learnt by the encoder?
On the representations
(And trust me, the EY loss is going to work better! I'm almost thinking about changing the default solver in this package to that one.)
Well, I will be transparent about my use case. I have a list of three tensors, say X, Y, Z, which I need to project into a common space with CCA. Since there are more than two views, I am using TCCA. I also need to write the training loop myself, which is why I had this doubt about the loss function definition in the DCCA class.
As shown in the screenshot below, def forward (line 34) returns an encoded representation of each view, collected into a list via a list comprehension.
However, def loss (line 41), as shown in the screenshot, contains these lines:

representations = self(batch["views"])
return {"objective": self.objective(representations)}

So this batch is the batch from the trainloader. My question is: this loss function does not use the encoded representations, right? Or if it does, please help me understand it a bit.
The loss is applied to the representations, because we call self.objective with representations as the argument.
Yes, I understand that, but this representations variable is batch["views"], which is just a batch from the trainloader, right?
I actually think that in future I will make the callable loss classes like TCCALoss just class methods of DTCCA. The current design is a throwback to an older version where it was helpful to do it this way.
Okay, okay. But if you could answer my question, it would really help me.
Sorry, half my response got lost, so that looked really weird, like I was just offering up something totally random 😅
The answer to your question is NO! I'm assuming familiarity with PyTorch here: when we call self(arg) on an nn.Module, we call its forward method.
So representations is the output of the forward method applied to the batch.
More generally, in Python classes can define a __call__ method, so loss(args) is the same as loss.__call__(args) (which is how the current loss class works).
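For example, here is a minimal, self-contained sketch of both mechanisms (Encoder and ToyLoss are made-up names, for illustration only):

import torch
from torch import nn

class Encoder(nn.Module):
    # hypothetical encoder, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 3)

    def forward(self, x):
        return self.linear(x)

class ToyLoss:
    # hypothetical callable loss class, mirroring how the current loss classes work
    def __call__(self, representations):
        # any scalar function of the representations will do here
        return sum(z.pow(2).mean() for z in representations)

encoder = Encoder()
z = encoder(torch.randn(4, 10))  # nn.Module: encoder(x) dispatches to encoder.forward(x)
loss_fn = ToyLoss()
print(loss_fn([z]))              # plain Python: loss_fn(args) dispatches to loss_fn.__call__(args)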
Yes, exactly, but this throws an error, which is what got me confused in the first place.
So it seems I would have to take the output of forward for a specific batch, store it in a dictionary with "views" as the key, and then pass that in. But that is about the last thing I wanted to do, which is why I thought of raising the issue first.
Thanks
No? The output of self() is a list.
Your data loader needs to produce batches where "views" is a key mapping to a list of arrays.
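For instance, a minimal dataset that yields that structure could look like this (a sketch; the ThreeViewDataset name and the tensor shapes are made up):

import torch
from torch.utils.data import Dataset, DataLoader

class ThreeViewDataset(Dataset):
    # hypothetical three-view dataset, for illustration only
    def __init__(self, X, Y, Z):
        self.views = (X, Y, Z)

    def __len__(self):
        return len(self.views[0])

    def __getitem__(self, idx):
        # each sample is a dict whose "views" key holds one tensor per view
        return {"views": [v[idx] for v in self.views]}

X, Y, Z = torch.randn(100, 8), torch.randn(100, 12), torch.randn(100, 5)
loader = DataLoader(ThreeViewDataset(X, Y, Z), batch_size=16)
batch = next(iter(loader))
# the default collate keeps the structure: batch["views"] is a list of three batched tensors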
Yeah, looking at your code, I think you should read what, e.g., my DCCA class is doing. It's not a loss function; it's a PyTorch Lightning module that implements DCCA with a specific loss function.
If you take a look in the files, you can see the loss function that DCCA is using behind the scenes.
Yes, in the DCCA class the loss method takes batch as its argument and passes batch["views"] on to the respective objective. So when I do z = DCCA()(batch) and then loss(z), the error is obvious, because the loss function expects a dictionary and not a list of tensors. That's where the error is coming from.
Ahhh! I've understood our confusion now!
So my DCCA class's loss method is applied to data, not representations!
Apologies, I thought you meant in general: is the DCCA loss applied to data or representations?
So just change your snippet to DCCA.loss(batch) and it will be fine.
(Because my DCCA loss method has a forward call inside.)
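Concretely, with model as a DCCA instance (a sketch based on the snippets above):

representations = model(batch["views"])  # forward pass: a list of encoded views
# model.loss(representations)            # wrong: loss expects the batch dict, not a list
out = model.loss(batch)                  # right: loss runs the forward pass itself, then .objective()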
Have just got off a 12-hour flight, so forgive me for not realising the motivation behind your question!
Exactly, so this was my confusion. So the DCCA loss takes a batch from the dataloader as its argument, right?
Forgive me, I don't know a lot about the theoretical aspects of CCA. So for my case, when I use loss = loss(batch) followed by loss.backward(), how will the neural network learn the parameters if we don't pass the encoded representations to the loss?
If the loss function were instead called 'encode_and_calculate_loss', you would understand, right? The name is maybe a bit misleading, but it's a common pattern in NN code.
The function first passes the raw data through the encoders and then calculates the loss. Look at what the function is actually doing.
Yes, that is alright, but my question is: after a forward pass, the DCCA encodes the given batch, then calls the loss on the batch, and performs loss.backward(). Is this sequence correct? That is my question.
This:

# batch is a dictionary with "views": a list of tensors
# .loss(batch) returns {"objective": ...} (see the snippet above), so take the scalar out
loss = DCCA().loss(batch)["objective"]
loss.backward()
optimiser.step()
The function that takes representations and returns a scalar loss is .objective().
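Putting the thread together, here is a minimal sketch of a manual training loop. The import path and constructor arguments are assumptions that vary between cca_zoo versions, the encoders are toy stand-ins, and for three views you would swap in the TCCA-style model as discussed above:

import torch
from torch import nn
from cca_zoo.deep import DCCA  # assumption: older versions expose this under cca_zoo.deepmodels

# toy encoders, one per view; in practice these are your own architectures
encoders = [nn.Linear(d, 3) for d in (8, 12, 5)]
model = DCCA(latent_dimensions=3, encoders=encoders)  # assumption: argument names vary by version
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in loader:  # loader yields {"views": [x, y, z]}, as sketched earlier
    optimiser.zero_grad()
    out = model.loss(batch)        # encodes batch["views"], then applies .objective()
    out["objective"].backward()    # per the snippet above, loss returns {"objective": ...}
    optimiser.step()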