Allow overriding layer metadata such as `:op` and `:op_name` for built-ins

Open • seanmor5 opened this issue 1 year ago • 1 comment

Needed for LoRA.

seanmor5 • Nov 20 '23 22:11

For a basic LoRA implementation, we have to determine which nodes represent the QKV matrices. To figure out which nodes need to be targeted, we currently have to invoke each node's name function and check whether the name ends in "key", "query", or "value":

  Axon.reduce_nodes(axon, [], fn
    # The Q/K/V projections are :dense layers.
    %Axon.Node{id: id, name: name_fn, op: :dense}, acc ->
      # The name function returns the fully qualified layer name, e.g.
      # "down_blocks.2.transformers.0.blocks.0.cross_attention.key"
      shortname =
        name_fn.(:dense, nil)
        |> String.split(".")
        |> List.last()

      if shortname in ["key", "query", "value"], do: [id | acc], else: acc

    # Every other node is skipped.
    %Axon.Node{}, acc ->
      acc
  end)

For more complex LoRA adaptation, we need to target nodes outside of QKV. For example, with LCM-LoRA, the following nodes in StableDiffusion would need adapters:

  • query, key, value
  • cross_attention.output
  • input_projection
  • output_projection
  • ffn.intermediate
  • ffn.output
  • conv_1 (e.g. "down_blocks.1.residual_blocks.1.conv_1")
  • conv_2
  • shortcut.projection (e.g. "down_blocks.2.residual_blocks.0.shortcut.projection")
  • downsamples.{m}.conv
  • upsamples.{m}.conv
  • timestep_projection
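With only the name function to go on, covering this longer list means matching full layer names against suffix patterns. The following is a minimal sketch, assuming names follow the dotted format shown above; the `LoraTargets` module and `target?/1` are hypothetical helpers, not part of Axon or Bumblebee.

```elixir
defmodule LoraTargets do
  # Hypothetical helper: name suffixes taken from the LCM-LoRA target list.
  @suffixes ~w(
    query key value
    cross_attention.output
    input_projection output_projection
    ffn.intermediate ffn.output
    conv_1 conv_2
    shortcut.projection
    timestep_projection
  )

  @doc "Returns true if a fully qualified layer name matches a LoRA target."
  def target?(name) when is_binary(name) do
    # `downsamples.{m}.conv` / `upsamples.{m}.conv`, where {m} is a block index.
    indexed_conv? = Regex.match?(~r/(^|\.)(down|up)samples\.\d+\.conv$/, name)

    # Match whole dotted segments so that e.g. "monkey" does not match "key".
    indexed_conv? or
      Enum.any?(@suffixes, fn suffix ->
        name == suffix or String.ends_with?(name, "." <> suffix)
      end)
  end
end
```

Inside the `Axon.reduce_nodes/3` callback, the `shortname` comparison could then become `LoraTargets.target?(name_fn.(op, nil))`, with the `op: :dense` pattern relaxed so that convolution nodes are visited too. This is exactly the string parsing the issue would like to avoid: it silently breaks if a model renames its layers.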

Instead of relying on the name function, it would be helpful to have node metadata (for example, an overridable `:op` or `:op_name` on built-in layers) that indicates whether a node needs to be modified.
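To illustrate the difference, here is a minimal sketch of selection by metadata rather than by name. It uses a stand-in struct instead of `Axon.Node`, and the `:attention_query`/`:attention_key`/`:attention_value` tags are hypothetical values a model might set if built-ins allowed overriding `:op_name`; none of this is existing Axon API.

```elixir
defmodule MetadataSelection do
  # Stand-in for a graph node (NOT Axon.Node); only the fields used here.
  defmodule Node do
    defstruct [:id, :op, :op_name]
  end

  # Hypothetical op_name tags a model would attach to its LoRA targets.
  @lora_tags [:attention_query, :attention_key, :attention_value]

  # Select node ids purely by metadata; no name strings are parsed.
  def lora_node_ids(nodes) do
    for %Node{id: id, op_name: op_name} <- nodes, op_name in @lora_tags, do: id
  end
end
```

The selection logic stays the same however the model names its layers; only the tags matter.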

wtedw • Dec 01 '23 22:12