
DPO supports multi-device training

yechenzhi opened this pull request • 11 comments

As previously discussed, we will support multi-device (distributed) training for DPO. However, the distributed config file is almost identical to the single-device one; the only difference is that activation checkpointing is enabled. Is it necessary to have two separate files?

yechenzhi • Apr 16 '24 08:04

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/771


No Failures

As of commit 03dac6dab4fa9e565f34c74663725060903275d8 with merge base cd779783f9acecccbebc3c50265f6caf97fa99aa: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] • Apr 16 '24 08:04

Thanks @yechenzhi for opening the PR! Sorry for the delay here; this one slipped through the cracks amid the excitement of the past couple of days. I will take a proper look today.

Regarding the question of separate configs vs. a single config: we've seen this in other recipes as well. It's OK to have two very similar config files (and in this case we actually prefer it). The reason is that there is a natural mapping between recipe files and config files, which we need in order to determine whether a given (recipe, config) pair is valid (e.g. I cannot run our Mistral full-finetune config with your DPO single-device recipe). While a single recipe may work with many different configs, we want each config to work with only one recipe. This rules out a many-to-many mapping, which is harder to reason about and maintain. It also makes it easier to experiment with advanced features: for example, some low-precision optimizers, such as 8-bit AdamW from bitsandbytes, do not interact well with FSDP, so having separate configs for single-device and distributed training lets us support such features in a more official capacity.
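The one-config-per-recipe rule can be pictured as a plain lookup table. This is a hypothetical sketch, not torchtune's actual implementation: `CONFIG_TO_RECIPE`, `validate_pair`, `configs_for`, and the entries marked "assumed name" are illustrative only.

```python
from typing import List

# Hypothetical registry: each config belongs to exactly one recipe,
# while one recipe may be paired with many configs.
CONFIG_TO_RECIPE = {
    "llama2/7B_lora_dpo": "lora_dpo_distributed",
    "llama2/7B_lora_dpo_single_device": "lora_dpo_single_device",  # assumed name
    "llama2/7B_lora": "lora_finetune_distributed",
    "mistral/7B_full": "full_finetune_distributed",  # assumed name
}


def validate_pair(recipe: str, config: str) -> bool:
    """Return True only if `config` is registered for exactly this `recipe`."""
    return CONFIG_TO_RECIPE.get(config) == recipe


def configs_for(recipe: str) -> List[str]:
    """List every config that may be used with `recipe` (one-to-many is allowed)."""
    return sorted(c for c, r in CONFIG_TO_RECIPE.items() if r == recipe)
```

Because a dict maps each key to exactly one value, no config can be claimed by two recipes, so the many-to-many situation described above cannot arise by construction.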

ebsmothers • Apr 19 '24 16:04

Main question is around testing: are you able to kick off a proper distributed run and see the loss decrease? One specific concern I have is the interaction with FSDP. Since we do multiple forward passes of the model (one under no_grad) and modify its internals in nontrivial ways in between, there are potentially some gotchas here. So let's definitely make sure it trains properly. Happy to help out on this front if needed, just let me know.
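For reference, the DPO objective combines four sequence log-probabilities: the policy's and the reference model's, each on the chosen and the rejected response. Below is a minimal pure-Python sketch of the sigmoid DPO loss and the reward metrics that show up in the logs later in this thread (`rewards/chosen`, `rewards/margins`, ...). It assumes the summed log-probs are already computed and that the reference values come from the no_grad forward pass; `beta=0.1` is an assumed value, and the function names are hypothetical.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Sigmoid DPO loss and reward metrics for one preference pair.

    Inputs are summed log-probabilities of the chosen/rejected responses under
    the trainable policy and the frozen reference model (assumed here to be
    produced by the extra forward pass run without gradients).
    """
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    margin = chosen_reward - rejected_reward
    return {
        "loss": -log_sigmoid(margin),
        "rewards/chosen": chosen_reward,
        "rewards/rejected": rejected_reward,
        "rewards/margins": margin,
        "rewards/accuracies": float(chosen_reward > rejected_reward),
    }
```

When the policy still equals the reference (as at the very start of LoRA training), the margin is zero and the loss is log 2 ≈ 0.693, which is a useful sanity check for the first logged steps.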

Yes, I ran distributed training on two NVIDIA RTX 4090 GPUs. The loss appeared normal; however, I encountered an issue where utilizing two GPUs did not expedite the training process. I also tried the command 'tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora', which leads me to suspect that the bottleneck may be the RTX 4090 GPUs themselves.

Loss curve: [image] (similar to here)

yechenzhi • Apr 20 '24 04:04

> The loss appeared normal; however, I encountered an issue where utilizing two GPUs did not expedite the training process. I also attempted the 'tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora' command, leading me to suspect that the bottleneck may be due to the RTX 4090 GPUs.

@yechenzhi thanks for the info. The loss curve looks good, so no major concerns from my side. I am curious about this comment, though. In general we expect slower per-iteration speed due to the extra communication overhead of FSDP. However, this should be balanced by the ability to fit larger batch sizes than you otherwise could. Out of curiosity, are you able to increase the per-device batch size beyond 4 on your RTX 4090s? And if so, does this improve the training time?

I will also patch your changes and run some experiments on my end.
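The trade-off described above comes down to simple arithmetic: under data parallelism each optimizer step consumes more samples, so a slower step can still mean higher throughput. A minimal sketch, assuming the number of gradient-accumulation steps stays the same across runs; the helper names are hypothetical.

```python
def global_batch_size(per_device_batch: int, num_devices: int, grad_accum_steps: int) -> int:
    """Samples consumed per optimizer step under data parallelism (e.g. FSDP)."""
    return per_device_batch * num_devices * grad_accum_steps


def samples_per_second(per_device_batch: int, num_devices: int,
                       grad_accum_steps: int, seconds_per_step: float) -> float:
    """Training throughput; improves only if step time grows slower than the global batch."""
    return global_batch_size(per_device_batch, num_devices, grad_accum_steps) / seconds_per_step
```

With a per-device batch size of 4 and 8 gradient-accumulation steps (values mentioned in this thread), going from one GPU to two doubles the global batch from 32 to 64 samples per step, so the two-GPU run comes out ahead on throughput as long as each step takes less than twice as long.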

ebsmothers • Apr 20 '24 22:04

> Out of curiosity, are you able to increase the per-device batch size beyond 4 on your RTX 4090s?

Yes, I've now set the per-device batch size to 6.

> And if so, does this improve the training time?

Compared to single-device training, there was no improvement in training speed. I suspect this is because the RTX 4090 lacks NVLink (though I'm not certain, as I'm not familiar with FSDP or NVLink).

> I will also patch your changes and run some experiments on my end.

Okay, you can take a look at this.

yechenzhi • Apr 21 '24 03:04

@yechenzhi my loss (admittedly over only a few hundred iterations) using the commands from the new config file looks more like this:

[image: Screenshot 2024-04-21 at 5 14 53 PM]

I may be missing something obvious here though, if so please let me know. I know you mentioned filtering out longer examples from the training set which I did not do. But does this really make such a substantial difference in the loss curve?

ebsmothers • Apr 22 '24 00:04

> I may be missing something obvious here though, if so please let me know. I know you mentioned filtering out longer examples from the training set which I did not do. But does this really make such a substantial difference in the loss curve?

Filtering out longer examples is implemented here and here. I've checked the 'stack-exchange-paired' dataset, and many of the longer examples involve writing code, which could make them harder to learn from. Additionally, I've been averaging the loss over every 10 steps. Could you apply these changes and check the new loss curve?

Have you tried the single-device recipe? If you don't mind, could you share your training recipe with me? I'd like to identify any potential issues.
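A minimal sketch of the kind of length filter discussed above. The actual filtering code is behind the links, so this is an illustration under assumptions: each example is assumed to already be a dict with `prompt`, `chosen`, and `rejected` strings, and `filter_by_length`, `max_length`, and `length_fn` are hypothetical names.

```python
from typing import Callable, Dict, Iterable, List


def filter_by_length(
    examples: Iterable[Dict[str, str]],
    max_length: int,
    length_fn: Callable[[str], int] = len,
) -> List[Dict[str, str]]:
    """Keep a preference pair only if both prompt+chosen and prompt+rejected fit.

    `length_fn` defaults to character count; pass a tokenizer-based function
    to filter on token count instead.
    """
    kept = []
    for ex in examples:
        prompt_len = length_fn(ex["prompt"])
        fits_chosen = prompt_len + length_fn(ex["chosen"]) <= max_length
        fits_rejected = prompt_len + length_fn(ex["rejected"]) <= max_length
        if fits_chosen and fits_rejected:
            kept.append(ex)
    return kept
```

Both completions are checked against the same prompt, because DPO needs the chosen and rejected sequences of a pair together; dropping only one side would leave an unusable example.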

yechenzhi • Apr 22 '24 01:04

I'll rerun my recipes to see if there are any issues with them.

yechenzhi • Apr 22 '24 02:04

> I'll rerun my recipes to see if there are any issues with them.

Here is the code I have written for visualizing loss curves:

```python
import re

import matplotlib.pyplot as plt
import numpy as np

filename = '/tmp/lora_dpo_output/log_1713751433.txt'  # multi-device run

# Metric names, in the order they appear after "Step N |" in each log line.
metric_names = [
    'loss', 'lr', 'rewards/chosen', 'rewards/rejected', 'rewards/accuracies',
    'rewards/margins', 'log_probs/rejected', 'log_probs/chosen',
    'logits/rejected', 'logits/chosen', 'gpu_resources',
]
# Equivalent to r'Step (\d+) \| loss:(\S+) lr:(\S+) ... gpu_resources:(\S+)'
pattern = re.compile(
    r'Step (\d+) \| ' + ' '.join(name + r':(\S+)' for name in metric_names)
)

steps = []
metrics = {name: [] for name in metric_names}

with open(filename, 'r') as file:
    for line in file:
        match = pattern.match(line)
        if match:
            steps.append(int(match.group(1)))
            # groups()[1:] are the metric values, in metric_names order
            for name, value in zip(metric_names, match.groups()[1:]):
                metrics[name].append(float(value))

steps = np.array(steps)
metrics = {name: np.array(values) for name, values in metrics.items()}

# Window of 80 log lines = 8 gradient-accumulation steps x 10 steps.
window_size = 80


def get_avg(arr, window_size):
    """Average arr over consecutive, non-overlapping windows of window_size."""
    return np.array([np.mean(arr[i:i + window_size])
                     for i in range(0, len(arr), window_size)])


avg_steps = get_avg(steps, window_size)
avg_metrics = {name: get_avg(values, window_size) for name, values in metrics.items()}

plt.figure(figsize=(10, 6))
plt.plot(avg_steps, avg_metrics['loss'], label='Average Loss')
plt.xlabel('Step')
plt.ylabel('Loss (window average)')
plt.title('Average Loss over Steps')
plt.legend()
plt.grid(True)
plt.show()
```

The window size is set to 80 because gradient accumulation uses 8 steps and I average over 10 steps.

yechenzhi • Apr 22 '24 09:04

@yechenzhi apologies for the delayed response; I'm only now getting to re-run the DPO recipe. The loss curve I shared previously was produced directly from this PR via the command

tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama2/7B_lora_dpo

I know you mentioned some data filtering up front in this comment; are you doing that here as well? I'll let this one run for a bit longer, plot the smoothed loss curve, and get back to you once I have the results.

Update: Here is the loss curve I'm getting now. I don't smooth it exactly the same way as the code you provided above, but it still looks generally similar. This is with the data filtering from your other comment; I'm curious whether you find that it impacts the results substantially. Modulo that question, I think this PR is looking good to merge.

[image: Screenshot 2024-04-22 at 2 20 42 PM]
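One common alternative to the non-overlapping block averages in the plotting script above is an exponential moving average (EMA), which yields one smoothed value per logged step instead of one per window. A minimal sketch; `ema_smooth` is a hypothetical helper, `alpha=0.9` is an assumed smoothing factor, and this is not necessarily the smoothing used for the screenshot above.

```python
from typing import List, Sequence


def ema_smooth(values: Sequence[float], alpha: float = 0.9) -> List[float]:
    """Exponential moving average: s_t = alpha * s_{t-1} + (1 - alpha) * x_t.

    The running value starts at 0, so each output is divided by (1 - alpha**t)
    (bias correction, as in Adam) to keep early points from being pulled
    toward zero.
    """
    smoothed = []
    running = 0.0
    for t, x in enumerate(values, start=1):
        running = alpha * running + (1 - alpha) * x
        smoothed.append(running / (1 - alpha ** t))
    return smoothed
```

A larger `alpha` gives a smoother but more delayed curve; with `alpha=0` the input is returned unchanged.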

ebsmothers • Apr 22 '24 19:04

> I know you mentioned some data filtering up front in this comment, are you doing that here as well?

Yes, I applied data filtering in both TRL and torchtune.

> I don't smooth exactly the same as in the code you provided above, but still it looks generally similar. This is with the data filtering from your other comment, just curious if you find this impacts the results substantially?

Indeed, I've observed a similar impact. I attribute it to the large size of the full dataset, which, as mentioned previously, seems to have a discernible ordering. However, considering that we train for only 1000 steps and shuffle the dataset, restricting training to the first 64000 * 5 samples (approximately 70000 samples after filtering) makes the training process easier.

yechenzhi • Apr 23 '24 01:04

Thanks @yechenzhi for patiently bearing with all my questions here. I believe this PR should be good to go.

ebsmothers • Apr 23 '24 03:04