
Construct train step from an objective function and optimizer

Open seanmor5 opened this issue 1 year ago • 0 comments

Right now, the only way to construct a train step is from a loss function and an optimizer:

def train_step(model, loss, optimizer, opts \\ []) do

This is suitable for most cases, but in some instances it may be easier to let a user pass an objective function to differentiate through, rather than just a loss function. In the default train step, the constructed objective function is:

  objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
    # hack to use trainable parameters as grad
    model_state =
      update_in(model_state, [Access.key!(:data)], fn data ->
        tree_merge(data, trainable_parameters, fn _, _, v -> v end)
      end)

    model_out = forward_model_fn.(model_state, inp)
    unscaled_loss = loss_fn.(tar, model_out.prediction)
    scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

    {model_out, scaled_loss, unscaled_loss}
  end
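One possible cleanup is to keep the parameter-merging and loss-scaling plumbing inside the step and let the user supply only the objective over the full model output. A rough sketch of that shape (`merge_trainable` and the `objective_fn` arity here are hypothetical names, not existing Axon API):

```elixir
# Sketch: the step still owns merging trainable parameters and loss scaling;
# the user-supplied objective_fn decides how targets and the full model
# output become a loss, instead of only seeing model_out.prediction.
step_objective = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
  # hypothetical helper replacing the update_in/tree_merge hack above
  model_state = merge_trainable(model_state, trainable_parameters)

  model_out = forward_model_fn.(model_state, inp)
  unscaled_loss = objective_fn.(tar, model_out)
  scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

  {model_out, scaled_loss, unscaled_loss}
end
```

Because the objective receives `model_out` rather than just the prediction, it can incorporate auxiliary outputs or penalty terms that the loss-only API cannot express.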

If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
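From the user's side, an objective-accepting variant could look roughly like the sketch below. The objective-based `Axon.Loop.trainer/3` clause is hypothetical (only the loss-based form exists today); `Axon.Losses`, `Nx`, and `Polaris.Optimizers` are the existing modules:

```elixir
# Hypothetical usage: pass an objective instead of a bare loss function,
# e.g. to add a regularization term computed from the model output.
objective_fn = fn tar, model_out ->
  base = Axon.Losses.mean_squared_error(tar, model_out.prediction, reduction: :mean)
  penalty = Nx.multiply(1.0e-4, Nx.sum(Nx.pow(model_out.prediction, 2)))
  Nx.add(base, penalty)
end

model
|> Axon.Loop.trainer(objective_fn, Polaris.Optimizers.adam(learning_rate: 1.0e-3))
|> Axon.Loop.run(train_data, %{}, epochs: 5)
```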

seanmor5 · Sep 11 '24