grok-1

Nr. of devices needed

Open zcobol opened this issue 1 year ago • 36 comments

Running python run.py on a single Nvidia GPU fails with ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

Can the number of devices be adjusted to just 1?
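
For context, the constraint behind this error is simply that the product of the mesh shape must equal the number of devices JAX can see. A minimal check along those lines (the (1, 8) value is the one hard-coded in run.py per this thread; nothing else here is taken from the repo):

    import math
    import jax

    local_mesh_config = (1, 8)  # value reported as hard-coded in run.py
    n_devices = jax.local_device_count()
    if math.prod(local_mesh_config) != n_devices:
        print(f"mesh product {math.prod(local_mesh_config)} != {n_devices} devices; "
              f"try local_mesh_config=(1, {n_devices}) instead")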

zcobol avatar Mar 18 '24 01:03 zcobol

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 2 devices in mesh
Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/opt/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
  File "/opt/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in <listcomp>
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 2 must equal the product of mesh_shape (1, 8)

Is this what you get?

nickorlabs avatar Mar 18 '24 01:03 nickorlabs

I put 1 instead of 8.

yarodevuci avatar Mar 18 '24 04:03 yarodevuci

I keep getting the same error: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'D:\dev\shm\tmpp53ohpcl'

yarodevuci avatar Mar 18 '24 04:03 yarodevuci

(Quoting nickorlabs's traceback above, which ends in "ValueError: Number of devices 2 must equal the product of mesh_shape (1, 8)")

Is this what you get?

I have the same issue. Is there a way to resolve this?

KHARAPSY avatar Mar 18 '24 09:03 KHARAPSY

Same issue even though all requirements are installed. I am using 8 GPUs.

zRzRzRzRzRzRzR avatar Mar 18 '24 09:03 zRzRzRzRzRzRzR

I have 2 GPUs and everything installed ok as well.

nickorlabs avatar Mar 18 '24 13:03 nickorlabs

In run.py, I changed line 60 from local_mesh_config=(1, 8), to local_mesh_config=(1, 1),

(I have 1 3090)
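
A sketch of how one might avoid hard-coding that value entirely; this helper is hypothetical, not part of the repo:

    import jax

    def pick_local_mesh_config() -> tuple[int, int]:
        # grok-1 uses a (data, model) mesh; on a single host the simplest
        # choice is to put every local device on the model axis.
        return (1, jax.local_device_count())

    print(pick_local_mesh_config())  # e.g. (1, 1) on a single-GPU machine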

bluevisor avatar Mar 18 '24 13:03 bluevisor

Ok got a little further this time!

Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/opt/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/opt/grok-1/checkpoint.py", line 196, in restore
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)
  File "/opt/grok-1/checkpoint.py", line 107, in load_tensors
    return [f.result() for f in fs]
  File "/opt/grok-1/checkpoint.py", line 107, in <listcomp>
    return [f.result() for f in fs]
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/opt/grok-1/checkpoint.py", line 72, in fast_unpickle
    with copy_to_shm(path) as tmp_path:
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
  File "/opt/grok-1/checkpoint.py", line 52, in copy_to_shm
    shutil.copyfile(file, tmp_path)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 269, in copyfile
    _fastcopy_sendfile(fsrc, fdst)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 158, in _fastcopy_sendfile
    raise err from None
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 144, in _fastcopy_sendfile
    sent = os.sendfile(outfd, infd, offset, blocksize)
OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'

I have 2 Quadro 5000s; I guess we do not have enough VRAM, doh.

nickorlabs avatar Mar 18 '24 13:03 nickorlabs

I'm at the same point. GPT told me /dev/shm is a ramdisk, which means we don't have enough RAM, not VRAM.

I have 64 GB; not sure how much we need... would 128 be enough?
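
A quick way to see how much staging space is actually available: the traceback above shows checkpoint.py's copy_to_shm() copying shards into /dev/shm, which is a RAM-backed tmpfs, so free space there is bounded by system RAM (the checkpoint is reported later in this thread as over 300 GB). Illustrative check, not repo code:

    import shutil

    # /dev/shm is tmpfs (RAM-backed), so this effectively reports how much room
    # the checkpoint staging done by copy_to_shm() in checkpoint.py has to work with.
    usage = shutil.disk_usage("/dev/shm")
    print(f"/dev/shm total {usage.total / 2**30:.1f} GiB, free {usage.free / 2**30:.1f} GiB")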

bluevisor avatar Mar 18 '24 14:03 bluevisor

I have 128 GB of RAM on this rig; with the two cards it's about 32 GB of VRAM, which is why I assumed VRAM. Maybe I'm wrong.

nickorlabs avatar Mar 18 '24 15:03 nickorlabs

bummer... guess we'll just have to wait for gguf...

bluevisor avatar Mar 18 '24 15:03 bluevisor

Possibly. I might spin up a RunPod, or wait for GGUF; I was reading that people need 8 GPUs.

nickorlabs avatar Mar 18 '24 15:03 nickorlabs

After changing the mesh to (1, 6), I get this error:

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-18 15:58:10.001688: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=8, num_layers=64, vocab_size=131072, widening_factor=8, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 235, in load_or_init
    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
ValueError: One of pjit outputs with pytree key path .params['transformer/decoder_layer_0/moe/linear']['w'] was given the sharding of NamedSharding(mesh=Mesh('data': 1, 'model': 6), spec=PartitionSpec(None, 'data', 'model')), which implies that the global size of its dimension 2 should be divisible by 6, but it is equal to 32768 (full shape: (8, 6144, 32768))

Looks like it doesn't like 6 either.
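
The failure is a divisibility problem: the model axis of the mesh has to evenly divide the sharded dimension of the MoE weight (32768 in the error above), and 6 does not. A quick illustrative check:

    # Which model-axis sizes evenly divide the MoE weight dimension from the error above?
    moe_dim = 32768  # from the reported full shape (8, 6144, 32768)
    for n in (1, 2, 4, 6, 8):
        print(n, "divides" if moe_dim % n == 0 else "does not divide", moe_dim)
    # 1, 2, 4 and 8 divide 32768 (a power of two); 6 does not, hence the error with (1, 6).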

thisIsLoading avatar Mar 18 '24 16:03 thisIsLoading

Looks like I have to set

        widening_factor=6,
        num_kv_heads=6,

in the TransformerConfig to match the number of devices as well.

thisIsLoading avatar Mar 18 '24 16:03 thisIsLoading

Did you get it up and running?

nickorlabs avatar Mar 18 '24 16:03 nickorlabs

Looks like I have to set

        widening_factor=6,
        num_kv_heads=6,

in the TransformerConfig to match the number of devices as well.

Did it work after that?

yarodevuci avatar Mar 18 '24 17:03 yarodevuci

@yarodevuci Still downloading the weights.

I was under the impression that the test would download everything (looks like I'm spoiled by the Hugging Face API, which does that). Will report tomorrow. Right now it tells me 17 more hours; don't know why it's so slow, I'm on 750 Mbit but the magnet download is painfully slow.

thisIsLoading avatar Mar 18 '24 17:03 thisIsLoading

I'm seeding (again); it took me most of the evening last night to download, and I have a 2000 Mbit connection.

nickorlabs avatar Mar 18 '24 18:03 nickorlabs

I'm at the same point. GPT told me /dev/shm is a ramdisk, which means we don't have enough RAM, not VRAM.

I have 64 GB; not sure how much we need... would 128 be enough?

My system has 192 GB of RAM and I hit the same error: OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpbeofn6hn'

ad1tyac0des avatar Mar 18 '24 19:03 ad1tyac0des

@ad1tyac0des It creates a temp folder with over 300 GB in it; do you have that much space on the hard drive?

yarodevuci avatar Mar 18 '24 19:03 yarodevuci

Has anybody here seen a live presentation where X developers ran it with these exact commands, or are we all just testing it for them?

toughcoding avatar Mar 18 '24 19:03 toughcoding

What a pain. Downloading it cost me a whole day of effort.

pwxpwxtop avatar Mar 19 '24 03:03 pwxpwxtop

I succeeded in increasing the space and got rid of the error "OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'",

but in exchange I ended up with a crashed system instead, so I will give up for now. I don't have enough RAM to run Grok-1, nor enough money to upgrade my hardware.

KHARAPSY avatar Mar 19 '24 04:03 KHARAPSY

Same issue even though all requirements are installed. I am using 8 GPUs.

I switched to 8x A100 GPUs and it takes about 65 GB of memory per GPU to run this model; the resources required are quite large. The requirements installed successfully.

Finally ran it with this command:

JAX_TRACEBACK_FILTERING=off python run.py

and it works.

zRzRzRzRzRzRzR avatar Mar 19 '24 04:03 zRzRzRzRzRzRzR

@ad1tyac0des It creates a temp folder with over 300 GB in it; do you have that much space on the hard drive?

I had about 100 GB of storage left, but at the moment the error occurred my system's RAM was completely utilized. That seems to be why the program stopped; the problem looks like high RAM usage rather than storage space.

ad1tyac0des avatar Mar 19 '24 06:03 ad1tyac0des

At 272/300 GB right now. Excitement starts to kick in; let's hope this thing runs.

I only have 6x 4090 (144 GB VRAM) and 512 GB RAM; if this isn't enough to at least run it, regardless of speed, then something is off.

thisIsLoading avatar Mar 19 '24 06:03 thisIsLoading

OK, got a little further, but still no cigar:

(.venv) loading@ai:~/PycharmProjects/grok-1$ python run.py
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-19 07:55:00.536833: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=6, num_layers=64, vocab_size=131072, widening_factor=6, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:Loading checkpoint at ./checkpoints/ckpt-0
Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/home/loading/PycharmProjects/grok-1/checkpoint.py", line 218, in restore
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 342, in host_local_array_to_global_array
    out_flat = [
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 343, in <listcomp>
    host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 423, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 250, in host_local_array_to_global_array_impl
    for d, index in local_sharding.devices_indices_map(arr.shape).items()]
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 110, in devices_indices_map
    return common_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 59, in common_devices_indices_map
    return gspmd_sharding.devices_indices_map(global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 898, in devices_indices_map
    return gspmd_sharding_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 826, in gspmd_sharding_devices_indices_map
    self.shard_shape(global_shape)  # raises a good error message
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 122, in shard_shape
    return _common_shard_shape(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 77, in _common_shard_shape
    raise ValueError(
ValueError: Sharding GSPMDSharding({devices=[1,1,6]<=[6]}) implies that array axis 2 is partitioned 6 times, but the dimension size is 32768 (full shape: (8, 6144, 32768), per-dimension tiling factors: [1, 1, 6] should evenly divide the shape)
(.venv) loading@ai:~/PycharmProjects/grok-1$

(The right-hand columns of the original paste were a system monitor and a watch nvidia-smi pane showing six idle GeForce RTX 4090s, driver 535.129.03, CUDA 12.2, no running processes.)

thisIsLoading avatar Mar 19 '24 07:03 thisIsLoading

  • In the file checkpoint.py, I changed /dev/shm/ to './dev/shm/' (see the sketch below)
  • In the terminal: mkdir -p ./dev/shm/
  • After that, I run python run.py

But my MacBook M1 Pro with 16 GB RAM / 512 GB SSD freezes after that, because Python eats up all my memory. I assume that even with the SSD there is not enough memory.
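
For illustration, a sketch of the kind of change described in the first bullet. This is not the actual checkpoint.py code, just the general idea of staging checkpoint shards through an on-disk directory instead of the /dev/shm ramdisk:

    import contextlib
    import os
    import shutil
    import tempfile

    STAGING_DIR = "./dev/shm"  # on-disk directory, e.g. created with `mkdir -p ./dev/shm`

    @contextlib.contextmanager
    def copy_to_staging(path: str):
        # Hypothetical stand-in for checkpoint.py's copy_to_shm(): copy a shard
        # into a temporary directory on disk and hand back the temporary path.
        os.makedirs(STAGING_DIR, exist_ok=True)
        with tempfile.TemporaryDirectory(dir=STAGING_DIR) as tmp_dir:
            tmp_path = os.path.join(tmp_dir, os.path.basename(path))
            shutil.copyfile(path, tmp_path)
            yield tmp_path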

malinichev avatar Mar 19 '24 12:03 malinichev

At 272/300 GB right now. Excitement starts to kick in; let's hope this thing runs. I only have 6x 4090 (144 GB VRAM) and 512 GB RAM; if this isn't enough to at least run it, regardless of speed, then something is off.

It probably isn't. I have 4x A100 and 512 GB per node as well, and I am not sure I can run it. It has been stuck loading checkpoints for a while now.

surak avatar Mar 19 '24 16:03 surak

You should install jaxlib for CUDA so that your 8 GPUs can be detected, or you can set local_mesh_config=(1, 1) and Grok will run on the CPU.
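
A quick way to confirm which backend jaxlib is actually using (with a CPU-only jaxlib, only CPU devices will be listed):

    import jax

    print(jax.devices())             # e.g. eight CUDA devices with the CUDA-enabled jaxlib
    print(jax.local_device_count())  # must match the product of local_mesh_config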

Christmas-Wong avatar Mar 20 '24 06:03 Christmas-Wong