Adds spectral functions to Mesh TensorFlow

Open EiffL opened this issue 4 years ago • 5 comments

This PR adds to Mesh TensorFlow the spectral operations needed for the flowpm project, which was the subject of this blog post: https://blog.tensorflow.org/2020/03/simulating-universe-in-tensorflow.html

These operations are useful for lots of applications including N-body simulations and MRI reconstructions. For now, we have only added the implementation of 3D FFTs.

The implementation applies a series of 1D FFTs along the trailing dimension of the input tensor, using all2all communications and local transpose operations to transpose the tensor between FFTs until all 3 dimensions have been transformed.

The algorithm is illustrated in a figure from https://www-user.tu-chemnitz.de/~potts/workgroup/pippig/paper/PFFT_SIAM_88588.pdf
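The transpose-based scheme described above can be sketched on a single process with NumPy. This is only an illustration, not the code in this PR: the function names `fft3d_by_transposes` and `ifft3d_by_transposes` are hypothetical, the all2all communications are left out entirely (only the local transposes remain), and the particular axis rotation is an assumption chosen because it reproduces the [batch, ky, kz, kx] output order quoted in the example further down.

```python
import numpy as np


def fft3d_by_transposes(field):
    """3D FFT over the last three axes using only 1D FFTs on the trailing axis.

    Input layout is [..., x, y, z]; output layout is [..., ky, kz, kx].
    """
    out = np.asarray(field, dtype=np.complex128)
    for step in range(3):
        # 1D FFT along the trailing axis only
        out = np.fft.fft(out, axis=-1)
        if step < 2:
            # Rotate the three spatial axes so that an axis that has not
            # been transformed yet becomes the trailing one.
            out = np.moveaxis(out, -1, -3)
    return out


def ifft3d_by_transposes(fft_field):
    """Inverse of fft3d_by_transposes: [..., ky, kz, kx] -> [..., x, y, z]."""
    out = np.asarray(fft_field, dtype=np.complex128)
    for step in range(3):
        out = np.fft.ifft(out, axis=-1)
        if step < 2:
            # Undo the rotation applied in the forward transform
            out = np.moveaxis(out, -3, -1)
    return out


rng = np.random.default_rng(0)
field = rng.uniform(size=(2, 3, 4, 5))   # [batch, nx, ny, nz]
fft_field = fft3d_by_transposes(field)   # [batch, ky, kz, kx]

# Reference: NumPy's 3D FFT, reordered to the same transposed layout
reference = np.transpose(np.fft.fftn(field, axes=(-3, -2, -1)), (0, 2, 3, 1))
assert np.allclose(fft_field, reference)
assert np.allclose(ifft3d_by_transposes(fft_field).real, field)
```

The non-cubic test shape makes the axis ordering visible: the forward transform returns shape (2, 4, 5, 3), and the inverse maps it back to (2, 3, 4, 5) without any extra transpose.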

2 things to note:

  • The user needs to specify the Fourier dimensions, which will be referenced in the mesh layout, to make sure the output of the FFT remains distributed.
  • The output of the FFT is transposed, to save on extra all2all operations. Both constraints could be avoided, but doing so would require at least 2 more all2all operations to transpose and reshape the output array back to the original memory layout of the input array. This could probably be provided as an option to the user.

A minimal example would be:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

# batch_size, nc, nblockx and nblocky are assumed to be defined by the user
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch_dim = mtf.Dimension("batch", batch_size)
x_dim = mtf.Dimension("nx", nc)
y_dim = mtf.Dimension("ny", nc)
z_dim = mtf.Dimension("nz", nc)

kx_dim = mtf.Dimension("kx", nc)
ky_dim = mtf.Dimension("ky", nc)
kz_dim = mtf.Dimension("kz", nc)

# Create field
field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])
# Apply FFT
fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64), [kx_dim, ky_dim, kz_dim])
# fft_field has shape: [batch, ky, kz, kx]
# Inverse FFT
recfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

# The following layout would be appropriate
mesh_shape = [("row", nblockx), ("col", nblocky)]
# Note that the Fourier dimensions are also split, but differently
layout_rules = [("nx", "row"), ("ny", "col"),
                ("ky", "row"), ("kz", "col")]
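To see what this layout means on each device, here is a small pure-Python sketch that computes the per-device block shape of a tensor from a mesh shape and layout rules. The helper `local_shape` is hypothetical and not part of Mesh TensorFlow, and the sizes (batch_size = 1, nc = 128, nblockx = 2, nblocky = 4) are example values assumed for illustration only.

```python
def local_shape(dims, mesh_shape, layout_rules):
    """Per-device shape of a tensor whose dimensions are (name, size) pairs.

    A tensor dimension is split only if a layout rule maps its name to a
    mesh dimension; otherwise it is kept whole on every device.
    """
    mesh_sizes = dict(mesh_shape)
    layout = dict(layout_rules)
    shape = []
    for name, size in dims:
        blocks = mesh_sizes[layout[name]] if name in layout else 1
        if size % blocks != 0:
            raise ValueError(f"{name}={size} is not divisible by {blocks}")
        shape.append((name, size // blocks))
    return shape


# Example values, assumed for illustration only
batch_size, nc, nblockx, nblocky = 1, 128, 2, 4

mesh_shape = [("row", nblockx), ("col", nblocky)]
layout_rules = [("nx", "row"), ("ny", "col"),
                ("ky", "row"), ("kz", "col")]

# Real-space field: [batch, nx, ny, nz]
real_local = local_shape(
    [("batch", batch_size), ("nx", nc), ("ny", nc), ("nz", nc)],
    mesh_shape, layout_rules)

# Fourier-space field, in the transposed output order: [batch, ky, kz, kx]
fourier_local = local_shape(
    [("batch", batch_size), ("ky", nc), ("kz", nc), ("kx", nc)],
    mesh_shape, layout_rules)

print(real_local)
print(fourier_local)
```

With these rules the trailing dimension (nz in real space, kx in Fourier space) is never split, so a 1D FFT along it can run locally on each device, while the two leading spatial dimensions stay distributed over "row" and "col" both before and after the transform.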

EiffL avatar Nov 26 '20 20:11 EiffL

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

:memo: Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with "@googlebot I signed it!" and we'll verify it.


What to do if you already signed the CLA

Individual signers
Corporate signers

ℹ️ Googlers: Go here for more info.

google-cla[bot] avatar Nov 26 '20 20:11 google-cla[bot]

@googlebot I signed it!

EiffL avatar Nov 26 '20 20:11 EiffL

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only "@googlebot I consent." in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

google-cla[bot] avatar Nov 26 '20 20:11 google-cla[bot]

@googlebot I consent.

zaccharieramzi avatar Nov 26 '20 20:11 zaccharieramzi

An important note: in addition to the spectral ops, this PR adds complex support for the gradient (otherwise the gradient would not flow through the spectral ops), as well as complex manipulation operations for converting between complex-valued and float-valued tensors.

zaccharieramzi avatar Nov 27 '20 14:11 zaccharieramzi