out of memory error
bash scripts/run_vision_chat.sh
I removed the --mesh_dim param. The model is LWM-Chat-32K-Jax. I get an out-of-memory error; how can I solve it? My card is an NVIDIA RTX 2080 Super (8 GB).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708500656.672727 10871 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0221 15:30:57.202437 140383335174272 xla_bridge.py:513] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 15:30:57.202921 140383335174272 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-21 15:36:18.340692: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.00GiB (rounded to 2147483648)requested by op
2024-02-21 15:36:18.340908: W external/tsl/tsl/framework/bfc_allocator.cc:497] *________**********************************************************************_____________________
2024-02-21 15:36:18.340944: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 2.00GiB
preallocated temp allocation: 0B
total allocation: 3.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.00GiB
Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
XLA Label: fusion
Shape: f32[32,4096,4096]
==========================
Buffer 2:
Size: 1.00GiB
Entry Parameter Subshape: bf16[32,4096,4096]
==========================
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 254, in <module>
run(main)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 249, in main
sampler = Sampler()
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 51, in __init__
self._load_model()
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 199, in _load_model
self.params = tree_apply(shard_fns, self.params)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in tree_apply
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in <lambda>
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 95, in shard_fn
return jax_shard_function(tensor).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 2.00GiB
preallocated temp allocation: 0B
total allocation: 3.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.00GiB
Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
XLA Label: fusion
Shape: f32[32,4096,4096]
==========================
Buffer 2:
Size: 1.00GiB
Entry Parameter Subshape: bf16[32,4096,4096]
==========================
I0000 00:00:1708500978.900009 10871 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(lwm) test@test-3:/mnt/data/test/LWM$ nvidia-smi
Wed Feb 21 15:47:00 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 2080 S... Off| 00000000:01:00.0 Off | N/A |
| 0% 40C P0 23W / 250W| 0MiB / 8192MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
Can you share your modified requirements.txt?
I did not modify requirements.txt; I modified run_vision_chat.sh. Here it is for your reference:
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# MODEL_NAME='LWM-Chat-1M-Jax'
# MODEL_NAME='LWM-Chat-128K-Jax'
MODEL_NAME='LWM-Chat-32K-Jax'
export llama_tokenizer_path="/mnt/data/test/LWM/models/${MODEL_NAME}/tokenizer.model"
export vqgan_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/vqgan"
export lwm_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/params"
export input_file="/mnt/data/test/2020-07-30_pose_test_006.mp4"
python3 -u -m lwm.vision_chat \
--prompt="What is the video about?" \
--input_file="$input_file" \
--vqgan_checkpoint="$vqgan_checkpoint" \
--dtype='fp32' \
--load_llama_config='7b' \
--max_n_frames=8 \
--update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer.vocab_file="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read
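A side note on this config: the OOM log above fails inside pjit(to_dtype)/convert_element_type[new_dtype=float32] at vision_chat.py line 199, on a bf16[32,4096,4096] entry parameter. That matches the --dtype='fp32' flag: the bf16 checkpoint weights are cast to f32, doubling their size, while the params are loaded. The two peak buffers in the log check out with plain arithmetic (a minimal sketch; only the shapes and dtypes printed in the log are used):

# Size of the tensor named in the OOM log: bf16[32,4096,4096] cast to f32.
elems = 32 * 4096 * 4096
print(elems * 2 / 2**30, "GiB as bf16")  # 1.0 GiB -> Buffer 2 in the log
print(elems * 4 / 2**30, "GiB as f32")   # 2.0 GiB -> Buffer 1 in the log

If the flag accepts it (an assumption; I have not checked which dtype names vision_chat.py supports), --dtype='bf16' would avoid that f32 copy, though as the next reply notes, even the bf16 weights alone would not fit in 8 GB.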
I don't think your GPU has enough memory: a 7B model in fp32 is about 28 GB for the weights alone (7B parameters × 4 bytes), and your card has 8 GB.
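A quick back-of-the-envelope check of that figure (a minimal sketch; the 7B count comes from the --load_llama_config='7b' flag in the script above, and the byte widths are the standard sizes for those dtypes):

# Rough weight-only memory for a 7B-parameter model at different dtypes.
params = 7e9  # from --load_llama_config='7b'
for name, nbytes in [("fp32", 4), ("bf16", 2)]:
    print(f"{name}: ~{params * nbytes / 1e9:.0f} GB just for the weights")
# fp32: ~28 GB, bf16: ~14 GB -- either way far more than the 8 GB on a
# 2080 Super, and that is before activations and the KV cache.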