dilated-attention-pytorch
dilated-attention-pytorch copied to clipboard
Training on yet-another-retnet script
Hello Frank!
I love what you have created, and am having a great time going through and parsing through your implementation of the paper. It appears you have nailed the dilated attention calculation method.
Here are my observations so far:
- I'm relatively new to this field, and am learning a lot of concepts on the go, so bear with me if I miss out finer details!
- I'm trying to train the LM variant that you designed on a text dataset for language modelling. My goal is to ultimately test out how many tokens can my GPUs handle during training and inference (2x RTX 3090).
- During training, it was taking around 8GB RAM to process 1024 tokens, scaling it up I think we can manage around 10k tokens within 50GBs of RAM consumption.
- I was trying to use the training script from this https://github.com/fkodom/yet-another-retnet repo that you created, and while I got to clearing the shape mismatch and other issues, the loss, starts very low (around 0.0004) and then goes to NaN and the iterations stop.
- I'm sharing how I did the training script with you, please let me know if you have any suggestions for me! I want to get a checkpoint that I can later use for inference. Thanks! :)
import os
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import tiktoken
import torch
from lightning import Fabric, seed_everything
from lightning.fabric.loggers import TensorBoardLogger
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm
#from yet_another_retnet.retnet import RetNet
from yet_another_retnet.utils.gutenberg import project_gutenberg_top_100_datapipe
torch.set_float32_matmul_precision("medium")
TOKENIZER = tiktoken.get_encoding("gpt2")
EVAL_PROMPT = "A Lannister always pays his debts."
def collate_fn(
batch: List[str],
max_length: int = 4096,
device: Optional[Union[torch.device, str]] = None,
) -> Tuple[Tensor, Tensor]:
x = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)
y = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)
for i, text in enumerate(batch):
encoding = torch.as_tensor(
TOKENIZER.encode(text), device=device, dtype=torch.long
)
seq_length = min(len(encoding) - 1, max_length)
x[i, :seq_length] = encoding[:seq_length]
y[i, :seq_length] = encoding[1 : seq_length + 1]
return x, y
@dataclass
class TrainingState:
fabric: Fabric
model: LongNetLM
optimizer: torch.optim.Optimizer
callbacks: Sequence[Callable[["TrainingState", float], None]] = ()
current_step: int = 0
current_epoch: int = 0
accumulate_grad_batches: int = 1
monitor: str = "val_loss"
monitor_mode: Literal["min", "max"] = "min"
@dataclass
class ModelCheckpoint:
state_dict: Dict[str, Tensor]
optimizer_state: Dict[str, Tensor]
current_step: int
current_epoch: int
@classmethod
def from_training_state(cls, state: TrainingState) -> "ModelCheckpoint":
return cls(
state_dict=state.model.state_dict(),
optimizer_state=state.optimizer.state_dict(),
current_step=state.current_step,
current_epoch=state.current_epoch,
)
def to_dict(self) -> Dict[str, Any]:
return {
"state_dict": self.state_dict,
"optimizer_state": self.optimizer_state,
"current_step": self.current_step,
"current_epoch": self.current_epoch,
}
def save(self, path: str) -> None:
torch.save(self.to_dict(), path)
@classmethod
def load(cls, path: str) -> "ModelCheckpoint":
checkpoint_dict = torch.load(path)
return cls(**checkpoint_dict)
class CheckpointCallback:
def __init__(
self, save_dir: str, name: str = "checkpoint_epoch-{epoch:03d}.pt"
) -> None:
self.save_dir = save_dir
self.name = name
self.best_path: Optional[str] = None
self.best_loss: Optional[float] = None
def __call__(self, state: TrainingState, loss: float) -> None:
if self.best_loss is None:
self.best_loss = loss
fabric = state.fabric
# 'local_rank == 0' means this only happens for the main process
if fabric.local_rank == 0 and loss <= self.best_loss:
checkpoint = ModelCheckpoint.from_training_state(state)
self.best_loss = loss
if self.best_path is not None:
os.remove(self.best_path)
self.best_path = os.path.join(
self.save_dir, self.name.format(epoch=state.current_epoch)
)
torch.save(checkpoint, self.best_path)
# All processes wait for main to finish saving the checkpoint.
fabric.barrier()
def train_one_epoch(
state: TrainingState,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
log_frequency: int = 25,
) -> None:
state.current_epoch += 1
fabric, model, optimizer = state.fabric, state.model, state.optimizer
is_training = model.training
model.train()
with tqdm(desc=f"Ep: {state.current_epoch}") as progbar:
train_loss, val_loss = 0.0, 0.0
for x, y in train_dataloader:
state.current_step += 1
accumulating = state.current_step % state.accumulate_grad_batches != 0
with fabric.no_backward_sync(model, enabled=accumulating):
loss = model.forward(x)
loss_value = loss.mean()
fabric.backward(loss_value)
if not accumulating:
optimizer.step()
optimizer.zero_grad()
if state.current_step % log_frequency == 0:
loss_scalar = loss.mean().item() # Calculate the mean and extract the scalar value
fabric.log("loss", loss_scalar, step=state.current_step) # Log the scalar loss
train_loss = loss_scalar # Update the train_loss variable
progbar.set_postfix_str(f"loss={train_loss:.4f}", refresh=False)
progbar.update(1)
model.eval()
val_progbar = tqdm(desc="val", position=1, leave=False)
for i, (x) in enumerate(val_dataloader):
with torch.inference_mode():
loss = model.forward(x)
val_loss = (val_loss * i + loss.mean().item()) / (i + 1)
if i % log_frequency == 0:
val_progbar.set_postfix_str(f"val_loss={val_loss:.4f}", refresh=False)
val_progbar.update(1)
progbar.update(1)
fabric.log("val_loss", val_loss, step=state.current_step)
val_progbar.close()
progbar.set_postfix_str(
f"loss={train_loss:.4f}, val_loss={val_loss:.4f}", refresh=False
)
for callback in state.callbacks:
callback(state, val_loss)
model.train(mode=is_training)
def train(
longnet: LongNetLM,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
accelerator: str = "auto",
strategy: str = "auto",
precision: Optional[str] = None,
epochs: int = 10,
lr: float = 3e-4,
log_frequency: int = 25,
):
if precision is None:
if torch.cuda.is_available():
# use bfloat16 if supported
version, _ = torch.cuda.get_device_capability()
precision = "bf16-mixed" if version >= 8 else "16-mixed"
else:
precision = "float32"
logger = TensorBoardLogger(root_dir="./")
fabric = Fabric(
accelerator=accelerator,
strategy=strategy,
precision=precision, # type: ignore
loggers=[logger],
)
fabric.launch()
print(f"Experiment version: {logger.version}")
print("-" * 40)
# Setup with fabric.
optimizer = torch.optim.AdamW(longnet.parameters(), lr=lr)
longnet, optimizer = fabric.setup(longnet, optimizer)
train_dataloader, val_dataloader = fabric.setup_dataloaders(
train_dataloader, val_dataloader
)
# Construct a training state and run the training loop.
state = TrainingState(
fabric=fabric,
model=longnet,
optimizer=optimizer,
callbacks=[CheckpointCallback(save_dir=logger.log_dir)],
)
for _ in range(epochs):
train_one_epoch(
state=state,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
log_frequency=log_frequency,
)
def generate(
longnet: LongNet,
prompt: str,
prompt_chunk_size: Optional[int] = None,
max_new_tokens: int = 4096,
stop_tokens: Sequence[str] = (),
top_k: int = 10,
temperature: float = 1.0,
seed: int = 42,
) -> Iterator[str]:
seed_everything(seed)
device = next(iter(longnet.parameters())).device
is_training = longnet.training
longnet.eval()
# Tokenize the prompt and convert to a tensor.
tokenized = TOKENIZER.encode(prompt)
x = torch.as_tensor(tokenized, dtype=torch.long, device=device).unsqueeze_(0)
if not prompt_chunk_size:
prompt_chunk_size = x.size(1)
prev_states: List[Optional[Tensor]] = [None] * longnet.num_layers
start_idx: int = 0
for start_idx in range(0, x.size(1), prompt_chunk_size):
y, prev_states = longnet.forward(
x, start_idx=start_idx, prev_states=prev_states
)
y = y[:, -1]
# Generate tokens until we reach the maximum number of tokens or a stop token.
for i in range(max_new_tokens):
probs: Tensor = torch.softmax(y.squeeze() / max(temperature, 1e-8), dim=-1)
# Get top-k tokens, renormalize their probabilities, and weighted sample.
tokens: Tensor # for mypy
probs, tokens = probs.topk(k=top_k, dim=-1)
probs /= probs.sum()
# Take weighted random sample from the top-k tokens.
sampled_idx: int = torch.multinomial(probs, num_samples=1).item() # type: ignore
token: int = tokens[sampled_idx].item() # type: ignore
tokenized.append(token)
yield TOKENIZER.decode(tokenized)
token_str: str = TOKENIZER.decode([token])
if token_str in stop_tokens:
break
elif i < (max_new_tokens - 1):
start_idx += 1
x = torch.as_tensor([token], dtype=torch.long, device=device)
y, prev_states = longnet.forward(
x, start_idx, prev_states=prev_states
)
# Restore the model's original training state.
longnet.train(mode=is_training)
def main(
model_checkpoint: Optional[str] = None,
accelerator: str = "auto",
strategy: str = "auto",
precision: Optional[str] = None,
epochs: int = 10,
batch_size: int = 16,
lr: float = 3e-4,
log_frequency: int = 25,
seed: int = 42,
eval_only: bool = False,
eval_prompt: str = EVAL_PROMPT,
eval_max_tokens: int = 1024,
):
seed_everything(seed)
# Create a (relatively small) model and dataloaders
longnet = LongNetLM(
num_tokens=TOKENIZER.n_vocab,
d_model=768,
nhead=12,
num_encoder_layers=12,
num_decoder_layers=12,
dim_feedforward=3072,
segment_lengths = [512,1024, 2048,4096],
dilation_rates = [1, 2, 4, 6],
dropout = 0.1,
activation = F.relu,
layer_norm_eps = 1e-5,
)
if model_checkpoint is not None:
longnet.load_state_dict(ModelCheckpoint.load(model_checkpoint).state_dict)
if not eval_only:
train_dataloader = DataLoader(
project_gutenberg_top_100_datapipe(
split="train",
chunk_size=4096,
step_size=1024,
shuffle=True,
drop_last=True,
),
batch_size=batch_size,
collate_fn=collate_fn,
drop_last=True,
)
val_dataloader = DataLoader(
project_gutenberg_top_100_datapipe(
split="val", chunk_size=4096, step_size=1024
),
batch_size=batch_size,
collate_fn=collate_fn,
)
train(
longnet=longnet,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
accelerator=accelerator,
strategy=strategy,
precision=precision,
epochs=epochs,
lr=lr,
log_frequency=log_frequency,
)
# Generate some text
prev_output: str = ""
for output in generate(longnet, eval_prompt, max_new_tokens=eval_max_tokens):
# Return to the start of the line and print the output (no newline)
print(output[len(prev_output) :], end="", flush=True)
prev_output = output
print()
# Define the default values or provide your desired values
default_model_checkpoint = None
default_accelerator = "auto"
default_strategy = "dp"
default_precision = None
default_epochs = 1
default_batch_size = 1
default_lr = 3e-4
default_log_frequency = 25
default_seed = 42
default_eval_only = False
default_eval_prompt = EVAL_PROMPT
default_eval_max_tokens = 1024
# Replace the argparse-related code
model_checkpoint = default_model_checkpoint
accelerator = default_accelerator
strategy = default_strategy
precision = default_precision
epochs = default_epochs
batch_size = default_batch_size
lr = default_lr
log_frequency = default_log_frequency
seed = default_seed
eval_only = default_eval_only
eval_prompt = default_eval_prompt
eval_max_tokens = default_eval_max_tokens
# Call the main function
main(
model_checkpoint=model_checkpoint,
accelerator=accelerator,
strategy=strategy,
precision=precision,
epochs=epochs,
batch_size=batch_size,
lr=lr,
log_frequency=log_frequency,
seed=seed,
eval_only=eval_only,
eval_prompt=eval_prompt,
eval_max_tokens=eval_max_tokens
)