Add 'encoder-decoder' support
I have been using LM-Cocktail to merge language models, specifically the 'mix_models_with_data' function. However, I noticed there are only implementations for encoder and decoder models, not encoder-decoder models.
It would be nice to consider adding this functionality to the repo. My own implementation is below; let me know what you think.
The merging was done using two finetuned versions of mT0-small:
- One finetuned on German data from XNLI.
- One finetuned on Arabic data from XNLI.

Merging weights were calculated using 50 examples from the Arabic XNLI validation set, passed as 'input'/'output' pairs (see the sketch below).
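For reference, 'example_data' is simply a list of dicts with 'input' and 'output' keys. Something like the following could build those 50 examples from the Arabic XNLI validation split (the prompt template here is only illustrative, not necessarily the one the models were finetuned on):

```python
# Illustrative sketch of how `examples` could be built; the prompt template is
# hypothetical, the only hard requirement is a list of dicts with 'input'/'output' keys.
from datasets import load_dataset

label_names = ["entailment", "neutral", "contradiction"]
xnli_ar = load_dataset("xnli", "ar", split="validation").select(range(50))

examples = [
    {
        "input": f"Premise: {ex['premise']} Hypothesis: {ex['hypothesis']} Label:",
        "output": label_names[ex["label"]],
    }
    for ex in xnli_ar
]
```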
```python
model1 = "output/xnli/experiment_mt0-small_xnli_ar"
model2 = "output/xnli/experiment_mt0-small_xnli_de"

model = mix_models_with_data(
    model_names_or_paths=[model1, model2],
    model_type='encoder-decoder',
    example_data=examples,
    temperature=5.0,
    max_input_length=512,
    neg_number=2,
    output_path="output/xnli-ar_de-datamix")
```
Output:
```
weight for each model: output/xnli/experiment_mt0-small_xnli_ar 0.5131871104240417  output/xnli/experiment_mt0-small_xnli_de 0.48681285977363586
Saving the new model to output/xnli-ar_de-datamix
```
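The merged checkpoint can then be loaded like any other seq2seq model. A minimal sketch, assuming the tokenizer was also saved to the output path (otherwise load it from one of the source checkpoints):

```python
# Quick sanity check of the merged model (sketch; the prompt is a placeholder).
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("output/xnli-ar_de-datamix")
merged = AutoModelForSeq2SeqLM.from_pretrained("output/xnli-ar_de-datamix")

inputs = tokenizer("Premise: ... Hypothesis: ... Label:", return_tensors="pt")
out = merged.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```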
Implementation
Updated the assert statement to accept 'encoder-decoder' as a model_type:
```python
assert model_type in ['decoder', 'encoder', 'encoder-decoder']
```
Updated 'load_model':
```python
def load_model(model_name:str, model_type:str, trust_remote_code:bool=True):
    if model_type == 'decoder':
        model = load_llm(model_name, trust_remote_code=trust_remote_code)
    elif model_type == 'encoder':
        model = load_embedder(model_name, trust_remote_code=trust_remote_code)
    elif model_type == 'reranker':
        model = load_reranker(model_name, trust_remote_code=trust_remote_code)
    elif model_type == 'encoder-decoder':
        # new branch for seq2seq models such as mT0
        model = load_seq2seq_model(model_name, trust_remote_code=trust_remote_code)
    else:
        raise NotImplementedError(f"not support this model_type: {model_type}")
    return model
```
'load_seq2seq_model' (new helper; needs 'AutoModelForSeq2SeqLM' from transformers):
```python
from transformers import AutoModelForSeq2SeqLM

def load_seq2seq_model(model_name:str, trust_remote_code:bool):
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    return model
```
Updated 'compute_weights':
```python
def compute_weights(base_model, tokenizer, param_list: List[Dict], model_type: str, example_data: List[Any],
                    temperature: float=5.0, batch_size:int=2, max_input_length:int=2048, neg_number:int=7):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    base_model = base_model.to(device)

    if model_type == 'decoder':
        input_data = preprocess_data_for_llm(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length)
        loss_func = llm_loss
    elif model_type == 'encoder':
        input_data = preprocess_data_for_embedder(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length, neg_number=neg_number)
        loss_func = embedder_loss
    elif model_type == 'encoder-decoder':
        # new branch: seq2seq preprocessing and loss
        input_data = preprocess_data_for_seq2seq(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length)
        loss_func = seq2seq_loss

    example_loss = []
    with torch.no_grad():
        for params in param_list:
            base_model.load_state_dict(params)
            loss = loss_func(base_model=base_model, input_data=input_data)
            example_loss.append(loss)

    weights = torch.softmax(-torch.FloatTensor(example_loss)/temperature, -1).numpy().tolist()
    return weights
```
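As a toy illustration of the final weighting step (the loss values below are made up, just to show how the temperature smooths the weights):

```python
import torch

# Hypothetical per-model losses on the example data; lower loss -> larger weight.
example_loss = [1.10, 1.23]
weights = torch.softmax(-torch.FloatTensor(example_loss) / 5.0, -1).tolist()
print(weights)  # ~[0.5065, 0.4935] -- temperature=5.0 keeps the weights close to uniform
```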
'seq2seq_loss':
```python
def seq2seq_loss(base_model, input_data):
    total_loss = 0
    with torch.no_grad():
        for batch in input_data:
            outputs = base_model(input_ids=batch["input_ids"],
                                 attention_mask=batch["attention_mask"],
                                 decoder_input_ids=batch["decoder_input_ids"],
                                 labels=batch["labels"])
            total_loss += outputs.loss.cpu()
    average_loss = total_loss / len(input_data)
    return float(average_loss)
```
'preprocess_data_for_seq2seq':
```python
def preprocess_data_for_seq2seq(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=512):
    batch_data = []
    for i in range(0, len(example_data), batch_size):
        batch_examples = example_data[i:i+batch_size]
        input_texts = [ex['input'] for ex in batch_examples]
        target_texts = [ex['output'] for ex in batch_examples]

        input_encodings = tokenizer(input_texts, max_length=max_input_length, padding=True, truncation=True, return_tensors="pt")
        target_encodings = tokenizer(target_texts, max_length=max_input_length, padding=True, truncation=True, return_tensors="pt")

        input_ids = input_encodings.input_ids.to(device)
        attention_mask = input_encodings.attention_mask.to(device)
        decoder_input_ids = target_encodings.input_ids.to(device)

        # clone so the -100 masking below does not also change decoder_input_ids
        labels = target_encodings.input_ids.clone().to(device)
        labels[labels == tokenizer.pad_token_id] = -100

        batch_data.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels
        })
    return batch_data
```
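A minimal sketch to sanity-check the two helpers together, assuming the public 'bigscience/mt0-small' checkpoint and two made-up NLI-style examples:

```python
import torch
from transformers import AutoTokenizer

# Dummy examples in the 'input'/'output' format expected above (hypothetical prompt template).
dummy = [
    {"input": "Premise: A man eats. Hypothesis: A person eats. Label:", "output": "entailment"},
    {"input": "Premise: A man eats. Hypothesis: A man sleeps. Label:", "output": "contradiction"},
]

tok = AutoTokenizer.from_pretrained("bigscience/mt0-small")
mdl = load_seq2seq_model("bigscience/mt0-small", trust_remote_code=True)

batches = preprocess_data_for_seq2seq(dummy, tok, torch.device("cpu"))
print(seq2seq_loss(base_model=mdl, input_data=batches))
```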
Alternatives
I noticed the decoder preprocessing and loss functions ('preprocess_data_for_llm', 'llm_loss') also work for mT0. However, the current 'load_model' function does not let you specify 'decoder' for such a checkpoint and reuse those functions directly. Using the decoder functions also gave different weights for my experiment:
```
weight for each model: output/xnli/experiment_mt0-small_xnli_ar 0.48479509353637695  output/xnli/experiment_mt0-small_xnli_de 0.5152048468589783
Saving the new model to output/xnli-ar_de-datamix
```
My implementation of the preprocess function follows a different style than the decoder version, so I also rewrote it in that style:
```python
def preprocess_data_for_seq2seq(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=2048):
    batch_input_ids = []
    batch_labels = []
    batch_decoder_input_ids = []
    batch_max_length = max_input_length

    for data in example_data:
        input, output = data['input'], data['output']
        input_ids = tokenizer.encode(input)
        output_ids = tokenizer.encode(output)

        labels = [-100]*len(input_ids) + output_ids + [tokenizer.eos_token_id]

        input_ids.append(tokenizer.eos_token_id)
        input_ids = input_ids[:batch_max_length]
        input_ids += [tokenizer.pad_token_id] * (batch_max_length - len(input_ids))
        batch_input_ids.append(input_ids)

        labels = labels[:batch_max_length]
        labels += [-100] * (batch_max_length - len(labels))
        batch_labels.append(labels)

        decoder_input_ids = output_ids + [tokenizer.eos_token_id]
        decoder_input_ids += [tokenizer.pad_token_id] * (batch_max_length - len(decoder_input_ids))
        batch_decoder_input_ids.append(decoder_input_ids)

    batch_input_ids = torch.LongTensor(batch_input_ids).to(device)
    batch_labels = torch.LongTensor(batch_labels).to(device)
    batch_decoder_input_ids = torch.LongTensor(batch_decoder_input_ids).to(device)
    attention_mask = batch_input_ids.ne(tokenizer.pad_token_id).to(device)

    batch_data = []
    for i in range(0, len(batch_input_ids), batch_size):
        batch_data.append(dict(
            input_ids=batch_input_ids[i:i+batch_size],
            labels=batch_labels[i:i+batch_size],
            decoder_input_ids=batch_decoder_input_ids[i:i+batch_size],
            attention_mask=attention_mask[i:i+batch_size],
        ))
    return batch_data
```
NOTE: This implementation gives the same weights as using the llm preprocessor ('preprocess_data_for_llm'), so maybe I made a mistake in one of the two functions. At the moment I cannot say which one is correct without more experiments.
Let me know what you think, or if I missed existing functionality that makes this obsolete.
It would be nice to know whether 'encoder-decoder' architectures are supported, as I also want to use LM-Cocktail for my research 😃
@RKoopal your implementation looks nice, curious to hear back from the maintainers!
@RKoopal, thanks for your suggestion!
MT5 uses the pad_token_id as the starting token when generating decoder_input_ids, but the preprocess function you used doesn't add a special token at the beginning of the decoder_input_ids. I recommend using the official approach: https://huggingface.co/docs/transformers/model_doc/mt5#transformers.MT5ForConditionalGeneration.example
```python
inputs = tokenizer(input_texts, text_target=output_texts, return_tensors="pt")
outputs = model(**inputs)
loss = outputs.loss
```
You can just call tokenizer(input_texts, text_target=output_texts, return_tensors="pt"), which is very simple.
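Applied to the batched preprocessing, that could look roughly like this (an untested sketch: with 'labels' set and no explicit 'decoder_input_ids', the model shifts the labels right itself, starting from the decoder start token, which is the pad token for mT5, so 'seq2seq_loss' could drop the 'decoder_input_ids' argument):

```python
# Sketch: same batching as before, but tokenization via text_target so that
# labels are built by the tokenizer and decoder_input_ids can be omitted
# (the model derives them from labels, starting with the decoder start token).
def preprocess_data_for_seq2seq(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=512):
    batch_data = []
    for i in range(0, len(example_data), batch_size):
        batch_examples = example_data[i:i+batch_size]
        input_texts = [ex['input'] for ex in batch_examples]
        target_texts = [ex['output'] for ex in batch_examples]

        enc = tokenizer(input_texts, text_target=target_texts, max_length=max_input_length,
                        padding=True, truncation=True, return_tensors="pt")

        labels = enc["labels"]
        labels[labels == tokenizer.pad_token_id] = -100  # ignore padding positions in the loss

        batch_data.append({
            "input_ids": enc["input_ids"].to(device),
            "attention_mask": enc["attention_mask"].to(device),
            "labels": labels.to(device),
        })
    return batch_data
```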
You're welcome to submit a PR~
@staoxiao Thank you for your reply! I have implemented your suggestions and will make a PR shortly.
@Nacho888 I'll link the PR here in case you're interested.
@Nacho888 @staoxiao Created PR: https://github.com/FlagOpen/FlagEmbedding/pull/761