Speed Up Uncertainty Predictions
Addresses #2030 with the following approach: it maintains the functionality of `predict_uncertainty` and `sample_posterior_predictive` while making `sample_model` and `sample_predictive_trend` obsolete.
However, the full original approach to uncertainty is still used when `mcmc_samples > 0`.
The current submission is a draft; the following is required:
- [x] check that the approach is functional
- [x] fix `sample_posterior_predictive`
- [x] enable setting a random seed
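The core idea can be sketched with NumPy (illustrative only; the function names and the simple Gaussian noise model are stand-ins, not Prophet's actual internals): instead of simulating each uncertainty sample in a Python loop, draw all samples at once as a 2-D array and take quantiles across the sample axis.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points = 100    # forecast horizon length
n_samples = 1000  # uncertainty_samples
sigma = 0.5       # observation noise scale (illustrative)

def sample_loop(trend):
    # loop-based: one simulated path per iteration (the old approach, schematically)
    draws = []
    for _ in range(n_samples):
        draws.append(trend + rng.normal(0, sigma, size=n_points))
    return np.column_stack(draws)  # shape (n_points, n_samples)

def sample_vectorized(trend):
    # vectorized: all paths in a single array operation
    noise = rng.normal(0, sigma, size=(n_points, n_samples))
    return trend[:, None] + noise  # broadcast the trend across samples

trend = np.linspace(8.0, 9.0, n_points)
draws = sample_vectorized(trend)
# quantiles across the sample axis give yhat_lower / yhat_upper
lower, upper = np.percentile(draws, [10, 90], axis=1)
```

The vectorized version replaces a Python-level loop of thousands of iterations with one NumPy call, which is where the speedup comes from.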
With holiday and `include_history=True` (the defaults in the test script below):
- This branch: 605 ms ± 5.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (of which 402 ms is everything except `.predict()`)
- Current release: 4.05 s ± 222 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- Current main branch: 1.56 s ± 5.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Without holiday and with `include_history=False`:
- This branch: 412 ms ± 966 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
- Current release: 2.91 s ± 73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- Current main branch: 1.06 s ± 7.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Main Version: *(forecast plot)*

Improved Version: *(forecast plot)*
Here, the `_y` columns are the main branch and the `_x` columns are this fork:
| | ds | yhat_x | yhat_upper_x | yhat_lower_x | yhat_y | yhat_upper_y | yhat_lower_y |
|---|---|---|---|---|---|---|---|
| 0 | 2016-01-21 00:00:00 | 8.50433 | 9.18666 | 7.87826 | 8.50433 | 9.09408 | 7.83959 |
| 1 | 2016-01-22 00:00:00 | 8.51883 | 9.1459 | 7.91799 | 8.51883 | 9.12654 | 7.8763 |
| 2 | 2016-01-23 00:00:00 | 8.2848 | 8.91713 | 7.71767 | 8.2848 | 8.94064 | 7.67322 |
| 3 | 2016-01-24 00:00:00 | 8.65445 | 9.2566 | 8.04 | 8.65445 | 9.21576 | 8.06857 |
| 4 | 2016-01-25 00:00:00 | 8.97835 | 9.60644 | 8.36301 | 8.97835 | 9.64668 | 8.42581 |
| 5 | 2016-01-26 00:00:00 | 8.74186 | 9.3522 | 8.1396 | 8.74186 | 9.37046 | 8.13031 |
| 6 | 2016-01-27 00:00:00 | 8.55594 | 9.18995 | 7.92813 | 8.55594 | 9.16879 | 7.98859 |
| 7 | 2016-01-28 00:00:00 | 8.55536 | 9.14276 | 7.95855 | 8.55536 | 9.16812 | 7.95943 |
| 8 | 2016-01-29 00:00:00 | 8.54914 | 9.16933 | 7.92421 | 8.54914 | 9.18301 | 7.93726 |
| 9 | 2016-01-30 00:00:00 | 8.29096 | 8.87615 | 7.65422 | 8.29096 | 8.87245 | 7.67648 |
| 10 | 2016-01-31 00:00:00 | 8.63309 | 9.2359 | 8.02737 | 8.63309 | 9.21735 | 8.00273 |
| 11 | 2016-02-01 00:00:00 | 8.92622 | 9.50353 | 8.35175 | 8.92622 | 9.57686 | 8.31946 |
| 12 | 2016-02-02 00:00:00 | 8.65596 | 9.27162 | 8.05305 | 8.65596 | 9.31287 | 8.05825 |
| 13 | 2016-02-03 00:00:00 | 8.43361 | 9.01963 | 7.80646 | 8.43361 | 9.00752 | 7.8454 |
| 14 | 2016-02-04 00:00:00 | 8.3944 | 9.01991 | 7.77109 | 8.3944 | 9.0127 | 7.77019 |
| 15 | 2016-02-05 00:00:00 | 8.34786 | 8.95809 | 7.67212 | 8.34786 | 8.98904 | 7.72464 |
| 16 | 2016-02-06 00:00:00 | 8.04831 | 8.66777 | 7.41942 | 8.04831 | 8.64124 | 7.43734 |
| 17 | 2016-02-07 00:00:00 | 8.3487 | 8.92544 | 7.75442 | 8.3487 | 8.9109 | 7.73461 |
| 18 | 2016-02-08 00:00:00 | 8.60046 | 9.22172 | 7.96822 | 8.60046 | 9.25325 | 7.93998 |
| 19 | 2016-02-09 00:00:00 | 8.28998 | 8.87171 | 7.68482 | 8.28998 | 8.92518 | 7.63397 |
| 20 | 2016-02-10 00:00:00 | 8.02937 | 8.63235 | 7.47261 | 8.02937 | 8.69591 | 7.42786 |
| 21 | 2016-02-11 00:00:00 | 7.95464 | 8.5395 | 7.33431 | 7.95464 | 8.53878 | 7.33922 |
| 22 | 2016-02-12 00:00:00 | 7.87612 | 8.49491 | 7.23573 | 7.87612 | 8.48426 | 7.30573 |
| 23 | 2016-02-13 00:00:00 | 7.54884 | 8.18376 | 6.95785 | 7.54884 | 8.13378 | 6.93322 |
| 24 | 2016-02-14 00:00:00 | 7.82643 | 8.44139 | 7.20395 | 7.82643 | 8.40929 | 7.20042 |
| 25 | 2016-02-15 00:00:00 | 7.64645 | 8.28459 | 7.05067 | 7.64645 | 8.21321 | 7.0273 |
| 26 | 2016-02-16 00:00:00 | 7.73917 | 8.36804 | 7.16921 | 7.73917 | 8.40742 | 7.18914 |
| 27 | 2016-02-17 00:00:00 | 7.47366 | 8.07386 | 6.88434 | 7.47366 | 8.06657 | 6.83621 |
| 28 | 2016-02-18 00:00:00 | 7.40067 | 7.99196 | 6.73279 | 7.40067 | 8.03888 | 6.79691 |
| 29 | 2016-02-19 00:00:00 | 7.33061 | 7.96718 | 6.71675 | 7.33061 | 7.9433 | 6.69069 |
```python
import pandas as pd
from importlib import reload
import prophet

df_in = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv')
reload(prophet)

def run_prophet_speedtest(periods=30, include_history=True, holiday=True):
    m = prophet.Prophet(uncertainty_samples=1000)
    if holiday:
        m.add_country_holidays(country_name='US')
    m.fit(df_in)
    future = m.make_future_dataframe(periods=periods, include_history=include_history)
    forecast = m.predict(future)
    return m, forecast

%timeit run_prophet_speedtest()  # run in IPython / Jupyter
```
Ready for review for #2030 @tcuongd @nicolaerosia @orenmatar
@tcuongd, this looks ready to test / approve / merge.
Hi, @tcuongd, @winedarksea, are there any updates or expected timeline for merging and releasing this update? Is there anything I can do to help speed it up?
Sorry, totally hadn't realized he had left comments. I will get on those now
Script for a side-by-side comparison of the new and PyPI versions' results:
```python
# -*- coding: utf-8 -*-
"""
If you get an error for plotting, "no keyword include_legend" or some such,
update your environment to the latest version of Prophet from PyPI.
"""
import sys
import time
import timeit
import importlib.util

import requests
import numpy as np
import pandas as pd
from prophet import Prophet  # the default site-packages prophet

# CHANGE THIS
dev_prophet_location = "C:/Users/Colin/Documents/prophet/python/prophet/forecaster.py"
# AND DON'T run this script from the dev prophet source location, or you'll load the same prophet twice.
# Note this only imports the forecaster file, so any other modules it needs will be pulled from the site-packages Prophet.
# Borrowed from https://stackoverflow.com/questions/67631/how-do-i-import-a-module-given-the-full-path
spec = importlib.util.spec_from_file_location("forecaster", dev_prophet_location)
forecaster = importlib.util.module_from_spec(spec)
sys.modules["forecaster"] = forecaster
spec.loader.exec_module(forecaster)

# Unfortunately this doesn't fully fix the base version's results per run, but it should for the new version.
np.random.seed(45)

# load a dataset
wiki_pages = [
    "all",
    "Standard_deviation",
    "Christmas",
    "William_Shakespeare",
]
wiki_language = "en"
sleep_seconds = 5
timeout = 60
s = requests.Session()
str_start = pd.to_datetime("2018-01-01").strftime("%Y%m%d00")
str_end = pd.Timestamp.now().strftime("%Y%m%d00")
headers = {
    "User-Agent": "prophet testing",
}
dataset_lists = []
for page in wiki_pages:
    if page == "all":
        url = f"https://wikimedia.org/api/rest_v1/metrics/pageviews/aggregate/all-projects/all-access/all-agents/daily/{str_start}/{str_end}?maxlag=5"
    else:
        url = f"https://wikimedia.org/api/rest_v1/metrics/pageviews/per-article/{wiki_language}.wikipedia/all-access/all-agents/{page}/daily/{str_start}/{str_end}?maxlag=5"
    data = s.get(url, timeout=timeout, headers=headers)
    data_js = data.json()
    if "items" not in data_js.keys():
        print(data_js)
    gdf = pd.DataFrame(data_js["items"])
    gdf["date"] = pd.to_datetime(gdf["timestamp"], format="%Y%m%d00")
    gresult = gdf.set_index("date")["views"].fillna(0)
    gresult.name = "wiki_" + str(page)[0:80]
    dataset_lists.append(gresult.to_frame())
    time.sleep(sleep_seconds)
df = pd.concat(dataset_lists, axis=1)

# begin modelling
mapes = {}
lower_mapes = {}
upper_mapes = {}
for col in df:
    new_df = pd.DataFrame({"y": df[col], "ds": df.index})
    m = Prophet()
    start_time = timeit.default_timer()
    m.fit(new_df)
    future = m.make_future_dataframe(periods=180)
    forecast = m.predict(future)
    print(f"Base runtime {timeit.default_timer() - start_time} for {col}")
    # using an EWM to smooth out some weekly seasonality so the upper/lower separation is clearer
    forecast[["yhat", "yhat_lower", "yhat_upper"]].ewm(span=7).mean().plot(
        linewidth=0.7, title=f"{col} Base Version"
    )
    # plt.savefig(f"{col}_base_version.png", dpi=300)

    # note that this pulls the dev version from the forecaster module directly
    m2 = forecaster.Prophet()
    start_time = timeit.default_timer()
    m2.fit(new_df)
    future = m2.make_future_dataframe(periods=180)
    forecast2 = m2.predict(future)
    print(f"Dev runtime {timeit.default_timer() - start_time} for {col}")
    forecast2[["yhat", "yhat_lower", "yhat_upper"]].ewm(span=7).mean().plot(
        linewidth=0.7, title=f"{col} Dev Version"
    )
    # plt.savefig(f"{col}_dev_version.png", dpi=300)

    # plot components (bug fix: the dev plot previously reused `forecast` instead of `forecast2`)
    m.plot_components(forecast).suptitle(f"{col} Base Version")
    m2.plot_components(forecast2).suptitle(f"{col} Dev Version")

    mapes[col] = round(
        ((forecast["yhat"] - forecast2["yhat"]) / forecast["yhat"]).abs().mean(), 5
    )
    lower_mapes[col] = round(
        ((forecast["yhat_lower"] - forecast2["yhat_lower"]) / forecast["yhat_lower"])
        .abs()
        .mean(),
        5,
    )
    upper_mapes[col] = round(
        ((forecast["yhat_upper"] - forecast2["yhat_upper"]) / forecast["yhat_upper"])
        .abs()
        .mean(),
        5,
    )

# we expect 0 error for yhat (unchanged) and some small error (~1%) for uncertainty due to different random sampling
print(mapes)
# MAPE values are artificially high for series that cross near zero (here, the lower bound for Christmas)
print(lower_mapes)
print(upper_mapes)
```
The `yhat` values matched identically, as expected, since that code path is unchanged.
Here the best MAPE was 0.001 and the worst was 0.03. When the absolute value was removed, the error was near 0, suggesting that the differences are almost perfectly balanced above and below the main branch's values.
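That distinction (nonzero absolute error, near-zero signed error) is exactly what symmetric sampling noise looks like; a small synthetic example with made-up numbers:

```python
import numpy as np

rng = np.random.default_rng(7)
base = np.full(1000, 8.0)                      # stand-in for main-branch values
new = base + rng.normal(0, 0.08, size=1000)    # fork values, differing only by symmetric noise

rel_err = (base - new) / base
abs_mape = np.abs(rel_err).mean()    # small but nonzero, like the MAPEs above
signed_err = rel_err.mean()          # near 0: deviations above and below cancel
```

If the fork were biased in one direction, the signed error would be of the same order as the absolute error instead of much smaller.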

Submitted an updated commit addressing all comments as requested.
ping for reviewers :)
@winedarksea thank you very much, I see a 2x speed improvement!
@nicolaerosia @winedarksea that's strange, I'd expect a much greater speed improvement based on: https://towardsdatascience.com/how-to-run-facebook-prophet-predict-x100-faster-cce0282ca77d
@orenmatar your timeit only measures the speed of the probabilistic uncertainty generation, while that 2x estimate is for the entire Prophet fit/predict. My `.fit()` takes significantly longer than the 91.5 ms you had in that article.
For me:

```python
Prophet().fit(df)
```

1.05 s ± 7.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Hi everyone, this speedup will help a lot with my own use case. Hence, just wanted to ping here, to see if this is ready to be merged. Thanks @winedarksea for the awesome work to get 2x speedup.
Sorry all for the delay on this, had the chance to look at it more deeply this weekend though and after making some tweaks, I think it's ready to merge!
@winedarksea the bulk of the changes I've made are to the function structure (the core of the vectorization code that you wrote is unchanged). I wanted to have most of the changes in `sample_predictive_trend()`, `sample_model()`, and other lower-level functions, rather than in higher-level functions like `predict()`. You can see the changes in this commit, but in summary:
- Created `_sample_predictive_trend()`, a vectorized equivalent of `sample_predictive_trend()`. This adds the expectation to the uncertainty draws, so it does one extra pass of the data through the trend function, but this isn't a huge speed penalty, and it means we can utilise the code / outputs in MCMC sampling as well.
- Instead of keeping the original code path for `sample_model()` when `mcmc_samples > 0`, we can generate samples for each MCMC iteration in a vectorized manner as well. This helps a lot when `uncertainty_samples > mcmc_samples`; e.g. if we want 1,000 MCMC draws and 5,000 total uncertainty samples, we can reduce the number of iterations by 5x.
- There seemed to be some bugs / typos in the `_logistic_uncertainty()` function, but I think I've fixed them; I cross-checked results against the main branch and everything matches.
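Schematically, the MCMC point above means the Python-level loop runs only over posterior draws, with each iteration's uncertainty samples generated in one vectorized call (illustrative sketch with toy numbers, not the actual Prophet code):

```python
import numpy as np

rng = np.random.default_rng(1)
mcmc_samples = 1000
uncertainty_samples = 5000
n_points = 50

# samples to draw per MCMC iteration so the totals line up
samples_per_iter = int(np.ceil(uncertainty_samples / mcmc_samples))  # 5

# stand-in for per-draw posterior noise scales
posterior_sigma = rng.uniform(0.3, 0.7, size=mcmc_samples)
trend = np.linspace(8.0, 9.0, n_points)

blocks = []
for sigma in posterior_sigma:  # 1,000 iterations instead of 5,000
    # draw all of this iteration's samples in one vectorized call
    blocks.append(trend[:, None] + rng.normal(0, sigma, size=(n_points, samples_per_iter)))
draws = np.concatenate(blocks, axis=1)  # shape (n_points, uncertainty_samples)
```

With 5 samples per iteration the loop count drops by 5x, which matches the "reduce the number of iterations by 5x" claim above.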
I've tested the vectorized vs. non-vectorized versions over a few parameters: linear v logistic v flat trend, and history-only v history+future prediction dataframe. The code seems to work and the outputs match well (the script I ran is here). In terms of performance:
- MAP (i.e. no MCMC sampling) predictions are faster by about 3-4x for flat / linear growth models, and by about 5-7x for logistic growth models.
- MCMC predictions are faster by about 2-3x for flat / linear growth models.
- Unfortunately, MCMC predictions can be slower by about 1.5x for logistic growth models. I'm guessing this varies based on how much bigger `uncertainty_samples` is compared to `mcmc_samples`; the greater the number of samples per MCMC iteration, the faster the vectorized method is, otherwise the overhead of the `_logistic_uncertainty()` function causes an overall slowdown.
Because of the mixed results, I've added a parameter `vectorized: bool = True` to the `predict()` method so users can toggle between the backends based on their needs.
I've also re-ordered the prediction functions in the code so that the link between different methods is clearer, but eventually I think we'll need to refactor forecaster.py out into different files if we want to make any other big changes to it.
Next steps
- @winedarksea @orenmatar @nicolaerosia if you guys have some time to take a quick look at these changes over the next few days that would be great, but otherwise I will likely merge this after the tests pass in order to get it into the next release.
- I'd imagine this is possible for R as well, although I haven't checked what we're currently doing there and how hard it would be to refactor.
@tcuongd I went over the code and it looks really well integrated. My only comment is that the names of the vectorized functions are not very indicative: the same name with just a leading underscore. Perhaps change them to `sample_model_vectorized` etc.? But that's really not crucial... I'm also confused as to why it saved so much more time when I ran it (see the blog post), but I don't know what to do about that... Anyway, this is terrific! I'm really excited about this release. Thank you all.
@tcuongd @winedarksea @nicolaerosia Finally had some time to check why this version was still significantly slower than what I got in my original blog post. The culprit: DataFrames. Their `__init__` and `__getitem__` methods are notoriously slow, which becomes significant when creating new DataFrames or getting items thousands of times. The fix is super easy: just replace them with a dict. I created a new PR for this, and testing yielded ~10x faster runtime.
The new PR, with profiling details: https://github.com/facebook/prophet/pull/2299
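The DataFrame-overhead claim is easy to reproduce with a microbenchmark (illustrative only, not the PR's actual code): repeated column access is far cheaper on a plain dict of arrays than on a DataFrame.

```python
import timeit

import numpy as np
import pandas as pd

arr = np.arange(1000.0)
df = pd.DataFrame({"a": arr, "b": arr})
d = {"a": arr, "b": arr}

# time 100k column lookups on each container
t_df = timeit.timeit(lambda: df["a"], number=100_000)
t_dict = timeit.timeit(lambda: d["a"], number=100_000)
print(f"DataFrame getitem: {t_df:.3f}s, dict getitem: {t_dict:.3f}s")
```

Inside a sampling loop that touches columns thousands of times per prediction, this per-access overhead compounds into the kind of slowdown described above.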