[BUG]: ZeRO causes a runtime error when using GRU with packed sequences
🐛 Describe the bug
I run the following script and get `RuntimeError: Function CudnnRnnBackward0 returned an invalid gradient at index 1 - got [0] but expected shape compatible with [768, 512]`. If I comment out `torch.nn.utils.rnn.pack_padded_sequence` and the related code, the script runs fine. Can someone help me with this?
```python
import colossalai
import colossalai.nn
import colossalai.utils
import colossalai.zero.init_ctx
import colossalai.zero.shard_utils
import torch
import torch.utils.data
from colossalai.core import global_context as colossal_gpc

IN_DIM, OUT_DIM = 512, 512
SEQ_LEN = 100
BATCH_SIZE = 1


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gru = torch.nn.GRU(
            IN_DIM,
            IN_DIM // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(IN_DIM, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, OUT_DIM),
        )

    def forward(self, x, input_lengths):
        rnn_input = torch.nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # rnn_input = x
        ret, _ = self.gru(rnn_input)
        ret, _ = torch.nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
        return self.linear(ret)


def main():
    config = {
        "zero": {
            "model_config": {
                "shard_strategy": colossalai.zero.shard_utils.BucketTensorShardStrategy(),
                "reduce_scatter_bucket_size_mb": 25,
                "fp32_reduce_scatter": False,
                "tensor_placement_policy": "cuda",
                "gradient_predivide_factor": 1.0,
                "reuse_fp16_shard": True,
            },
            "optimizer_config": {
                "gpu_margin_mem_ratio": 0.5,
                "initial_scale": 2**5,
                "min_scale": 1,
                "growth_factor": 2,
                "backoff_factor": 0.5,
                "growth_interval": 1000,
                "hysteresis": 2,
                "max_scale": 2**32,
            },
        },
    }
    colossalai.launch_from_torch(config)

    ctx = colossalai.zero.init_ctx.ZeroInitContext(
        target_device=colossalai.utils.get_current_device(),
        shard_strategy=colossal_gpc.config.zero.model_config.shard_strategy,
        shard_param=True,
    )
    with ctx:
        model = MyModel()
    optim = colossalai.nn.HybridAdam(model.parameters())
    criterion = torch.nn.MSELoss()
    engine, _, _, _ = colossalai.initialize(model, optim, criterion)

    engine.train()
    x = torch.randn(BATCH_SIZE, SEQ_LEN, IN_DIM).cuda()
    y = torch.randn(BATCH_SIZE, SEQ_LEN, OUT_DIM).cuda()
    input_lengths = torch.tensor(BATCH_SIZE * [SEQ_LEN]).cuda()
    y_hat = engine(x, input_lengths)
    y_hat = y_hat.float()
    loss = engine.criterion(y_hat, y)
    engine.backward(loss)
    engine.step()
    engine.zero_grad()


if __name__ == "__main__":
    main()
```
Environment
colossalai version: 0.1.5+torch1.11cu10.2
I think ZeRO does not support pack_padded_sequence right now. Since RNNs usually do not have many parameters and plain data parallelism is often enough for them, we do not test RNNs in ColossalAI.
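For context, the shape in the error message appears to correspond to the GRU's input-to-hidden weight: with `hidden_size=256`, the three GRU gates are stacked row-wise, so `weight_ih_l0` has shape `(3 * 256, 512) = (768, 512)`, which is the parameter whose gradient comes back as an empty `[0]` tensor. A quick check (an aside, not part of the fix):

```python
import torch

# The [768, 512] shape in the error matches the bidirectional GRU's
# input-to-hidden weight: 3 gates * hidden_size 256 rows, input_size 512 columns.
gru = torch.nn.GRU(512, 256, num_layers=1, batch_first=True, bidirectional=True)
print(gru.weight_ih_l0.shape)           # torch.Size([768, 512])
print(gru.weight_ih_l0_reverse.shape)   # torch.Size([768, 512])
```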
But it's possible to use an RNN as part of a large model. Is packed-sequence support on the roadmap?
Yes, it is on the roadmap for v0.2.0. That version will have a new interface for ZeRO.
> But it's possible to use an RNN as part of a large model. Is packed-sequence support on the roadmap?
You can implement this feature yourself for now. Tokenizers usually return padded input_ids and a mask. Here is an example:
```python
import torch
import torch.nn.functional as F


def pad_batch_sequence(batch):
    assert len(batch) > 0
    max_len = max(tensor.size(0) for tensor in batch)
    tensors = []
    masks = []
    for tensor in batch:
        # Right-pad each sequence to max_len and build a 0/1 mask of valid positions.
        padding_size = max_len - tensor.size(0)
        padded = F.pad(tensor, (0, padding_size))
        mask = torch.ones(max_len, device=tensor.device)
        mask[tensor.size(0):].fill_(0)
        tensors.append(padded)
        masks.append(mask)
    return torch.stack(tensors), torch.stack(masks)


VOCAB_SIZE = 1000
SEQ_LENS = [3, 5, 7, 6]
input_list = [torch.randint(0, VOCAB_SIZE, (length,)) for length in SEQ_LENS]
inputs, mask = pad_batch_sequence(input_list)

BATCH_SIZE = len(SEQ_LENS)
MAX_LEN = max(SEQ_LENS)
logits = torch.rand(BATCH_SIZE, VOCAB_SIZE, MAX_LEN)
labels = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MAX_LEN))

criterion = torch.nn.CrossEntropyLoss(reduction='none')
loss = criterion(logits, labels)   # per-position loss, shape (BATCH_SIZE, MAX_LEN)
loss = torch.mean(loss * mask)     # zero out padded positions before averaging
print(loss)
```
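Applying the same idea to the GRU model from the report: if packed sequences are the blocker, one possible workaround (a sketch only, reusing the `MyModel` layout from above; `MyModelNoPack` and `masked_mse` are hypothetical names) is to feed the padded batch to the GRU directly and mask the padded time steps out of the MSE loss instead of calling `pack_padded_sequence`:

```python
import torch

IN_DIM, OUT_DIM = 512, 512
SEQ_LEN, BATCH_SIZE = 100, 1


class MyModelNoPack(torch.nn.Module):
    """Hypothetical variant of MyModel that skips pack_padded_sequence."""

    def __init__(self) -> None:
        super().__init__()
        self.gru = torch.nn.GRU(IN_DIM, IN_DIM // 2, num_layers=1,
                                batch_first=True, bidirectional=True)
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(IN_DIM, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, OUT_DIM),
        )

    def forward(self, x):
        # Run the GRU on the padded batch directly instead of a PackedSequence.
        ret, _ = self.gru(x)
        return self.linear(ret)


def masked_mse(y_hat, y, mask):
    # mask: (batch, seq_len), 1 for valid time steps and 0 for padding.
    mask = mask.unsqueeze(-1)                    # (batch, seq_len, 1)
    se = (y_hat - y) ** 2 * mask                 # zero the error at padded positions
    return se.sum() / (mask.sum() * y.size(-1))  # average over valid elements only


model = MyModelNoPack()
x = torch.randn(BATCH_SIZE, SEQ_LEN, IN_DIM)
y = torch.randn(BATCH_SIZE, SEQ_LEN, OUT_DIM)
mask = torch.ones(BATCH_SIZE, SEQ_LEN)           # all steps valid in this toy batch
loss = masked_mse(model(x), y, mask)
loss.backward()
print(loss.item())
```

Note that for a bidirectional GRU the outputs at valid positions will not be identical to the packed version, because the reverse direction also reads the padded steps; whether that is acceptable depends on the application.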
We have updated a lot since this was filed. This issue was closed due to inactivity. Thanks.