seaborn
sns.regplot regression line fails with large values
With large x values, the regression line can be drawn incorrectly:
The first three lines are good, the last three aren't. (Note also the strange shape of the green shaded region, which is caused by the same issue.)
This is not strictly a seaborn issue - see https://github.com/statsmodels/statsmodels/issues/9258 for further information. However, the other OLS method implemented by statsmodels (QR) produces an accurate fit in these cases, as does sklearn (dotted black line), so I thought it might be worth posting here. It may be worth considering whether a more robust algorithm that avoids these issues could be used, which would simplify visualising the line of best fit.
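To give a rough sense of why the fit becomes fragile, the sketch below computes the ratio of the smallest to the largest singular value of the design matrix `[1, x]` for the same six x ranges used in this report. Pseudo-inverse solvers typically discard singular values below a relative cutoff; the value `ASSUMED_RCOND = 1e-15` is an assumption for illustration, not something confirmed here, and the computed ratios are themselves close to floating-point noise.

```python
import numpy as np

# Same x ranges as in the reproduction script of this report.
x_base = np.linspace(4e13, 10e13, 10)

# Assumption: relative cutoff of a pinv-style solver, used only for comparison.
ASSUMED_RCOND = 1e-15

ratios = []
for i in range(6):
    x = x_base + i * 3e13
    # Design matrix with an intercept column and the raw (unscaled) x column.
    design = np.column_stack([np.ones_like(x), x])
    # Singular values are returned in descending order.
    singular_values = np.linalg.svd(design, compute_uv=False)
    ratio = singular_values[-1] / singular_values[0]
    ratios.append(ratio)
    print(f"i={i}: smallest/largest singular value = {ratio:.2e}, "
          f"below assumed cutoff: {ratio < ASSUMED_RCOND}")
```

If the ratio drops below the solver's cutoff, the second singular direction is effectively thrown away, which would be consistent with the fit degrading as x grows.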
The code used to generate the above plot is:
import numpy as np
import statsmodels.api
import sklearn.linear_model
import matplotlib.pyplot as plt
import seaborn as sns
x_base = np.linspace(4e13, 10e13, 10)
y = np.linspace(1, 0, 10)
for i in range(6):
    x = x_base + (i*3e13)
    # solve using statsmodels
    stats_ols = statsmodels.regression.linear_model.OLS(
        endog=y, exog=statsmodels.api.add_constant(x))
    stats_ols_fitted = stats_ols.fit()  # uses method = "pinv" by default
    # stats_ols_fitted = stats_ols.fit(method = "qr")  # fits correctly
    # solve & predict using sklearn
    sklearn_ols = sklearn.linear_model.LinearRegression()
    sklearn_ols.fit(x.reshape((-1,1)), y)
    x_sklearn = np.linspace(x.min(), x.max())
    y_sklearn = sklearn_ols.predict(x_sklearn.reshape((-1,1)))
    # compose informative legend label for each set of data/LR model
    label = 'statsmodels OLS: $r^2=' + str(np.round(stats_ols_fitted._results.rsquared, 3)) + '$'
    label += '\nStatsmodels params: ' + ', '.join(['{:0.3}'.format(param) for param in stats_ols_fitted._results.params])
    label += '\nSklearn params: ' + ', '.join(['{:0.3}'.format(param) for param in [sklearn_ols.intercept_] + list(sklearn_ols.coef_)])
    # plot using seaborn
    sns.regplot(x=x, y=y, label=label, ax=plt.gca())
    # plot the LR fits (sklearn)
    plt.plot(x_sklearn, y_sklearn,
             label=('sklearn LinearRegression' if i in [2,5] else ''),
             ls=':', lw=1.5, c='k')
plt.legend(fontsize='small', loc='center left', bbox_to_anchor=(1, 0.5), ncols=2)
plt.show()
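As a possible user-side workaround until the underlying solver changes, the sketch below standardizes x before fitting and then converts the coefficients back to the original scale. This is not what seaborn or statsmodels does internally; `fit_line_standardized` is a hypothetical helper written for illustration, and it uses only `np.linalg.lstsq` on a well-conditioned design matrix.

```python
import numpy as np

def fit_line_standardized(x, y):
    """Fit y = intercept + slope * x, standardizing x first (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mu, sigma = x.mean(), x.std()
    # Standardized x has mean 0 and unit variance, so the design matrix
    # is well conditioned regardless of the original magnitude of x.
    z = (x - mu) / sigma
    design = np.column_stack([np.ones_like(z), z])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    # Convert coefficients back to the original x scale.
    slope = beta[1] / sigma
    intercept = beta[0] - slope * mu
    return intercept, slope

# Largest x range from the report (i = 5), where the default fit goes wrong.
x_large = np.linspace(4e13, 10e13, 10) + 5 * 3e13
y = np.linspace(1, 0, 10)
intercept, slope = fit_line_standardized(x_large, y)
y_pred = intercept + slope * x_large
print(intercept, slope)
```

The resulting line can be drawn with `plt.plot(x_large, y_pred)` on top of `sns.regplot(..., fit_reg=False)` if only the scatter is wanted from seaborn.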