# ColossalAI
# [Gemini] independent runtime tracer
import torch
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemtracerWrapper
import colossalai
# Minimal example: trace the memory behaviour of a model wrapped in
# MemtracerWrapper across one forward + backward pass.

# Initialise ColossalAI's distributed context with an empty config.
# NOTE(review): launch_from_torch presumably expects to be started via the
# torch.distributed launcher (e.g. torchrun) — confirm.
colossalai.launch_from_torch(config={})
# Currently unused below; presumably kept for per-rank logging.
rank = torch.distributed.get_rank()

# A deep stack of large Linear layers so the traced footprint is substantial.
# The model must live on the same device as its input (CUDA, see below);
# without .cuda() the forward pass fails with a CPU/CUDA device mismatch.
# .cuda() targets the current CUDA device (presumably selected per rank by
# launch_from_torch — confirm).
model = nn.Sequential(*[nn.Linear(2000, 2000) for _ in range(20)]).cuda()
model = MemtracerWrapper(model)

# Named `data` rather than `input` to avoid shadowing the builtin.
# nn.Linear accepts a 1-D tensor of shape (in_features,), so (2000,) is valid.
data = torch.randn(2000).cuda()
loss = model(data).sum()
# Backward goes through the wrapper rather than loss.backward(), presumably so
# the tracer can observe the backward pass — confirm against MemtracerWrapper.
model.backward(loss)