
Gaussian quadrature

Open · elastufka opened this pull request 2 years ago • 8 comments

Add Gaussian quadrature methods

Base Gaussian quadrature class created

Several Gaussian quadrature methods implemented, for integration over [a,b], [0,inf] and [-inf,inf]

Minor changes to BaseIntegrator to allow for args for input functions
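
For context (not part of the PR's code, just the standard construction it relies on): integration over [a,b] comes from the Gauss-Legendre rule on [-1,1] via an affine change of variables, with the weights picking up the Jacobian (b-a)/2. A minimal NumPy sketch, with an illustrative helper name:

import numpy as np

# Minimal sketch of mapping a Gauss-Legendre rule from [-1, 1] to [a, b];
# the helper name is illustrative, not the class added in this PR.
def gauss_legendre_ab(f, a, b, N=10):
    t, w = np.polynomial.legendre.leggauss(N)  # nodes/weights on [-1, 1]
    x = 0.5 * (b - a) * t + 0.5 * (a + b)      # affine map of nodes to [a, b]
    return 0.5 * (b - a) * np.sum(w * f(x))    # weights scaled by the Jacobian (b-a)/2

print(gauss_legendre_ab(np.sin, 0.0, np.pi))   # ~2.0, the exact value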

Resolved Issues

  • [x] fixes #126

How Has This Been Tested?

  • [X] See examples in docstrings

How Has This Not Been Tested?

None of the methods except Gauss-Legendre has been tested on the GPU. Most have also not been tested in multiple dimensions.

No backends besides Torch and NumPy have been tested.

Known issues

Finding the roots of some polynomials (Jacobi, Laguerre) relies on scipy functions. I don't know if/how this is compatible with autoray.do()
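
For reference, the scipy routines in question are presumably along these lines (illustrative calls; the exact ones used in the PR may differ):

from scipy import special

# Nodes/weights for Gauss-Jacobi, i.e. weight (1-x)^alpha * (1+x)^beta on [-1, 1]
xj, wj = special.roots_jacobi(5, alpha=0.5, beta=0.5)

# Nodes/weights for Gauss-Laguerre, i.e. weight exp(-x) on [0, inf)
xl, wl = special.roots_laguerre(5)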

Transforming points and weights for integration can be tricky with the GPU involved; we need to be able to check/enforce that all tensors are on the same device.

Gauss-Jacobi integration does not currently give the correct solution outside the interval [-1,1]

Not yet implemented

Other Gaussian methods, including Gauss-Kronrod / QUADPACK

elastufka avatar Apr 21 '22 12:04 elastufka

I changed the PR base, as new features go into develop first :pray:

gomezzz avatar Apr 22 '22 07:04 gomezzz

Just request a review when I should review this!

gomezzz avatar Apr 22 '22 07:04 gomezzz

Hi,

I can't seem to request a review through the usual method. However, please feel free to take a look at your convenience.

All results for 1-dimensional integrals are correct now. You can see the tests that I did in this Colab.

The integration probably does not work correctly for higher dimensions. When I first looked at this, I was interested in vectorized integration (integrating over vectors of limits in one dimension), which is just the result of treating dims as the number of limits and not summing over all dimensions at the end. This is what quadpy does as well. Since it turns out that doing fixed-point and iterative integration in Torch is very slow even with large numbers of points, I won't be looking into it much more, or into high-dimensional integration at all. (Also, for the particular integral I wanted to compute quickly, it looks like tensor-based backends won't help either, since it is bogged down by a high-order cross-section calculation.)
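
(To make "vectorized integration" concrete, here is a minimal NumPy sketch of batching 1-D Gauss-Legendre over a vector of limits; names are illustrative and this is not the PR's code:)

import numpy as np

t, w = np.polynomial.legendre.leggauss(10)
a = np.array([0.0, 0.0, 1.0])    # one lower limit per batch entry
b = np.array([1.0, np.pi, 2.0])  # one upper limit per batch entry

# Map the shared 1-D nodes into each interval, then do one weighted sum per row,
# instead of summing over all dimensions as a true multi-dimensional rule would.
x = 0.5 * (b - a)[:, None] * t + 0.5 * (a + b)[:, None]
vals = 0.5 * (b - a) * np.sum(w * np.sin(x), axis=-1)  # three 1-D integrals at once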

Hopefully someone who has a better knowledge of how high-dimensional integration should work can quickly get it working based on the start that's been made here.

elastufka avatar Apr 28 '22 13:04 elastufka

(switched base to refresh diff)

gomezzz avatar May 06 '22 08:05 gomezzz

@elastufka First off, thank you for your efforts! The code looks very well structured and fits in neatly!

A few thoughts on integrating seamlessly with CI, the current version, etc.:

  • [ ] Could you have a look at the merge conflicts / ideally merge the current status of develop into this?
  • [ ] Thanks for the colab as well! We will need to add some unit tests for this to make sure the results stay correct etc. Could you possibly add a test for the Gaussian integrators similar to, e.g., the trapezoid one? It can be very similar. Or you can use some of the results from your colab if you prefer.

None of the methods except Gauss-Legendre has been tested on the GPU. Most have also not been tested in multiple dimensions. No backends besides Torch and NumPy have been tested.

I will try this out soon :)

Finding the roots of some polynomials (Jacobi, Laguerre) relies on scipy functions. I don't know if/how this is compatible with autoray.do()

In general it is fine if not all integrators are compatible with everything, as long as we throw meaningful exceptions / make sure we know. The easiest way to test this would be:

  • [ ] Add the new integrators to GradientTests at https://github.com/esa/torchquad/blob/bbbb3782cda4ff56f0e8093102dadaf87517b2a0/torchquad/tests/gradient_test.py#L173
  • [ ] Please add the version of scipy you used to the requirements here, here, here and here (sorry for the redundancy; supporting CI, pip, conda and rtd has become a bit tedious). In my env with scipy 1.8.0 I currently get "module 'scipy' has no attribute 'special'" when trying the Gaussian integrators.

Transforming points and weights for integration can be tricky with the GPU involved; we need to be able to check/enforce that all tensors are on the same device.

Don't worry too much about this. I will try it out and we should see errors then.

Gauss-Jacobi integration does not currently give the correct solution outside the interval [-1,1]

  • [x] I would suggest either throwing an error for other bounds or transforming the integrand into this interval automatically?

The integration probably does not work correctly for higher dimensions. When I first looked at this, I was interested in vectorized integration (integrating over vectors of limits in one dimension), which is just the result of treating dims as the number of limits and not summing over all dimensions at the end. This is what quadpy does as well. Since it turns out that doing fixed-point and iterative integration in Torch is very slow even with large numbers of points, I won't be looking into it much more, or into high-dimensional integration at all. (Also, for the particular integral I wanted to compute quickly, it looks like tensor-based backends won't help either, since it is bogged down by a high-order cross-section calculation.)

Of course. We can always improve on this later. Ideally, could you maybe add an issue describing what needs improvement later and, if you have an idea already, how that could be done?

If this becomes too annoying / too much of a time investment for you, I can also take care of some of the above things. I just want to make sure we stay future-proof. :)

Thanks!

gomezzz avatar May 06 '22 09:05 gomezzz

(I will wait for your reply before I dig in deeper)

gomezzz avatar May 06 '22 09:05 gomezzz

Thanks for taking a look! Sorry for the radio silence, I was on holiday. I'll get back to you in more detail next week, when I have some time to look at this again, but in the meantime I can address some quick points.

Please add the version of scipy you used to the requirements here, here, here and here (sorry for the redundancy; supporting CI, pip, conda and rtd has become a bit tedious). In my env with scipy 1.8.0 I currently get "module 'scipy' has no attribute 'special'" when trying the Gaussian integrators.

I added an import line that fixes that.
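
(Sketch of the failure mode and the fix: a bare top-level import does not reliably expose scipy's submodules as attributes, so the submodule has to be imported explicitly:)

import scipy
# scipy.special.roots_legendre(5)  # can raise: module 'scipy' has no attribute 'special'

from scipy import special          # the added import line
nodes, weights = special.roots_legendre(5)  # works across supported scipy versions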

Gauss-Jacobi integration does not currently give the correct solution outside the interval [-1,1]

  • [x] I would suggest either throwing an error for other bounds or transforming the integrand into this interval automatically?

It works correctly now.

elastufka avatar May 12 '22 09:05 elastufka

  • [X] Could you have a look at the merge conflicts / ideally merge the current status of develop into this?
  • [X] Thanks for the colab as well! We will need to add some unit tests for this to make sure the results stay correct etc. Could you possibly add a test for the Gaussian integrators similar to, e.g., the trapezoid one? It can be very similar. Or you can use some of the results from your colab if you prefer.
  • [X] Add the new integrators to GradientTests at https://github.com/esa/torchquad/blob/bbbb3782cda4ff56f0e8093102dadaf87517b2a0/torchquad/tests/gradient_test.py#L173

I added the tests but haven't run any of them yet. I'm pretty sure that even if the high-dimension tests pass, the results will be incorrect, so that will need a closer look.

  • [ ] Please add the version of scipy you used to the requirements here, here, here and here (sorry for the redundancy; supporting CI, pip, conda and rtd has become a bit tedious). In my env with scipy 1.8.0 I currently get "module 'scipy' has no attribute 'special'" when trying the Gaussian integrators.

I didn't do the last one, but looking at the scipy 1.6.0 documentation (the minimum version according to setup.py), all the functions I use are already present there. The 'from scipy import special' line should be compatible with all supported versions.

elastufka avatar May 16 '22 09:05 elastufka

Whoops, didn't mean to push here... but the PR has also been untouched and did not appear to work...

ilan-gold avatar Jan 13 '23 15:01 ilan-gold

Ok, so I'm definitely making progress here. All of the 1D integrand tests (at least for numpy) pass now.

ilan-gold avatar Jan 13 '23 17:01 ilan-gold

Now I need to handle the multi-dimensional integrand case; resizing the weights seems like all that's left, maybe with an einsum.
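
For illustration, here is one way the tensor-product weights could be formed; a minimal NumPy sketch (not the PR's actual code), where the meshgrid product and the einsum route agree:

import numpy as np

t, w = np.polynomial.legendre.leggauss(5)  # 1-D Gauss-Legendre nodes/weights

# dim-dimensional rule: the weight of each grid point is the product of the
# 1-D weights along every axis, flattened to match the flattened point grid.
dim = 3
W_mesh = np.prod(np.meshgrid(*([w] * dim), indexing="ij"), axis=0).ravel()

# The same outer product written as an einsum (spelled out for dim == 3):
W_einsum = np.einsum("i,j,k->ijk", w, w, w).ravel()
assert np.allclose(W_mesh, W_einsum)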

ilan-gold avatar Jan 13 '23 17:01 ilan-gold

Ok, all tests pass now for Gauss-Legendre integration. Not sure what we want to do about the other ones, all of which involve some sort of restriction that the current full suite of tests does not obey (domain or type of function). So we'll need to do one of the following: (a) make new testing files, (b) don't merge those integrators (they are, after all, fairly trivial to implement and extremely restricted), or (c) something in between.

ilan-gold avatar Jan 14 '23 11:01 ilan-gold

Also, the multi-dimensional integrand case is now supported, which brings full parity with fixed_quad (my original interest, after all: https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.fixed_quad.html). I would finish #160 first, since it is merged into this branch, and then we can move forward with this one.
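
For reference, "parity with fixed_quad" here means handling vector-valued integrands the way scipy does, by summing the weighted samples along the last axis:

import numpy as np
from scipy.integrate import fixed_quad

# Scalar integrand: integral of sin over [0, pi] is exactly 2.
val, _ = fixed_quad(np.sin, 0.0, np.pi, n=10)

# Vector-valued integrand: each row is integrated over the same interval,
# since fixed_quad reduces over the last axis of the integrand's output.
vals, _ = fixed_quad(lambda x: np.stack([np.sin(x), np.cos(x)]), 0.0, np.pi / 2, n=10)
print(val, vals)  # ~2.0 and ~[1.0, 1.0]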

ilan-gold avatar Jan 14 '23 12:01 ilan-gold

It looks like this will be at minimum 5x faster than scipy for reasonably sized systems for me (say ~50x150 at N=60), and caps out at at least 45x for massive systems (I will need to check this, but I can do a grid that is 128x as large as ~50x150 with a ~9x speed improvement over the GPU at 50x150, which is itself already 5x faster than scipy). So this is huge for me.

ilan-gold avatar Jan 14 '23 13:01 ilan-gold

@ilan-gold wow, you have been busy! :)

Ok, all tests pass now for Gauss-Legendre integration. Not sure what we want to do about the other ones, all of which involve some sort of restriction that the current full suite of tests does not obey (domain or type of function). So we'll need to do one of the following: (a) make new testing files, (b) don't merge those integrators (they are, after all, fairly trivial to implement and extremely restricted), or (c) something in between.

Can you elaborate a bit on what the problem is? I am not sure I get it.

gomezzz avatar Jan 16 '23 09:01 gomezzz

Added JIT support and ensured gradients are actually preserved!
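
To illustrate the gradient part: a fixed-node quadrature is just a weighted sum of integrand evaluations, so autograd flows through it naturally. A minimal Torch sketch, independent of this PR's code:

import numpy as np
import torch

t, w = np.polynomial.legendre.leggauss(20)   # rule on [-1, 1]
x = torch.tensor(0.5 * np.pi * (t + 1.0))    # nodes mapped to [0, pi]
wt = torch.tensor(0.5 * np.pi * w)           # weights scaled by the Jacobian

a = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
integral = torch.sum(wt * torch.sin(a * x))  # approximates integral of sin(a*x) over [0, pi]
integral.backward()
print(integral.item(), a.grad.item())        # ~2.0 and ~-2.0 (= integral of x*cos(x) over [0, pi])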

ilan-gold avatar Jan 19 '23 11:01 ilan-gold

@ilan-gold wow, you have been busy! :)

Ok, all tests pass now for Gauss-Legendre integration. Not sure what we want to do about the other ones, all of which involve some sort of restriction that the current full suite of tests does not obey (domain or type of function). So we'll need to do one of the following: (a) make new testing files, (b) don't merge those integrators (they are, after all, fairly trivial to implement and extremely restricted), or (c) something in between.

Can you elaborate a bit on what the problem is? I am not sure I get it.

So the other integrators all have restrictions on how you can use them that the current test suite does not obey, like domain restrictions or integrand restrictions (they must be of a certain form). So we'd have to write new tests just for them. I personally do not think it's worth it.
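
(As a concrete example of such a restriction: Gauss-Hermite rules bake the weight function exp(-x^2) into the quadrature, so the test suite's plain integrands would need rewriting to fit that form:)

import numpy as np
from scipy import special

x, w = special.roots_hermite(20)
# Gauss-Hermite approximates the integral of f(x) * exp(-x^2) over (-inf, inf),
# so summing the weights alone gives the integral of exp(-x^2), i.e. sqrt(pi).
print(np.sum(w), np.sqrt(np.pi))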

ilan-gold avatar Jan 19 '23 11:01 ilan-gold

@ilan-gold wow, you have been busy! :)

Ok, all tests pass now for Gauss-Legendre integration. Not sure what we want to do about the other ones, all of which involve some sort of restriction that the current full suite of tests does not obey (domain or type of function). So we'll need to do one of the following: (a) make new testing files, (b) don't merge those integrators (they are, after all, fairly trivial to implement and extremely restricted), or (c) something in between.

Can you elaborate a bit on what the problem is? I am not sure I get it.

So the other integrators all have restrictions on how you can use them that the current test suite does not obey, like domain restrictions or integrand restrictions (they must be of a certain form). So we'd have to write new tests just for them. I personally do not think it's worth it.

Ah, you mean Gauss-Jacobi, Gauss-Hermite, etc.? Ok for me to leave those out, I think.

gomezzz avatar Jan 19 '23 12:01 gomezzz

Yup, exactly those. I think it could be a nice demo in the docs.

ilan-gold avatar Jan 19 '23 12:01 ilan-gold

@gomezzz I have not forgotten about this! I just finished my exam so I am circling back to this. I will put this in a stable place by the end of next week if not sooner!

ilan-gold avatar Feb 22 '23 09:02 ilan-gold

@gomezzz I have not forgotten about this! I just finished my exam so I am circling back to this. I will put this in a stable place by the end of next week if not sooner!

@ilan-gold Great to hear! Hope the exams went well! :) Take your time, no rush. I will look at creating a release after that. :v:

Btw. if you are interested in continuing with torchquad and joining me as a maintainer, I would be very open to that! :)

gomezzz avatar Feb 22 '23 10:02 gomezzz

@gomezzz so I've run into a bit of an issue: I think JIT is broken here for JAX. Do you have any pointers/advice on getting JAX and autoray to work well together in this case? I'm trying to get this part working:

File ~/torchquad/torchquad/integration/gaussian.py:59, in Gaussian._weights(self, N, dim, backend, requires_grad)
     52     return anp.prod(
     53         anp.stack(
     54             list(anp.meshgrid(*([weights] * dim))), like=backend, dim=0
     55         ),
     56         axis=0,
     57     ).ravel()
     58 else:
---> 59     return anp.prod(
     60         anp.meshgrid(*([weights] * dim), like=backend), axis=0
     61     ).ravel()

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3,3,3])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function _weights at /home/ig62/torchquad/torchquad/integration/gaussian.py:37 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[3] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /home/ig62/masters-thesis/venv/lib/python3.8/site-packages/autoray/autoray.py:79 (do)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

I added jit decorators to non-static functions for JAX, but this breaks other things. It seemed like Torch worked fine, but maybe I wasn't actually jit'ing it. Thanks!
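
For context, the error means a traced JAX array reached a plain NumPy function, which tried to materialize it via __array__. A minimal reproduction with a jit-safe variant (illustrative, not the PR's code):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def bad(w):
    # np.meshgrid calls __array__ on the tracer -> TracerArrayConversionError
    return np.prod(np.meshgrid(w, w), axis=0).ravel()

@jax.jit
def good(w):
    # staying inside jax.numpy keeps everything traceable under jit
    return jnp.prod(jnp.stack(jnp.meshgrid(w, w)), axis=0).ravel()

print(good(jnp.ones(3)))  # works
# bad(jnp.ones(3))        # raises jax.errors.TracerArrayConversionError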

ilan-gold avatar Mar 03 '23 16:03 ilan-gold

I think the problem goes beyond this PR - I think jax broke with the last PR because of the changes to calculate_result. I think it should probably be refactored. I can make a PR into this one to fix this.

ilan-gold avatar Mar 03 '23 16:03 ilan-gold

I think the problem goes beyond this PR - I think jax broke with the last PR because of the changes to calculate_result. I think it should probably be refactored. I can make a PR into this one to fix this.

Running test CI again to confirm, will update in a sec :v:

gomezzz avatar Mar 06 '23 09:03 gomezzz

On the branch for this PR the tests seem to be passing, but with a lot of warnings 🤔

=============================== warnings summary ===============================
torchquad/tests/boole_test.py: 4 warnings
torchquad/tests/gauss_test.py: 4 warnings
torchquad/tests/gradient_test.py: 3 warnings
torchquad/tests/integrator_types_test.py: 4 warnings
torchquad/tests/monte_carlo_test.py: 4 warnings
torchquad/tests/simpson_test.py: 7 warnings
torchquad/tests/trapezoid_test.py: 4 warnings
  /home/runner/work/torchquad/torchquad/torchquad/tests/../integration/utils.py:255: UserWarning: DEPRECATION WARNING: In future versions of torchquad, an array-like object will be returned.
    warnings.warn(

torchquad/tests/boole_test.py::test_integrate_torch
  /home/runner/micromamba-root/envs/torchquad/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1675740247391/work/aten/src/ATen/native/TensorShape.cpp:3190.)
    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

torchquad/tests/gauss_test.py::test_integrate_torch
  /home/runner/micromamba-root/envs/torchquad/lib/python3.10/site-packages/autoray/autoray.py:79: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    return get_lib_fn(backend, fn)(*args, **kwargs)

torchquad/tests/monte_carlo_test.py::test_integrate_jax
  /home/runner/micromamba-root/envs/torchquad/lib/python3.10/site-packages/autoray/autoray.py:79: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return get_lib_fn(backend, fn)(*args, **kwargs)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 52 passed, 33 warnings in 106.23s (0:01:46) ==================

They also pass on develop, albeit with somewhat fewer (but still a lot of) warnings:

 =============================== warnings summary ===============================
torchquad/tests/boole_test.py: 4 warnings
torchquad/tests/gradient_test.py: 3 warnings
torchquad/tests/integrator_types_test.py: 4 warnings
torchquad/tests/monte_carlo_test.py: 4 warnings
torchquad/tests/simpson_test.py: 7 warnings
torchquad/tests/trapezoid_test.py: 4 warnings
  /home/runner/work/torchquad/torchquad/torchquad/tests/../integration/utils.py:255: UserWarning: DEPRECATION WARNING: In future versions of torchquad, an array-like object will be returned.
    warnings.warn(

torchquad/tests/boole_test.py::test_integrate_torch
  /home/runner/micromamba-root/envs/torchquad/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1675740247391/work/aten/src/ATen/native/TensorShape.cpp:3190.)
    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

torchquad/tests/monte_carlo_test.py::test_integrate_jax
  /home/runner/micromamba-root/envs/torchquad/lib/python3.10/site-packages/autoray/autoray.py:79: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return get_lib_fn(backend, fn)(*args, **kwargs)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 48 passed, 28 warnings in 84.06s (0:01:24) ==================

Are the warnings related to your problems, @ilan-gold? I created a separate issue for them: #163

Or is there a separate minimal example for the error you posted? Then we can create another issue for that. :)

gomezzz avatar Mar 06 '23 09:03 gomezzz

@gomezzz Oh my gosh, I'm so sorry. I forgot the word jit, which we don't test.

ilan-gold avatar Mar 06 '23 09:03 ilan-gold

Wait no, I didn't forget it. But we don't test jit, so this makes sense. Things should be passing.

ilan-gold avatar Mar 06 '23 09:03 ilan-gold

I just looked a bit back in time: before #160 we had

=================== 48 passed, 1 warning in 68.67s (0:01:08) =================== ( https://github.com/esa/torchquad/actions/runs/3830307450/jobs/6606968775 )

and then

================== 48 passed, 28 warnings in 73.58s (0:01:13) ================== ( https://github.com/esa/torchquad/actions/runs/3995901069/jobs/6855308819 )

So at the least, the warnings seem to be related. Seems I should have complained more :D Sorry, I only looked for the ✅ .... 🙈

gomezzz avatar Mar 06 '23 09:03 gomezzz

(As an afterthought, this could also be related to some updates in other packages; it doesn't necessarily have to be #160. I noticed, e.g., that autoray got some updates recently.)

gomezzz avatar Mar 06 '23 09:03 gomezzz

Ok, jit stopped working because of the decorator. I think the answer is to just not use *args and **kwargs.
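
A sketch of why the generic signature hurts (hypothetical example, not torchquad's actual decorator): jax.jit needs shape-determining arguments marked static, and that is hard to express when everything is funneled through *args/**kwargs:

from functools import partial

import jax
import jax.numpy as jnp

# Explicit signature: N can be declared static so it may define array shapes.
@partial(jax.jit, static_argnums=(0,))
def weights(N, scale):
    return scale * jnp.linspace(-1.0, 1.0, N)  # N must be a static Python int

print(weights(5, 2.0))

# A wrapper like `def call(fn, *args, **kwargs): return fn(*args, **kwargs)`
# hides which position is static, so static_argnums no longer lines up.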

ilan-gold avatar Mar 06 '23 10:03 ilan-gold