Improved Jamba DeepSpeed ZeRO-3 compatibility
Description
Implements ZeRO-3 leaf module registration for Jamba: the sparse MoE block is flagged as a DeepSpeed ZeRO-3 leaf so its parameters are gathered as a single unit when the block runs, instead of per-submodule hooks being installed on each expert.
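A minimal sketch of the approach, assuming DeepSpeed's `set_z3_leaf_modules` helper and the `JambaSparseMoeBlock` class from transformers' Jamba implementation; the wrapper function name is illustrative, not taken from this PR:

```python
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock


def register_jamba_z3_leaves(model):
    # Flag the sparse MoE block as a ZeRO-3 leaf: DeepSpeed then fetches all of
    # its parameters when the block executes, instead of installing per-expert
    # hooks (which misbehave when only a subset of experts is activated).
    set_z3_leaf_modules(model, [JambaSparseMoeBlock])
```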
Motivation and Context
As requested on Discord: https://discord.com/channels/1104757954588196865/1104758010959634503/1224077312648024235
Getting:

```
ValueError: module class JambaSparseMoeBlock not found in JambaMambaMixer
```

Probably because the lookup goes down the JambaMambaMixer recursive path first: that module has children, but none of them is the required module class, so it raises an exception before the rest of the tree has been searched.
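For reference, a minimal sketch of the fix being described, with a hypothetical helper name (not from this PR): collect matches across the whole module tree and raise only if no subtree contains the target class, rather than failing on the first branch that comes up empty.

```python
import torch.nn as nn


def get_modules_of_type(root: nn.Module, target_cls: type):
    # nn.Module.modules() walks the tree recursively, so a branch with no
    # matches (e.g. JambaMambaMixer) no longer aborts the whole search.
    matches = [m for m in root.modules() if isinstance(m, target_cls)]
    if not matches:
        raise ValueError(
            f"module class {target_cls.__name__} not found in {root.__class__.__name__}"
        )
    return matches
```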
@bjoernpl I updated so this should work now.
Maybe we don't need this for the Jamba MoE?
```
raise RuntimeError(
RuntimeError: tracing error at step 916:
module id: 1565, training: True
expected the next 1 parameters in the parameter fetch queue to be ({'id': 'name=base_model.model.model.layers.31.mamba.dt_proj.lora_A.default.weight id=3012', 'status': 'AVAILABLE', 'numel': 2048, 'ds_numel': 2048, 'shape': (8, 256), 'ds_shape': (8, 256), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {1560, 1565}, 'ds_tensor.shape': torch.Size([256])},)
but got
({'id': 'name=base_model.model.model.layers.31.mamba.dt_proj.base_layer.bias id=3011', 'status': 'AVAILABLE', 'numel': 8192, 'ds_numel': 8192, 'shape': (8192,), 'ds_shape': (8192,), 'requires_grad': False, 'grad_shape': None, 'persist': True, 'active_sub_modules': {1537}, 'ds_tensor.shape': torch.Size([1024])},).
```
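If this tracing error comes from the fetch order inside the Mamba mixer varying between steps (the mismatched parameters are dt_proj's LoRA adapter weight and base-layer bias), one possibility, not confirmed by this PR, would be to extend the same leaf-module trick to the mixer. A hedged sketch, assuming the same `set_z3_leaf_modules` API and transformers' `JambaMambaMixer` class:

```python
# Hypothetical mitigation sketch; whether this is actually needed is exactly
# the open question raised above.
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.jamba.modeling_jamba import JambaMambaMixer


def mark_mamba_mixer_as_leaf(model):
    # Fetch the mixer's parameters as one unit so the ZeRO-3 prefetch trace no
    # longer depends on the order in which the base layer and LoRA adapters run.
    set_z3_leaf_modules(model, [JambaMambaMixer])
```

Disabling parameter prefetch (`stage3_prefetch_bucket_size: 0` in the ZeRO config) is another commonly reported workaround for fetch-queue tracing errors, at some cost to throughput.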
Is this still needed?