[PROPOSAL]: refactor core API of Engine
Proposal
Motivation
- The current initialization process is complicated and hard to maintain. It contains hundreds of hard-coded `if-else` branches, which are hard to read and modify.
- The current `Engine` is hard to use. Its usage is very different from native torch, and users may need some effort to learn it before starting their first applications.
- The current `Engine` is not flexible. It relies on a configuration file or dict and a global context. If we want to run two models with different parallelism methods, that is hard to implement now. It also only supports single-model training, which rules out some popular RL algorithms like PPO.
- There is too much legacy code. `Gemini` and auto-parallelism both have separate entry points instead of `Engine`.
Design
We keep Engine as the main entry point of ColossalAI training.
Engine has 6 main components:
- CheckpointIO: manage how to save and load checkpoints.
- Scheduler: manage forward and backward process of pipeline parallelism.
- PrecisionBolt: manage training (mixed) precision, including grad scaler and clipping.
- EnvironmentTable: store distributed environment info, like cluster info, global process group info and device mesh info.
- Accelerator: manage computation device.
- ParallelismPlugin: manage parallelism training.
Engine's features include:
- parallel training (forward, backward, or pipeline execution)
- save / load checkpoint
- mixed precision training
- grad clipping
- (maybe) support for various devices
- torch-like grad accumulation (providing `no_sync()`), as sketched below
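As an illustration of the grad accumulation feature, a minimal sketch (assuming an `accumulation_steps` hyperparameter and a `(inputs, labels)` dataloader, both hypothetical here) could look like:
# gradient accumulation: skip gradient synchronization for all but the last micro-batch
for step, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps
    if (step + 1) % accumulation_steps != 0:
        # no_sync() avoids an all-reduce on this micro-batch
        with engine.no_sync(model):
            engine.backward(loss, optimizer)
    else:
        engine.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()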
Engine is not a singleton, though in most cases a single engine is enough.
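Because Engine is not a singleton, two models could be trained with different parallelism methods in one process, e.g. (a sketch; `TensorParallelPlugin` and the plugin arguments are illustrative assumptions, not part of this proposal):
# two engines with different parallelism plugins in the same process (sketch)
actor_engine = Engine(precision='fp16', parallelism_plugin=GeminiPlugin(stage=3))
critic_engine = Engine(precision='fp16', parallelism_plugin=TensorParallelPlugin(tp_size=4))  # hypothetical plugin
actor, actor_optim = actor_engine.initialize(actor, actor_optim)
critic, critic_optim = critic_engine.initialize(critic, critic_optim)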
Possible sample code (pseudo-code)
# create engine
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)
# initialize models, optimizers, lr schedulers
model, optimizer, lr_scheduler = engine.initialize(model, optimizer, lr_scheduler)
# or multi-models
actor, critic, actor_optimizer, critic_optimizer = engine.initialize(actor, critic, actor_optimizer, critic_optimizer)
# forward backward
outputs = model(inputs)
loss = criterion(outputs, labels)  # criterion and labels as in a usual torch training loop
engine.backward(loss, optimizer)
optimizer.step()
# run pipeline (another paradigm)
engine.execute_pipeline(data_iter, model, criterion, optimizer, ...)
optimizer.step()
# HF models generation
sequences = model.generate(input_ids)
# IO Support 2 styles:
# 1. torch style (target path is a file)
# 2. Huggingface style (target path is a directory)
# torch style (does not consider checkpoint size; may OOM for large models)
engine.load(model, 'model.pt', plan='torch')
engine.save(optimizer, 'optimizer.pt', plan='torch')
# huggingface style (save checkpoint in chunks)
engine.save(model, 'checkpoint/gpt2', max_file_size_gb=10, plan='huggingface')
engine.load(optimizer, 'checkpoint/gpt2', plan='huggingface')
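How the huggingface-style plan could shard a checkpoint is sketched below; the `save_sharded` helper, the shard naming scheme, and the index file layout are all assumptions for illustration, not settled design:
# sketch of chunked saving for the huggingface-style plan (all details are assumptions)
import json, os
import torch

def save_sharded(state_dict, ckpt_dir, max_file_size_gb=10):
    os.makedirs(ckpt_dir, exist_ok=True)
    max_bytes = max_file_size_gb * 1024 ** 3
    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        size = tensor.numel() * tensor.element_size()
        if current and current_size + size > max_bytes:
            # current shard is full, start a new one
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += size
    shards.append(current)
    index = {}
    for i, shard in enumerate(shards):
        fname = f'pytorch_model-{i + 1:05d}-of-{len(shards):05d}.bin'
        torch.save(shard, os.path.join(ckpt_dir, fname))
        index.update({name: fname for name in shard})
    # index file maps each parameter name to the shard that stores it
    with open(os.path.join(ckpt_dir, 'pytorch_model.bin.index.json'), 'w') as f:
        json.dump({'weight_map': index}, f)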
Single-model supervised learning train loop without pipeline
colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)
model = GPT2()
optimizer = Adam(model.parameters())
dataloader = DataLoader(dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()
model, optimizer, lr_scheduler, dataloader = engine.initialize(model, optimizer, lr_scheduler, dataloader)
for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.logits, input_ids)
        engine.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
Single-model supervised learning train loop with pipeline
colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)
model = GPT2()
optimizer = Adam(model.parameters())
dataloader = DataLoader(dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()
model, optimizer, lr_scheduler, dataloader = engine.initialize(model, optimizer, lr_scheduler, dataloader)
for epoch in range(max_epochs):
    num_steps = len(dataloader)
    for step in range(num_steps):
        loss = engine.execute_pipeline(dataloader, model, criterion, optimizer, return_loss=True, return_outputs=False)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
Multi-model RL train loop without pipeline
colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)
actor = GPT2Actor()
critic = GPT2Critic()
actor_optim = Adam(actor.parameters())
critic_optim = Adam(critic.parameters())
actor_loss_fn = ActorLoss()
critic_loss_fn = CriticLoss()
actor, critic, actor_optim, critic_optim = engine.initialize(actor, critic, actor_optim, critic_optim)
for epoch in range(max_epochs):
    for experience in replay_buffer:
        action_log_probs = actor(experience.sequences)
        actor_loss = actor_loss_fn(action_log_probs, experience.old_log_probs, experience.adv)
        engine.backward(actor_loss, actor_optim)
        actor_optim.step()
        actor_optim.zero_grad()
        values = critic(experience.sequences)
        critic_loss = critic_loss_fn(values, experience.old_values, experience.reward)
        engine.backward(critic_loss, critic_optim)
        critic_optim.step()
        critic_optim.zero_grad()
Possible class definition (pseudo-code)
class Engine:

    def __init__(self,
                 device: Union[str, torch.device] = 'cuda',
                 precision: str = 'fp32',
                 grad_clipping_type: str = 'norm',
                 grad_clipping_value: float = 0.0,
                 parallelism_plugin: Optional[ParallelismPlugin] = None) -> None:
        # sanity check
        assert device in parallelism_plugin.supported_devices
        assert precision in parallelism_plugin.supported_precisions
        self.parallelism_plugin = parallelism_plugin
        self.accelerator = None
        self.precision_bolt = None
        if not parallelism_plugin.control_device:
            self.accelerator = Accelerator(device)
        if not parallelism_plugin.control_precision:
            self.precision_bolt = PrecisionBolt(precision, grad_clipping_type, grad_clipping_value)
        self.environment_table = EnvironmentTable(parallelism_plugin.device_mesh_shapes)
        self.checkpoint_io = CheckpointIO(self.parallelism_plugin, self.precision_bolt, self.accelerator)

    def initialize(self, *args: Union[Module, Optimizer, LRScheduler, DataLoader]) -> List[Union[Module, Optimizer, LRScheduler, DataLoader]]:
        rets = []
        for arg in args:
            if isinstance(arg, Module):
                arg = self.parallelism_plugin.setup_model(arg, self.environment_table.device_mesh_pool)
                if not self.parallelism_plugin.control_precision:
                    arg = self.precision_bolt.setup_model(arg)
                if not self.parallelism_plugin.control_device:
                    arg = self.accelerator.setup_model(arg)
            elif isinstance(arg, Optimizer):
                arg = self.parallelism_plugin.setup_optimizer(arg)
                if not self.parallelism_plugin.control_precision:
                    arg = self.precision_bolt.setup_optimizer(arg)
            else:
                # TODO: handle LRScheduler and DataLoader
                pass
            rets.append(arg)
        return rets

    def backward(self, loss: Tensor, optimizer: Optimizer) -> None:
        # do backward when not using pipeline
        if not self.parallelism_plugin.control_precision:
            loss = self.precision_bolt.scale_loss(loss)
        optimizer.backward(loss)

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: Module,
                         criterion: Callable[[Inputs, Outputs], Tensor],
                         optimizer: Optimizer,
                         return_loss: bool = True,
                         return_outputs: bool = False) -> Tuple[Optional[Tensor], ...]:
        # run pipeline forward backward pass
        # return loss or outputs if needed
        pass

    def no_sync(self, model: Module) -> Context:
        if not self.parallelism_plugin.support_no_sync:
            raise RuntimeError('the current parallelism plugin does not support no_sync()')
        return model.no_sync()

    def save(self, obj: Union[Module, Optimizer, LRScheduler], path_like: str, plan: str = 'torch', **kwargs) -> None:
        pass

    def load(self, obj: Union[Module, Optimizer, LRScheduler], path_like: str, plan: str = 'torch', **kwargs) -> None:
        pass
class EnvironmentTable:

    def __init__(self, intra_op_world_sizes: List[int]):
        self.world_size
        self.rank
        self.global_group
        self.device_mesh_pool  # generated from intra_op_world_sizes

    @property
    def is_master(self) -> bool:
        pass
class Accelerator:

    def __init__(self, device):
        self.device = device

    def setup_model(self, model) -> nn.Module:
        pass
class PrecisionBolt:

    def __init__(self, precision_type: dtype, grad_clipping_type: str, grad_clipping_value: float):
        self.precision_type = precision_type
        self.grad_clipping_type = grad_clipping_type
        self.grad_clipping_value = grad_clipping_value

    def setup_model(self, model) -> nn.Module:
        pass

    def setup_optimizer(self, optimizer) -> Optimizer:
        # inject grad clipping and loss unscaling
        pass

    def scale_loss(self, loss) -> torch.Tensor:
        pass
class ParallelismPlugin:

    @property
    def supported_devices(self) -> List[device]:
        pass

    @property
    def supported_precisions(self) -> List[str]:
        pass

    @property
    def control_precision(self) -> bool:
        pass

    @property
    def control_device(self) -> bool:
        pass

    @property
    def support_no_sync(self) -> bool:
        pass

    def setup_model(self, model, device_mesh_pool) -> Module:
        pass

    def setup_optimizer(self, optimizer) -> Optimizer:
        pass

    def setup_dataloader(self, dataloader) -> DataLoader:
        pass

    @property
    def device_mesh_shape(self) -> List[Tuple[int, ...]]:
        pass
Further work
Hugging Face's Accelerate and Lightning's Fabric may have similar designs. We may provide a ColossalAI plugin / strategy for these libraries.
Self-service
- [X] I'd be willing to do some initial work on this proposal myself.
@ver217 There are some suggestions regarding the API design:
- Rename `PrecisionBolt`, as "bolt" is rather unclear.
- Don't name the plugin `ParallelismPlugin`, as we can extend it to other features such as quantization. I think it is enough to simply name it `Plugin`.
- Don't name it `Engine`, as that can be a bit misleading, as discussed earlier on.
This issue has been migrated to #3046, so I will close it for now and all discussions will take place in #3046.