lightning-thunder
lightning-thunder copied to clipboard
The `_FabricModule` cannot be jitted after #78
🐛 Bug
extensions/thunder/pretrain.py:146: in setup
main(
extensions/thunder/pretrain.py:233: in main
fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
extensions/thunder/pretrain.py:253: in fit
validate(fabric, model, val_dataloader, max_iters=2) # sanity check
../nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115: in decorate_context
return func(*args, **kwargs)
extensions/thunder/pretrain.py:389: in validate
loss = forward_and_loss(model, input_ids, targets)
../lightning-thunder/thunder/__init__.py:629: in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:262: in cache_info_wrapper
res = fn(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:504: in get_computation_and_inputs
prologue_trc, computation_trc, *maybe_epilogue = interpreter(
../lightning-thunder/thunder/__init__.py:175: in _general_frontend
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
../lightning-thunder/thunder/core/jit_ext.py:1430: in thunder_general_jit
result = jfn(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6669: in fn_
raise e
../lightning-thunder/thunder/core/interpreter.py:6632: in fn_2
return fn(*args, **kwargs)
extensions/thunder/pretrain.py:371: in forward_and_loss
logits = model(input_ids)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
return forward_call(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
return fn.__func__(fn.__self__, *args, **kwargs)
../lightning/src/lightning/fabric/wrappers.py:142: in forward
with precision.forward_context():
../lightning/src/lightning/fabric/plugins/precision/half.py:54: in forward_context
return self.tensor_init_context()
../lightning/src/lightning/fabric/plugins/precision/half.py:46: in tensor_init_context
return _DtypeContextManager(self._desired_input_dtype)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
return fn.__func__(fn.__self__, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def __init__(self, dtype: torch.dtype) -> None:
> self._previous_dtype: torch.dtype = torch.get_default_dtype()
E NotImplementedError: Trying to call function torch.get_default_dtype, but it is not yet supported. Please file an issue requesting support. To find out which operations are not yet recongnized by `thunder.jit`, please run `examine` as per:
E
E from thunder.examine import examine
E examine(<your thunder.jit callable argument>, ...)
../lightning/src/lightning/fabric/plugins/precision/utils.py:33: NotImplementedError
Jitting the `_FabricModule` is currently necessary in order to compile the forward pass and the loss computation together as a single joint function.
To Reproduce
from lightning import Fabric
import torch
import thunder
fabric = Fabric(devices=1, precision="16-true")
model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
x = torch.randn(1, 1)
x = fabric.to_device(x)
fmodel = fabric.setup(model)
tmodel = thunder.jit(fmodel)
print(tmodel(x))
cc @nikitaved