Unable to replace existing model layers
I'm currently trying to implement the LoRA algorithm in Axon. It involves freezing the original model's weights, adding new weights inside an existing layer, and feeding the input into both the original and the new weights.
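For context, the core computation I'm after is roughly this (a rough Nx-only sketch, nothing Axon-specific yet; the names are made up, w is the frozen original weight, and a/b are the new low-rank matrices):

defmodule LoraSketch do
  import Nx.Defn

  # Rough sketch of the LoRA forward pass: the W*x path stays frozen,
  # only the low-rank A and B matrices are trained
  defn lora_forward(x, w, a, b, opts \\ []) do
    opts = keyword!(opts, scaling: 1.0)

    frozen = Nx.dot(x, w)
    low_rank = x |> Nx.dot(a) |> Nx.dot(b)

    frozen + low_rank * opts[:scaling]
  end
end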
I noticed that Axon.map_nodes used to be able to replace layers, but now it only maps over %Axon.Node{} structs. However, if I were to make a custom layer, it would return an %Axon{} struct. I figured it's possible to unravel the Axon struct to retrieve the node, but that doesn't seem right.
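By "unravel" I mean something along these lines (just a sketch, assuming the %Axon{} struct keeps its graph in a nodes map keyed by id, with output pointing at the root node):

# Hypothetical: pull the underlying node back out of the %Axon{} a custom layer returns
%Axon{output: id, nodes: nodes} = Lora.custom_layer(input, target)
replacement_node = nodes[id]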
My skeleton draft atm:
# Define custom LoRA layer
defmodule Lora do
  import Nx.Defn

  def custom_layer(%Axon{} = input, %Axon{} = target_to_be_replaced) do
    # ...
    Axon.layer(&custom_layer_impl/2, [input])
  end

  defn custom_layer_impl(x, _opts \\ []) do
    # ... actual LoRA math goes here
    x
  end
end
# Import model
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )

# Get Axon model
%{spec: spec, model: unet_model, params: params} = unet
# Replace attention nodes
new_model =
  Axon.map_nodes(unet_model, fn
    %Axon.Node{} = node ->
      # Can't use Lora.custom_layer() here because it returns an %Axon{} struct
      node

    node ->
      node
  end)
For reference, here is the previous way of replacing layers; I believe the documentation is outdated: https://hexdocs.pm/axon/Axon.html#map_nodes/2
Another use case is to replace entire classes of layers with another. For example, you may want to replace all relu layers with tanh layers:
new_model = Axon.map_nodes(model, fn
  %Axon{op: :relu} = graph ->
    # Get node's immediate parent
    parent = Axon.get_parent(graph)
    # Replace node with a tanh
    Axon.tanh(parent)

  graph ->
    graph
end)
Thanks for bringing this up. The %Axon{} data structure used to be the only data structure that represented a model/layer/etc., but now we have %Axon.Node{}, and as you pointed out that means Axon.map_nodes can no longer be used to replace layers in the way you described.
As it stands, you could probably just replace the node's internal properties to point to the custom layer implementation rather than using Lora.custom_layer, but honestly that feels hacky and I don't like it. I will need to think a bit about what a good API for this looks like.
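For concreteness, the hack I'm describing would look something like this (just a sketch; it assumes built-in dense nodes carry op: :dense and that a node's implementation lives in its :op / :op_name fields, and lora_impl/4 is a hypothetical function that could only see the node's existing inputs, which is part of why this is unsatisfying):

new_model =
  Axon.map_nodes(model, fn
    %Axon.Node{op: :dense} = node ->
      # Swap the built-in dense implementation for a custom one in-place
      %{node | op: &lora_impl/4, op_name: :lora_dense}

    node ->
      node
  end)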
I haven't looked at a LoRA implementation yet, so just to confirm: what you need is to be able to replace specific Axon nodes with a LoRA version of the layer?
Gotcha, that makes sense.
I'm not well versed in the LoRA implementation, but it looks like you'd need a wrapper layer that keeps the original implementation and also creates two new parameters (lora_A, lora_B) to learn.
So for the implementation function, the calculation looks something like this:
# (?) marks arguments I'm not sure how to get into the layer yet
defn lora_embedding_impl(x, original_layer_impl(?), lora_A(?), lora_B(?), opts \\ []) do
  original_output = original_layer_impl(x)

  # LoRA has different layers for embedding / conv / linear,
  # but they all perform the same matrix operation: BAx
  after_a = Axon.Layers.embedding(x, lora_A)
  after_b = Nx.dot(after_a, lora_B)

  # Combine the original output with our LoRA calculation
  Nx.add(original_output, after_b)
end
# reference: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L79-L85
I think even if I replaced the node's internal properties to point to the custom implementation, it's not possible to inject new LoRA parameters with Axon.param().
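For reference, this is roughly how new parameters are normally declared when building a layer from scratch with Axon.layer, which is exactly the step that mutating an existing node skips (the shapes, sizes, and lora_impl/4 here are placeholders):

r = 4
out_features = 320  # hypothetical output size of the layer being wrapped

# LoRA initializes A randomly and B to zeros so training starts as a no-op
lora_a = Axon.param("lora_A", fn input_shape -> {elem(input_shape, 1), r} end)
lora_b = Axon.param("lora_B", fn _input_shape -> {r, out_features} end, initializer: :zeros)

wrapped = Axon.layer(&lora_impl/4, [x, lora_a, lora_b], name: "lora_dense", op_name: :lora)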
I'll keep hacking away at it to see what's possible. Appreciate your input!
Hey @seanmor5, I was able to implement LoRA with a couple of tricks.
I ended up not using map_nodes. Instead I created new nodes by extracting them out of Axon.layer.
Afterwards I added these new nodes into the original Axon struct, and then wired existing nodes to connect to the new nodes.
Thought I'd leave this comment for anybody who's trying to do something similar. For more details, see https://github.com/wtedw/lorax/blob/main/lib/lorax.ex#L47
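Roughly, the trick is the following (a simplified sketch: it assumes Axon's internal %Axon{nodes: %{id => %Axon.Node{}}, output: id} layout and a :parent list of ids on each node; target_axon, target_id, lora_a/lora_b, and lora_impl/4 are placeholders, and the real implementation is in the linked source):

# Build the LoRA computation as a standalone %Axon{} just to get fresh nodes out of Axon.layer
lora_axon = Axon.layer(&lora_impl/4, [target_axon, lora_a, lora_b], op_name: :lora)
%Axon{output: lora_id, nodes: lora_nodes} = lora_axon

# Add the new nodes into the original model's node map
merged_nodes = Map.merge(unet_model.nodes, lora_nodes)

# Rewire: nodes that used to consume the target now consume the LoRA node.
# The LoRA node itself keeps its own parents so it can still read the target's output.
rewired_nodes =
  Map.new(merged_nodes, fn
    {^lora_id, node} ->
      {lora_id, node}

    {id, node} ->
      parents = Enum.map(node.parent, fn p -> if p == target_id, do: lora_id, else: p end)
      {id, %{node | parent: parents}}
  end)

new_model = %{unet_model | nodes: rewired_nodes}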