Kernel crashes when importing JAX and TensorFlow at the same time
There is a JAX example in the Kaggle tutorial below: https://www.kaggle.com/code/nilaychauhan/jigsaw-toxic-comment-classification-using-jax-flax/notebook
I was going to follow the code on a Google Cloud TPU VM (3-8), but just executing the import cell crashes my VS Code notebook kernel, so I tried to figure out which line crashes it.
The kernel crashes when I import TensorFlow after JAX. I found a similar issue reported in 2018, and I think it still happens. Is there any solution?
(The same kernel works fine in other .ipynb files.)
```python
import os
import time
import jax
import flax
import optax
import datasets
import pandas as pd
import numpy as np
from jax import jit
import jax.numpy as jnp
import tensorflow as tf  # importing TensorFlow after JAX crashes the kernel
```
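To pin down which import kills the kernel without restarting the notebook each time, one option is to run each import order in a fresh Python subprocess and check its exit code. This is a minimal sketch, not part of the original report: the helper names `try_imports` and `describe` are hypothetical, and decoding negative exit codes assumes POSIX semantics (a child killed by a signal reports `-signum`).

```python
import signal
import subprocess
import sys
from typing import List


def try_imports(modules: List[str], timeout: float = 600.0) -> int:
    """Import `modules` in order inside a fresh interpreter; return its exit code."""
    script = "\n".join(f"import {name}" for name in modules)
    result = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return result.returncode


def describe(code: int) -> str:
    """Turn a subprocess exit code into a short human-readable summary."""
    if code < 0:
        # On POSIX, a negative code means the child was killed by signal -code.
        return f"killed by {signal.Signals(-code).name}"
    return "ok" if code == 0 else f"exited with code {code}"


if __name__ == "__main__":
    # Each order runs in its own process, so a crash cannot take down this one.
    for order in (["jax"], ["tensorflow"], ["jax", "tensorflow"], ["tensorflow", "jax"]):
        print(" -> ".join(order), ":", describe(try_imports(order)))
```

Comparing `jax -> tensorflow` with `tensorflow -> jax` shows whether the import order matters outside Jupyter as well; a result such as `killed by SIGSEGV` or `killed by SIGABRT` would point at a native-library conflict rather than an ordinary Python exception.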
Environment:
- Python 3.8.10
- flax 0.5.2
- ipykernel 6.15.0
- ipython 8.4.0
- jax 0.3.14
- jaxlib 0.3.14
- tensorflow 2.8.0
Log after the code executed (the parts I suspect are relevant):
```text
error 1:44:18.491: Disposing session as kernel process died ExitCode: undefined, Reason: /home/dlfrnaos19/.local/lib/python3.8/site-packages/traitlets/traitlets.py:2392: FutureWarning: Supporting extra quotes around strings is deprecated in traitlets 5.0. You can use 'hmac-sha256' instead of '"hmac-sha256"' if you require traitlets >=5.
warn(
/home/dlfrnaos19/.local/lib/python3.8/site-packages/traitlets/traitlets.py:2346: FutureWarning: Supporting extra quotes around Bytes is deprecated in traitlets 5.0. Use '72b99a43-3c72-46c7-bf0d-56bff3c674b6' instead of 'b"72b99a43-3c72-46c7-bf0d-56bff3c674b6"'.
warn(
error 1:44:18.492: Raw kernel process exited code: undefined
error 1:44:18.494: Error in waiting for cell to complete [Error: Canceled future for execute_request message before replies were done
at t.KernelShellFutureHandler.dispose (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:2:32353)
at /home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:2:51405
at Map.forEach (<anonymous>)
at y._clearKernelState (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:2:51390)
at y.dispose (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:2:44872)
at /home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:24:251157
at t.swallowExceptions (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:29:120529)
at dispose (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:24:251135)
at t.RawSession.dispose (/home/dlfrnaos19/.vscode-server/extensions/ms-toolsai.jupyter-2022.6.1001902341/out/extension.node.js:24:256072)
at runMicrotasks (<anonymous>)
at processTicksAndRejections (node:internal/process/task_queues:96:5)]
warn 1:44:18.495: Cell completed with errors {
message: 'Canceled future for execute_request message before replies were done'
}
info 1:44:18.496: Cancel all remaining cells true || Error || undefined
info 1:44:18.496: Cancel pending cells
info 1:44:18.497: Cell 0 executed with state Error
```
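The VS Code log only records that the kernel process died, not where it died. One way to get more detail is to run the same imports outside Jupyter with Python's built-in `faulthandler` enabled (`python -X faulthandler`), which dumps the Python traceback to stderr when the process receives a fatal signal such as SIGSEGV or SIGABRT. This is a minimal sketch: the helper name `run_with_faulthandler` is hypothetical, and the snippet simply mirrors the JAX-then-TensorFlow import order from the notebook above.

```python
import subprocess
import sys


def run_with_faulthandler(snippet: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run `snippet` in a fresh interpreter with faulthandler enabled.

    If the child dies from SIGSEGV, SIGABRT, SIGBUS, SIGFPE or SIGILL,
    faulthandler writes the Python traceback to the child's stderr.
    """
    return subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


if __name__ == "__main__":
    result = run_with_faulthandler("import jax\nimport tensorflow as tf")
    print("exit code:", result.returncode)
    # The tail of stderr usually shows which import was running when the process died.
    print(result.stderr[-4000:])
```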

Please let me know if you need more info!
I'm trying to reproduce this tutorial on the tpu-vm-base version of Cloud TPU, and there it works well for training. Previously I was running the code on tensorflow-tpu-vm, so maybe that is what triggered the bug.
I suspect nothing good will happen if you import TensorFlow's TPU support at the same time as JAX's TPU support.
One thing you can try: install a CPU-only version of TensorFlow, e.g., the tensorflow-cpu pip package. It might also work to configure TF not to use the TPU devices.
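Because `tensorflow` and `tensorflow-cpu` both provide the same `tensorflow` module, it can help to confirm which distribution is actually installed after switching. The sketch below uses only the standard library (`importlib.metadata`, available from Python 3.8); the helper name `installed_tf_distributions` and the list of candidate distribution names are assumptions for illustration.

```python
from importlib import metadata
from typing import Dict, Iterable

# Assumed list of common TensorFlow distribution names; "tensorflow-cpu" is
# the CPU-only package suggested above.
TF_DISTRIBUTIONS = ("tensorflow", "tensorflow-cpu", "tf-nightly", "tf-nightly-cpu")


def installed_tf_distributions(candidates: Iterable[str] = TF_DISTRIBUTIONS) -> Dict[str, str]:
    """Return {distribution name: version} for each candidate that is installed."""
    found: Dict[str, str] = {}
    for name in candidates:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed in this environment
    return found


if __name__ == "__main__":
    dists = installed_tf_distributions()
    if not dists:
        print("No TensorFlow distribution found.")
    elif set(dists) == {"tensorflow-cpu"}:
        print("Only the CPU-only package is installed:", dists)
    else:
        print("Other TensorFlow builds are installed:", dists)
```

For the second suggestion, `tf.config.set_visible_devices([], "TPU")` is one way to hide a device type from TensorFlow, but it must run before TensorFlow initializes its devices, and it cannot help if the crash already happens during `import tensorflow` itself; whether it avoids this particular crash is untested here.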
I have a closely related issue with PyTorch instead of TensorFlow:
```python
import jax
from torch.utils.data import Dataset
```
This crashes the Jupyter kernel on my Intel Mac (no GPU).