ChatRWKV
[pip package] Make loading aware that os.environ can change
Expected behavior:
Running something like `autopep8 mycode.py` on code that uses RWKV should not break it.
Pressing Ctrl-S in VS Code with "format on save" enabled should not break the code either.
It does break, though: the formatter moves imports to where they belong, the very top of the file, above any statements like `os.environ[...] = '1'`.
Actual behavior:
`from torch.utils.cpp_extension import load` is only executed if the environment variable is already set at the moment `rwkv.model` is imported.
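For context, the guard in `rwkv/model.py` presumably looks something like this (a simplified sketch inferred from the traceback below; the exact code may differ):

```python
# sketch of the import-time pattern that causes the problem (not the verbatim rwkv source)
import os

if os.environ.get('RWKV_CUDA_ON') == '1':
    # only runs if the flag was already set when rwkv.model was imported
    from torch.utils.cpp_extension import load

# later, inside RWKV.__init__, something like
#   rwkv5 = load(name="rwkv5", sources=[...])
# raises NameError if the branch above was skipped
```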
Therefore this simple snippet

```python
import os
from rwkv.model import RWKV

os.environ['RWKV_CUDA_ON'] = '1'  # too late: rwkv.model has already been imported
RWKV(os.path.expanduser("~/models/recursal_EagleX_1-7T/EagleX-1_7T.pth"), "cuda fp16i8")
```

is unable to load the model:

```
RWKV_JIT_ON 1 RWKV_CUDA_ON 1 RESCALE_LAYER 6
Loading /home/fella/models/recursal_EagleX_1-7T/EagleX-1_7T.pth ...
Model detected: v5.2
Strategy: (total 32+1=33 layers)
* cuda [float16, uint8], store 33 layers
0-cuda-float16-uint8 1-cuda-float16-uint8 2-cuda-float16-uint8 3-cuda-float16-uint8 4-cuda-float16-uint8 5-cuda-float16-uint8 6-cuda-float16-uint8 7-cuda-float16-uint8 8-cuda-float16-uint8 9-cuda-float16-uint8 10-cuda-float16-uint8 11-cuda-float16-uint8 12-cuda-float16-uint8 13-cuda-float16-uint8 14-cuda-float16-uint8 15-cuda-float16-uint8 16-cuda-float16-uint8 17-cuda-float16-uint8 18-cuda-float16-uint8 19-cuda-float16-uint8 20-cuda-float16-uint8 21-cuda-float16-uint8 22-cuda-float16-uint8 23-cuda-float16-uint8 24-cuda-float16-uint8 25-cuda-float16-uint8 26-cuda-float16-uint8 27-cuda-float16-uint8 28-cuda-float16-uint8 29-cuda-float16-uint8 30-cuda-float16-uint8 31-cuda-float16-uint8 32-cuda-float16-uint8
emb.weight f16 cpu 65536 4096
blocks.0.ln1.weight f16 cuda:0 4096
blocks.0.ln1.bias f16 cuda:0 4096
blocks.0.ln2.weight f16 cuda:0 4096
blocks.0.ln2.bias f16 cuda:0 4096
blocks.0.att.time_mix_k f16 cuda:0 4096
blocks.0.att.time_mix_v f16 cuda:0 4096
blocks.0.att.time_mix_r f16 cuda:0 4096
blocks.0.att.time_mix_g f16 cuda:0 4096
blocks.0.att.time_decay f32 cuda:0 64 64
blocks.0.att.time_first f32 cuda:0 64 64
blocks.0.att.receptance.weight i8 cuda:0 4096 4096
blocks.0.att.key.weight i8 cuda:0 4096 4096
blocks.0.att.value.weight i8 cuda:0 4096 4096
blocks.0.att.output.weight i8 cuda:0 4096 4096
blocks.0.att.gate.weight i8 cuda:0 4096 4096
blocks.0.att.ln_x.weight f32 cuda:0 4096
blocks.0.att.ln_x.bias f32 cuda:0 4096
blocks.0.ffn.time_mix_k f16 cuda:0 4096
blocks.0.ffn.time_mix_r f16 cuda:0 4096
blocks.0.ffn.key.weight i8 cuda:0 4096 14336
blocks.0.ffn.receptance.weight i8 cuda:0 4096 4096
blocks.0.ffn.value.weight i8 cuda:0 14336 4096
....................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................
blocks.31.ln1.weight f16 cuda:0 4096
blocks.31.ln1.bias f16 cuda:0 4096
blocks.31.ln2.weight f16 cuda:0 4096
blocks.31.ln2.bias f16 cuda:0 4096
blocks.31.att.time_mix_k f16 cuda:0 4096
blocks.31.att.time_mix_v f16 cuda:0 4096
blocks.31.att.time_mix_r f16 cuda:0 4096
blocks.31.att.time_mix_g f16 cuda:0 4096
blocks.31.att.time_decay f32 cuda:0 64 64
blocks.31.att.time_first f32 cuda:0 64 64
blocks.31.att.receptance.weight i8 cuda:0 4096 4096
blocks.31.att.key.weight i8 cuda:0 4096 4096
blocks.31.att.value.weight i8 cuda:0 4096 4096
blocks.31.att.output.weight i8 cuda:0 4096 4096
blocks.31.att.gate.weight i8 cuda:0 4096 4096
blocks.31.att.ln_x.weight f32 cuda:0 4096
blocks.31.att.ln_x.bias f32 cuda:0 4096
blocks.31.ffn.time_mix_k f16 cuda:0 4096
blocks.31.ffn.time_mix_r f16 cuda:0 4096
blocks.31.ffn.key.weight i8 cuda:0 4096 14336
blocks.31.ffn.receptance.weight i8 cuda:0 4096 4096
blocks.31.ffn.value.weight i8 cuda:0 14336 4096
ln_out.weight f16 cuda:0 4096
ln_out.bias f16 cuda:0 4096
head.weight i8 cuda:0 4096 65536
Traceback (most recent call last):
File "/tmp/a.py", line 5, in <module>
RWKV(os.path.expanduser("~/models/recursal_EagleX_1-7T/EagleX-1_7T.pth"), "cuda fp16i8")
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/jit/_script.py", line 303, in init_then_script
original_init(self, *args, **kwargs)
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/rwkv/model.py", line 467, in __init__
rwkv5 = load(name="rwkv5", sources=[f"{current_path}/cuda/rwkv5_op.cpp", f"{current_path}/cuda/rwkv5.cu"],
^^^^
NameError: name 'load' is not defined. Did you mean: 'float'?
```
yeah because you need to do os.environ['RWKV_CUDA_ON'] = '1' before import rwkv
Yes, exactly! That is precisely the issue. Normal imports look like this:
```python
# (beginning of file)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
```
which is exactly how transformers, exllama, llama.cpp, and HQQ do it (HQQ, just like RWKV, has several backends, but only one can be used at a time).
Not like this:

```python
import os

os.environ["RWKV_CUDA_ON"] = "1"

if "I don't want the standard formatter to break my code":
    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE, PIPELINE_ARGS
```
Running autopep8 should not break the code, yet it does, for no good reason, unless an `if` fence like this is used.
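As an aside, a `# noqa: E402` comment (E402 is the "module level import not at top of file" check) is another fence that autopep8 is supposed to honor, though other tools such as isort need their own directives:

```python
import os

os.environ["RWKV_CUDA_ON"] = "1"  # must run before the imports below

from rwkv.model import RWKV  # noqa: E402
from rwkv.utils import PIPELINE, PIPELINE_ARGS  # noqa: E402
```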
At the very least, it's possible to ship a couple of QoL modules like `rwkv/cuda_jit.py`, along the lines of:
```python
import os

os.environ["RWKV_CUDA_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"

if "imports come after os":
    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE, PIPELINE_ARGS
```
so user code can simply do `from rwkv.cuda_jit import RWKV, PIPELINE` without caring whether the knob is spelled `environ["RWKV_CUDA_ON"] = "1"` or `environ["RWKV_CUDA"] = "ON"`, which is impossible to tell via LSP.
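User code would then look like this (assuming the hypothetical `rwkv/cuda_jit` module sketched above):

```python
import os

# the hypothetical module sets the env vars internally before importing rwkv.model,
# so a formatter can reorder these imports freely without breaking anything
from rwkv.cuda_jit import RWKV, PIPELINE

model = RWKV(os.path.expanduser("~/models/recursal_EagleX_1-7T/EagleX-1_7T.pth"), "cuda fp16i8")
```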
Autopep8 is a friend! It shouldn't have to be worked around.
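To sketch what the title asks for: the package could re-read `os.environ` at construction time instead of import time, e.g. by moving the guarded import into `__init__` (hypothetical code, not the current `rwkv` source):

```python
import os

class RWKV:
    def __init__(self, model_path: str, strategy: str):
        # check the flag when the model is constructed, not when the module is
        # imported, so the user may set os.environ anywhere before this call
        if os.environ.get('RWKV_CUDA_ON') == '1':
            # a local import cannot be hoisted above the caller's environ assignment
            from torch.utils.cpp_extension import load
            self._load_kernel = load
        # ... kernel compilation and weight loading would follow here
```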