pyro
TorchScript error in JitTraceEnum_ELBO (Torch 2.2.1, CUDA 12.3)
Hi there, I noticed a bug in JitTraceEnum_ELBO. My code runs fine with a previous version of pytorch, or with JitTrace_ELBO (I can substitute RelaxedOneHotCategorical for the OneHotCategorical site I was enumerating). I don't personally need this fixed right now, and the bug is out of my depth to understand, but I figured I'd report it in case someone else hits the same problem.
The error seems to come from a TorchScript issue when computing the enumerated ELBO in pyro.infer.SVI:
315 def step(self, *args, **kwargs):
316 # Compute loss and gradients
317 with poutine.trace(param_only=True) as param_capture:
--> 318 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
320 loss_val = torch_item(loss)
321 self.losses.append(loss_val)
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:564, in JitTraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
563 def loss_and_grads(self, model, guide, *args, **kwargs):
--> 564 differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
565 differentiable_loss.backward() # this line triggers jit compilation
566 loss = differentiable_loss.item()
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:561, in JitTraceEnum_ELBO.differentiable_loss(self, model, guide, *args, **kwargs)
557 return elbo * (-1.0 / self.num_particles)
559 self._differentiable_loss = differentiable_loss
--> 561 return self._differentiable_loss(*args, **kwargs)
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/ops/jit.py:120, in CompiledFunction.__call__(self, *args, **kwargs)
118 with poutine.block(hide=self._param_names):
119 with poutine.trace(param_only=True) as param_capture:
--> 120 ret = self.compiled[key](*params_and_args)
122 for name in param_capture.trace.nodes.keys():
123 if name not in self._param_names:
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: default_program(23): error: extra text after expected end of number
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
^
default_program(23): error: extra text after expected end of number
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
^
2 errors detected in the compilation of "default_program".
nvrtc compilation failed:
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
__device__ T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}
extern "C" __global__
void fused_clamp_sub_exp(float* tt_3, float* tshift_1, float* aten_exp) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<45150ll ? 1 : 0) {
float tshift_1_1 = __ldg(tshift_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
float v = __ldg(tt_3 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
}}
}
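For what it's worth, the fused kernel above is just a clamp-subtract-exp sequence. A CPU-only sketch of that op pattern (my reconstruction, not the exact graph pyro emits) traces and runs fine; the failure only appears when the GPU fuser hands the generated source to nvrtc:

```python
import torch

def clamp_sub_exp(t, shift):
    # torch.finfo(torch.float32).min is the -3.4028234663852886e+38 constant
    # that shows up mis-printed as "-3.402823466385289e+38.f" in the CUDA source.
    return (t - shift.clamp(min=torch.finfo(torch.float32).min)).exp()

traced = torch.jit.trace(clamp_sub_exp, (torch.randn(4), torch.randn(4)))
out = traced(torch.zeros(3), torch.zeros(3))  # exp(0 - 0) == 1 elementwise
```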
My environment is as follows:
absl-py==2.1.0
aiohttp==3.9.1
aiosignal==1.3.1
anndata==0.10.4
annotated-types==0.6.0
anyio==4.2.0
array_api_compat==1.4.1
arrow==1.3.0
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
attrs==23.2.0
backoff==2.2.1
beautifulsoup4==4.12.3
blessed==1.20.0
boto3==1.34.28
botocore==1.34.28
certifi==2023.11.17
charset-normalizer==3.3.2
chex==0.1.7
click==8.1.7
comm @ file:///work/ci_py311/comm_1677709131612/work
contextlib2==21.6.0
contourpy==1.2.0
croniter==1.4.1
cycler==0.12.1
dateutils==0.6.12
debugpy @ file:///croot/debugpy_1690905042057/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
deepdiff==6.7.1
dm-tree==0.1.8
docrep==0.3.2
editor==1.6.6
etils==1.6.0
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
fastapi==0.109.0
filelock @ file:///croot/filelock_1700591183607/work
flax==0.8.0
fonttools==4.47.2
frozenlist==1.4.1
fsspec==2023.12.2
gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work
h11==0.14.0
h5py==3.10.0
idna==3.6
igraph==0.11.3
importlib-resources==6.1.1
inquirer==3.2.1
ipykernel @ file:///croot/ipykernel_1705933831282/work
ipython @ file:///croot/ipython_1704833016303/work
itsdangerous==2.1.2
jax==0.4.23
jaxlib==0.4.23
jedi @ file:///work/ci_py311_2/jedi_1679336495545/work
Jinja2 @ file:///work/ci_py311/jinja2_1676823587943/work
jmespath==1.0.1
joblib==1.3.2
jupyter_client @ file:///croot/jupyter_client_1699455897726/work
jupyter_core @ file:///croot/jupyter_core_1698937308754/work
kiwisolver==1.4.5
leidenalg==0.10.2
lightning==2.0.9.post0
lightning-cloud==0.5.61
lightning-utilities==0.10.1
llvmlite==0.41.1
markdown-it-py==3.0.0
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
matplotlib==3.8.2
matplotlib-inline @ file:///work/ci_py311/matplotlib-inline_1676823841154/work
mdurl==0.1.2
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
ml-collections==0.1.1
ml-dtypes @ file:///croot/ml_dtypes_1702691022032/work
mpmath @ file:///croot/mpmath_1690848262763/work
msgpack==1.0.7
mudata==0.2.3
multidict==6.0.4
multipledispatch==1.0.0
natsort==8.4.0
nest-asyncio @ file:///work/ci_py311/nest-asyncio_1676823382924/work
networkx==3.2.1
numba==0.58.1
numpy==1.26.1
numpyro==0.13.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
opt-einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1696448916724/work
optax==0.1.8
orbax-checkpoint==0.5.1
ordered-set==4.1.0
packaging @ file:///croot/packaging_1693575174725/work
pandas==2.2.0
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
patsy==0.5.6
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pillow==10.2.0
platformdirs @ file:///croot/platformdirs_1692205439124/work
prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work
protobuf==4.25.2
psutil @ file:///work/ci_py311_2/psutil_1679337388738/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pydantic==2.1.1
pydantic_core==2.4.0
Pygments @ file:///croot/pygments_1684279966437/work
PyJWT==2.8.0
pymde==0.1.18
pynndescent==0.5.11
pyparsing==3.1.1
pyro-api==0.1.2
pyro-ppl==1.8.6
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
python-multipart==0.0.6
pytorch-lightning==2.1.3
pytz==2023.3.post1
PyYAML @ file:///croot/pyyaml_1698096049011/work
pyzmq @ file:///croot/pyzmq_1705605076900/work
readchar==4.0.5
requests==2.31.0
rich==13.7.0
runs==1.2.2
s3transfer==0.10.0
scanpy==1.9.6
scikit-learn==1.3.2
scipy==1.11.4
scvi-tools==1.0.4
seaborn==0.13.1
session-info==1.0.0
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio==1.3.0
soupsieve==2.5
sparse==0.15.1
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
starlette==0.35.1
starsessions==1.3.0
statsmodels==0.14.1
stdlib-list==0.10.0
sympy @ file:///croot/sympy_1701397643339/work
tensorstore==0.1.52
texttable==1.7.0
threadpoolctl==3.2.0
toolz==0.12.1
torch==2.2.1
torchmetrics==1.3.0.post0
torchvision==0.17.1
tornado @ file:///croot/tornado_1696936946304/work
tqdm==4.66.1
traitlets @ file:///work/ci_py311/traitlets_1676823305040/work
triton==2.2.0
types-python-dateutil==2.8.19.20240106
typing_extensions @ file:///croot/typing_extensions_1705599297034/work
tzdata==2023.4
umap-learn==0.5.5
urllib3==2.0.7
uvicorn==0.27.0
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
websocket-client==1.7.0
websockets==12.0
xarray==2024.1.1
xgboost==2.0.1
xmod==1.8.1
yarl==1.9.4
zipp==3.17.0
Thanks for all the development work, pyro rules!
Thanks for the bug report. My guess is that this is an upstream bug in PyTorch's code generation, where the fuser writes two decimal points into a floating point constant: the generated `-3.402823466385289e+38.f` has a stray `.` between the exponent and the `f` suffix, which is not a valid C literal, so nvrtc rejects it. I'm not sure there's anything we can do on the Pyro side but wait for an upstream fix.
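In the meantime, a possible workaround (untested here, and relying on an internal API that may change between releases) is to disable the TorchScript GPU fusion pass, so the malformed CUDA source is never generated or compiled:

```python
import torch

# Internal/unofficial switch: turn off TorchScript kernel fusion on GPU.
# The JIT still executes the graph, just without fused nvrtc-compiled kernels,
# which sidesteps the malformed-literal compilation error at some speed cost.
torch._C._jit_override_can_fuse_on_gpu(False)
```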