Running locally on aarch64 system
Hello,
I am trying to install and run Magma locally on a Jetson device, specifically a Jetson IGX Orin. It has 64 GB of unified memory, Ubuntu 20.04, and CUDA 11.4, but the required libraries appear to be incompatible with this setup. Could you share the minimum system requirements?
Thanks!
Hi @jwyang
I installed all the libraries successfully on the IGX Orin, but the inference step of the example code throws this error:
Traceback (most recent call last):
  File "local_inf.py", line 35, in <module>
    generate_ids = model.generate(**inputs, **generation_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2024, in generate
    result = self._sample(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2982, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Magma-8B/691a1d808b747692d9f9531bd5116492ba38116e/modeling_magma.py", line 751, in forward
    outputs = self.language_model.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 971, in forward
    causal_mask = self._update_causal_mask(
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 1086, in _update_causal_mask
    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 96, in _prepare_4d_causal_attention_mask_with_cache_position
    causal_mask = torch.triu(causal_mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
Can you help me fix it?
My nvcc version is:
Cuda compilation tools, release 11.4, V11.4.315
Build cuda_11.4.r11.4/compiler.31964100_0
and my PyTorch-related package versions are:
open_clip_torch==2.31.0
pytorch-lightning==2.4.0
torch @ file:///opt/torch-2.1.0a0%2B41361538.nv23.06-cp38-cp38-linux_aarch64.whl
torchvision==0.16.0
@rr3087, it seems that your environment (or your transformers version) does not support BFloat16 in triu_tril_cuda_template. I suggest changing the data type from BFloat16 to Float16.
@jwyang, converting to float16 throws a different error:
  File "/usr/local/lib/python3.8/dist-packages/timm/models/convnext.py", line 573, in _create_convnext
    model = build_model_with_cfg(
  File "/usr/local/lib/python3.8/dist-packages/timm/models/_builder.py", line 424, in build_model_with_cfg
    model = model_cls(**kwargs)
  File "/usr/local/lib/python3.8/dist-packages/timm/models/convnext.py", line 406, in __init__
    named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  File "/usr/local/lib/python3.8/dist-packages/timm/models/_manipulate.py", line 38, in named_apply
    named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  File "/usr/local/lib/python3.8/dist-packages/timm/models/_manipulate.py", line 38, in named_apply
    named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  File "/usr/local/lib/python3.8/dist-packages/timm/models/_manipulate.py", line 40, in named_apply
    fn(module=module, name=name)
  File "/usr/local/lib/python3.8/dist-packages/timm/models/convnext.py", line 514, in _init_weights
    trunc_normal_(module.weight, std=.02)
  File "/usr/local/lib/python3.8/dist-packages/timm/layers/weight_init.py", line 67, in trunc_normal_
    return _trunc_normal_(tensor, mean, std, a, b)
  File "/usr/local/lib/python3.8/dist-packages/timm/layers/weight_init.py", line 32, in _trunc_normal_
    tensor.erfinv_()
RuntimeError: "erfinv_vml_cpu" not implemented for 'Half'
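Both tracebacks end in the same kind of failure: this PyTorch build has no kernel for a reduced-precision dtype. The kernel names also hint at where each failure happens. triu_tril_cuda_template is the GPU path for torch.triu in BFloat16. erfinv_vml_cpu is a CPU kernel, so the Float16 failure appears to occur while timm initializes the ConvNeXt weights on the CPU, before anything reaches the GPU. The sketch below checks which dtypes this build supports for those two ops, so you do not have to switch dtypes by trial and error. It makes several assumptions. first_supported and _demo are hypothetical helper names. The probe tensors are arbitrary small inputs, not the tensors that transformers or timm actually pass.

```python
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


def first_supported(candidates: Iterable[T], probe: Callable[[T], object]) -> Optional[T]:
    """Return the first candidate for which probe(candidate) runs without a RuntimeError.

    PyTorch reports a missing dtype kernel as a RuntimeError such as
    '"triu_tril_cuda_template" not implemented for 'BFloat16'', so catching
    RuntimeError lets us fall through to the next candidate dtype.
    Any other exception type is deliberately not caught.
    """
    for candidate in candidates:
        try:
            probe(candidate)
        except RuntimeError:
            continue  # no kernel for this candidate; try the next one
        return candidate
    return None  # none of the candidates worked


def _demo() -> None:
    # Hypothetical demo: probe the two ops from the tracebacks in this thread.
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed; skipping the demo.")
        return

    dtypes = [torch.bfloat16, torch.float16, torch.float32]

    # First traceback: torch.triu on the GPU (uses the CPU here if no GPU is visible).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    triu_dtype = first_supported(
        dtypes, lambda dt: torch.triu(torch.ones(4, 4, device=device, dtype=dt), diagonal=1)
    )

    # Second traceback: in-place erfinv_ on a CPU tensor, as in timm's trunc_normal_.
    erfinv_dtype = first_supported(
        dtypes, lambda dt: torch.zeros(4, dtype=dt).erfinv_()
    )

    print(f"triu on {device}: first working dtype = {triu_dtype}")
    print(f"erfinv_ on cpu: first working dtype = {erfinv_dtype}")


if __name__ == "__main__":
    _demo()
```

Run this as a script on the IGX Orin. Suppose triu only works in float32 on cuda, or erfinv_ has no Half kernel on the CPU. That would point at the PyTorch build (here the NVIDIA 2.1.0a0 wheel) rather than at Magma or timm.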
Hmm, interesting. It seems that some functions in timm do not support Half, although they should. Can you share your environment setup, e.g., OS version and package versions?
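For the setup details requested above, a small script can print everything at once. This is a sketch that assumes Python 3.8+, which provides importlib.metadata. The package list is simply the packages mentioned in this thread, not an official requirements list, and collect_env is a hypothetical helper name. PyTorch also ships python -m torch.utils.collect_env, which prints a fuller report, including the CUDA version PyTorch was built with.

```python
import importlib.metadata as metadata  # available since Python 3.8
import platform
import sys
from typing import Dict, Iterable

# Packages mentioned in this thread (an assumption; extend as needed).
DEFAULT_PACKAGES = (
    "torch",
    "torchvision",
    "transformers",
    "timm",
    "open_clip_torch",
    "pytorch-lightning",
)


def collect_env(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, str]:
    """Return a flat {name: version} report of the OS, Python, and selected packages."""
    report = {
        "os": platform.platform(),
        "machine": platform.machine(),  # e.g. 'aarch64' on Jetson boards
        "python": sys.version.split()[0],
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```

Paste the printed lines into the thread together with the nvcc output already shown above.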