`predict`: How to get embedding layer?

Open talegari opened this issue 2 years ago • 3 comments

Question: How do I get the embeddings after fitting using triplet loss in this example: https://mlverse.github.io/luz/articles/examples/mnist-triplet.html ?

talegari, Jun 07 '23 17:06

You could do something like this:

dataset <- mnist_dataset(dir, transform = transform_to_tensor)
preds <- predict(fitted, dataset)

preds

Calling predict is just calling the forward method with the model in eval mode.

dfalbel, Jun 08 '23 07:06

Sorry, the above is wrong. You could modify the triplet model to be something like:

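triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    # `net` is the embedding network defined in the vignette
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  # used during fitting: embed anchor, positive and negative, then compute the triplet loss
  loss = function(input, ...) {
    embeds <- lapply(input, self$embedding)
    self$criterion(
      embeds$anchor,
      embeds$positive,
      embeds$negative
    )
  },
  # used by predict(): return the embedding of a single input batch
  predict = function(x) {
    self$embedding(x)
  }
)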

This adds a predict method that just calls the embedding module. Then:

dataset <- mnist_dataset(dir, transform = transform_to_tensor)
preds <- predict(fitted, dataset)

preds
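
If you need the embeddings as plain R objects (for plotting, say), something along these lines may help. This is only a sketch: `preds` here is a random stand-in tensor rather than real model output, and the `labels` vector and column names are made up for illustration.

```r
library(torch)

# Stand-in for the tensor returned by predict() (assumption: 5 embeddings of dimension 2).
preds <- torch_randn(5, 2)

# Move to CPU and convert to a base R matrix.
emb <- as_array(preds$cpu())

# Hypothetical class labels, only to show how to pair embeddings with classes.
labels <- c(3, 1, 4, 1, 5)
emb_df <- data.frame(dim1 = emb[, 1], dim2 = emb[, 2], label = factor(labels))
```

With real predictions, the labels would come from the dataset itself instead of a hand-written vector.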

You can also access the embedding module directly from the fitted object, but in that case you have to manually put the model in eval mode, disable gradients, and move tensors to the correct device.

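# put the model in eval mode
fitted$model$eval()
# disable gradient tracking while computing the embedding
with_no_grad({
  # add a batch dimension and move the input to the model's device ("mps" here)
  fitted$model$embedding(dataset[1]$x$unsqueeze(1)$to(device = "mps"))
})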
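
The "mps" device above is specific to one machine. A possible variation, sketched below, reads the device from the model's own parameters so the same code runs wherever the model lives. Note the assumptions: `toy_model` is a made-up stand-in for `fitted$model` (it only has an `embedding` submodule), the input is random data in MNIST shape, and all parameters are assumed to sit on a single device.

```r
library(torch)

# Toy stand-in for `fitted$model` (assumption): a module with an `embedding` submodule.
toy_model <- nn_module(
  initialize = function(embedding_dim = 2) {
    self$embedding <- nn_sequential(
      nn_flatten(),
      nn_linear(28 * 28, embedding_dim)
    )
  }
)
model <- toy_model()

# Eval mode, as in the manual approach above.
model$eval()

# Reuse whatever device the first parameter lives on (assumes a single device).
device <- model$parameters[[1]]$device

# Random stand-in for one MNIST image with a batch dimension: (1, 1, 28, 28).
x <- torch_randn(1, 1, 28, 28)

# Compute the embedding without tracking gradients.
emb <- with_no_grad({
  model$embedding(x$to(device = device))
})
```

With the real fitted object, `model` would be `fitted$model` and `x` would be `dataset[1]$x$unsqueeze(1)`.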

dfalbel, Jun 08 '23 09:06

Thanks Daniel. IMHO, most folks would require embeddings after the training process. The predict method above should be added to the vignette to make it complete.

talegari, Jun 10 '23 11:06