
Running into issues for mac M1

Open Pkoiralap opened this issue 1 year ago • 4 comments

I am trying to run the run_sample_video.sh file from the scripts folder, and I am running into a lot of dependency issues on a Mac M1. Has anyone been successful in running it on M1?

Pkoiralap avatar Feb 15 '24 18:02 Pkoiralap

Was able to make it work as soon as I posted this, SMH. Here is what I did:

  1. Remove the decord line from the requirements.txt file. Instead, install eva-decord from https://github.com/georgia-tech-db/eva-decord
  2. Update jax, chex, flax, and tux to the latest version:
     pip install jax -U
     pip install chex -U
     pip install flax -U
     pip install tux -U
  3. Remove the --mesh_dim='!-1,1,8,1' \ line from the run_sample_video.sh file.
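
Steps 1 and 3 above are both "delete one line from a text file" edits, so they can be scripted. Here is a minimal sketch; `drop_lines` and `patch_file` are hypothetical helpers (not part of LWM), and the sample command below uses placeholder names rather than the real contents of run_sample_video.sh:

```python
from pathlib import Path


def drop_lines(text: str, needle: str) -> str:
    """Return text without any line that contains needle."""
    kept = [line for line in text.splitlines(keepends=True) if needle not in line]
    return "".join(kept)


def patch_file(path: str, needle: str) -> None:
    """Rewrite the file at path in place, dropping lines that contain needle."""
    p = Path(path)
    p.write_text(drop_lines(p.read_text(), needle))


# In-memory example shaped like a shell command with line continuations.
# "some.module" and "--some_other_flag" are placeholders, not real LWM names.
sample = (
    "python -m some.module \\\n"
    "    --mesh_dim='!-1,1,8,1' \\\n"
    "    --some_other_flag=1\n"
)
patched = drop_lines(sample, "--mesh_dim")
```

Assuming the files sit at `requirements.txt` and `scripts/run_sample_video.sh` relative to the repo root, the calls would be `patch_file("requirements.txt", "decord")` and `patch_file("scripts/run_sample_video.sh", "--mesh_dim")`. If the removed flag happens to be the last line of the command, the previous line is left with a dangling `\` that has to be cleaned up by hand.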

This made it run on CPU. However, making it run on the M1 processor was a whole other level of dependency issues, so I just stopped at that. For anyone who wants to open and clean that can of worms, start by installing jax-metal. This will make the script run on the Apple M1 processor by default, and the show begins.

Since the original problem with this issue was resolved, I will let the repo owners decide if they should close this issue or keep it open as it technically doesn't run on the metal architecture.

Pkoiralap avatar Feb 15 '24 18:02 Pkoiralap

@Pkoiralap thanks for the info. I'm using jax-metal; I installed eva-decord, updated all the libs, and removed the mesh_dim line, but I still get a data type unsupported error:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>, res_attrs = [{jax.result_info = "", mhlo.sharding = "{replicated}"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<4096x32000xbf16>):
  %0 = "mhlo.convert"(%arg0) : (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>
  "func.return"(%0) : (tensor<4096x32000xf32>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>, res_attrs = [{jax.result_info = "", mhlo.sharding = "{replicated}"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<4096x32000xbf16>):
  %0 = "mhlo.convert"(%arg0) : (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>
  "func.return"(%0) : (tensor<4096x32000xf32>) -> ()
}) : () -> ()

any ideas? thanks
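
One way to read the log above: the only operation in the failing function is an `mhlo.convert` from `bf16` to `f32`, which suggests (but does not prove) that the bfloat16 input is the data type the Metal backend rejects. A small sketch that pulls element types out of such a dump; `tensor_dtypes` is a hypothetical helper, and the regex only covers the simple `tensor<AxBxdtype>` form seen in this log:

```python
import re

# Matches ranked tensor types such as tensor<4096x32000xbf16>.
TENSOR_RE = re.compile(r"tensor<(?:\d+x)*([A-Za-z][A-Za-z0-9]*)>")


def tensor_dtypes(log: str) -> list[str]:
    """Return the distinct element types in log, in order of first appearance."""
    seen: list[str] = []
    for dtype in TENSOR_RE.findall(log):
        if dtype not in seen:
            seen.append(dtype)
    return seen


# One line copied from the error message in this thread.
line = '%0 = "mhlo.convert"(%arg0) : (tensor<4096x32000xbf16>) -> tensor<4096x32000xf32>'
dtypes = tensor_dtypes(line)
```

Running this over a full error message gives a quick list of the dtypes involved, which makes it easier to see which one to cast away from before the array reaches the device.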

eisneim avatar Feb 17 '24 07:02 eisneim

The code is tested on Ubuntu; we are not sure how well JAX works on Mac.

haoliuhl avatar Feb 18 '24 19:02 haoliuhl

What Python version are you using?

I get the following error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Documents/GitHub/LWM/lwm/vision_generation.py", line 11, in <module>
    from tux import (
  File "/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/tux/__init__.py", line 6, in <module>
    from .distributed import (FlaxTemperatureLogitsWarper, JaxDistributedConfig,
  File "/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/tux/distributed.py", line 15, in <module>
    from jax.experimental.pjit import \
ImportError: cannot import name 'with_sharding_constraint' from 'jax.experimental.pjit' (/Documents/GitHub/LWM/venv/lib/python3.11/site-packages/jax/experimental/pjit.py)
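
The traceback shows `tux` importing `with_sharding_constraint` from `jax.experimental.pjit`, where the installed JAX no longer provides it. Newer JAX releases expose the function as `jax.lax.with_sharding_constraint`, so one possible workaround is a fallback import that tries several locations in order. The sketch below uses a hypothetical helper, `resolve_attr`, and is demonstrated with standard-library modules so that it runs without JAX installed. Whether patching `tux` this way is enough to get LWM running is not confirmed anywhere in this thread.

```python
import importlib
from typing import Any, Sequence


def resolve_attr(name: str, module_names: Sequence[str]) -> Any:
    """Return the first attribute called name found in module_names, tried in order."""
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # module missing in this environment, try the next one
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"{name!r} not found in any of {list(module_names)}")


# Demonstration with stdlib modules: the first candidate does not exist, the second does.
sqrt = resolve_attr("sqrt", ["no_such_module_for_demo", "math"])

# For the case in this thread, the call would look like this (requires JAX):
# with_sharding_constraint = resolve_attr(
#     "with_sharding_constraint", ["jax.lax", "jax.experimental.pjit"]
# )
```

The other option is to pin JAX to a release that still has the old import path, or to update `tux` to a version that already imports from the new location, if one exists.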

grumpyp avatar Feb 20 '24 10:02 grumpyp