JAX-Flax-Tutorial-Image-Classification-with-Linen

Error while running CNN Python notebook with CIFAR10 Images

Open AbhiDu96 opened this issue 1 year ago • 0 comments

`Colab_The_annotated_MNIST_image_classification_example_with_Flax_Linen_and_Optax.ipynb`

In the notebook above, I ran the code unchanged on the CIFAR10 image dataset from tfds. The dataset itself loaded without problems; the trouble was in `train_step`, as described below.

When `train_step` took `state` and `batch` (a dictionary) as its arguments, I got the same error shown below, and not even the first epoch completed. I then changed `train_step` to take `(state, batch_labels, batch_images)`, all as arrays. This time the first training epoch finished, but then the following error was thrown.
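A likely culprit, which this issue does not confirm, is a non-numeric field in the batch dictionaries. The tfds CIFAR10 records are assumed here to carry a string `id` feature alongside `image` and `label` (MNIST has no such field), and a string column arrives as an object- or bytes-dtype NumPy array that JAX refuses. A minimal sketch of a workaround, using a hypothetical helper `numeric_only` that is not part of the notebook, keeps only numeric fields before a batch reaches a jitted step:

```python
import numpy as np


def numeric_only(batch):
    """Return a copy of `batch` keeping only entries JAX can turn into arrays.

    Hypothetical helper: drops any field whose NumPy dtype kind is not
    boolean, signed/unsigned integer, float or complex, e.g. an assumed
    string `id` field that arrives as an object-dtype array of bytes.
    """
    kept = {}
    for name, value in batch.items():
        arr = np.asarray(value)
        if arr.dtype.kind in "biufc":  # numeric kinds only
            kept[name] = arr
    return kept


# Illustrative batch shaped like a CIFAR10 batch; the `id` values are made up.
batch = {
    "image": np.zeros((2, 32, 32, 3), dtype=np.uint8),
    "label": np.array([3, 7]),
    "id": np.array([b"train_00001", b"train_00002"], dtype=object),
}
print(sorted(numeric_only(batch)))
```

With such a filter, `train_step(state, numeric_only(batch))` could keep the original dictionary signature, and the same filter could be applied to `test_ds` before it is passed to `eval_step`; both calls are untested against the notebook.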

```python
Training -epoch: 1, loss: 1.9994, accuracy: 0.29
Traceback (most recent call last):
  File "", line 4, in <cell line: 1>
    test_loss, test_accuracy = eval_model(state.params, test_ds)
  File "", line 2, in eval_model
    metrics = eval_step(model, test_ds)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 301, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 460, in common_infer_params
    avals.append(shaped_abstractify(a))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 563, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 575, in _numpy_array_abstractify
    dtypes.check_valid_dtype(dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 441, in check_valid_dtype
    raise TypeError(f"Dtype {dtype} is not a valid JAX array "
TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```

Everything else is the same as in the notebook.
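The final `TypeError` does not depend on Flax. A minimal reproduction, assuming only `jax` and `numpy` are installed: a jitted function accepts a numeric array but raises `TypeError` when handed an object-dtype array, such as the byte strings a tfds text field is assumed to turn into. `mean_label` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def mean_label(labels):
    # The failure happens before this body runs: JAX first converts each
    # argument to an abstract array, and that check rejects dtype=object.
    return jnp.mean(labels)


numeric = np.array([3, 7])
strings = np.array([b"test_00001", b"test_00002"], dtype=object)

print(mean_label(numeric))  # numeric input traces and runs normally

try:
    mean_label(strings)  # object dtype: rejected during argument checking
except TypeError as err:
    print("TypeError:", err)
```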

AbhiDu96, Jul 01 '23 19:07