No module named 'flax' when using thunder/benchmarks/benchmark_litgpt.py
🐛 Bug
With the newest version of the Docker image (tested on pjnl-20240512; pjnl-20240511 still worked), there are import errors for falcon-7b, Nous-Hermes-13, Llama-3-8B, and other models, irrespective of the compilation method used.
/opt/pytorch/lightning-thunder/thunder/torch/__init__.py:118: UserWarning: Given {self.id=} does not start with the namespace of torch
  warnings.warn("Given {self.id=} does not start with the namespace of torch")
benchmark_litgpt.py 539
    CLI(benchmark_main)
_cli.py 53 CLI
    caller = inspect.stack()[1][0]
inspect.py 1673 stack
    return getouterframes(sys._getframe(1), context)
inspect.py 1650 getouterframes
    frameinfo = (frame,) + getframeinfo(frame, context)
inspect.py 1624 getframeinfo
    lines, lnum = findsource(frame)
inspect.py 952 findsource
    module = getmodule(object, file)
inspect.py 869 getmodule
    if ismodule(module) and hasattr(module, '__file__'):
util.py 247 __getattribute__
    self.__spec__.loader.exec_module(self)
<frozen importlib._bootstrap_external> 883 exec_module
<frozen importlib._bootstrap> 241 _call_with_frames_removed
__init__.py 6
    from . import flax
__init__.py 5
    from .module import DenseGeneral, LayerNorm
module.py 14
    from flax import linen as nn
ModuleNotFoundError: No module named 'flax'
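For what it's worth, the bottom frames (an __init__.py doing "from . import flax", then "from flax import linen as nn" via DenseGeneral/LayerNorm) look like TransformerEngine's JAX extension being executed lazily. If that reading is right (hypothesis only, not confirmed), importing the extension directly inside the container should hit the same error without running the benchmark at all:

# Hypothesis only: if the failing frames are TransformerEngine's JAX extension,
# this single import inside the container should fail the same way.
import transformer_engine.jax  # expected: ModuleNotFoundError: No module named 'flax'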
To Reproduce
Start the container:
mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v $PWD/output:/output -it INTERNAL_IMAGE:pjnl-20240512
Run the benchmarking script. For falcon-7b:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name falcon-7b --distributed_mode none --shard_mode None --compile thunder_cudnn --checkpoint_activations False
For Llama-3-8B:
torchrun --nproc-per-node=8 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Llama-3-8B --distributed_mode fsdp --compile eager --checkpoint_activations False
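As a quick sanity check (just a sketch; the package names are simply the ones implicated by the traceback, not an exhaustive list), the following confirms inside the container which of the relevant packages are actually installed:

# Sketch: report whether the packages implicated by the traceback are importable.
import importlib.util

for pkg in ("flax", "jax", "transformer_engine"):
    spec = importlib.util.find_spec(pkg)
    print(pkg, "->", spec.origin if spec else "NOT installed")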
Expected behavior
We can run the falcon-7b and Llama-3-8B models.
Environment
As in the Docker image. These results come from a machine with 8x H100 GPUs.
At a glance, https://github.com/NVIDIA/TransformerEngine/commit/07291027ed353287149e9df6030862e1e815f32f could be related:
https://github.com/NVIDIA/TransformerEngine/pull/839
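For context, here is a minimal sketch of the suspected mechanism, with illustrative names (te_jax_shim is hypothetical, not TransformerEngine's real layout): a module registered with importlib.util.LazyLoader defers its body, and inspect.stack() later forces it because inspect.getmodule() probes every sys.modules entry for __file__, surfacing the missing flax import far away from any explicit import in user code:

# Minimal sketch of the suspected mechanism. Assumptions: flax is NOT installed,
# and the module name te_jax_shim is an illustrative stand-in, not
# TransformerEngine's real layout. Run this as a script file (not `python -c`)
# so that inspect can locate the caller's source.
import importlib.util
import inspect
import pathlib
import sys
import tempfile

# A throwaway module whose body needs flax, standing in for a lazily loaded
# framework extension.
tmpdir = pathlib.Path(tempfile.mkdtemp())
(tmpdir / "te_jax_shim.py").write_text("from flax import linen as nn\n")
sys.path.insert(0, str(tmpdir))

# The documented importlib.util.LazyLoader recipe: the module body only runs
# on first attribute access.
spec = importlib.util.find_spec("te_jax_shim")
spec.loader = importlib.util.LazyLoader(spec.loader)
module = importlib.util.module_from_spec(spec)
sys.modules["te_jax_shim"] = module
spec.loader.exec_module(module)  # execution is still deferred at this point

# inspect.stack() -> getmodule() does hasattr(mod, '__file__') on every entry
# in sys.modules, which forces the deferred body and raises
# ModuleNotFoundError: No module named 'flax'.
inspect.stack()

If that is indeed what happens here, any code path that calls inspect.stack() (the CLI() call in the traceback, in this case) can trip the error, which would explain why the compilation mode doesn't matter.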
Should I report it in the TransformerEngine repo then?
You needn't; the discussion is ongoing in the PR I referenced :)
I think the issue was resolved and it doesn't happen anymore.