
Is the loss function calculated on the raw data?

Neel-132 opened this issue on Dec 18, 2023 • 22 comments

Sorry to disturb you again, but I am confused while writing the training loop myself: is the loss calculated on the input data rather than on the latent representation learnt by the encoder?

Neel-132 avatar Dec 18 '23 06:12 Neel-132

On the representations

jameschapman19 avatar Dec 18 '23 06:12 jameschapman19

(And trust me, the EY loss is going to work better! I'm almost thinking of changing the default solver in this package to that one.)

jameschapman19 avatar Dec 18 '23 07:12 jameschapman19

Well, I will be transparent about my use case. I have a list of three tensors, say X, Y, Z, which I need to project into a common space with CCA. Since there are more than two views, I am using TCCA. I also need to write the training loop myself, which is why I had this doubt about the loss function definition in the DCCA class.

As shown in the screenshot below, def forward on line 34 returns an encoded representation of each view and collects them in a list comprehension.

However, def loss on line 41, as shown in the screenshot, contains these lines:

representations = self(batch["views"])
return {"objective": self.objective(representations)}

This batch is the batch from the trainloader. So my question is: this loss function does not use the encoded representations, right? Or if it does, please help me understand how.

[Screenshot: the DCCA class, showing the forward method (line 34) and the loss method (line 41)]
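A minimal sketch of the structure being discussed (the forward/loss shape follows the snippet quoted above; the class name, encoder handling, and the toy objective are placeholders, not the actual cca_zoo source):

```python
import torch
import torch.nn as nn

class DCCASketch(nn.Module):
    def __init__(self, encoders):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)

    def forward(self, views):
        # One encoded representation per view, collected in a list.
        return [encoder(view) for encoder, view in zip(self.encoders, views)]

    def objective(self, representations):
        # Stand-in for the real CCA objective (e.g. the TCCA loss);
        # it returns a scalar computed from the representations.
        return torch.stack([z.mean() for z in representations]).sum()

    def loss(self, batch):
        # `batch` comes straight from the dataloader, but self(...) runs
        # forward first, so the objective sees representations, not raw data.
        representations = self(batch["views"])
        return {"objective": self.objective(representations)}
```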

Neel-132 avatar Dec 18 '23 08:12 Neel-132

The loss is applied to the representations, because we are calling self.objective with the representations as its argument.

jameschapman19 avatar Dec 18 '23 08:12 jameschapman19

Yes, I understand that, but this representations variable is batch["views"], which is just a batch from the trainloader, right?

Neel-132 avatar Dec 18 '23 08:12 Neel-132

I actually think that in future I will change the callable loss classes like TCCALoss into plain class methods of DTCCA. It's just a throwback to an older version where it was helpful to do it the current way.

jameschapman19 avatar Dec 18 '23 08:12 jameschapman19

Okay okay. But if you could answer my question it would really help me.

Neel-132 avatar Dec 18 '23 08:12 Neel-132

Sorry, half my response got lost, so that looked really weird, like I was just offering up something totally random 😅

The answer to your question is NO! I'm assuming familiarity with PyTorch, but when we call self(arg) on an nn.Module we call its forward method.

So representations is the output of the forward method applied to the batch.
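A minimal illustration of that dispatch (a toy module, not from cca_zoo):

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return 2 * x

m = Doubler()
# m(x) and m.forward(x) are equivalent here: nn.Module.__call__
# dispatches to forward (plus any registered hooks).
print(m(torch.ones(3)))  # tensor([2., 2., 2.])
```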

jameschapman19 avatar Dec 18 '23 08:12 jameschapman19

More generally in Python, you can define the __call__ method on a class, so loss(args) is the same as loss.__call__(args) (which is how the current loss class works).
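For example (a generic Python class, not the actual loss class):

```python
class AddLoss:
    def __call__(self, *args):
        return sum(args)

loss = AddLoss()
# loss(1, 2) is syntactic sugar for loss.__call__(1, 2)
assert loss(1, 2) == loss.__call__(1, 2) == 3
```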

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

Yes, exactly, but this throws an error, which got me confused in the first place:

[Screenshot: the traceback from the failing loss call]

Neel-132 avatar Dec 18 '23 09:12 Neel-132

So it seems I have to take the output of forward for a specific batch, store it in a dictionary with "views" as the key, and then pass that. But that is kind of the last thing I wanted to do, and that's why I thought of raising the issue first.

Thanks

Neel-132 avatar Dec 18 '23 09:12 Neel-132

No? The output of self() is a list.

Your dataloader needs to be structured so that "views" is a key mapping to a list of arrays.
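A minimal sketch of a dataset with that structure (shapes and the class name are illustrative; the X, Y, Z naming follows the use case above):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MultiViewDataset(Dataset):
    """Each item is a dict with "views" mapping to a list of tensors."""
    def __init__(self, X, Y, Z):
        self.views = [X, Y, Z]

    def __len__(self):
        return len(self.views[0])

    def __getitem__(self, idx):
        return {"views": [view[idx] for view in self.views]}

X, Y, Z = torch.randn(100, 10), torch.randn(100, 8), torch.randn(100, 6)
loader = DataLoader(MultiViewDataset(X, Y, Z), batch_size=16)
batch = next(iter(loader))  # batch["views"] is a list of three batched tensors
```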

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

Yeah, looking at your code, I think you should read what, e.g., my DCCA class is doing. It's not a loss function; it's a PyTorch Lightning module that implements DCCA with a specific loss function.

If you take a look in the files, you can see the loss function that DCCA is using behind the scenes.

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

Yes, in the DCCA class, the loss method takes batch as its argument and passes batch['views'] to the respective objective. So when I do z = DCCA()(batch) and then loss(z), the error is obvious, because the loss function expects a dictionary and not a list of tensors. That's where the error is coming from.
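In other words (a sketch of the mismatch being described, with model standing in for the DCCA instance):

```python
z = model(batch["views"])  # calls forward: z is a list of representation tensors
model.loss(z)              # fails: loss indexes its argument with ["views"]
model.loss(batch)          # works: batch is the {"views": [...]} dict from the loader
```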

Neel-132 avatar Dec 18 '23 09:12 Neel-132

Ahhh! I've understood our confusion now!

So my DCCA class's loss method is applied to data, not representations!

Apologies I thought you meant in general is the DCCA loss applied to data or representations.

So just change your snippet to DCCA.loss(batch) and it will be fine.

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

(Because my DCCA loss method has a forward call inside)

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

Have just got off a 12 hour flight so forgive me for not realising the motivation behind your Q!

jameschapman19 avatar Dec 18 '23 09:12 jameschapman19

Exactly, so this was my confusion. So the DCCA loss takes a batch from the dataloader as its argument, right?

Forgive me, I don't know a lot about the theoretical aspects of CCA. So, for my case, when I use

loss = loss(batch)
loss.backward()

how will the neural network learn its parameters if we don't pass the encoded representations to the loss?

Neel-132 avatar Dec 18 '23 09:12 Neel-132

If the loss function were instead called 'encode_and_calculate_loss' you would understand, right? It's maybe a bit misleading, but it's also a done thing in NN code.

The function first passes the raw data through the encoders and then calculates the loss - look at what the function is actually doing.

jameschapman19 avatar Dec 18 '23 14:12 jameschapman19

Yes, that is alright, but my question is: after a forward pass, DCCA encodes the given batch, then calls the loss on the batch and performs loss.backward(). Is this sequence correct? That is my question.

Neel-132 avatar Dec 18 '23 15:12 Neel-132

This:

loss = DCCA().loss(batch)  # batch is a dictionary with "views": list of tensors
loss.backward()
optimiser.step()
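For completeness, the same snippet expanded into a loop (a sketch: the model/optimiser setup is assumed, and the ["objective"] indexing follows the dict that the loss method was quoted returning earlier in the thread):

```python
import torch

# `model` is the DCCA instance and `loader` the dataloader, both assumed
# to be set up elsewhere as discussed above.
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in loader:                       # batch is {"views": [x, y, z]}
    optimiser.zero_grad()
    loss = model.loss(batch)["objective"]  # encode + objective in one call
    loss.backward()
    optimiser.step()
```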

jameschapman19 avatar Dec 18 '23 15:12 jameschapman19

The function that takes representations and returns a scalar loss is .objective().
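Side by side (a sketch, with model and batch as above):

```python
representations = model(batch["views"])    # forward: raw views -> representations
scalar = model.objective(representations)  # objective: representations -> scalar loss
full = model.loss(batch)["objective"]      # loss: both steps in one call
```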

jameschapman19 avatar Dec 18 '23 15:12 jameschapman19