Adding numpyro backend

freddyaboulton opened this issue 3 years ago • 14 comments

POC for adding a non-stan backend to Prophet, based on numpyro.

The benefits of using a non-stan backend are that users can avoid installing pystan2 and avoid introducing a GPL-licensed dependency into their project.

Since numpyro can also be easily pip installed, it lets users install prophet and all of its dependencies with pip install prophet, or pip install prophet[numpyro] if we decide to make numpyro optional. From experience, this is much easier than using cmdstan and cmdstanpy.
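If numpyro becomes an optional extra, its modules would only be imported when that backend is actually selected. Below is a minimal standard-library sketch of that lazy-import pattern. BACKEND_MODULES, INSTALL_HINTS and load_backend_module are hypothetical names, not Prophet's API, and the install hint assumes an extra declared in setup.py along the lines of extras_require={"numpyro": ["numpyro", "jax"]}.

```python
import importlib
from types import ModuleType

# Hypothetical mapping from a stan_backend name to the module it needs.
BACKEND_MODULES = {
    "CMDSTANPY": "cmdstanpy",
    "NUMPYRO": "numpyro",
}

# Hypothetical install hints; assumes setup.py declares a "numpyro" extra.
INSTALL_HINTS = {
    "NUMPYRO": "pip install prophet[numpyro]",
}


def load_backend_module(name: str, modules: dict = BACKEND_MODULES) -> ModuleType:
    """Import a backend's dependency only when that backend is requested."""
    try:
        module_path = modules[name]
    except KeyError:
        raise ValueError(
            f"Unknown stan_backend {name!r}; expected one of {sorted(modules)}"
        ) from None
    try:
        # The import happens here, not at package import time.
        return importlib.import_module(module_path)
    except ImportError as err:
        hint = INSTALL_HINTS.get(name, f"pip install {module_path}")
        raise ImportError(
            f"stan_backend={name!r} requires the optional dependency "
            f"{module_path!r}; install it with: {hint}"
        ) from err
```

With this layout, importing the package never touches numpyro or jax, so users who stay on cmdstanpy don't need them installed; they only see the install hint if they ask for the NUMPYRO backend without the extra.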

There are also some performance benefits to using numpyro that I've noticed: the sampling runs faster and gets results as good as stan's.

Timing on Peyton Manning: [image]

Numpyro plot on Peyton Manning: [image]

Stan plot on Peyton Manning: [image]

IMO, the numpyro sampling is also nicer to work with. Sometimes I have to change my environment in non-intuitive ways to get the stan sampling to work, e.g. https://github.com/facebook/prophet/issues/1889. It should also be possible to run prophet on the GPU with numpyro, but I have not tested that.

In short, I think this gives users more options for how to run Prophet in Python. The implementation still needs work, but I want to check in with the maintainers to see if this is something you would welcome in Prophet before proceeding.

Thanks!

freddyaboulton, Jul 27 '21

Hi @freddyaboulton!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

facebook-github-bot, Jul 27 '21

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

facebook-github-bot, Jul 27 '21

This is awesome! I'm keen to try it out. It seems that the confidence intervals are wider with NumPyro than with Stan; do you know why this is? Does this also happen if we try a larger mcmc_samples value?

Otherwise, my thoughts on using Numpyro as the default engine:

  • Agreed that it's easier to install than cmdstan, and with Pystan2 no longer being updated, it seems like a sensible option.
  • One thing to consider is that it might be slightly more work to maintain parity between the R and Python models, since they're not both in Stan. It might also be confusing to end users if the confidence interval results are materially different between the stan and numpyro backends. That said, this wouldn't stop us from adding numpyro (and jax) as an optional dependency.
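One way to put a number on whether the intervals are "materially different" is to compare the average width of yhat_upper - yhat_lower in the two backends' forecasts. A small sketch, assuming both forecasts come from predict() on the same future dataframe with the same interval_width; mean_interval_width and width_ratio are hypothetical helpers. They accept anything indexable by column name, such as a pandas DataFrame or a dict of lists.

```python
from statistics import fmean
from typing import Mapping, Sequence


def mean_interval_width(forecast: Mapping[str, Sequence[float]]) -> float:
    """Average width of the uncertainty interval, yhat_upper - yhat_lower."""
    lower = forecast["yhat_lower"]
    upper = forecast["yhat_upper"]
    return fmean(hi - lo for lo, hi in zip(lower, upper))


def width_ratio(forecast_a, forecast_b) -> float:
    """How many times wider forecast_a's intervals are than forecast_b's, on average."""
    return mean_interval_width(forecast_a) / mean_interval_width(forecast_b)
```

A ratio well above 1 for the numpyro forecast against the cmdstanpy one would reflect the wider intervals noted above. Some difference is expected either way, since Prophet estimates its intervals by simulation (controlled by uncertainty_samples).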

Would need to wait for @bletham's thoughts though :)

tcuongd, Jul 31 '21

Thanks @tcuongd ! One reason the confidence intervals with Numpyro may be wider is that I have not yet used the values in stan_init to initialize the sampling. I need to do more investigation though!

I think it's totally reasonable to treat the numpyro backend as optional. The concerns you raise about parity with R and with the results from the stan backend are valid, and for now my intention is to provide an option for Python users who don't want to rely on stan.

Hopefully, we can narrow the gap between the backends in the future, but I think providing more backends to users is good even if they don't provide 100% fidelity to stan! It moves Prophet from being a particular implementation to a "standard" that can be implemented in different probabilistic languages. I think that level of flexibility will be useful for users.

freddyaboulton, Aug 02 '21

@bletham Do you have any thoughts on this? Thank you!

freddyaboulton, Aug 20 '21

Just an update: the CI checks all passed (installing the numpyro dependency and running the numpyro test script), so I'm pretty confident that this works properly.

I'll try to do a more thorough review of the code soon.

tcuongd, Aug 28 '21

Thank you for the thoughtful review @tcuongd ! I really appreciate it. I will address your comments shortly!

freddyaboulton, Sep 07 '21

@bletham Do you have any thoughts on this? I want to know if this kind of change is mergeable before continuing to work on it. Thank you!

freddyaboulton, Oct 18 '21

I'm also interested in this, any thoughts @bletham ?

nicolaerosia, Oct 24 '21

@bletham Have you had a chance to review this?

freddyaboulton, Dec 07 '21

@tcuongd glad to see you're quite active :) Any chance we could get this awesome work by @freddyaboulton merged?

nicolaerosia, May 26 '22

Kind ping :)

nicolaerosia, Jun 24 '22

Hey @freddyaboulton @nicolaerosia ! Finally getting back to this one. We're at a place now where stan installation isn't as much of an issue, but I think the benefits that numpyro brings in terms of scale -- being able to MCMC sample much faster when N is large -- makes it worth adding as a backend. I think we should continue with this work as an enhancement to the Python package and I'm happy to merge it in once it's ready.

Next steps on this PR (let me know if you guys want to take this on or I can try to find time for it):

  • [x] Rebase with main
  • [x] Add initial rates to MCMC sampling and MAP estimation so that uncertainty intervals match stan's
  • [x] Add numpyro as an optional dependency via setup.py, and only import numpyro-related modules when stan_backend='NUMPYRO' is passed.
  • [x] (Another PR) Incorporate the testing for the numpyro model into the test scripts, using either subTest() (https://docs.python.org/3/library/unittest.html#distinguishing-test-iterations-using-subtests) or pytest.fixture parameterizations (https://docs.pytest.org/en/6.2.x/fixture.html#parametrizing-fixtures).
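The subTest() option from the last item could look roughly like this: one test method loops over the backends, and each iteration is reported separately if it fails instead of the loop stopping at the first failure. fit_stub is a hypothetical stand-in so the example runs without either backend installed; a real test would build Prophet(stan_backend=backend), fit it, and check the forecast.

```python
import unittest

BACKENDS = ["CMDSTANPY", "NUMPYRO"]


def fit_stub(backend: str, ys):
    """Hypothetical stand-in for fitting Prophet with the given stan_backend.

    Every 'backend' just returns the mean of ys so the example runs anywhere.
    """
    if backend not in BACKENDS:
        raise ValueError(f"unknown backend {backend!r}")
    return sum(ys) / len(ys)


class TestAcrossBackends(unittest.TestCase):
    def test_fit_matches_reference(self):
        ys = [1.0, 2.0, 3.0]
        for backend in BACKENDS:
            # Each backend gets its own subtest; a failure names the backend.
            with self.subTest(backend=backend):
                self.assertAlmostEqual(fit_stub(backend, ys), 2.0)
```

Run it with python -m unittest. With pytest, the same loop could instead be a parametrized fixture, so each backend appears as its own test ID.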

tcuongd, Jul 03 '22

Very interested in this, looking forward to the merge! :D

bernardoiconpro, Jul 19 '22

Looking forward to the merge. Installing pystan is a pain.

NBekmuratov, Sep 22 '22

Heya @WardBrian, I've revived this PR and had a go at integrating the numpyro backend. The main changes are a new module, numpyro_model.py, and a new class in models.py. Could you review the design? (I'm not as fussed about the numpyro / jax code style.)

~Everything seems to run fine (see code block below). However I'm getting really slow runtimes for the logistic trend -- I hadn't tested logistic trend in the previous benchmark -- I might ask for advice from the numpyro forum unless anyone else in this thread has ideas!~

import logging
from datetime import datetime, timezone

import pandas as pd
from prophet import Prophet

# Silence cmdstanpy's INFO logging once, so only the timings are printed.
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)

N_TRAIN = 750
df = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv').iloc[:N_TRAIN]
# Logistic growth needs a capacity column; flat and linear growth ignore it.
df["cap"] = 50.0

models = {}
for backend in ["CMDSTANPY", "NUMPYRO"]:
    for growth in ["flat", "linear", "logistic"]:
        for mcmc_samples in [0, 200]:
            model_name = f"{backend=}-{growth=}-{mcmc_samples=}"
            # The two samplers name their progress-bar option differently,
            # and MAP fitting (mcmc_samples=0) takes neither.
            fit_kwargs = {}
            if mcmc_samples > 0:
                fit_kwargs = {"progress_bar": False} if backend == "NUMPYRO" else {"show_progress": False}
            tic = datetime.now(timezone.utc)
            m = Prophet(growth=growth, mcmc_samples=mcmc_samples, stan_backend=backend)
            m.fit(df, **fit_kwargs)
            toc = datetime.now(timezone.utc)
            models[model_name] = m
            print(f"{model_name}: {(toc - tic).total_seconds():.4f} seconds")

Results:

N = 750
backend='CMDSTANPY'-growth='flat'-mcmc_samples=0: 0.1038 seconds
backend='CMDSTANPY'-growth='flat'-mcmc_samples=200: 2.8249 seconds
backend='CMDSTANPY'-growth='linear'-mcmc_samples=0: 0.1040 seconds
backend='CMDSTANPY'-growth='linear'-mcmc_samples=200: 8.6726 seconds
backend='CMDSTANPY'-growth='logistic'-mcmc_samples=0: 0.3011 seconds
backend='CMDSTANPY'-growth='logistic'-mcmc_samples=200: 23.7867 seconds
backend='NUMPYRO'-growth='flat'-mcmc_samples=0: 3.7928 seconds
backend='NUMPYRO'-growth='flat'-mcmc_samples=200: 5.8178 seconds
backend='NUMPYRO'-growth='linear'-mcmc_samples=0: 1.2319 seconds
backend='NUMPYRO'-growth='linear'-mcmc_samples=200: 8.5912 seconds
backend='NUMPYRO'-growth='logistic'-mcmc_samples=0: 1.9608 seconds
backend='NUMPYRO'-growth='logistic'-mcmc_samples=200: 27.5186 seconds
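For reference when debugging the logistic case, here is a pure-Python sketch of the piecewise logistic trend, following the computation in the Python package's Prophet.piecewise_logistic (t and cap are assumed to be already scaled, as Prophet does internally). The gamma offsets that keep the curve continuous at each changepoint are defined recursively, each depending on the ones before it, which is one piece a JAX/numpyro port has to express differently (for example with jax.lax.scan or a vectorized rewrite). This sketch does not establish whether that is what makes the logistic fit slow.

```python
import math
from typing import List, Sequence


def piecewise_logistic(
    t: Sequence[float],
    cap: Sequence[float],
    deltas: Sequence[float],
    k: float,
    m: float,
    changepoint_ts: Sequence[float],
) -> List[float]:
    """Logistic trend cap / (1 + exp(-k_t * (t - m_t))) with rate changepoints."""
    # Growth rate in each segment: k, k + delta_1, k + delta_1 + delta_2, ...
    k_cum = [k]
    for d in deltas:
        k_cum.append(k_cum[-1] + d)
    # Offset adjustments that make the curve continuous at each changepoint.
    # Each gamma uses the sum of the previous gammas: a sequential recursion.
    gammas: List[float] = []
    for j, t_s in enumerate(changepoint_ts):
        gammas.append((t_s - m - sum(gammas)) * (1 - k_cum[j] / k_cum[j + 1]))
    out = []
    for t_i, c_i in zip(t, cap):
        # Accumulate rate and offset changes for changepoints already passed.
        k_t, m_t = k, m
        for t_s, d, g in zip(changepoint_ts, deltas, gammas):
            if t_i >= t_s:
                k_t += d
                m_t += g
        out.append(c_i / (1 + math.exp(-k_t * (t_i - m_t))))
    return out
```

Note that k_cum[j] / k_cum[j + 1] divides by zero if a changepoint brings the rate to exactly 0; the sketch does not guard against that.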

tcuongd, Jul 12 '23

Just an update on this:

  • I did further checks on the numpyro model's predictions and uncertainty widths
  • It produces roughly the same results as the cmdstanpy backend for the flat and linear trends.
  • However, I'm still struggling to get it to work as expected for the logistic trend function. MCMC sampling also gives wider uncertainty intervals for the linear trend function, and times out for the logistic trend.
  • Parallelization with processes doesn't play nicely with JAX either (I think it reads from and writes to the device, which is shared across processes), so we'd need to warn or prevent people from using cross_validation with parallel='processes' on the NUMPYRO backend

Given these limitations I'm not sure if it's worth the added complexity in the code + maintenance required to add this backend. Of course, if anyone can solve all the issues above then I'd be happy to proceed.

I also wanted to get a sense of how many people would benefit from this backend now. It seems the initial frustrations were with having to install pystan, which we've already resolved for a few versions now by switching to cmdstanpy.
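The cross_validation restriction mentioned in the list above could be enforced with a small check before any worker processes start. A sketch, assuming a hypothetical helper check_parallel_mode that receives the model's backend name and the parallel argument of prophet.diagnostics.cross_validation. It warns and falls back to serial execution (parallel=None) rather than raising, and it leaves other modes such as 'threads' untouched, since the thread doesn't say whether those work with JAX.

```python
import warnings


def check_parallel_mode(stan_backend: str, parallel):
    """Hypothetical guard: avoid process-based parallelism with NUMPYRO.

    Returns the parallel mode to actually use for cross-validation.
    """
    if stan_backend == "NUMPYRO" and parallel == "processes":
        warnings.warn(
            "parallel='processes' is not supported with stan_backend='NUMPYRO'; "
            "running cross-validation serially instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        # None means no parallelism, i.e. folds are fitted one after another.
        return None
    return parallel
```

Raising a ValueError instead of warning would be the stricter "prevent" option.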

tcuongd, Sep 10 '23

Closing, happy to reconsider if someone is keen to tackle the issues mentioned above.

tcuongd, Oct 04 '23