
Metal: add support for complex-valued math

Open hawkinsp opened this issue 2 years ago • 14 comments

Description

Repro:

In [1]: import jax

In [2]: jax.lax.add(1+2j, 3+4j)
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

loc("-":3:10): error: 'mps.add' op operand #0 must be tensor of MPS type values, but got 'tensor<complex<f32>>'
Segmentation fault: 11

It's fine that the plugin doesn't support complex numbers, but it shouldn't segfault.

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.2

Which accelerator(s) are you using?

Apple GPU

Additional system info

No response

NVIDIA GPU info

No response

hawkinsp avatar Jun 15 '23 01:06 hawkinsp

Following this one. @hawkinsp any news?

Saladino93 avatar Aug 20 '23 13:08 Saladino93

@Saladino93 This is an issue in the Metal plugin, so it can only be fixed by Apple. The right people to take a look are already assigned to the bug.

hawkinsp avatar Aug 21 '23 13:08 hawkinsp

Metal plugin doesn't support complex yet, but we will support erroring out more gracefully.

shuhand0 avatar Aug 21 '23 16:08 shuhand0

Any plans to support complex numbers in the near future?

FilipeMaia avatar Oct 04 '23 19:10 FilipeMaia

Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

shuhand0 avatar Oct 05 '23 17:10 shuhand0

I'm doing X-ray scattering simulations, which require wave interference calculations using complex numbers. This should be common to all electromagnetic simulations. FFTs are another extremely important operation that requires complex-number support.

FilipeMaia avatar Oct 06 '23 12:10 FilipeMaia
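
One possible stopgap for the wave-interference use case, not something proposed in this thread, is to carry the real and imaginary parts as separate float32 arrays so that no complex dtype is ever created. The sketch below assumes that real float32 arithmetic is available on the target backend; the helper names `complex_mul` and `interference_intensity` are hypothetical and exist only for illustration.

```python
import jax.numpy as jnp


def complex_mul(a_re, a_im, b_re, b_im):
    """Multiply (a_re + i*a_im) by (b_re + i*b_im) using real arithmetic only."""
    return a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re


def interference_intensity(phase_a, phase_b):
    """Intensity |exp(i*phase_a) + exp(i*phase_b)|^2 of two unit-amplitude waves."""
    re = jnp.cos(phase_a) + jnp.cos(phase_b)
    im = jnp.sin(phase_a) + jnp.sin(phase_b)
    return re * re + im * im


# (1 + 2i) * (3 + 4i), kept as a (real, imaginary) pair of float32 scalars.
prod_re, prod_im = complex_mul(
    jnp.float32(1.0), jnp.float32(2.0), jnp.float32(3.0), jnp.float32(4.0)
)

# First entry: waves in phase. Second entry: waves half a cycle apart.
phase_a = jnp.array([0.0, 0.5], dtype=jnp.float32)
phase_b = jnp.array([0.0, 0.5 + jnp.pi], dtype=jnp.float32)
intensity = interference_intensity(phase_a, phase_b)
```

This trades convenience for portability: every operation has to be written out for the pair representation, and library routines that expect complex inputs cannot be used directly.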

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

I'd like to chime in here and say that this would also be a huge plus for my work in particle physics. I do a lot of amplitude analysis, and quantum amplitudes are typically complex (in particular, lots of spherical harmonics and Breit-Wigners). Universal support of complex numbers on Metal would help me pitch this library more to my collaboration. Although we typically use large computer clusters which don't run Apple Silicon, there is certainly demand for being able to run some of these things on laptops, and almost everyone I know in this field uses M1/M2 Macs for their personal computers for some reason.

Is there somewhere I can follow the progress on this? Is jax-metal being developed in an open-source format? I can't find much through the Apple developer forums.

denehoffman avatar Oct 19 '23 17:10 denehoffman

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

If it's helpful - I ran into this today as well, also as a physicist. I'm writing some code for lattice-QCD applications, where gauge fields and all derived quantities are complex-valued (all production code is run on clusters, but local testing on Apple silicon with jax-metal would be nice).

joshuazlin avatar Jan 05 '24 01:01 joshuazlin

For the complex element types, do you require fp64, or would fp32 be good enough for your JAX applications?

shuhand0 avatar Jan 05 '24 21:01 shuhand0

Single precision would be good enough for me.

FilipeMaia avatar Jan 05 '24 21:01 FilipeMaia

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

Hello, I'm a member of the development team for dynamiqs (GitHub - dynamiqs), a python package designed for quantum mechanics simulations. Our package heavily depends on complex number computations. Specifically, we mostly perform complex matrix-matrix multiplications and utilize various linear algebra routines like eigh, expm and schur. We are in the process of switching from PyTorch to JAX and the support of complex numbers on Apple Silicon for JAX would be a huge plus.

Most simulations can be run in single precision, but double precision is required for the high-accuracy simulations we sometimes need.

EDIT: updated link

abocquet avatar Jan 12 '24 13:01 abocquet

Hi @hawkinsp

Using jax-metal 0.0.7 with jax 0.4.28 and jaxlib 0.4.28, the code above now raises an XlaRuntimeError instead of crashing with "Segmentation fault: 11".

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

jax.print_environment_info():

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

Thank you.

rajasekharporeddy avatar Jun 03 '24 07:06 rajasekharporeddy

Thanks @rajasekharporeddy . I think we can now redesignate this as an "enhancement", not a "bug". It's a bug if the plugin crashes or produces wrong output. It's now just a missing feature.

hawkinsp avatar Jun 03 '24 13:06 hawkinsp
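
Until the feature exists, one way to keep the repro from reaching the Metal compiler at all is to run the complex part of a program on the CPU backend. This is a minimal sketch, assuming that the CPU backend remains available alongside the plugin (that is, that `jax.devices("cpu")` returns a device on the machine in question); the helper name `add_complex_on_cpu` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: a CPU device is registered even when an accelerator plugin is installed.
cpu = jax.devices("cpu")[0]


def add_complex_on_cpu(x, y):
    """Add two complex values with the CPU as the default device."""
    with jax.default_device(cpu):
        # Arrays created inside this block are placed on the CPU, so the
        # complex<f32> tensors are never handed to the accelerator backend.
        x = jnp.asarray(x, dtype=jnp.complex64)
        y = jnp.asarray(y, dtype=jnp.complex64)
        return jax.lax.add(x, y)


result = add_complex_on_cpu(1 + 2j, 3 + 4j)
```

This gives up GPU acceleration for the complex portion, so it is mainly useful for local testing of code that will run elsewhere in production.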

Chiming in to add that I would also greatly appreciate support for complex numbers. I also do lattice QCD (physics simulations), most of which runs on clusters, but local testing would be a huge plus. In my use case, FP64 support is crucial.

akelman avatar Aug 26 '24 23:08 akelman

It's essential for physics & ML-in-physics applications to have complex support at both complex64 and complex128 levels. What is the timescale on this being implemented in metal?

benjaminpope avatar Sep 27 '24 00:09 benjaminpope

https://github.com/rafael-fuente/diffractsim/ uses np.float (float64) and np.complex (complex128) dtypes, neither of which is supported by the JAX GPU backend provided by jax-metal.

Even without jax-metal, setting the backend to jax.numpy still provides a substantial speedup over numpy.

phansel avatar Oct 29 '24 00:10 phansel

Yep - anything involving optics or wave physics more or less requires complex dtypes.

benjaminpope avatar Oct 29 '24 04:10 benjaminpope

Just chiming in to note that popular linear complexity transformer alternatives like Mamba and LRU also use complex numbers.

smorad avatar Nov 09 '24 14:11 smorad

Checking in on this thread - is there any enthusiasm at all for complex dtypes? Almost any area of hardware engineering & physics will want this.

benjaminpope avatar Mar 15 '25 06:03 benjaminpope

Some MCMC applications also require float64 precision. I also strongly appeal for support of float64 and complex64, and even complex128; they are crucial for research and engineering.

xuhengdada avatar Apr 06 '25 12:04 xuhengdada

The problem here, I think, is that MPS does not support float64 (see https://developer.apple.com/documentation/metalperformanceshaders/mpsdatatype?language=objc), so official float64 support in MPS would make it easier for JAX to use float64. But the problem is how to contact the people at Apple and ask for their help.

xuhengdada avatar Apr 06 '25 13:04 xuhengdada

I heavily use jax on CPUs for Quantum Machine Learning simulations with https://github.com/PennyLaneAI/pennylane . I was kind of hoping to exploit my Apple GPUs as well now. This feature will certainly be of great help.

Any news?

fran-scala avatar Apr 14 '25 15:04 fran-scala

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

Hi @shuhand0, any update on this? Major blocker for the adoption of MPS for many applications. Thanks!

callumtilbury avatar Aug 28 '25 13:08 callumtilbury

This would be extremely useful for me

neurosamtle avatar Sep 09 '25 01:09 neurosamtle

Just bumping this thread @shuhand0 - if complex support is in the backend stack, what would it take for us to integrate it into jax-metal?

benjaminpope avatar Oct 17 '25 00:10 benjaminpope

I do a lot of cellular automata and particle system simulations. JAX is the best framework for this, and I would like to run them on my MacBook. The major blocker is that complex numbers are not supported, so jnp.fft cannot run; for now I can only run experiments in Colab notebooks...

Chakazul avatar Oct 30 '25 12:10 Chakazul
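
For small transforms, a discrete Fourier transform can be written with real matrices only, which avoids complex dtypes at the cost of O(n²) work instead of the O(n log n) of a true FFT. The sketch below is an illustration under that assumption, not a replacement for `jnp.fft`; the function name `dft_real_pairs` is hypothetical, and the input and output are (real, imaginary) pairs of float32 arrays.

```python
import jax.numpy as jnp


def dft_real_pairs(x_re, x_im):
    """DFT of x_re + i*x_im along the last axis, returned as (real, imaginary)."""
    n = x_re.shape[-1]
    k = jnp.arange(n)
    # Reduce k*m modulo n in integer arithmetic so the angles stay in [0, 2*pi).
    km = (k[:, None] * k[None, :]) % n
    angle = (2.0 * jnp.pi / n) * km.astype(jnp.float32)
    c, s = jnp.cos(angle), jnp.sin(angle)
    # (x_re + i*x_im) * (cos - i*sin) = (x_re*cos + x_im*sin) + i*(x_im*cos - x_re*sin)
    out_re = x_re @ c + x_im @ s
    out_im = x_im @ c - x_re @ s
    return out_re, out_im


x_re = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
x_im = jnp.zeros(4, dtype=jnp.float32)
out_re, out_im = dft_real_pairs(x_re, x_im)
```

For the input `[1, 2, 3, 4]` the exact transform is `[10, -2+2i, -2, -2-2i]`, which the float32 result should match to within rounding error. The dense matrices make this impractical for large grids.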