JAX-Flax-Tutorial-Image-Classification-with-Linen
Error while running CNN Python notebook with CIFAR10 Images
Colab_The_annotated_MNIST_image_classification_example_with_Flax_Linen_and_Optax.ipynb
I ran the code from the notebook above unchanged, except that I used the CIFAR10 image dataset from tfds instead of MNIST. Loading the dataset itself caused no problem; the error occurred in train_step, as shown below.
When train_step took state and batch (a dictionary) as its arguments, as in the notebook, I got the same error shown below, and not even the first epoch completed. I then changed train_step to take (state, batch_labels, batch_images), all as arrays. This time the first training epoch finished, but then the following error was raised:
```text
Training -epoch: 1, loss: 1.9994, accuracy: 0.29
Traceback (most recent call last):
  File "
  File "
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 301, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 460, in common_infer_params
    avals.append(shaped_abstractify(a))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 563, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 575, in _numpy_array_abstractify
    dtypes.check_valid_dtype(dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 441, in check_valid_dtype
    raise TypeError(f"Dtype {dtype} is not a valid JAX array "
TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```
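The message means a NumPy array with dtype=object reached a jitted function. One plausible cause, stated as an assumption rather than a confirmed diagnosis: unlike MNIST, the tfds cifar10 dataset also has an id feature holding byte strings, which NumPy stores with dtype=object. If that key stays in the dataset dict, any jitted step that receives the whole dict fails this way. Because the error appears right after the first epoch's training metrics are printed, the evaluation step, which receives the full test dict, is a likely place for it. The sketch below lists and drops such fields; the helper names non_numeric_fields and numeric_only are hypothetical, not part of the notebook or of JAX.

```python
import numpy as np


def non_numeric_fields(ds):
    """Return the names of arrays in `ds` that JAX would likely reject.

    Conservative check: keep numeric and boolean dtypes, flag everything
    else (e.g. byte strings, which NumPy stores as dtype=object).
    """
    bad = []
    for name, arr in ds.items():
        dtype = np.asarray(arr).dtype
        if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
            bad.append(name)
    return bad


def numeric_only(ds):
    """Return a shallow copy of `ds` without the fields flagged above."""
    bad = set(non_numeric_fields(ds))
    return {k: v for k, v in ds.items() if k not in bad}
```

For example, printing non_numeric_fields(train_ds) and non_numeric_fields(test_ds) before training would show whether id (or any other field) is the culprit, and numeric_only can then strip it from both dicts.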
Everything else is unchanged from the notebook.
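The notebook's dict-based train_step signature may not need to be abandoned. A minimal sketch, assuming the whole split was loaded as a dict of NumPy arrays (as the notebook does with tfds) and that only the image and label features are needed: build each shuffled batch from an explicit list of keys, so extra non-numeric features (for CIFAR10 in tfds, presumably the string id field) never reach the jitted function. The function make_batches and its keys parameter are hypothetical names used for illustration.

```python
import numpy as np


def make_batches(ds, batch_size, rng, keys=("image", "label")):
    """Yield shuffled mini-batches that contain only the requested keys.

    ds: dict mapping feature name -> NumPy array sharing a leading axis.
    rng: a numpy.random.Generator used for the per-epoch permutation.
    """
    n = len(ds[keys[0]])
    steps = n // batch_size  # skip the incomplete final batch
    perms = rng.permutation(n)[: steps * batch_size].reshape((steps, batch_size))
    for perm in perms:
        # Only whitelisted (numeric) features end up in the batch dict.
        yield {k: ds[k][perm, ...] for k in keys}
```

The same key selection would also apply to the test dict before it is passed to the evaluation step, since that call receives the full dict rather than mini-batches.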