scikit-optimize
[Feature Request] Add 'ax' argument to skopt.plots functions
Currently, I can't find a way to put the plots produced by plot_evaluations and plot_objective in a subplot: they always create a new figure and show it.
Ideally, something like the ax argument of sklearn.inspection.plot_partial_dependence (in 0.24.2) would serve all use cases:
ax : Matplotlib axes or array-like of Matplotlib axes, default=None
- If a single axis is passed in, it is treated as a bounding axes and a grid of partial dependence plots will be drawn within these bounds. The n_cols parameter controls the number of columns in the grid.
- If an array-like of axes are passed in, the partial dependence plots will be drawn directly into these axes.
- If None, a figure and a bounding axes is created and treated as the single axes case.
An example of why I want this: I have a lot of models, and for each model I want to put the results of plot_evaluations and plot_objective side by side, to save space and present the plots in a nice format.
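At the time of this request the skopt.plots functions create their own figure, so a layout like this has to be prepared with plain matplotlib. Below is a minimal sketch, assuming matplotlib 3.x (`Figure.add_gridspec` and `SubplotSpec.subgridspec`). `model_grid_layout` is a hypothetical helper: it builds one row per model with two n_dims × n_dims grids of empty Axes side by side, matching the square grid of subplots that plot_evaluations and plot_objective draw, i.e. the kind of target an `ax` argument would need to accept.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def model_grid_layout(n_models, n_dims):
    """Hypothetical helper: one row per model, two n_dims x n_dims grids per row.

    Returns the figure and a list of (left_axes, right_axes) pairs, each an
    (n_dims, n_dims) object array of Axes (e.g. evaluations left, objective right).
    """
    fig = plt.figure(figsize=(8, 4 * n_models))
    outer = fig.add_gridspec(n_models, 2)  # rows = models, columns = the two plots
    rows = []
    for i in range(n_models):
        pair = []
        for j in range(2):
            # Split this outer cell into its own n_dims x n_dims grid.
            inner = outer[i, j].subgridspec(n_dims, n_dims)
            axes = np.empty((n_dims, n_dims), dtype=object)
            for r in range(n_dims):
                for c in range(n_dims):
                    axes[r, c] = fig.add_subplot(inner[r, c])
            pair.append(axes)
        rows.append(tuple(pair))
    return fig, rows
```

Nesting grids this way keeps each model's two plots inside their own cell, so the grids never overlap regardless of how many models are stacked.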
plot_evaluations doesn't plot on a single Axes object, as each plot in the final result is an independent Axes. I don't see how to pass an ax: Axes parameter in that context.
I'm not aware of how sklearn does it, but it can draw a grid of axes within the specified ax.
I think implementing similar behavior in the skopt.plots functions would offer more flexibility.
If this is not doable, please let me know so I'll close this.
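The three `ax` cases quoted from sklearn can be dispatched in a few lines of matplotlib. This is a minimal sketch, not how sklearn or skopt actually implement it: `resolve_axes` is a hypothetical helper, and subdividing a single Axes assumes that Axes was created from a grid (e.g. by `plt.subplots`), so that `get_subplotspec()` returns its cell. The bounding Axes is removed and replaced by an `n_rows × n_cols` grid occupying the same area; hiding it with `set_axis_off()` instead would be an equally valid design choice.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def resolve_axes(ax, n_rows, n_cols):
    """Hypothetical helper mirroring the three `ax` cases quoted from sklearn.

    Returns (fig, axes) where axes is an (n_rows, n_cols) object array of Axes.
    """
    if ax is None:
        # None: create a figure and one bounding Axes, then treat it as case 1.
        fig, ax = plt.subplots()
    if isinstance(ax, Axes):
        # Single Axes: use it as a bounding box and draw a grid inside it.
        spec = getattr(ax, "get_subplotspec", lambda: None)()
        if spec is None:
            raise ValueError("bounding Axes must come from a grid, e.g. plt.subplots")
        fig = ax.figure
        inner = spec.subgridspec(n_rows, n_cols)
        ax.remove()  # the bounding Axes itself is not drawn on
        axes = np.empty((n_rows, n_cols), dtype=object)
        for r in range(n_rows):
            for c in range(n_cols):
                axes[r, c] = fig.add_subplot(inner[r, c])
        return fig, axes
    # Array-like of Axes: draw directly into them if the count matches.
    axes = np.asarray(ax, dtype=object)
    if axes.size != n_rows * n_cols:
        raise ValueError(f"expected {n_rows * n_cols} Axes, got {axes.size}")
    return axes.flat[0].figure, axes.reshape(n_rows, n_cols)
```

For example, after `fig, (left, right) = plt.subplots(1, 2)`, calling `resolve_axes(left, 2, 2)` fills `left` with a 2 × 2 grid and leaves `right` free for another plot.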
Oh, I'll look into it then!
@Abdelgha-4 I am working on #1081 to add this argument. For now I have modified plot_objective, which seems to work properly. Does that cover your feature request?
Tested using:
```python
import skopt
from skopt import space, utils, plots
import numpy as np
import matplotlib.pyplot as plt

# Two-dimensional search space
dims = [
    space.Real(0, 2 * np.pi, name="x"),
    space.Real(0, 2 * np.pi, name="y"),
]

@utils.use_named_args(dims)
def func(x, y):
    return (np.sin(x / 3) + np.cos(1.5 * y) + np.cos(0.5 * x * y)) * np.sin(x * y)

res = skopt.gp_minimize(func, dims, n_points=30)

fig, axes = plt.subplots(1, 2)

# Left: the true objective evaluated on a grid (func works element-wise on arrays)
X = np.linspace(0, 2 * np.pi)
Y = np.linspace(0, 2 * np.pi)
XX, YY = np.meshgrid(X, Y)
axes[0].contourf(XX, YY, func([XX, YY]), locator=None, cmap="viridis_r")

# Right: plot_objective drawn inside the given Axes (the `ax` argument from #1081)
plots.plot_objective(res, ax=axes[1])
plt.show()
```
@QuentinSoubeyran Yes! This works as expected, many thanks!
plot_evaluations now also has the ax argument (as well as some other arguments that used to be hard-coded).