grok-1
                                
Segmentation fault
OK, so after increasing the size of /dev/shm (to fix the "No space left on device" problem) and replacing Python 3.10 with 3.11 (to fix the "_pickle.UnpicklingError: invalid load key, '\x00'." problem), I arrived at a new error message:
bluevisor@AMD3090:/mnt/c/grok-1$ python3 run.py
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 1) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=8, num_layers=64, vocab_size=131072, widening_factor=8, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
Segmentation fault
Do you have enough RAM on your machine to load the weights?
I have 64 GB. I know it's probably not enough; I just wanted to see how far I could get. By the way, I'm now getting "killed" instead of "Segmentation fault".
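For anyone who wants to check this before launching, here is a rough sketch that compares the on-disk size of the checkpoint with the memory Linux currently reports as available, and also prints the free space in /dev/shm. The helper names (`dir_size_bytes`, `mem_available_bytes`, `report`) are made up for illustration and are not part of the grok-1 repo. It assumes a Linux or WSL system with /proc/meminfo and the ./checkpoints/ckpt-0 path from the logs. On-disk size is only a lower-bound hint, since loading may need more memory than the files occupy.

```python
import os
import shutil


def dir_size_bytes(path):
    """Sum the sizes of all regular files under path."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            full = os.path.join(root, name)
            if os.path.isfile(full):
                total += os.path.getsize(full)
    return total


def mem_available_bytes(meminfo_path="/proc/meminfo"):
    """Return MemAvailable in bytes, or None if it cannot be read."""
    try:
        with open(meminfo_path) as f:
            for line in f:
                if line.startswith("MemAvailable:"):
                    # /proc/meminfo reports this value in kB
                    return int(line.split()[1]) * 1024
    except OSError:
        return None
    return None


def report(ckpt_dir="./checkpoints/ckpt-0", shm_path="/dev/shm"):
    """Print checkpoint size, available RAM, and free /dev/shm space."""
    gib = 1024 ** 3
    if os.path.isdir(ckpt_dir):
        print(f"checkpoint size: {dir_size_bytes(ckpt_dir) / gib:.1f} GiB")
    else:
        print(f"checkpoint directory not found: {ckpt_dir}")
    avail = mem_available_bytes()
    if avail is not None:
        print(f"RAM available:   {avail / gib:.1f} GiB")
    if os.path.isdir(shm_path):
        free = shutil.disk_usage(shm_path).free
        print(f"{shm_path} free:  {free / gib:.1f} GiB")


if __name__ == "__main__":
    report()
```

If the checkpoint is far larger than the available RAM, a crash or an out-of-memory kill during "Loading checkpoint" would not be surprising.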
@bluevisor what NVIDIA driver are you using? And is that a WSL instance? On Ubuntu with driver version 545.23.8 and CUDA version 12.3, using Python 3.11, it starts okay with an A4000 16 GB adapter and 32 GB of system RAM:
(py311) elsaco@texas:~/grok-1$ python run.py
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 1) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
2024-03-18 13:15:06.204845: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=8, num_layers=64, vocab_size=131072, widening_factor=8, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:Loading checkpoint at ./checkpoints/ckpt-0
FYI: also segfaults for me https://github.com/xai-org/grok-1/issues/164#issuecomment-2004922821
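Since both "Segmentation fault" and "killed" show up in this thread, a small wrapper that reports which signal ended the process can help tell them apart. This is only a sketch: `describe_exit` and `run_and_describe` are hypothetical helper names, and it relies on the documented POSIX behavior that `subprocess` reports a negative return code when the child is terminated by a signal. SIGKILL is often (not always) the Linux out-of-memory killer, while SIGSEGV is an actual segfault.

```python
import signal
import subprocess
import sys


def describe_exit(returncode):
    """Turn a subprocess return code into a readable description."""
    if returncode == 0:
        return "exited normally"
    if returncode < 0:
        # On POSIX, a negative code -N means the child died from signal N
        try:
            name = signal.Signals(-returncode).name
        except ValueError:
            name = f"signal {-returncode}"
        return f"terminated by {name}"
    return f"exited with status {returncode}"


def run_and_describe(cmd):
    """Run cmd and describe how it ended."""
    result = subprocess.run(cmd)
    return describe_exit(result.returncode)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Example: python check_exit.py python3 run.py
    print(run_and_describe(sys.argv[1:]))
```

Saved under a name such as check_exit.py (an arbitrary choice), it could be run as `python check_exit.py python3 run.py`. If it reports SIGKILL, checking `dmesg` for out-of-memory messages is a reasonable next step.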