candle-lora

In the Llama model, only the embedding layer is converted to a LoRA layer.

Open · Adamska1008 opened this issue 10 months ago · 5 comments

I tried to fine-tune TinyLlama with this crate. After training, the saved safetensors file contains only two tensors:

lora_llama.b0
lora_llama.a0

I expanded the macro in mod llama and found that these two tensors belong to the converted embedding layer (wte):

        pub fn get_lora_model<'a>(
            &'a mut self,
            lora_config: candle_lora::LoraConfig,
            vb: &candle_nn::VarBuilder,
            linear_config: Option<candle_lora::LoraLinearConfig>,
            conv1d_config: Option<candle_lora::LoraConv1dConfig>,
            conv2d_config: Option<candle_lora::LoraConv2dConfig>,
            embed_config: Option<candle_lora::LoraEmbeddingConfig>,
        ) {
            let mut linear: ::std::collections::HashMap<
                String,
                &dyn candle_lora::LinearLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv1d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv1dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv2d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv2dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut embed: ::std::collections::HashMap<
                String,
                &dyn candle_lora::EmbeddingLayerLike,
            > = ::std::collections::HashMap::new();
            [(embed.insert("wte".to_string(), &*self.wte))];
            if !linear.is_empty() && linear_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not speified for linear layers."),
                    );
                };
            }
            if !conv1d.is_empty() && conv1d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not speified for conv1d layers."),
                    );
                };
            }
            if !conv2d.is_empty() && conv2d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not speified for conv2d layers."),
                    );
                };
            }
            if !embed.is_empty() && embed_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not speified for embedding layers."),
                    );
                };
            }
            let mut builder = candle_lora::SelectedLayersBuilder::new();
            if linear_config.is_some() {
                builder = builder.add_linear_layers(linear, linear_config.unwrap());
            }
            if conv1d_config.is_some() {
                builder = builder.add_conv1d_layers(conv1d, conv1d_config.unwrap());
            }
            if conv2d_config.is_some() {
                builder = builder.add_conv2d_layers(conv2d, conv2d_config.unwrap());
            }
            if embed_config.is_some() {
                builder = builder.add_embed_layers(embed, embed_config.unwrap());
            }
            let selection = builder.build();
            let new_layers = candle_lora::Lora::convert_model(selection, lora_config, &vb);
            [
                (self
                    .wte = ::std::sync::Arc::new(
                    new_layers.embed.get("wte").unwrap().clone(),
                )),
            ];
        }

So none of the linear layers in the self-attention blocks are converted to LoRA layers. When I use my fine-tuned model, it behaves exactly the same as before.

Adamska1008, Apr 09 '24 03:04

Without code to look at, I can only speculate that it's because the linear layers aren't being converted by:

#[replace_layer_fields]
#[derive(Debug, Clone, AutoLoraConvert)]

or you're not calling get_lora_model on those layers.

Since you're getting the embedding weights (which likely sit in the top module), I suspect you're only calling the top module's get_lora_model method.

I'm trying to fine-tune a Phi-3 model myself, and there I do get LoRA weights for the attention (and MLP) layers:

[src/main.rs:38:5] &vars = [
    (
        "model.layers.4.mlp.down_proj.lora_linear.a0.weight",
        Var(
            Tensor[dims 1, 8192; bf16, cuda:0],
        ),
    ),
    (
        "model.layers.9.self_attn.o_proj.lora_linear.a0.weight",
        Var(
            Tensor[dims 1, 3072; bf16, cuda:0],
        ),
    ),
    (
        "model.layers.9.mlp.down_proj.lora_linear.a0.weight",
        Var(
            Tensor[dims 1, 8192; bf16, cuda:0],
        ),
    ),
... // more weights
]

Here's the Phi model's top module:

#[replace_layer_fields]
#[derive(Debug, Clone, AutoLoraConvert)]
pub struct PhiModel {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
}

This converts the Embedding and Linear fields to LoRA layers, but it does not convert self.layers automatically. You have to iterate through them as is done here: https://github.com/EricLBuehler/candle-lora/blob/9dc75e1af6142d00cf9e5257faa0c4b9a77d9759/candle-lora-transformers/src/llama.rs#L553, where each module's load calls get_lora_model on that module's child modules.
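To make the nesting issue concrete, here is a toy, std-only Rust sketch. `Layer`, `DecoderLayer`, `Model` and their methods are hypothetical stand-ins, not candle-lora types; the only point is that converting the fields declared directly on the top module leaves everything inside `layers: Vec<DecoderLayer>` untouched unless each child's own conversion is also called.

```rust
/// Hypothetical stand-in for a layer that can be swapped for a LoRA variant.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Layer {
    Plain,
    Lora,
}

impl Layer {
    fn convert(&mut self) {
        *self = Layer::Lora;
    }
}

/// Hypothetical decoder block with two projections.
struct DecoderLayer {
    q_proj: Layer,
    o_proj: Layer,
}

impl DecoderLayer {
    /// Converts this module's own fields, like a per-module get_lora_model.
    fn get_lora_model(&mut self) {
        self.q_proj.convert();
        self.o_proj.convert();
    }
}

/// Hypothetical top module shaped like PhiModel: direct fields plus a Vec of blocks.
struct Model {
    embed_tokens: Layer,
    layers: Vec<DecoderLayer>,
    lm_head: Layer,
}

impl Model {
    fn new(n_layers: usize) -> Self {
        let layers = (0..n_layers)
            .map(|_| DecoderLayer { q_proj: Layer::Plain, o_proj: Layer::Plain })
            .collect();
        Model { embed_tokens: Layer::Plain, layers, lm_head: Layer::Plain }
    }

    /// Converts only the fields declared directly on the top module.
    fn get_lora_model_top_only(&mut self) {
        self.embed_tokens.convert();
        self.lm_head.convert();
    }

    /// Converts the top-level fields, then walks `layers` and converts each child.
    fn get_lora_model_recursive(&mut self) {
        self.get_lora_model_top_only();
        for layer in &mut self.layers {
            layer.get_lora_model();
        }
    }

    /// Counts how many layers ended up as LoRA layers.
    fn lora_count(&self) -> usize {
        let top = [self.embed_tokens, self.lm_head]
            .iter()
            .filter(|&&l| l == Layer::Lora)
            .count();
        let nested: usize = self
            .layers
            .iter()
            .map(|d| [d.q_proj, d.o_proj].iter().filter(|&&l| l == Layer::Lora).count())
            .sum();
        top + nested
    }
}

fn main() {
    let mut top_only = Model::new(2);
    top_only.get_lora_model_top_only();
    println!("top-level only: {} of 6 layers converted", top_only.lora_count());

    let mut recursive = Model::new(2);
    recursive.get_lora_model_recursive();
    println!("recursive: {} of 6 layers converted", recursive.lora_count());
}
```

With two decoder layers, the top-level-only pass converts 2 of the 6 layers (embed_tokens and lm_head) and the recursive pass converts all 6, which mirrors the original report where only the top-level embedding received LoRA weights.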

I'm not sure whether inner: Box<dyn LinearLayerLike> is converted. It looks like a combination of the old way (https://github.com/EricLBuehler/candle-lora/blob/9dc75e1af6142d00cf9e5257faa0c4b9a77d9759/candle-lora-examples/examples/linear_old.rs#L16) and the new way (https://github.com/EricLBuehler/candle-lora/blob/9dc75e1af6142d00cf9e5257faa0c4b9a77d9759/candle-lora-examples/examples/linear_macro.rs#L10).

Another thing to keep in mind is that when doing:

let mut optimizer = candle_nn::SGD::new(varmap.all_vars(), 0.003).unwrap();

you'll have both LoRA and base-model variables in the varmap. I haven't tested whether this causes problems, slowdowns, or OOM during training.

I think you can filter out the non-LoRA variables like this:

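// Keep (name, var) pairs for LoRA variables only; clone them so the lock guard can drop.
let vars = varmap
        .data()
        .lock()
        .unwrap()
        .iter()
        .filter(|(name, _)| name.contains("lora"))
        .map(|(name, var)| (name.clone(), var.clone()))
        .collect::<Vec<_>>();
// SGD::new takes a Vec<Var>, so pass just the variables.
let mut optimizer =
    candle_nn::SGD::new(vars.iter().map(|(_, var)| var.clone()).collect(), 0.003).unwrap();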

This works because you can name the variables with the VarBuilder before passing it to get_lora_model (all the LoRA weights in the dump above contain lora_linear).

However, I'm running out of memory with only 1,608,000 trainable params, counted like this:

let num_params = vars
    .iter()
    .map(|s| s.1.shape().dims().iter().product::<usize>()) // assuming vectors and matrices of weights
    .sum::<usize>();
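To put that number in perspective: assuming the usual LoRA factorization (A of shape [r, d_in], B of shape [d_out, r]), which matches the a0 shapes in the dump above, a rank-r adapter on a d_in -> d_out linear layer adds r * (d_in + d_out) parameters. The sketch below works this out for down_proj and o_proj with r = 1 (read off the [1, 8192] and [1, 3072] a0 shapes; down_proj's output size of 3072 is an assumption) and computes the bf16 storage for 1,608,000 parameters. The helper names are hypothetical.

```rust
/// Extra parameters a rank-`r` LoRA adapter adds to a linear layer mapping `d_in` -> `d_out`,
/// assuming A has shape [r, d_in] and B has shape [d_out, r].
fn lora_linear_params(d_in: usize, d_out: usize, r: usize) -> usize {
    r * d_in + d_out * r
}

/// Bytes needed to store `n` elements of `bytes_per_elem` bytes each (2 for bf16).
fn storage_bytes(n: usize, bytes_per_elem: usize) -> usize {
    n * bytes_per_elem
}

fn main() {
    // r = 1 is read off the a0 shapes in the dump; d_out = 3072 for down_proj is an assumption.
    let down_proj = lora_linear_params(8192, 3072, 1);
    let o_proj = lora_linear_params(3072, 3072, 1);
    println!("down_proj adds {down_proj} params, o_proj adds {o_proj}");

    // The 1,608,000 trainable parameters counted above, stored in bf16.
    let bytes = storage_bytes(1_608_000, 2);
    println!("{bytes} bytes = {:.2} MiB", bytes as f64 / (1024.0 * 1024.0));
}
```

About 3 MiB of adapter weights is tiny next to 24 GB, so the OOM is more likely driven by the base model and backprop than by the LoRA parameters themselves.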

while running Phi-3 with no input on an RTX 3090 (24 GB), so I'm not entirely sure I'm doing this correctly either. Specifically, it panics with CUDA_ERROR_OUT_OF_MEMORY at https://github.com/huggingface/candle/blob/c68ed8963fb6fc842f20d84baa07ff97b56aedb4/candle-nn/src/optim.rs#L21, but this might be a separate issue, or I might just not understand the memory usage of backprop.

Perhaps @EricLBuehler can give some inputs?

AntBlo, Apr 28 '24 10:04

Found this: https://github.com/huggingface/candle/issues/2079

However, the following modification (loading all tensors into a HashMap first) doesn't help with the OOM:

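use std::collections::HashMap;
use std::path::Path;

use candle_core::{safetensors::MmapedSafetensors, DType, Device, Error};
use candle_nn::var_builder::{SimpleBackend, VarBuilder, VarBuilderArgs};

// Eagerly loads every tensor from the memory-mapped safetensors files into a
// HashMap, then builds a VarBuilder over that map.
pub fn from_mmaped_safetensors<'a, P: AsRef<Path>>(
    paths: &[P],
    dtype: DType,
    device: &Device,
    _silent: bool, // both branches of the original did the same work, so this is unused
) -> Result<VarBuilderArgs<'a, Box<dyn SimpleBackend>>, Error> {
    let mut map = HashMap::new();
    {
        // Safety: the files must not be modified while they are memory-mapped.
        let tensors = unsafe { MmapedSafetensors::multi(paths)? };

        for (name, _) in tensors.tensors() {
            // `load` already places the tensor on `device`; only the dtype is converted.
            let tensor = tensors.load(&name, device)?.to_dtype(dtype)?;
            map.insert(name, tensor);
        }
        // `tensors` (and its memory maps) are dropped at the end of this scope.
    }

    Ok(VarBuilder::from_tensors(map, dtype, device))
}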

AntBlo, Apr 29 '24 16:04

@AntBlo The memory usage of backprop is very high. What is your GPU's memory capacity?

EricLBuehler, Aug 20 '24 14:08

@EricLBuehler

From nvidia-smi: NVIDIA GeForce RTX 3090 with 24576MiB (24GB VRAM)

I've put this on the back burner for a bit, but if there's anything I can test, let me know.

AntBlo, Aug 23 '24 12:08

@AntBlo 24GB should be enough for backprop. This may be connected to #21.
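A rough back-of-the-envelope check of that claim, as a minimal sketch: it counts one bf16 copy of the weights plus one bf16 gradient per parameter that receives a gradient, and ignores activations, temporaries and allocator overhead, so it is only a lower bound. The 3.8B total assumes the Phi-3-mini variant (consistent with the 3072/8192 dimensions in the dump above), and `rough_training_gib` is a hypothetical helper.

```rust
/// Bytes in one GiB.
const GIB: f64 = 1024.0 * 1024.0 * 1024.0;

/// Lower-bound estimate of training memory in GiB: one copy of every weight plus one
/// gradient buffer for each parameter that receives a gradient. Activations, temporaries
/// and allocator overhead are deliberately ignored.
fn rough_training_gib(total_params: u64, grad_params: u64, bytes_per_elem: u64) -> f64 {
    ((total_params + grad_params) * bytes_per_elem) as f64 / GIB
}

fn main() {
    // Assumption: Phi-3-mini with roughly 3.8B base parameters, stored in bf16 (2 bytes each).
    let base_params: u64 = 3_800_000_000;
    let lora_params: u64 = 1_608_000;
    let total = base_params + lora_params;

    // Gradients only for the LoRA parameters.
    let lora_only = rough_training_gib(total, lora_params, 2);
    // Gradients for the base weights as well (e.g. if they are tracked as variables).
    let all_grads = rough_training_gib(total, total, 2);

    println!("LoRA-only gradients: {lora_only:.2} GiB");
    println!("gradients for everything: {all_grads:.2} GiB");
    println!("RTX 3090 capacity: {:.2} GiB", 24576.0 / 1024.0);
}
```

Under these assumptions the weights fit comfortably, and even gradients for every base weight stay below 24 GiB (24576 MiB); whatever headroom remains goes to activations, which grow with sequence length and batch size.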

EricLBuehler, Aug 31 '24 01:08