KTO training produces NaN rewards
During training with the KTOTrainer I occasionally get nan values as rewards.
I am running the training as a job on Microsoft Azure with a single GPU (NVIDIA A100 80 GB PCIe).
Ultimately these issues cause my Azure job to crash and retry...
The log output I get from the KTOTrainer:
{'loss': 0.0202, 'grad_norm': 0.016206106171011925, 'learning_rate': 8.595282268751656e-06, 'rewards/chosen': 11.158143043518066, 'rewards/rejected': -29.0671443939209, 'rewards/margins': 40.22528839111328, 'kl': 0.0, 'logps/chosen': -15.192975044250488, 'logps/rejected': -180.1438446044922, 'epoch': 0.43}
{'loss': 0.0155, 'grad_norm': 4.091757774353027, 'learning_rate': 8.568778160614896e-06, 'rewards/chosen': 10.752923965454102, 'rewards/rejected': -26.606868743896484, 'rewards/margins': 37.35979461669922, 'kl': 0.0, 'logps/chosen': -13.974691390991211, 'logps/rejected': -156.9815673828125, 'epoch': 0.44}
{'loss': 0.0124, 'grad_norm': 0.06709074974060059, 'learning_rate': 8.542274052478135e-06, 'rewards/chosen': 10.838713645935059, 'rewards/rejected': -29.24416732788086, 'rewards/margins': 40.08287811279297, 'kl': 0.0, 'logps/chosen': -10.99155044555664, 'logps/rejected': -165.8121795654297, 'epoch': 0.44}
{'loss': 0.0113, 'grad_norm': 14.28693675994873, 'learning_rate': 8.515769944341374e-06, 'rewards/chosen': 11.07004451751709, 'rewards/rejected': -30.99440574645996, 'rewards/margins': 42.064453125, 'kl': 0.0, 'logps/chosen': -13.967004776000977, 'logps/rejected': -176.50094604492188, 'epoch': 0.45}
{'loss': 0.0193, 'grad_norm': 3.899095296859741, 'learning_rate': 8.489265836204611e-06, 'rewards/chosen': 10.825413703918457, 'rewards/rejected': -34.434303283691406, 'rewards/margins': 45.25971984863281, 'kl': 0.0, 'logps/chosen': -12.9598388671875, 'logps/rejected': -186.38381958007812, 'epoch': 0.46}
{'loss': 0.0109, 'grad_norm': 0.009407841600477695, 'learning_rate': 8.46276172806785e-06, 'rewards/chosen': nan, 'rewards/rejected': -33.95360565185547, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': nan, 'logps/rejected': -176.4713897705078, 'epoch': 0.47}
{'loss': 0.0324, 'grad_norm': 17.832523345947266, 'learning_rate': 8.43625761993109e-06, 'rewards/chosen': 10.286358833312988, 'rewards/rejected': -33.60068893432617, 'rewards/margins': 43.887046813964844, 'kl': 0.0, 'logps/chosen': -20.224634170532227, 'logps/rejected': -184.18112182617188, 'epoch': 0.48}
{'loss': 0.0029, 'grad_norm': 0.03802444413304329, 'learning_rate': 8.409753511794329e-06, 'rewards/chosen': 10.086004257202148, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -12.816671371459961, 'logps/rejected': nan, 'epoch': 0.48}
{'loss': 0.012, 'grad_norm': 2.815098524093628, 'learning_rate': 8.383249403657568e-06, 'rewards/chosen': 10.549690246582031, 'rewards/rejected': -31.304590225219727, 'rewards/margins': 41.85428237915039, 'kl': 0.0, 'logps/chosen': -13.178544998168945, 'logps/rejected': -169.447509765625, 'epoch': 0.49}
{'loss': 0.0074, 'grad_norm': 0.001768477726727724, 'learning_rate': 8.356745295520805e-06, 'rewards/chosen': 11.22235107421875, 'rewards/rejected': -33.09156799316406, 'rewards/margins': 44.31391906738281, 'kl': 0.0, 'logps/chosen': -13.94648265838623, 'logps/rejected': -178.08566284179688, 'epoch': 0.5}
{'loss': 0.0055, 'grad_norm': 8.117822647094727, 'learning_rate': 8.330241187384045e-06, 'rewards/chosen': 11.166982650756836, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -14.707374572753906, 'logps/rejected': nan, 'epoch': 0.51}
{'loss': 0.0206, 'grad_norm': 1.6973105669021606, 'learning_rate': 8.303737079247284e-06, 'rewards/chosen': 10.326757431030273, 'rewards/rejected': -33.753868103027344, 'rewards/margins': 44.08062744140625, 'kl': 0.0, 'logps/chosen': -19.15297508239746, 'logps/rejected': -181.1234130859375, 'epoch': 0.52}
{'loss': 0.0136, 'grad_norm': 9.740607261657715, 'learning_rate': 8.277232971110523e-06, 'rewards/chosen': 10.298160552978516, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -15.718103408813477, 'logps/rejected': nan, 'epoch': 0.52}
my pip freeze:
accelerate==0.28.0
aiohttp==3.9.3
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.0
certifi==2024.2.2
charset-normalizer==3.3.2
datasets==2.18.0
dill==0.3.8
docstring_parser==0.16
filelock==3.13.1
frozenlist==1.4.1
fsspec==2024.2.0
huggingface-hub==0.21.4
idna==3.6
Jinja2==3.1.3
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==2.2.1
peft==0.9.0
protobuf==5.26.0
psutil==5.9.8
pyarrow==15.0.1
pyarrow-hotfix==0.6
Pygments==2.17.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
rich==13.7.1
safetensors==0.4.2
sentencepiece==0.2.0
shtab==1.7.1
six==1.16.0
sympy==1.12
tokenizers==0.15.2
torch==2.2.1
tqdm==4.66.2
transformers==4.38.2
triton==2.2.0
trl @ git+https://github.com/huggingface/trl@a2aa0f0b09671eaf81a945eb5e4913165fee92fa
typing_extensions==4.10.0
tyro==0.7.3
tzdata==2024.1
urllib3==2.2.1
xxhash==3.4.1
yarl==1.9.4
the training script I use:
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_path: Optional[str] = field(default=None, metadata={"help": "the online dataset to use, should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    data_files: Optional[str] = field(default=None, metadata={"help": "the file(s) including data to use, this looks for 'data/{data_files}_train/test.jsonl.gz'. Datasets should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    file_type: Optional[str] = field(default=None, metadata={"help": "the file type to open, e.g. 'json', 'csv'"})
    max_tokens: Optional[int] = field(default=4096, metadata={"help": "the maximum number of tokens returned by the data collator"})
    # debugging
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
    print(f"train with {script_args}, \n{model_args}")

    # PEFT & quantisation
    quantization_config = BitsAndBytesConfig(load_in_8bit=model_args.load_in_8bit)
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=model_args.lora_r, lora_alpha=model_args.lora_alpha, lora_dropout=model_args.lora_dropout)

    # Load the trainable model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=quantization_config,
        torch_dtype=getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
        device_map="auto",
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config=peft_config)
    model.print_trainable_parameters()

    # Reference model
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=quantization_config,
        torch_dtype=getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
        device_map="auto",
    )
    model_ref = prepare_model_for_kbit_training(model_ref)
    model_ref = get_peft_model(model_ref, peft_config=peft_config)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the desired dataset
    if script_args.dataset_path is not None:
        dataset = load_dataset(script_args.dataset_path)
    elif script_args.data_files is not None and script_args.file_type is not None:
        dataset = load_dataset(script_args.file_type, data_files={"train": f"./data/{script_args.data_files}_train.jsonl.gz", "test": f"./data/{script_args.data_files}_test.jsonl.gz"})
    else:
        print("either dataset_path or data_files & file_type have to be defined")
        exit(1)

    if script_args.sanity_check:
        dataset["train"] = dataset["train"].select(range(1000))

    # Create split if not existing already
    if "test" not in dataset:
        dataset = dataset["train"].train_test_split(train_size=0.9)

    # Apply chat template if not preformatted
    if "prompt" not in dataset["train"].features:
        dataset = dataset.map(lambda x: {"prompt": tokenizer.apply_chat_template(x["messages"], tokenize=False, add_generation_prompt=False)})

    # Set max. lengths for the DefaultDataCollator
    max_prompt_len, max_compl_len, max_len = 0, 0, 0
    tokenizer.model_max_length = script_args.max_tokens
    tokenizer.max_model_input_sizes = script_args.max_tokens
    for sample in dataset["train"]:
        compl_len = len(tokenizer(sample["completion"], truncation=True)["input_ids"])
        total_len = len(tokenizer(sample["prompt"] + sample["completion"], truncation=True)["input_ids"])
        prompt_len = total_len - compl_len
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_compl_len = max(max_compl_len, compl_len)
        max_len = max(max_len, total_len)
    kto_args.max_prompt_length = max_prompt_len
    kto_args.max_completion_length = max_compl_len
    kto_args.max_length = max_len
    print(dataset)
    print(f"max_prompt_length={kto_args.max_prompt_length}, max_completion_length={kto_args.max_completion_length}, max_len={kto_args.max_length}")

    # Set desired/undesired weights
    desired_weight = len(dataset["train"]) / (2 * len(dataset["train"].filter(lambda d: d["label"] == True)))
    undesired_weight = len(dataset["train"]) / (2 * len(dataset["train"].filter(lambda d: d["label"] == False)))
    kto_args.desirable_weight = desired_weight
    kto_args.undesirable_weight = undesired_weight

    # Initialize the KTO trainer
    kto_trainer = KTOTrainer(
        model,
        model_ref,
        args=kto_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
    )

    # Train
    kto_trainer.train()
the call arguments:
python train_kto.py \
--model_name_or_path DiscoResearch/DiscoLM_German_7b_v1 \
--data_files wp_rag_kto_20k \
--file_type json \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--num_train_epochs 3 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 2 \
--logging_steps 10 \
--evaluation_strategy epoch \
--output_dir kto_finetuned \
--optim adamw_bnb_8bit \
--warmup_steps 10 \
--logging_first_step \
--use_peft \
--lora_r 8 \
--lora_alpha 16 \
--report_to none \
--disable_tqdm False \
--beta 0.5 \
--torch_dtype bfloat16 \
--bf16 \
--load_in_8bit
Maybe @lewtun can help
cc also @kashif
@claralp depending on the batch size it could be that some of the metrics are nan; this should not affect the training, and special attention has been paid to make sure the loss etc. is robust to these nans when doing back-prop.
@claralp I do not think nans in a dict should cause this to crash... do you have some crash back-traces?
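For illustration only (a minimal sketch, not the actual KTOTrainer code): a logged reward metric can come out as nan when a micro-batch happens to contain only desirable or only undesirable completions, because the mean of an empty tensor is nan, while a loss that only aggregates over the samples that are actually present stays finite.

```python
import torch

# hypothetical per-sample rewards and labels for one logging step
rewards = torch.tensor([0.8, 1.2, 0.5, 0.9])
is_desirable = torch.tensor([True, True, True, True])  # batch with no undesirable samples

chosen_rewards = rewards[is_desirable]
rejected_rewards = rewards[~is_desirable]  # empty tensor

print(chosen_rewards.mean())    # tensor(0.8500)
print(rejected_rewards.mean())  # tensor(nan) -> logged as 'rewards/rejected': nan

# a loss that only sums over the samples that exist stays finite,
# so back-prop is unaffected by the nan metric
loss = -chosen_rewards.sum() - rejected_rewards.sum()
print(loss)  # tensor(-3.4000); the sum of an empty tensor is 0, not nan
```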
@kashif there are no errors or warnings in the stdout/stderr, it just stops at some point after the nan rewards appear, so I cannot provide a stack trace here.
However, the Azure execution wrapper log shows a blocking process:
2024-03-19T03:33:30.165457Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution::process_manager: Failed blocking user process detected, process name: echo, process pid: 34, code: None success_return_code=Zero { additional_codes: [] } code=None
2024-03-19T03:33:31.167084Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution: Execution process terminated by a signal, which may be due to failure in other user processes on the same node or node ran out of memory. local_rank=0 name=echo
The lifecycler log shows only a preemption signal:
2024-03-19T03:33:29.494161Z WARN run_lifecycler:run_service_and_step_through_lifecycle:step_through_lifecycle: lifecycler::lifecycle: Received abort message, exiting lifecycle abort_message=AbortMessage { error: Some(Error { code: "ReceivedPreemptionSignal", message: "{\"Compliant\":\"Job was terminated due to: Runtime received a preemption signal.\"}", target: "", node_info: None, category: UserError, error_details: [], inner_error: None }), broadcast_abort: true, request_timeout: 25 }
I think this could be the "normal" low-priority Azure preemption? :-(
Important note here: the crash only appears after the training shows nan values; otherwise it doesn't.
I even saw cases where all reward metrics converge to nan values:
{'loss': 0.0, 'grad_norm': 281.6248474121094, 'learning_rate': 9.856115107913668e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.17875319719314575, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 192.08326721191406, 'learning_rate': 9.848121502797762e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0570355653762817, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 33.55568313598633, 'learning_rate': 9.840127897681853e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.1016669273376465, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 44.5154914855957, 'learning_rate': 9.832134292565947e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.197722911834717, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 10.592936515808105, 'learning_rate': 9.82414068745004e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0713751316070557, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
{'loss': 0.0, 'grad_norm': 61.1552734375, 'learning_rate': 9.81614708233413e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3863883912563324, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
Could there be anything wrong with the hyperparameter choice, @kashif?
@claralp so the main hyperparam that could affect this is the batch size, as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny so that should be good... what is your batch size when you get all nans?
also, does this happen if you try locally, outside of Azure?
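One quick way to check the per-batch label mix is a small debugging helper (a hypothetical snippet of mine, not part of TRL, and it only approximates the trainer's sampler by walking the shuffled train split in chunks of the per-device batch size):

```python
from collections import Counter


def count_single_label_batches(labels, batch_size=8):
    """Count how many consecutive chunks of `batch_size` contain only one label value."""
    counts = Counter()
    for i in range(0, len(labels) - batch_size + 1, batch_size):
        chunk = labels[i : i + batch_size]
        if all(chunk):
            counts["all_desirable"] += 1
        elif not any(chunk):
            counts["all_undesirable"] += 1
        else:
            counts["mixed"] += 1
    return counts


# e.g. labels = dataset["train"].shuffle(seed=42)["label"]
print(count_single_label_batches([True, True, False, True] * 4, batch_size=8))
```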
The output below is from a test with very unbalanced data, namely 2k desired completions and 10k undesired ones.
I know that a ratio between 4:3 and 1:1 is required for proper training.
This is just an experiment to see whether batches that happen to contain only one kind of sample might be the reason behind the nan rewards (a rough estimate of how often that happens is sketched after the log excerpt).
But here I get nan losses even without nan rewards...
{'loss': 1.0431, 'grad_norm': 42.099464416503906, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/margins': 0.0, 'kl': 0.0, 'logps/chosen': -37.16696548461914, 'logps/rejected': -87.62107849121094, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 41.9438362121582, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -32.92508316040039, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 29.28327178955078, 'learning_rate': 3e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.15479230880737305, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.70748519897461, 'learning_rate': 4.000000000000001e-06, 'rewards/chosen': 0.06518054008483887, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.43951892852783203, 'logps/chosen': -31.101844787597656, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 44.989227294921875, 'learning_rate': 5e-06, 'rewards/chosen': 0.3087962865829468, 'rewards/rejected': 0.23543643951416016, 'rewards/margins': 0.07335984706878662, 'kl': 1.230994462966919, 'logps/chosen': -32.83413314819336, 'logps/rejected': -74.81724548339844, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 55.32667541503906, 'learning_rate': 6e-06, 'rewards/chosen': 0.3336696922779083, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3016533851623535, 'logps/chosen': -38.598453521728516, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.44403839111328, 'learning_rate': 7e-06, 'rewards/chosen': 0.8524215817451477, 'rewards/rejected': 0.5893988609313965, 'rewards/margins': 0.2630227208137512, 'kl': 0.7648882865905762, 'logps/chosen': -35.86614227294922, 'logps/rejected': -93.13447570800781, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 26.85154914855957, 'learning_rate': 8.000000000000001e-06, 'rewards/chosen': 0.8056153059005737, 'rewards/rejected': 0.40718716382980347, 'rewards/margins': 0.39842814207077026, 'kl': 1.3891675472259521, 'logps/chosen': -34.07681655883789, 'logps/rejected': -113.53411102294922, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 25.181703567504883, 'learning_rate': 9e-06, 'rewards/chosen': nan, 'rewards/rejected': 0.9289813041687012, 'rewards/margins': nan, 'kl': 1.279036521911621, 'logps/chosen': nan, 'logps/rejected': -132.0060272216797, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 36.62141799926758, 'learning_rate': 1e-05, 'rewards/chosen': 1.4094278812408447, 'rewards/rejected': 0.8396401405334473, 'rewards/margins': 0.5697878003120422, 'kl': 2.0255985260009766, 'logps/chosen': -30.87615394592285, 'logps/rejected': -102.92286682128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.035221099853516, 'learning_rate': 9.997300215982722e-06, 'rewards/chosen': 1.5928469896316528, 'rewards/rejected': 1.5922844409942627, 'rewards/margins': 0.0005625784397125244, 'kl': 2.884922981262207, 'logps/chosen': -39.46299362182617, 'logps/rejected': -121.78970336914062, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 33.07608413696289, 'learning_rate': 9.994600431965443e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.1301448345184326, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.48128128051758, 'learning_rate': 9.991900647948165e-06, 'rewards/chosen': 2.113973617553711, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.475428819656372, 'logps/chosen': -26.679065704345703, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 31.501819610595703, 'learning_rate': 9.989200863930886e-06, 'rewards/chosen': 2.6266024112701416, 'rewards/rejected': 2.2295963764190674, 'rewards/margins': 0.3970060348510742, 'kl': 4.643209934234619, 'logps/chosen': -42.25154495239258, 'logps/rejected': -95.91471862792969, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 34.09553527832031, 'learning_rate': 9.986501079913607e-06, 'rewards/chosen': 2.7660703659057617, 'rewards/rejected': 2.6509010791778564, 'rewards/margins': 0.11516910791397095, 'kl': 4.8384199142456055, 'logps/chosen': -49.93422317504883, 'logps/rejected': -73.00190734863281, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.591957092285156, 'learning_rate': 9.983801295896329e-06, 'rewards/chosen': 3.131122350692749, 'rewards/rejected': 2.9620559215545654, 'rewards/margins': 0.1690664291381836, 'kl': 4.498130798339844, 'logps/chosen': -29.836196899414062, 'logps/rejected': -105.75230407714844, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 13.737163543701172, 'learning_rate': 9.98110151187905e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.1204824447631836, 'rewards/margins': nan, 'kl': 6.049262523651123, 'logps/chosen': nan, 'logps/rejected': -96.40724182128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.375396728515625, 'learning_rate': 9.978401727861771e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.636046886444092, 'rewards/margins': nan, 'kl': 6.3599958419799805, 'logps/chosen': nan, 'logps/rejected': -97.00442504882812, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 27.26076889038086, 'learning_rate': 9.975701943844493e-06, 'rewards/chosen': 4.384129524230957, 'rewards/rejected': 3.9822707176208496, 'rewards/margins': 0.40185898542404175, 'kl': 8.23063850402832, 'logps/chosen': -24.248661041259766, 'logps/rejected': -105.89572143554688, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 18.513507843017578, 'learning_rate': 9.973002159827214e-06, 'rewards/chosen': 4.265963077545166, 'rewards/rejected': 3.8863425254821777, 'rewards/margins': 0.3796207308769226, 'kl': 6.635190010070801, 'logps/chosen': -24.802963256835938, 'logps/rejected': -68.99553680419922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.997692108154297, 'learning_rate': 9.970302375809935e-06, 'rewards/chosen': 5.037494659423828, 'rewards/rejected': 4.227317810058594, 'rewards/margins': 0.8101770877838135, 'kl': 8.07493782043457, 'logps/chosen': -24.345657348632812, 'logps/rejected': -74.88150024414062, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 26.245861053466797, 'learning_rate': 9.967602591792658e-06, 'rewards/chosen': 4.526309490203857, 'rewards/rejected': 4.603299140930176, 'rewards/margins': -0.07698965072631836, 'kl': 8.698637008666992, 'logps/chosen': -22.94290542602539, 'logps/rejected': -99.22356414794922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 22.14063835144043, 'learning_rate': 9.964902807775378e-06, 'rewards/chosen': 5.355809211730957, 'rewards/rejected': 4.891297340393066, 'rewards/margins': 0.464511513710022, 'kl': 8.954204559326172, 'logps/chosen': -23.850910186767578, 'logps/rejected': -87.7445068359375, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.642059326171875, 'learning_rate': 9.962203023758101e-06, 'rewards/chosen': 5.606294631958008, 'rewards/rejected': 6.807004928588867, 'rewards/margins': -1.2007099390029907, 'kl': 9.733396530151367, 'logps/chosen': -24.039264678955078, 'logps/rejected': -119.2092514038086, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 10.412492752075195, 'learning_rate': 9.959503239740822e-06, 'rewards/chosen': 5.953470230102539, 'rewards/rejected': 5.025949954986572, 'rewards/margins': 0.9275206327438354, 'kl': 10.74533462524414, 'logps/chosen': -16.727996826171875, 'logps/rejected': -80.9796142578125, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 17.695709228515625, 'learning_rate': 9.956803455723542e-06, 'rewards/chosen': nan, 'rewards/rejected': 6.109594345092773, 'rewards/margins': nan, 'kl': 11.900070190429688, 'logps/chosen': nan, 'logps/rejected': -121.30842590332031, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 35.035892486572266, 'learning_rate': 9.954103671706265e-06, 'rewards/chosen': 6.687896251678467, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 12.4317626953125, 'logps/chosen': -16.511333465576172, 'logps/rejected': nan, 'epoch': 0.02}
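As a rough sanity check (my own back-of-the-envelope estimate, not something the trainer reports, and assuming batches are drawn roughly i.i.d. from the 2k desirable / 10k undesirable mix), single-label batches should not be rare with a per-device batch size of 8; the weighting formula from the script is also evaluated for this split:

```python
# probability that a batch of 8 contains only one kind of label,
# assuming samples are drawn i.i.d. with p(desirable) = 2k / 12k
p_desirable = 2_000 / 12_000
batch_size = 8
p_no_desirable = (1 - p_desirable) ** batch_size  # ~0.23 -> roughly 1 in 4 batches has no desirable sample
p_no_undesirable = p_desirable ** batch_size      # ~6e-7 -> batches without undesirable samples are rare

# desirable/undesirable weights as computed in the training script above
n_train, n_pos, n_neg = 12_000, 2_000, 10_000
desired_weight = n_train / (2 * n_pos)    # 3.0
undesired_weight = n_train / (2 * n_neg)  # 0.6
# desired_weight * n_pos == undesired_weight * n_neg == 6000, i.e. a 1:1 effective ratio
print(p_no_desirable, p_no_undesirable, desired_weight, undesired_weight)
```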
> @claralp so the main hyperparam that could affect this is the batch size, as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny so that should be good... what is your batch size when you get all nans?
batch size is 8 and gradient accumulation steps is 2, as in the config above
> also, does this happen if you try locally, outside of Azure?
currently checking this
Closed with #1499 and #1514.