
Could someone please explain why this is happening? (batcher.rs seq_len:4294967040)

Open · dotori1995 opened this issue 1 year ago · 1 comment

Hello! I am a student in Korea.

While trying to write training code for an LLM, I ran into a problem.

My code is based on the llama2-c > training.rs example, which uses the dataset and batcher like this:

    let device = candle_examples::device(common_args.cpu)?;
    let dataset = Dataset::new(&args.pretokenized_dir)?;
    println!(
        "loaded dataset, train: {} files, valid: {} files",
        dataset.train_tokens(),
        dataset.valid_tokens()
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny_15m();
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, config)?;
    let params = candle_nn::ParamsAdamW {
        lr: args.learning_rate,
        ..Default::default()
    };
    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch?;
        let logits = model.forward(&inp, 0, &mut cache)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        opt.backward_step(&loss)?;

        if batch_index > 0 && batch_index % 100 == 0 {
            // TODO: Add a way to deactivate the backprop graph tracking when computing the
            // validation loss.
            let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
            println!("{batch_index} {loss}");
        }
        if batch_index > 0 && batch_index % 1000 == 0 {
            varmap.save("checkpoint.safetensors")?
        }
    }

I wrote the following code:

// Imports for this file: std I/O for writing the token files, the HF tokenizers
// crate, and the dataset/batcher types used by the llama2-c training example.
use std::fs::File;
use std::io::Write;

use candle_core::Device;
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
use tokenizers::Tokenizer;

fn main() {
    // `tokenizer` is my own helper module wrapping the tokenizers crate.
    tokenizer::train_tokenizer("data", "bert-wiki.json");
    let tokenizer = tokenizer::load_tokenizer("bert-wiki.json");
    encode_and_write_to_file(&tokenizer.unwrap(), "Hello, world!", "token/data.bin").expect("error");

    let tokenizer = tokenizer::load_tokenizer("bert-wiki.json");
    encode_and_write_to_file(&tokenizer.unwrap(), "Hel21453lo, worldzxas!", "token/data1.bin").expect("error");

    let tokenizer = tokenizer::load_tokenizer("bert-wiki.json");
    encode_and_write_to_file(&tokenizer.unwrap(), "Hel21453lo, zx zasasv!", "token/data2.bin").expect("error");

    let pretokenized_dir = "token";
    let batch_size = 1;
    let sequence_length = 8;

    let dataset = Dataset::new(&pretokenized_dir).unwrap();
    println!(
        "loaded dataset, train: {} files, valid: {} files",
        dataset.train_tokens(),
        dataset.valid_tokens()
    );

    let device = Device::Cpu;
    let iter = DatasetRandomIter::new(&dataset, false, sequence_length, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(batch_size);
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch.unwrap();
        println!("batch: {}", batch_index);
        println!("input: {:?}", inp.to_vec2::<u32>());
        // input shape
        println!("input: {:?}", inp.shape());
        println!("target: {:?}", tgt.to_vec2::<u32>());
    }
}

fn encode_and_write_to_file(tokenizer: &Tokenizer, text: &str, file_path: &str) -> std::io::Result<()> {
    let encode_output = tokenizer::encode(tokenizer, text);
    let binding = encode_output.expect("error");
    let ids = binding.get_ids();
    println!("{:?}", ids);

    let mut file = File::create(file_path)?;

    // Each id is a u32, so to_le_bytes() writes 4 bytes per token.
    for &num in ids {
        println!("write {:?}", num);
        file.write_all(&num.to_le_bytes())?;
    }

    Ok(())
}
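
For comparison, here is a hypothetical variant of the write step (write_tokens_u16 is just an illustrative name), assuming the pretokenized .bin files are supposed to contain raw u16 little-endian token ids, as in the llama2.c / TinyStories data, rather than the 4-byte u32 values my current code writes:

use std::fs::File;
use std::io::Write;

// Hypothetical writer: narrow each id to u16 and write it little-endian,
// so every token occupies exactly 2 bytes on disk.
fn write_tokens_u16(ids: &[u32], file_path: &str) -> std::io::Result<()> {
    let mut file = File::create(file_path)?;
    for &id in ids {
        let id = u16::try_from(id).expect("token id does not fit in u16");
        file.write_all(&id.to_le_bytes())?;
    }
    Ok(())
}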

But the loop keeps running forever at this point:

for (batch_index, batch) in batch_iter.enumerate() {
    let (inp, tgt) = batch.unwrap();
    println!("batch: {}", batch_index);
    println!("input: {:?}", inp.to_vec2::<u32>());
    // input shape
    println!("input: {:?}", inp.shape());
    println!("target: {:?}", tgt.to_vec2::<u32>());
}

Upon checking here, the seq_len is 4294967040.
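
For what it's worth, 4294967040 is 0xFFFF_FF00 in hex, i.e. u32::MAX minus 255, which to me looks more like a value reconstructed from unexpected bytes than a real sequence length:

fn main() {
    // 4294967040 printed in hex is 0xFFFFFF00, i.e. u32::MAX - 255.
    println!("{:#X}", 4294967040u32);
    assert_eq!(4294967040u32, u32::MAX - 255);
}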

batcher.rs

impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
    pub fn new_r2(inner: I) -> Self {
        Self::new(IterResult2 { inner })
    }
}
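
As I read it, the batcher only stops when its inner iterator stops. A minimal sketch with a finite, hand-built inner iterator (toy tensors, not my real dataset) does terminate as expected, assuming candle_core and candle_datasets as dependencies:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let seq_len = 8u32;
    // Four hand-made (input, target) pairs; the target is the input shifted by one.
    let mut pairs: Vec<Result<(Tensor, Tensor)>> = Vec::new();
    for i in 0..4u32 {
        let inp = Tensor::arange(i, i + seq_len, &device)?;
        let tgt = Tensor::arange(i + 1, i + 1 + seq_len, &device)?;
        pairs.push(Ok((inp, tgt)));
    }
    // new_r2 yields None only once the inner iterator does, so with 4 pairs and
    // batch_size 2 this loop runs exactly twice and then ends.
    let batch_iter = candle_datasets::Batcher::new_r2(pairs.into_iter()).batch_size(2);
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch?;
        println!("batch {batch_index}: inp {:?} tgt {:?}", inp.shape(), tgt.shape());
    }
    Ok(())
}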

Screenshot 2024-04-08 211008

dotori1995 · Apr 08 '24 12:04

I was making a mistake in how the encoded data was being read back with memmap2.

However, there's still an issue.

dotori1995 · Apr 08 '24 15:04