LWM
Running into issues on Mac M1
I am trying to run the `run_sample_video.sh` file from the `scripts` folder.
I am running into a lot of dependency issues when running it on a Mac M1.
Has anyone successfully run it on an M1?
I was able to make it work as soon as I posted this, SMH. Here is what I did:
- Remove the `decord` line from the `requirements.txt` file. Instead, install `eva-decord` from https://github.com/georgia-tech-db/eva-decord.
- Update `jax`, `chex`, `tux`, and `flax` to the latest version:
pip install jax -U
pip install chex -U
pip install flax -U
pip install tux -U
- Remove the `--mesh_dim='!-1,1,8,1' \` line from the `run_sample_video.sh` file.
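For anyone who wants to script the two file edits above, here is a minimal Python sketch. It assumes the layout described in this thread (`requirements.txt` at the repo root, `scripts/run_sample_video.sh`); the helper names `drop_requirement`, `drop_flag_line`, and `patch_repo` are hypothetical and not part of LWM.

```python
import re
from pathlib import Path


def drop_requirement(text: str, package: str) -> str:
    """Remove requirement lines for exactly `package` (e.g. 'decord', 'decord==0.6.0').

    Lines for other distributions whose names merely contain `package`,
    such as 'eva-decord', are kept.
    """
    # The name must be followed by end of line, a version operator, extras,
    # a direct reference (@), an environment marker (;), or a comment (#).
    pattern = re.compile(rf"^\s*{re.escape(package)}\s*(?:[\[<>=!~;#@]|$)", re.IGNORECASE)
    kept = [line for line in text.splitlines() if not pattern.match(line)]
    return "\n".join(kept) + "\n"


def drop_flag_line(text: str, flag: str) -> str:
    """Remove every shell-script line whose first token starts with `flag` (e.g. '--mesh_dim')."""
    kept = [line for line in text.splitlines() if not line.lstrip().startswith(flag)]
    return "\n".join(kept) + "\n"


def patch_repo(root: str = ".") -> None:
    """Apply both edits in place. The paths are assumptions about the LWM checkout."""
    req = Path(root, "requirements.txt")
    req.write_text(drop_requirement(req.read_text(), "decord"))
    script = Path(root, "scripts", "run_sample_video.sh")
    script.write_text(drop_flag_line(script.read_text(), "--mesh_dim"))
```

Run `patch_repo()` from the root of the checkout. Because the `--mesh_dim` line ends with its own `\`, dropping it leaves the surrounding line continuations intact; `eva-decord` still has to be installed separately as described in the first step.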
This made it run on the CPU. However, making it run on the M1 GPU was a whole other level of dependency issues, so I stopped there. For anyone who wants to open this can of worms, start by installing `jax-metal`. It makes the script run on the Apple M1 GPU (via Metal) by default, and the show begins.
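Since installing `jax-metal` switches the default device, a CPU setup that already works can break. JAX's `JAX_PLATFORMS` environment variable restricts which backends it initializes, so setting it to `cpu` keeps the CPU path available. A minimal sketch, assuming the script is launched with `bash` from the repo root; the helper names and the `scripts/run_sample_video.sh` path are assumptions:

```python
import os
import subprocess
from typing import Dict, Mapping, Optional


def jax_env(platform: str = "cpu", base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of `base` (default: os.environ) with JAX_PLATFORMS set to `platform`."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = platform  # JAX only initializes the listed backend(s)
    return env


def run_sample_video_on_cpu(script: str = "scripts/run_sample_video.sh") -> int:
    """Run the sample script with JAX limited to the CPU backend; return its exit code."""
    completed = subprocess.run(["bash", script], env=jax_env("cpu"))
    return completed.returncode
```

Setting the variable directly in the shell (`JAX_PLATFORMS=cpu bash scripts/run_sample_video.sh`) has the same effect; the Python wrapper only helps if the run is driven from another script.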
Since the original problem in this issue is resolved, I will let the repo owners decide whether to close it or keep it open, as it technically still doesn't run on Metal.
@Pkoiralap thanks for the info. I'm using jax-metal: I installed eva-decord, updated all libs, and removed the mesh_dim line, but I still get an unsupported data type error:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>, res_attrs = [{jax.result_info = "", mhlo.sharding = "{replicated}"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<4096x32000xbf16>):
%0 = "mhlo.convert"(%arg0) : (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>
"func.return"(%0) : (tensor<4096x32000xf32>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>, res_attrs = [{jax.result_info = "", mhlo.sharding = "{replicated}"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<4096x32000xbf16>):
%0 = "mhlo.convert"(%arg0) : (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>
"func.return"(%0) : (tensor<4096x32000xf32>) -> ()
}) : () -> ()
Any ideas? Thanks.
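The op that fails in the log is a plain bfloat16-to-float32 conversion of a 4096x32000 tensor, which suggests the Metal backend rejects `bfloat16` inputs. One possible workaround, not verified on an M1 and not part of LWM, is to do that cast on the host with NumPy before arrays reach the Metal device. A minimal sketch; `upcast_on_host` is a hypothetical helper that only walks mappings, lists, and tuples:

```python
from collections.abc import Mapping
from typing import Any, Iterable

import numpy as np


def upcast_on_host(tree: Any, dtype_names: Iterable[str] = ("bfloat16",)) -> Any:
    """Return a copy of `tree` with matching array leaves cast to float32 in NumPy.

    Mappings become plain dicts, lists stay lists, and tuples (including
    namedtuples) become plain tuples; other leaves are returned unchanged.
    """
    names = frozenset(dtype_names)
    if isinstance(tree, Mapping):
        return {key: upcast_on_host(value, names) for key, value in tree.items()}
    if isinstance(tree, list):
        return [upcast_on_host(value, names) for value in tree]
    if isinstance(tree, tuple):
        return tuple(upcast_on_host(value, names) for value in tree)
    dtype = getattr(tree, "dtype", None)
    if dtype is not None and getattr(dtype, "name", None) in names:
        # np.asarray brings device arrays (e.g. JAX arrays) into host memory,
        # so the cast runs in NumPy instead of being compiled for the device.
        return np.asarray(tree).astype(np.float32)
    return tree
```

Where to hook this into LWM's checkpoint loading, and whether the model also requests `bfloat16` compute elsewhere, is not established in this thread. Note that every upcast tensor takes twice the memory it did in `bfloat16`.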
The code is tested on Ubuntu; we are not sure how well JAX works on Mac.
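Since only Ubuntu is tested, macOS reports are easier to compare when they include the OS, CPU architecture, Python version, and package versions. A minimal stdlib-only sketch; the package list is an assumption based on the packages mentioned in this thread, and `environment_report` is a hypothetical helper:

```python
import importlib.metadata as metadata
import platform
from typing import Dict, Iterable, Optional

# Distributions mentioned in this thread; adjust as needed.
PACKAGES = ("jax", "jaxlib", "jax-metal", "flax", "chex", "tux", "eva-decord")


def environment_report(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Collect OS, CPU architecture, Python version, and installed package versions."""
    report: Dict[str, Optional[str]] = {
        "system": platform.system(),          # e.g. 'Darwin' on macOS
        "machine": platform.machine(),        # e.g. 'arm64' for a native Apple silicon Python
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

An `x86_64` value for `machine` on an M1 usually means Python is running under Rosetta rather than natively.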
What Python version are you using?
I get the following error:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/Documents/GitHub/LWM/lwm/vision_generation.py", line 11, in <module>
from tux import (
File "/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/tux/__init__.py", line 6, in <module>
from .distributed import (FlaxTemperatureLogitsWarper, JaxDistributedConfig,
File "/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/tux/distributed.py", line 15, in <module>
from jax.experimental.pjit import \
ImportError: cannot import name 'with_sharding_constraint' from 'jax.experimental.pjit' (/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/jax/experimental/pjit.py)
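The traceback shows `tux/distributed.py` importing `with_sharding_constraint` from `jax.experimental.pjit`, which the installed JAX no longer provides there, so upgrading `tux` together with `jax` (as in the first reply) is likely the simplest fix. In recent JAX releases the function is available as `jax.lax.with_sharding_constraint`. For your own code that must work across older and newer JAX, one option is a fallback import. A minimal sketch; `import_first` is a hypothetical helper, and the JAX usage is left as a comment so the snippet runs without JAX installed:

```python
import importlib
from typing import Any, List


def import_first(attr: str, *modules: str) -> Any:
    """Return `attr` from the first module in `modules` that imports and defines it."""
    errors: List[str] = []
    for module_name in modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:  # also covers ModuleNotFoundError
            errors.append(f"{module_name}: {exc}")
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
        errors.append(f"{module_name}: no attribute {attr!r}")
    raise ImportError(f"could not import {attr!r}; tried " + "; ".join(errors))


# Newer location first, older location as a fallback:
# with_sharding_constraint = import_first(
#     "with_sharding_constraint", "jax.lax", "jax.experimental.pjit"
# )
```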