Colab not working
Hi guys,
Has anyone encountered this issue in Colab?
I run into the error below when running any cell after the third one. I did not make any changes to the notebook.
```
TypeError                                 Traceback (most recent call last)
<ipython-input-22-a54bb9e92c2a> in <cell line: 9>()
      7 print(x)
      8
----> 9 triton_viz.trace(demo)[(1, 1, 1)](torch.ones(4, 3))
     10 triton_viz.launch()

3 frames

/usr/local/lib/python3.10/dist-packages/triton_viz/interpreter.py in _grid_executor_call(self, *args_dev, **kwargs)
    140 if kwargs.pop("warmup", False):
    141     return
--> 142 args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
    143 # Remaps core language functions to interpreted ones
    144 _patch_lang(self.fn)

TypeError: GridExecutor._init_args_hst() takes 2 positional arguments but 3 were given
```
I also reproduced this issue by following the notebook. It seems that the way the installed triton-viz calls `_init_args_hst` and `_restore_args_dev` does not match their implementation in `GridExecutor` in triton 3.0.0. Upgrading triton to 3.1.0 solves this issue.
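The mismatch described above can be reproduced without triton at all. In the sketch below, `GridExecutorV300` and `GridExecutorV310` are hypothetical stand-ins, not the real triton classes: their `_init_args_hst` signatures are assumed from the traceback (triton-viz calls `self._init_args_hst(args_dev, kwargs)`, while the 3.0.0 method accepts only `self` plus one argument).

```python
class GridExecutorV300:
    """Hypothetical stand-in: _init_args_hst takes self plus one argument."""

    def _init_args_hst(self, args_dev):
        return list(args_dev)


class GridExecutorV310:
    """Hypothetical stand-in: _init_args_hst also accepts kwargs."""

    def _init_args_hst(self, args_dev, kwargs):
        return list(args_dev), dict(kwargs)


def call_like_triton_viz(executor, args_dev, kwargs):
    # Same call shape as in the traceback: self._init_args_hst(args_dev, kwargs)
    return executor._init_args_hst(args_dev, kwargs)


try:
    call_like_triton_viz(GridExecutorV300(), (1, 2), {})
except TypeError as err:
    # e.g. "...takes 2 positional arguments but 3 were given"
    print(err)

# The newer-style signature accepts the extra argument without error.
print(call_like_triton_viz(GridExecutorV310(), (1, 2), {"example": 1}))
```

Calling the older-style method with the extra `kwargs` argument raises the same kind of `TypeError` as in the traceback, which is why moving to a triton release whose signature matches triton-viz's call (3.1.0, per this thread) makes the error go away.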
How can I upgrade triton to 3.1.0 on Colab?
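One way to confirm which triton the runtime actually sees after reinstalling (for example with `!pip install triton==3.1.0`, as in the install cell later in this thread) is to read the installed distribution's version. This is a minimal sketch using only the standard library; `release_tuple` and `version_at_least` are hypothetical helpers, and they compare only the leading numeric release segment, so suffixes such as `+git...` are ignored. If triton was already imported in the session, restart the Colab runtime first so the new version is the one loaded.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str) -> tuple:
    """Leading numeric release segment, e.g. '3.1.0+git1234' -> (3, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(installed: str, minimum: str) -> bool:
    """Compare release segments, padding the shorter one with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


try:
    installed = version("triton")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("triton is not installed")
elif version_at_least(installed, "3.1.0"):
    print(f"triton {installed} meets the 3.1.0 minimum suggested in this thread")
else:
    print(f"triton {installed} is older than 3.1.0; an upgrade is needed")
```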
I found a solution
By the way, I found that triton 3.1.0 installed via pip gives incorrect results with the interpreter, possibly for the reason mentioned in triton-lang/triton#4274 (comment).
Here is the output of the first demo with triton 3.1.0 installed via pip:
I reinstalled triton@d997364bd617ba91911ecd73070f57f291611203 from source (the current triton@main should also work), and everything works as expected:
To get this working, I had to change the install cell as below and, on my own fork of triton-viz, change SVG to PNG in the source code, as suggested here:
```
%%capture
# Only need to run the first time.
# Works with latest triton. Sorry, this takes a minute to install.
!pip install jaxtyping
!pip install git+https://github.com/triton-lang/triton.git@main
!pip install git+https://github.com/LukeWeidenwalker/triton-viz.git@main
!apt install libcairo2-dev
!pip install pycairo
```
```
# Each ! line runs in its own subshell, so these exports do not persist to later lines or cells.
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
!apt-get install libcairo2-dev
!pip install pycairo
!pip install git+https://github.com/Deep-Learning-Profiling-Tools/triton-viz.git
!pip install triton==3.1.0
```
Try this.
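After running either install cell, a quick sanity check is to confirm that the modules the notebook relies on can be located. This sketch assumes the import names `triton`, `triton_viz` (as in the traceback path above) and `cairo` (the module provided by pycairo); `missing_modules` is a hypothetical helper.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["triton", "triton_viz", "cairo"])
if missing:
    print("Not importable yet:", ", ".join(missing))
else:
    print("triton, triton_viz and cairo can all be located")
```

`find_spec` only locates a module without executing it, so a failure that happens at import time (for example a missing shared library) still needs an actual `import` to surface.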
Confirmed: @w1ndseeker's solution works, thanks! See here: https://github.com/Deep-Learning-Profiling-Tools/triton-viz/issues/48