
PyTorch Metal Problem: Introduction to Keras for engineers

Status: Open. pultar opened this issue on Jan 28, 2024.

I copied the tutorial from https://keras.io/getting_started/intro_to_keras_for_engineers/ and tried to run it on my Mac with the PyTorch Metal (MPS) backend:

import os

# Set both variables before importing keras.
os.environ["KERAS_BACKEND"] = "torch"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

I get the error:

File ~/miniconda3/envs/pytorch_keras3_py311/lib/python3.11/site-packages/keras/src/backend/torch/numpy.py:1347, in where(condition, x1, x2)
   1344     x2 = convert_to_tensor(x2)
-> 1347     return torch.where(condition, x1, x2)
   1348 else:
   1349     return torch.where(condition)

RuntimeError: expected scalar type long long but found int

Manually casting the second tensor (x2) to int64 makes it work, but this is of course only a band-aid for an underlying issue:

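import torch
from keras.src.backend.torch.core import convert_to_tensor


def where(condition, x1, x2):
    condition = convert_to_tensor(condition, dtype=bool)
    if x1 is not None and x2 is not None:
        x1 = convert_to_tensor(x1)
        x2 = convert_to_tensor(x2)
        # Band-aid: force x2 to int64 so both branches match on MPS.
        # Only correct when x1 is int64; float values in x2 would be truncated.
        x2 = x2.to(dtype=torch.int64)
        return torch.where(condition, x1, x2)
    else:
        return torch.where(condition)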
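The error text suggests the two branches reached the MPS kernel with different integer dtypes ("long long" is int64, "int" is int32 on macOS). Assuming that is the cause, a less invasive band-aid than hard-coding int64 is to promote both branches to their common dtype with `torch.result_type` before calling `torch.where`, which leaves float inputs intact. This is a minimal sketch, not the upstream fix: `where_aligned` is a hypothetical name, and whether it avoids the error on PyTorch 2.1.0's MPS backend has not been verified here.

```python
import torch


def where_aligned(condition, x1, x2):
    """Hypothetical helper: torch.where with both branches cast to one dtype.

    Uses PyTorch's own type-promotion rules (torch.result_type) so the
    kernel never receives mixed dtypes such as int64 and int32.
    """
    condition = torch.as_tensor(condition, dtype=torch.bool)
    x1 = torch.as_tensor(x1)
    x2 = torch.as_tensor(x2)
    common = torch.result_type(x1, x2)  # e.g. int64 for (int64, int32)
    return torch.where(condition, x1.to(common), x2.to(common))


# Usage on CPU tensors with mismatched integer dtypes.
cond = torch.tensor([True, False, True])
a = torch.tensor([1, 2, 3], dtype=torch.int64)
b = torch.tensor([10, 20, 30], dtype=torch.int32)
out = where_aligned(cond, a, b)
print(out, out.dtype)
```

On a Mac the same call would be made with the tensors moved to the `"mps"` device, for example via `.to("mps")`.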

The example runs fine on the CPU. Keras and PyTorch versions are 3.0.4 and 2.1.0, respectively.

I also found that I had to set the PYTORCH_ENABLE_MPS_FALLBACK environment variable. Maybe that should be documented in the first pages of the documentation as well?
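A setup snippet of the kind that could be documented might look like the following. It is a sketch under stated assumptions: `configure_keras_torch_mps` is a hypothetical helper, the guard relies on Keras reading `KERAS_BACKEND` only when it is first imported, and setting `PYTORCH_ENABLE_MPS_FALLBACK` before `torch` is loaded (which importing keras with this backend does) is treated as the safe ordering.

```python
import os
import sys


def configure_keras_torch_mps() -> None:
    """Hypothetical helper: select the torch backend and enable MPS CPU fallback.

    Keras reads KERAS_BACKEND when it is first imported, so this refuses to
    run once keras is already in sys.modules.
    """
    if "keras" in sys.modules:
        raise RuntimeError("call configure_keras_torch_mps() before importing keras")
    os.environ["KERAS_BACKEND"] = "torch"
    # Ask PyTorch to run ops that lack an MPS kernel on the CPU instead of raising.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


configure_keras_torch_mps()
# import keras  # import only after both variables are set
```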

I haven't had time to look into the underlying cause yet, but I could give it a try unless you already know what the problem is.

pultar, Jan 28 '24 22:01

Hi @pultar,

It seems the issue is specific to PyTorch Metal. The tutorial works fine in a Colab environment without any modifications. Attached a gist for reference.

SuryanarayanaY, Jan 29 '24 04:01

The issue is indeed specific to PyTorch's Metal (MPS) backend. I think it was fixed in this pull request: https://github.com/pytorch/pytorch/pull/121476

M7Saad, Apr 05 '24 05:04