axon
Allow overriding layer metadata such as `:op` and `:op_name` for built-ins
Needed for LoRA
For a basic LoRA implementation, we have to determine which nodes represent the QKV matrices. To figure out which nodes need to be targeted, we currently have to invoke each node's name function and check whether the last segment of the name is "key", "query", or "value":
```elixir
# Collect the ids of dense layers whose name ends in "key", "query" or "value"
Axon.reduce_nodes(axon, [], fn
  %Axon.Node{id: id, name: name_fn, op: :dense}, acc ->
    # The full name is something like
    # "down_blocks.2.transformers.0.blocks.0.cross_attention.key"
    shortname =
      name_fn.(:dense, nil)
      |> String.split(".")
      |> List.last()

    if shortname in ["key", "query", "value"] do
      [id | acc]
    else
      acc
    end

  # Every other node is skipped
  %Axon.Node{}, acc ->
    acc
end)
```
For more complex LoRA adaptation, we need to target nodes outside of QKV. For example, with LCM-LoRA, the following nodes in StableDiffusion would need adapters:
- query, key, value
- cross_attention.output
- input_projection
- output_projection
- ffn.intermediate
- ffn.output
- conv_1 (e.g. "down_blocks.1.residual_blocks.1.conv_1")
- conv_2
- shortcut.projection (e.g. "down_blocks.2.residual_blocks.0.shortcut.projection")
- downsamples.{m}.conv
- upsamples.{m}.conv
- timestep_projection
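With only the name function available, covering this list means matching full names against patterns rather than comparing the last segment, since some targets span two segments (`ffn.intermediate`) and some contain an index (`downsamples.{m}.conv`). Below is a minimal, Axon-independent sketch; the module `LoraTargets` and its `target?/1` function are hypothetical, the patterns are transcribed from the list above, and `{m}` is assumed to be any integer index.

```elixir
defmodule LoraTargets do
  # Hypothetical helper: decides from a full layer name whether the layer
  # should receive a LoRA adapter. Each regex anchors on a "." boundary (or
  # the start of the string) so that e.g. "monkey" does not match "key".
  defp patterns do
    [
      ~r/(^|\.)(query|key|value)$/,
      ~r/(^|\.)cross_attention\.output$/,
      ~r/(^|\.)(input_projection|output_projection)$/,
      ~r/(^|\.)ffn\.(intermediate|output)$/,
      ~r/(^|\.)(conv_1|conv_2)$/,
      ~r/(^|\.)shortcut\.projection$/,
      # `{m}` from the list is assumed to be a numeric index
      ~r/(^|\.)(downsamples|upsamples)\.\d+\.conv$/,
      ~r/(^|\.)timestep_projection$/
    ]
  end

  @spec target?(String.t()) :: boolean()
  def target?(name) when is_binary(name) do
    Enum.any?(patterns(), &Regex.match?(&1, name))
  end
end
```

This would replace the `shortname` comparison inside the `Axon.reduce_nodes/3` callback, but it still depends on naming conventions staying stable, which is the fragility this issue is about.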
Instead of calling the name function and parsing the result, it would be helpful to have metadata on the node itself, such as an overridable `:op_name`, to determine whether a node needs to be modified.
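To illustrate what such metadata would enable, here is a minimal sketch assuming built-in layers could be created with an overridden `:op_name`. The atoms `:attention_query`, `:attention_key` and `:attention_value` are hypothetical values of the kind this issue requests, not names Axon assigns today, and plain maps stand in for `%Axon.Node{}` so the example runs without Axon.

```elixir
defmodule LoraSelect do
  # Hypothetical op_name overrides that would mark LoRA targets
  @targets [:attention_query, :attention_key, :attention_value]

  # Returns the ids of nodes whose op_name is one of `targets`, in order.
  # Only the :id and :op_name keys of each node map are used.
  def target_ids(nodes, targets \\ @targets) do
    nodes
    |> Enum.filter(fn %{op_name: op_name} -> op_name in targets end)
    |> Enum.map(& &1.id)
  end
end
```

Selection becomes a match on a field rather than a call to the name function with `nil` as its second argument; in a real graph the same check would sit in the `Axon.reduce_nodes/3` callback, matching on `%Axon.Node{op_name: ...}`.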