
Unable to replace existing model layers

Open wtedw opened this issue 9 months ago • 3 comments

I'm currently trying to implement the LoRA algorithm in Axon. It involves freezing the original model's weights, adding new weights inside an existing layer, and feeding the input into both original weights and new weights.
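For context, the forward pass I'm after is just W0x + BAx with W0 frozen and only A and B trained. A rough Nx sketch of the math (module and names are only illustrative, this isn't an Axon layer yet):

# Rough Nx sketch of the LoRA forward pass (illustrative only, not an Axon layer).
# w0 is the frozen pretrained weight; only a and b are trained.
defmodule LoraMath do
  import Nx.Defn

  # x: {batch, k}, w0: {k, d}, a: {k, r}, b: {r, d}
  defn forward(x, w0, a, b) do
    frozen = Nx.dot(x, w0)                 # original projection, weights stay frozen
    update = x |> Nx.dot(a) |> Nx.dot(b)   # low-rank update B(Ax)
    Nx.add(frozen, update)                 # LoRA: W0 x + B A x
  end
end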

I noticed that Axon.map_nodes used to be able to replace layers, but now it only operates on %Axon.Node{} structs. However, if I make a custom layer, it returns an %Axon{} struct, not a node. I figured it's possible to unravel the Axon struct to retrieve its node, but that doesn't seem right.

My skeleton draft atm:

# Define custom LoRA layer
defmodule Lora do
  import Nx.Defn

  def custom_layer(%Axon{} = input, %Axon{} = target_to_be_replaced) do
    ...
    Axon.layer(&custom_layer_impl/2, [input])
  end

  defn custom_layer_impl(input, _opts), do: ...
end


# Import model
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )

# Get Axon model
%{spec: spec, model: unet_model, params: params} = unet

# Replace attention nodes
new_model =
  Axon.map_nodes(unet_model, fn
    %Axon.Node{} = node ->
      # Can't use Lora.custom_layer() here because it returns an %Axon{} struct,
      # while this callback has to return an %Axon.Node{}
      node

    node ->
      node
  end)

For reference, this was the previous way of replacing layers (I believe the documentation is outdated): https://hexdocs.pm/axon/Axon.html#map_nodes/2

Another use case is to replace entire classes of layers with another. For example, you may want to replace all relu layers with tanh layers:

new_model = Axon.map_nodes(model, fn
  %Axon{op: :relu} = graph ->
    # Get nodes immediate parent
    parent = Axon.get_parent(graph)
    # Replace node with a tanh
    Axon.tanh(parent)

  graph ->
    graph
end)

wtedw avatar Sep 15 '23 14:09 wtedw

Thanks for bringing this up. The %Axon{} data structure used to be the only one representing a model/layer/etc., but now we also have %Axon.Node{}, and as you pointed out, that means Axon.map_nodes can no longer be used to replace layers the way you described.

As it stands, you could probably just replace the node's internal properties to point to the custom layer implementation rather than using Lora.custom_layer, but honestly that feels hacky and I don't like it. I will need to think a bit about what a good API for this looks like.
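Something like this, roughly (untested sketch; I'm going from memory on the %Axon.Node{} field names, and note it still can't introduce new trainable parameters):

new_model =
  Axon.map_nodes(unet_model, fn
    # assumes :op_name identifies the layer type and :op holds the impl function
    %Axon.Node{op_name: :dense} = node ->
      # my_custom_dense_impl/4 is hypothetical; its arity must match the node's
      # inputs + parameters + the trailing opts argument
      %{node | op: &my_custom_dense_impl/4}

    node ->
      node
  end)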

I haven't looked at a LoRA implementation yet, so just to confirm: you need to be able to replace specific Axon nodes with a LoRA version of the layer?

seanmor5 avatar Sep 15 '23 15:09 seanmor5

Gotcha, that makes sense.

I'm not deeply versed in the LoRA implementation, but it looks like you'd need a wrapper layer that keeps the original implementation while also creating two new trainable parameters (lora_A, lora_B).
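Concretely, the forward pass becomes h = W0·x + B·A·x (optionally scaled by alpha / r), where W0 is the frozen pretrained weight, A projects the input down to rank r, B projects it back up, and only A and B receive gradient updates.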


So for the implementation function, the calculation looks something like this:

defn lora_embedding_impl(x, original_layer_impl(?), lora_A(?), lora_B(?), opts \\ []) do
  original_output = original_layer_impl(x)

  # LoRA has different layers for embedding / conv / linear,
  # but they all perform the same matrix operation: BAx
  after_a = Axon.Layers.embedding(x, lora_A)
  after_b = Nx.dot(after_a, lora_B)

  # Combine the original output with our LoRA calculation
  Nx.add(original_output, after_b)
end

# reference: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L79-L85

I think even if I replaced the node's internal properties to point to the custom implementation, it's not possible to inject new LoRA parameters with Axon.param().
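For comparison, declaring fresh params on a brand-new layer looks something like this (dims are made up); the missing piece is doing the same thing in place of an existing node while keeping its original impl and params:

# Made-up dims; just the normal custom-layer flow for a *new* %Axon{} layer
lora_a = Axon.param("lora_a", {768, 4})
lora_b = Axon.param("lora_b", {4, 768})

# arity 4 = x + lora_a + lora_b + opts; this version omits the
# original_layer_impl argument from the sketch above, which would still
# need to be captured somehow
lora = Axon.layer(&lora_embedding_impl/4, [input, lora_a, lora_b], name: "lora_embed")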

I'll keep hacking away at it to see what's possible. Appreciate your input!

wtedw avatar Sep 15 '23 16:09 wtedw

Hey @seanmor5, I was able to implement LoRA with a couple of tricks.

I ended up not using map_nodes. Instead, I created new nodes by building layers with Axon.layer and extracting the nodes out of the returned structs. Afterwards, I added these new nodes into the original Axon struct and wired existing nodes to connect to the new ones.

Thought I'd leave this comment for anybody trying to do something similar. For more details, see: https://github.com/wtedw/lorax/blob/main/lib/lorax.ex#L47
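In case it helps, the rough shape of the trick looks like the sketch below. It's illustrative only, not the actual lorax code, and it assumes %Axon{} is %Axon{output: id, nodes: %{id => %Axon.Node{}}} with each node listing its input ids under :parent (double-check against the real structs):

# `lora` is a fresh %Axon{} built via Axon.layer/3 that reads from the target
# node; `target_id` is the id of the node being wrapped (hypothetical bindings).
merged_nodes = Map.merge(unet_model.nodes, lora.nodes)

# Point every *original* consumer of the target node at the LoRA node instead,
# leaving the LoRA node itself (which still reads from the target) untouched.
rewired_nodes =
  Map.new(merged_nodes, fn {id, node} ->
    if Map.has_key?(unet_model.nodes, id) do
      parents = Enum.map(node.parent, fn p -> if p == target_id, do: lora.output, else: p end)
      {id, %{node | parent: parents}}
    else
      {id, node}
    end
  end)

# (if target_id were the model output, unet_model.output would need the same swap)
new_model = %{unet_model | nodes: rewired_nodes}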

wtedw avatar Nov 03 '23 19:11 wtedw