t5_finetune

Choice of AdamW vs AdaFactor?

Open FL33TW00D opened this issue 4 years ago • 9 comments

Hi there, I'm wondering if you could share your reasoning for using AdamW over AdaFactor.

Thanks for sharing the script. Regards, Chris

FL33TW00D, Nov 28 '20 22:11

It was the huggingface default (as were the epsilon of 1e-8 and the learning rates in the 1e-4 to 1e-5 range). It could be worth trying other optimizers. Have you seen better results with AdaFactor?

jsrozner, Nov 30 '20 06:11


I've personally been using AdaFactor, based on Google's recommendations and the following thread, which you may have already seen: https://discuss.huggingface.co/t/t5-finetuning-tips/684/12
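For reference, the documentation of transformers' Adafactor suggests settings along the lines of the sketch below for T5 fine-tuning: a fixed learning rate of 1e-3 with Adafactor's relative step size and parameter scaling turned off. This is a minimal sketch, not the poster's exact configuration; the nn.Linear model and the single random training step are placeholders standing in for the T5 model and a real batch.

```python
import torch
from torch import nn
from transformers.optimization import Adafactor

# Placeholder model; in the fine-tuning script this would be the T5 model.
model = nn.Linear(8, 2)

# Fixed external learning rate; Adafactor's own relative step size and
# parameter-scale-based scaling are disabled.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

# One illustrative training step on random data.
w_before = model.weight.detach().clone()
x = torch.randn(4, 8)
y = torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

With relative_step=True (the default), Adafactor computes its own step size and lr must be left as None; the fixed-lr variant above is easier to compare directly against an AdamW run.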

I plan to do a quantitative comparison of the optimizers soon.

Regards, Chris

FL33TW00D, Nov 30 '20 12:11

I'd been meaning to read through that post and tune the optimizer choice as well! I think the transformers finetune.py script defaults to Adam (and a lot of the notebooks seem to use Adam too).

What params have you been using with adafactor? And how did you settle on them?

With Adam I've settled on a learning rate of 1e-4 or 3e-4, a linear schedule, and epsilon 1e-8 on a dataset of 90k training and 30k eval examples, but it begins to overfit fairly quickly, after only about 13 epochs.
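A minimal sketch of the AdamW setup described above (lr 1e-4, eps 1e-8, linear decay), in plain PyTorch so it stays self-contained. transformers also ships get_linear_schedule_with_warmup, which does the same decay with an optional warmup; the LambdaLR below is a swapped-in equivalent with no warmup. The model and num_training_steps are hypothetical placeholders.

```python
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(8, 2)  # placeholder for the T5 model

num_training_steps = 1000  # hypothetical: epochs * batches per epoch
optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8)

# Linear decay from the initial learning rate down to 0 over training.
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda step: max(0.0, 1.0 - step / num_training_steps),
)

for step in range(3):
    x = torch.randn(4, 8)
    y = torch.randint(0, 2, (4,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule once per optimizer step
    optimizer.zero_grad()
```

Because the schedule is tied to num_training_steps, shortening training (for example stopping around the epoch where validation loss starts rising) also changes how fast the learning rate decays, so it is worth recomputing that value when comparing runs.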

jsrozner, Dec 07 '20 18:12

Hi @jsrozner, I'm not sure if you're still using this repo, but I used it as the baseline for my own and have been making incremental improvements over time, so thank you!

One of the main ways I've found to speed up training is implementing a collate_fn with dynamic padding and uniform batch lengths. If you're still experimenting with T5, I thought I'd attach the snippet:

from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """
    Take a list of samples from a Dataset and collate them into a batch,
    padding every sequence only to the longest one in this batch.
    Returns:
        A dictionary of tensors (plus the raw source/target strings)
    """
    pad_token_id = 0  # T5's pad token id
    src_ids = pad_sequence([sample['source_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    src_text = [sample['source_text'] for sample in batch]
    # attention masks are padded with 0 so the model ignores padded positions
    src_mask = pad_sequence([sample['source_mask'] for sample in batch], batch_first=True, padding_value=0)

    tgt_ids = pad_sequence([sample['target_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    # replace padding in the labels with -100, the index ignored by the cross-entropy loss
    tgt_ids[tgt_ids == pad_token_id] = -100
    tgt_mask = pad_sequence([sample['target_mask'] for sample in batch], batch_first=True, padding_value=0)
    tgt_text = [sample['target_text'] for sample in batch]

    return {
        'source_ids': src_ids,
        'target_ids': tgt_ids,
        'source_mask': src_mask,
        'target_mask': tgt_mask,
        'source_text': src_text,
        'target_text': tgt_text,
    }

FL33TW00D, Jan 29 '21 20:01

cool, collate was on my feature list actually! and i'm glad you've found it useful!

i've also been making a lot of changes - i've made it considerably more modular so that everything inherits from a common abstract trainer base class.
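The refactor itself isn't shown in the thread, so the following is only a hypothetical sketch of the pattern described: an abstract base class owns the generic epoch loop, and subclasses fill in the model-specific step and evaluation. All class and method names here (BaseTrainer, train_step, evaluate, fit, DummyTrainer) are illustrative, not the repo's actual API.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List


class BaseTrainer(ABC):
    """Hypothetical base class: owns the loop, subclasses own the model logic."""

    def __init__(self, max_epochs: int):
        self.max_epochs = max_epochs
        self.history: List[float] = []  # one validation metric per epoch

    @abstractmethod
    def train_step(self, batch) -> float:
        """Run one optimization step on a batch and return its loss."""

    @abstractmethod
    def evaluate(self) -> float:
        """Return a validation metric for the current model."""

    def fit(self, batches: Iterable) -> List[float]:
        batches = list(batches)
        for _ in range(self.max_epochs):
            for batch in batches:
                self.train_step(batch)
            self.history.append(self.evaluate())
        return self.history


class DummyTrainer(BaseTrainer):
    """Toy subclass that just counts steps, to exercise fit()."""

    def __init__(self, max_epochs: int):
        super().__init__(max_epochs)
        self.steps = 0

    def train_step(self, batch) -> float:
        self.steps += 1
        return float(sum(batch))

    def evaluate(self) -> float:
        return 1.0 / self.steps


trainer = DummyTrainer(max_epochs=2)
history = trainer.fit([[1, 2], [3]])
```

Because BaseTrainer declares abstract methods, it cannot be instantiated directly; a T5-specific subclass would implement train_step with the forward pass, loss, and optimizer step.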

i'll probably push the changes to this repo once my research calms down a little. i can incorporate the collate function then.

jsrozner, Jan 29 '21 20:01

also, huggingface's transformers tokenizers offer batch encoding (batch_encode_plus, or just calling the tokenizer on a list of strings) that should take care of uniform padding and length

jsrozner, Feb 04 '21 21:02

Hi @jsrozner, I did it this way following this blog post: https://wandb.ai/pommedeterresautee/speed_training/reports/Train-HuggingFace-Models-Twice-As-Fast--VmlldzoxMDgzOTI

It means we no longer need to pad to the max length when batch encoding, and we can strategically build batches from samples of similar length to reduce the amount of padding needed. This really accelerated my training, since my mean sample length is 48 tokens but the max is 128.
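To see why the gain can be large, here is a rough token count comparing three strategies: padding everything to a fixed 128, padding each batch to its own longest sequence, and padding after sorting by length so similar lengths share a batch. The lengths are synthetic (a clipped normal distribution centred near 48 with a max of 128), not the poster's data, and padded_tokens is a hypothetical helper.

```python
import random


def padded_tokens(lengths, batch_size, pad_to=None):
    """Total tokens (real + padding) fed to the model.

    If pad_to is given, every sequence is padded to that fixed length;
    otherwise each batch is padded only to its own longest sequence.
    """
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        width = pad_to if pad_to is not None else max(batch)
        total += width * len(batch)
    return total


random.seed(0)
# Synthetic lengths: 160 full batches of 64, mean near 48, clipped to [1, 128].
lengths = [min(128, max(1, int(random.gauss(48, 20)))) for _ in range(64 * 160)]

real = sum(lengths)
fixed = padded_tokens(lengths, batch_size=64, pad_to=128)
dynamic = padded_tokens(lengths, batch_size=64)
grouped = padded_tokens(sorted(lengths), batch_size=64)

print(f"fixed/real: {fixed / real:.2f}  dynamic/real: {dynamic / real:.2f}  "
      f"grouped/real: {grouped / real:.2f}")
```

Dynamic padding alone still pays for the longest sequence in each random batch, which with 64 samples is usually far above the mean; sorting by length is what brings the padded total close to the real token count.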

Can we do this with batch_encode? It would be easier if so.

FL33TW00D, Feb 04 '21 22:02

I wrote the following, using the huggingface tokenizer to handle the batch encoding. It pads each batch to the length of its longest sequence.

This also substantially reduces the memory footprint from what I had before.

You wrote: "It means we no longer need to pad to max length when we are batch encoding, and can strategically take batches of similar length samples in order to reduce the amount of padding needed."

At first I read this to mean that you intentionally collate batches of similar-length sequences, but that probably isn't what you'd want if length is at all correlated with your objective, since your batches would then no longer be drawn truly at random?

This implementation does not attempt to group examples of similar length into the same batch, so if the longest sequence in a batch is 100 tokens and all the others are 10, every sequence is still padded to 100. The huggingface tokenizer's padding options ('longest', 'max_length', 'do_not_pad') don't address this directly; grouping by length would have to happen when the batches are sampled.
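One way to get most of the padding savings while keeping batches reasonably random is the "mega-batch" trick: shuffle all indices, cut them into large random windows, sort by length only inside each window, split into batches, then shuffle the batch order. This is a minimal sketch under that assumption (similar in spirit to the length-grouped sampling behind the Trainer's group_by_length option, but not that code); LengthGroupedBatchSampler and mega_factor are hypothetical names.

```python
import random
from typing import Iterator, List, Sequence


class LengthGroupedBatchSampler:
    """Yield batches of indices whose examples have similar lengths.

    Sorting happens only inside random windows of batch_size * mega_factor
    indices, so batch membership still depends on the shuffle.
    """

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 mega_factor: int = 50, seed: int = 0):
        self.lengths = lengths
        self.batch_size = batch_size
        self.mega_size = batch_size * mega_factor
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(len(self.lengths)))
        self.rng.shuffle(indices)  # new random windows every epoch
        batches = []
        for start in range(0, len(indices), self.mega_size):
            window = sorted(indices[start:start + self.mega_size],
                            key=lambda i: self.lengths[i])
            batches.extend(window[j:j + self.batch_size]
                           for j in range(0, len(window), self.batch_size))
        self.rng.shuffle(batches)  # avoid a short-to-long ordering of batches
        return iter(batches)

    def __len__(self) -> int:
        # mega_size is a multiple of batch_size, so only the final window
        # can produce a partial batch.
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


demo_lengths = [5, 1, 9, 3, 7, 2, 8, 4, 6, 10]
sampler = LengthGroupedBatchSampler(demo_lengths, batch_size=3, mega_factor=2)
batches = list(sampler)
```

In a PyTorch DataLoader, an object like this would be passed as batch_sampler= (with batch_size and shuffle left unset). The lengths would have to be computed up front, for example by tokenizing each source once; a smaller mega_factor trades padding savings for more randomness, which speaks to the concern about length correlating with the objective.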

For an even larger dataset that won't fit easily into memory, we'd want to write an IterableDataset.

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Dict

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

log = logging.getLogger(__name__)

@dataclass
class DataSetEntry:
	src: str
	tgt: str

@dataclass
class DataLoaderConfig:
	shuffle: bool = True
	batch_size: int = 64
	num_workers: int = 4

@dataclass
class DatasetConfig:
	tokenizer: PreTrainedTokenizer
	max_examples: int = -1  # if > 0, keep only the first max_examples examples
	src_len: int = 100
	tgt_len: int = 20
	
class ClueDatasetBatched(Dataset):
	def __init__(self,
	             dataset_config: DatasetConfig,
	             data_dir: str,
	             type_path):
		valid_type_paths = ["test", "train", "val"]
		assert type_path in valid_type_paths, f"Type path must be one of {valid_type_paths}"

		self.example_path = Path(data_dir) / type_path
		self.max_examples = dataset_config.max_examples

		# populated in build
		self._len = None        # the total number of examples
		self.data_list: Optional[List[DataSetEntry]] = None
		
		self._build()  # fills self.data_list and self._len

	def __len__(self):
		return self._len

	def __getitem__(self, index):
		return self.data_list[index]

	def _build(self):
		source_path = self.example_path.with_suffix(".source")
		target_path = self.example_path.with_suffix(".target")

		with open(source_path, 'r') as f_source, \
			open(target_path, 'r') as f_target:

			source_lines, target_lines = f_source.readlines(), f_target.readlines()

			# do length calcs
			source_ct, target_ct = len(source_lines), len(target_lines)
			assert source_ct == target_ct, f"Lengths don't match: {source_ct} source vs {target_ct} target lines"
			if self.max_examples > 0:
				source_ct = min(self.max_examples, source_ct)
			self._len = source_ct

			self.data_list = []
			for idx in range(source_ct):
				src = source_lines[idx].strip()
				tgt = target_lines[idx].strip()
				self.data_list.append(DataSetEntry(src, tgt))


def collate_fn(tokenizer: PreTrainedTokenizer, batch_list: List[DataSetEntry]) -> Dict:
	src_text = [e.src for e in batch_list]
	tgt_text = [e.tgt for e in batch_list]

	tokenized_inputs = tokenizer(src_text, padding='longest', return_tensors='pt')
	tokenized_outputs = tokenizer(tgt_text, padding='longest', return_tensors='pt')

	source_ids = tokenized_inputs["input_ids"]
	target_ids = tokenized_outputs["input_ids"]
	src_mask = tokenized_inputs["attention_mask"]      # already (batch_size, longest_len); no squeeze needed
	target_mask = tokenized_outputs["attention_mask"]  # same shape convention as src_mask

	# We cast these to torch.long in preprocess batch
	ret = {"source_ids": source_ids,
	       "source_mask": src_mask,
	       "target_ids": target_ids,
	       "target_mask": target_mask,
	       "source_text": src_text,
	       "target_text": tgt_text}

	return ret


def get_dataloader_batched(tokenizer,
                           dataset_config: DatasetConfig,
                           dl_config: DataLoaderConfig,
                           data_dir,
                           type_path: str = None) -> DataLoader:

	def curried_collate_fn(input_list) -> Dict:
		return collate_fn(tokenizer, input_list)

	data_set = ClueDatasetBatched(dataset_config,
	                              data_dir=data_dir,
	                              type_path=type_path)
	dataloader = DataLoader(data_set,
	                        batch_size=dl_config.batch_size,
	                        shuffle=dl_config.shuffle,
	                        num_workers=dl_config.num_workers,
	                        collate_fn=curried_collate_fn)
	log.info(f'Dataset {type_path} loaded with size: {len(data_set)}')
	return dataloader
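For the larger-than-memory case mentioned above, an IterableDataset can stream the same .source/.target files line by line instead of reading them all with readlines(). This is a minimal sketch assuming the same file layout as ClueDatasetBatched; StreamingClueDataset is a hypothetical name, and DataSetEntry is redefined so the block runs on its own and yields objects the collate_fn above could consume. Note that DataLoader does not accept shuffle=True for an IterableDataset, and each worker here still reads the full files, keeping every num_workers-th line.

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator

from torch.utils.data import IterableDataset, get_worker_info


@dataclass
class DataSetEntry:  # mirrors the dataclass in the snippet above
    src: str
    tgt: str


class StreamingClueDataset(IterableDataset):
    """Stream DataSetEntry pairs from parallel .source/.target files."""

    def __init__(self, data_dir: str, type_path: str):
        base = Path(data_dir) / type_path
        self.source_path = base.with_suffix(".source")
        self.target_path = base.with_suffix(".target")

    def __iter__(self) -> Iterator[DataSetEntry]:
        worker = get_worker_info()  # None when iterated in the main process
        num_workers = worker.num_workers if worker is not None else 1
        worker_id = worker.id if worker is not None else 0
        with open(self.source_path) as f_src, open(self.target_path) as f_tgt:
            # zip stops at the shorter file; unlike _build, lengths aren't checked.
            for line_no, (src, tgt) in enumerate(zip(f_src, f_tgt)):
                # Shard by line number so workers don't yield duplicates.
                if line_no % num_workers == worker_id:
                    yield DataSetEntry(src.strip(), tgt.strip())


# Tiny usage demo with throwaway files.
tmp = Path(tempfile.mkdtemp())
(tmp / "train.source").write_text("question one\nquestion two\n")
(tmp / "train.target").write_text("answer one\nanswer two\n")
entries = list(StreamingClueDataset(str(tmp), "train"))
```

Without random access, any shuffling has to happen upstream (for example by pre-shuffling the files) or through a small in-memory shuffle buffer.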

jsrozner, Feb 05 '21 01:02

@FL33TW00D what do you think of the new implementation?

jsrozner, Feb 13 '21 20:02