python3.12+triton=🐛?
@latkins love the triton kernel 🎉
Not sure if Triton supports Python 3.12 yet (see error below).
We could restrict pyproject.toml until we're sure:
requires-python = ">=3.10,<3.12"
Apologies if I'm missing something.
File "/home/alexmath/miniconda3/envs/boltz/lib/python3.12/site-packages/boltz/model/layers/triangular_attention/primitives.py", line 517, in forward
o = _trifast_attn(q, k, v, biases)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/alexmath/miniconda3/envs/boltz/lib/python3.12/site-packages/boltz/model/layers/triangular_attention/primitives.py", line 689, in _trifast_attn
o = triangle_attention(q, k, v, b, mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/alexmath/miniconda3/envs/boltz/lib/python3.12/site-packages/trifast/torch.py", line 227, in triangle_attention
o, _ = _triangle_attention(q, k, v, b, mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/alexmath/miniconda3/envs/boltz/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 641, in __call__
return self._opoverload(*args, **kwargs)
...
File "/home/alexmath/miniconda3/envs/boltz/lib/python3.12/site-packages/triton/compiler/compiler.py", line 390, in _init_handles
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
SystemError: PY_SSIZE_T_CLEAN macro must be defined for '#' formats
Predicting DataLoader 0: 0%| | 0/1 [00:06<?, ?it/s]
For others hitting this: installing into a fresh Python 3.11 environment worked for me:
conda create -n boltz_py311 python=3.11
conda activate boltz_py311
pip install boltz
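If it helps, here is a minimal sketch for checking whether the interpreter you are running falls inside the range suggested above for pyproject.toml (>=3.10,<3.12). The function name and constants are made up for illustration, and the range is only the one proposed in this thread, not a confirmed constraint of boltz or Triton.

```python
import sys

# Range proposed in this thread for requires-python; an assumption,
# not a confirmed constraint of boltz, trifast, or Triton.
MIN_VERSION = (3, 10)
MAX_VERSION_EXCLUSIVE = (3, 12)


def python_in_suggested_range(version_info=None):
    """Return True if the (major, minor) version is >=3.10 and <3.12.

    Defaults to the running interpreter when no version is passed.
    """
    major_minor = tuple((version_info or sys.version_info)[:2])
    return MIN_VERSION <= major_minor < MAX_VERSION_EXCLUSIVE


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    if python_in_suggested_range():
        print(f"Python {current} is inside the suggested range.")
    else:
        print(f"Python {current} is outside the suggested range (>=3.10,<3.12).")
```

Running it inside the activated conda environment shows which interpreter pip actually installed boltz into.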
I still can't find a working combination of Python, torch, triton, and trifast versions for CUDA 11.8. I have tried Python 3.10/3.11/3.12 with torch 2.2+cu118 or 2.7+cu118 and triton 2.x or 3.x, without success. The error if triton==3.x is:
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
The error if triton==2.x is:
TypeError: Autotuner.__init__() takes from 7 to 10 positional arguments but 13 were given
Can anyone help, please?
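When reporting a version mismatch like this, it may help to paste the exact environment. Below is a rough sketch that collects the Python version, the installed versions of the packages mentioned in this thread, and (if torch imports) the CUDA version torch was built against plus the GPU's compute capability. The helper names are hypothetical, and the distribution names (torch, triton, trifast, boltz) are assumed to match what pip installed. It only gathers information; it does not diagnose either error.

```python
import platform
from importlib import metadata

# Distribution names assumed from this thread; adjust if yours differ.
PACKAGES = ("torch", "triton", "trifast", "boltz")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def environment_report():
    """Collect version details that are useful to include in a bug report."""
    report = {"python": platform.python_version()}
    report.update(installed_versions())
    try:
        import torch  # optional: only needed for the CUDA details below

        # CUDA version torch was built against (None for CPU-only builds).
        report["torch_cuda_build"] = torch.version.cuda
        if torch.cuda.is_available():
            report["gpu"] = torch.cuda.get_device_name(0)
            report["compute_capability"] = torch.cuda.get_device_capability(0)
    except (ImportError, OSError):
        # torch is missing or failed to load; keep the package versions only.
        pass
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Posting this output alongside the traceback makes it easier for others to compare against a combination that works for them.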