unsloth icon indicating copy to clipboard operation
unsloth copied to clipboard

Implementing weighted loss function

Open skerit opened this issue 1 year ago • 1 comments

Mistral has a new finetuner repository where you can assign weights to specific messages, and those will be taken into account when the loss is calculated. I wanted to implement something similar for SFTTrainer, because my dataset contains information that doesn't really make a lot of sense to punish the model for not knowing. But switching completely to DataCollatorForCompletionOnlyLM is also not possible.

My problem is that it's not working at all :sweat_smile:

I might be misunderstanding what Unsloth is doing to the existing trainer. Is it as simple to just create a new trainer class, let it inherit from SFTTrainer with a custom compute_loss function and expect it to run with Unsloth, or is that a no go?

Here's an example dataset to illustrate what I'm trying to achieve

{"pieces":[{"text":"### system:\n","weight":0.5},{"text":"QuirkyQuarters v1.1\n","weight":1},{"text":"\n","weight":0.1},{"text":"### parameters:\n","weight":0.5}]}
{"pieces":[{"text":"### system:\n","weight":0.5},{"text":"QuirkyQuarters v1.2\n","weight":1},{"text":"\n","weight":0.1},{"text":"### parameters:\n","weight":0.5}]}

skerit avatar May 26 '24 09:05 skerit

Oh weighting is possible, but you'll need to add a custom cross entropy loss function ie via removing the LM Head, and putting a custom one

danielhanchen avatar May 26 '24 13:05 danielhanchen

Oh, is that different than implementing a new compute_loss method in a Trainer class?

skerit avatar May 28 '24 16:05 skerit

@danielhanchen Sorry to make you look at some newbie trainer code, but this custom trainer of mine works locally, but always OOMs on Google Collab, when the non-custom trainer does work.


dataset = load_dataset("json", data_files="drive/MyDrive/Unsloth/dataset.jsonl", split = "train")

def generate_and_tokenize_pieces(sample):

	all_input_ids = []
	all_attention_masks = []
	all_weight_ranges = []
	current_length = 0

	for item in sample['pieces']:
		tokenized = tokenizer(item['text'], return_tensors='pt')

		# Get tensor, remove batch dimension
		input_ids = tokenized.input_ids.squeeze()

		# Get tensor, remove batch dimension
		attention_mask = tokenized.attention_mask.squeeze()

		start_idx = current_length
		end_idx = start_idx + len(input_ids) - 1

		all_input_ids.append(input_ids)
		all_attention_masks.append(attention_mask)
		all_weight_ranges.append((start_idx, end_idx, item['weight']))

		# Update current length
		current_length = end_idx + 1

	concatenated_input_ids = torch.cat(all_input_ids, dim=0) if all_input_ids else torch.tensor([], dtype=torch.long)
	concatenated_attention_masks = torch.cat(all_attention_masks, dim=0) if all_attention_masks else torch.tensor([], dtype=torch.long)

	expanded_weight_ranges = torch.tensor([], dtype=torch.long)

	# Convert the weight ranges
	for start_idx, end_idx, weight in all_weight_ranges:
		# Turn the weight into an integer
		weight = int(weight * 100)
		expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.tensor([weight] * (end_idx - start_idx + 1))])

	# If there are no weight ranges, we return a tensor of ones
	if len(expanded_weight_ranges) == 0:
		expanded_weight_ranges = torch.ones_like(concatenated_input_ids)

	# Pad all the tensors to the same length (max_seq_length)
	concatenated_input_ids = torch.cat([concatenated_input_ids, torch.zeros(max_seq_length - concatenated_input_ids.size(0), dtype=torch.long)])
	concatenated_attention_masks = torch.cat([concatenated_attention_masks, torch.zeros(max_seq_length - concatenated_attention_masks.size(0), dtype=torch.long)])
	expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.zeros(max_seq_length - expanded_weight_ranges.size(0), dtype=torch.long)])

	return {
		"input_ids": concatenated_input_ids,
		"attention_mask": concatenated_attention_masks,
		"labels": concatenated_input_ids.clone(),
		"weights": expanded_weight_ranges,
	}

tokenized_train_dataset = dataset.map(generate_and_tokenize_pieces, remove_columns=["pieces"])

# My naive custom Trainer class with a custom weighted loss computation
class WeightedLossTrainer(transformers.Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):

        # Pop off my custom weights property
        weights = inputs.pop("weights")

        # Get the labels
        labels = inputs.get("labels")

        # This always takes a long, long time and OOMs the GPU
        outputs = model(**inputs)

        logits = outputs.get("logits")

        batch_size, seq_len, num_classes = logits.shape

        total_weighted_loss = 0.0
        total_weights = 0.0

        for batch_idx in range(batch_size):
            for seq_idx in range(seq_len):
                weight = weights[batch_idx, seq_idx]
                if weight > 0:  # Only consider tokens that have a weight > 0
                    token_logits = logits[batch_idx, seq_idx]
                    token_label = labels[batch_idx, seq_idx]
                    token_loss = F.cross_entropy(token_logits.unsqueeze(0), token_label.unsqueeze(0), reduction='none')

                    weighted_token_loss = token_loss * (weight / 100)

                    total_weighted_loss += weighted_token_loss.item()
                    total_weights += (weight / 100)

        # Compute the mean loss.
        mean_loss = total_weighted_loss / total_weights if total_weights > 0 else 0.0
        mean_loss = torch.tensor(mean_loss, dtype=torch.float32, device=logits.device, requires_grad=True)

        return (mean_loss, outputs) if return_outputs else mean_loss

training_args = transformers.TrainingArguments(
      per_device_train_batch_size = 2,
      gradient_accumulation_steps = 4,
      warmup_steps = 5,
      max_steps = 60,
      learning_rate = 2e-4,
      fp16 = not is_bfloat16_supported(),
      bf16 = is_bfloat16_supported(),
      logging_steps = 5,
      optim = "adamw_8bit",
      weight_decay = 0.01,
      lr_scheduler_type = "linear",
      seed = 3407,
      output_dir = "outputs",
      remove_unused_columns=False,
)

trainer = WeightedLossTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = tokenized_train_dataset,
    args = training_args,
)

trainer_stats = trainer.train()

Is this because I'm bypassing some kind of Unsloth optimization, or is what I'm doing just ... wrong?

skerit avatar May 30 '24 12:05 skerit

You need to use autocasting ie

with torch.cuda.amp.autocast(dtype = torch.bfloat16):
    model(...)

danielhanchen avatar May 30 '24 18:05 danielhanchen

@skerit did you manage to get it working. I am also working on a similar problem and would love to start with your code

KamranMK avatar Dec 03 '24 22:12 KamranMK

@KamranMK I found this it really helped me a lot!

You should not need a new Custom trainer for a custom loss function

https://huggingface.co/docs/evaluate/transformers_integrations

darkness8i8 avatar Dec 15 '24 10:12 darkness8i8