yellowbrick
(sklearn 0.24) Pretty-printing a LearningCurve object in the IPython REPL raises an error
Describe the bug
With sklearn 0.24, displaying a LearningCurve object's representation as the output of an interactive cell raises an error. A ValidationCurve object displays fine, and with sklearn 0.23 both work. I think changes to sklearn/utils/_pprint.py in 0.24 are the reason.
To Reproduce
In an IPython cell (Jupyter notebook):
from sklearn.linear_model import LogisticRegression
from yellowbrick.model_selection import LearningCurve
lc=LearningCurve(LogisticRegression())
lc
After running the cell:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/dev/python38/myenv/lib/python3.8/site-packages/IPython/core/formatters.py in __call__(self, obj, include, exclude)
968
969 if method is not None:
--> 970 return method(include=include, exclude=exclude)
971 return None
972 else:
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/base.py in _repr_mimebundle_(self, **kwargs)
462 def _repr_mimebundle_(self, **kwargs):
463 """Mime bundle used by jupyter kernels to display estimator"""
--> 464 output = {"text/plain": repr(self)}
465 if get_config()["display"] == 'diagram':
466 output["text/html"] = estimator_html_repr(self)
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/base.py in __repr__(self, N_CHAR_MAX)
258 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
259
--> 260 repr_ = pp.pformat(self)
261
262 # Use bruteforce ellipsis when there are a lot of non-blank characters
/usr/lib/python3.8/pprint.py in pformat(self, object)
151 def pformat(self, object):
152 sio = _StringIO()
--> 153 self._format(object, sio, 0, 0, {}, 0)
154 return sio.getvalue()
155
/usr/lib/python3.8/pprint.py in _format(self, object, stream, indent, allowance, context, level)
168 self._readable = False
169 return
--> 170 rep = self._repr(object, context, level)
171 max_width = self._width - indent - allowance
172 if len(rep) > max_width:
/usr/lib/python3.8/pprint.py in _repr(self, object, context, level)
402
403 def _repr(self, object, context, level):
--> 404 repr, readable, recursive = self.format(object, context.copy(),
405 self._depth, level)
406 if not readable:
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in format(self, object, context, maxlevels, level)
178
179 def format(self, object, context, maxlevels, level):
--> 180 return _safe_repr(object, context, maxlevels, level,
181 changed_only=self._changed_only)
182
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
423 recursive = False
424 if changed_only:
--> 425 params = _changed_params(object)
426 else:
427 params = object.get_params(deep=False)
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _changed_params(estimator)
110 return False
111
--> 112 return {k: v for k, v in params.items() if has_changed(k, v)}
113
114
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in <dictcomp>(.0)
110 return False
111
--> 112 return {k: v for k, v in params.items() if has_changed(k, v)}
113
114
~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in has_changed(k, v)
98 if k not in init_params: # happens if k is part of a **kwargs
99 return True
--> 100 if init_params[k] == inspect._empty: # k has no default value
101 return True
102 # try to avoid calling repr on nested estimators
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Desktop:
- OS: Linux
- Python Version 3.8
- scikit-learn Version 0.24
- Yellowbrick Version v1.3
@bole1 Thank you for using Yellowbrick and thank you for reporting this bug! My guess is that the issue is arising from the comparison of the LearningCurve parameter train_sizes=DEFAULT_TRAIN_SIZES; DEFAULT_TRAIN_SIZES is an np.array, and this ValueError seems to happen a lot with arrays. Unfortunately, none of the code in the stack trace you provided is Yellowbrick code, which makes it really difficult to get in the way of this error or make changes that might modify it. I assume that pprint is using inspect to get the default values of the Visualizer, hence the comparison. This leads me to some questions:
- Do no other scikit-learn estimators use NumPy arrays as initial arguments? If they do, then this is a bug for them as well, and we could issue an upstream PR to try to fix scikit-learn itself. If not, then we could model what scikit-learn estimators do when they could take array input (e.g. alphas? maybe?)
- Is there a way that we could pprint a visualizer in a custom fashion that doesn't require the scikit-learn _pprint.py utility? It's getting tough to maintain our extension of scikit-learn and we'd like to depend on them a bit less.
As always, we're happy for PRs or your thoughts on the matter!
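The diagnosis above can be checked in isolation with NumPy alone. This sketch assumes an illustrative array default, np.linspace(0.1, 1.0, 5), standing in for DEFAULT_TRAIN_SIZES, and uses inspect.Parameter.empty, which is the same sentinel object as the inspect._empty that has_changed() compares against in the traceback.

```python
import inspect

import numpy as np

# Illustrative stand-in for an array-valued default such as DEFAULT_TRAIN_SIZES.
# Assumption: the exact values do not matter, only that it is an ndarray.
default = np.linspace(0.1, 1.0, 5)

# inspect.Parameter.empty is the same object as inspect._empty.
sentinel = inspect.Parameter.empty

# NumPy compares every element with the sentinel and returns a boolean
# array instead of a single True/False.
result = default == sentinel
print(result)  # [False False False False False]

# Using that array as an `if` condition, as has_changed() does, raises ValueError.
error = None
try:
    if result:
        pass
except ValueError as exc:
    error = exc
print(error)

# An identity check never triggers an elementwise comparison.
print(default is sentinel)  # False
```

The identity test (is) is the usual way to check for a sentinel object, and it is safe for arrays because it never calls ndarray.__eq__.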
I was trying Yellowbrick today and ran into the same problem. It's indeed caused by train_sizes. I got around it by changing the code of the LearningCurve class in yellowbrick/model_selection/learning_curve.py:
class LearningCurve(ModelVisualizer):
    ...
    def __init__(
        ...
        # train_sizes=DEFAULT_TRAIN_SIZES,  # original default: a NumPy array
        train_sizes=DEFAULT_TRAIN_SIZES.tolist(),  # workaround: the same values as a list
        ...
    ):
        ...
Hope this helps.
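Why the .tolist() change helps: a list compares to the sentinel as a whole object and returns one bool, so an if check like the one in has_changed() no longer fails. A minimal sketch, again assuming the illustrative default np.linspace(0.1, 1.0, 5) in place of DEFAULT_TRAIN_SIZES:

```python
import inspect

import numpy as np

# Illustrative array default and its .tolist() version (the workaround above).
array_default = np.linspace(0.1, 1.0, 5)
list_default = array_default.tolist()

# list == sentinel falls back to an identity comparison and yields a single bool,
# so `if default == inspect.Parameter.empty:` works without raising.
is_empty = list_default == inspect.Parameter.empty
print(type(is_empty).__name__, is_empty)  # bool False

# The values are unchanged, so the learning curve itself should behave the same.
print(np.allclose(np.asarray(list_default), array_default))  # True
```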
Python implementation: CPython
Python version : 3.11.0
IPython version : 8.12.0
yellowbrick: 1.5
xgboost: 1.7.4
from xgboost import XGBRegressor
import yellowbrick.model_selection as ms
# model1, X, y and seed were defined in earlier cells (not shown)
ms.learning_curve(estimator=model1, X=X, y=y, train_sizes=[0.1, 0.325, 0.55, 0.775, 1], scoring='r2', n_jobs=-1, random_state=seed)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:974, in MimeBundleFormatter.__call__(self, obj, include, exclude)
971 method = get_real_method(obj, self.print_method)
973 if method is not None:
--> 974 return method(include=include, exclude=exclude)
975 return None
976 else:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:631, in BaseEstimator._repr_mimebundle_(self, **kwargs)
629 def _repr_mimebundle_(self, **kwargs):
630 """Mime bundle used by jupyter kernels to display estimator"""
--> 631 output = {"text/plain": repr(self)}
632 if get_config()["display"] == "diagram":
633 output["text/html"] = estimator_html_repr(self)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
249 # use ellipsis for sequences with a lot of elements
250 pp = _EstimatorPrettyPrinter(
251 compact=True,
252 indent=1,
253 indent_at_name=True,
254 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
255 )
--> 257 repr_ = pp.pformat(self)
259 # Use bruteforce ellipsis when there are a lot of non-blank characters
260 n_nonblank = len("".join(repr_.split()))
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
156 def pformat(self, object):
157 sio = _StringIO()
--> 158 self._format(object, sio, 0, 0, {}, 0)
159 return sio.getvalue()
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
173 self._readable = False
174 return
--> 175 rep = self._repr(object, context, level)
176 max_width = self._width - indent - allowance
177 if len(rep) > max_width:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
454 def _repr(self, object, context, level):
--> 455 repr, readable, recursive = self.format(object, context.copy(),
456 self._depth, level)
457 if not readable:
458 self._readable = False
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
188 def format(self, object, context, maxlevels, level):
--> 189 return _safe_repr(
190 object, context, maxlevels, level, changed_only=self._changed_only
191 )
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
438 recursive = False
439 if changed_only:
--> 440 params = _changed_params(object)
441 else:
442 params = object.get_params(deep=False)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
99 if k not in init_params: # happens if k is part of a **kwargs
100 return True
--> 101 if init_params[k] == inspect._empty: # k has no default value
102 return True
103 # try to avoid calling repr on nested estimators
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:708, in PlainTextFormatter.__call__(self, obj)
701 stream = StringIO()
702 printer = pretty.RepresentationPrinter(stream, self.verbose,
703 self.max_width, self.newline,
704 max_seq_length=self.max_seq_length,
705 singleton_pprinters=self.singleton_printers,
706 type_pprinters=self.type_printers,
707 deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
709 printer.flush()
710 return stream.getvalue()
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\lib\pretty.py:410, in RepresentationPrinter.pretty(self, obj)
407 return meth(obj, self, cycle)
408 if cls is not object \
409 and callable(cls.__dict__.get('__repr__')):
--> 410 return _repr_pprint(obj, self, cycle)
412 return _default_pprint(obj, self, cycle)
413 finally:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\lib\pretty.py:778, in _repr_pprint(obj, p, cycle)
776 """A pprint that just redirects to the normal repr function."""
777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
779 lines = output.splitlines()
780 with p.group():
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
249 # use ellipsis for sequences with a lot of elements
250 pp = _EstimatorPrettyPrinter(
251 compact=True,
252 indent=1,
253 indent_at_name=True,
254 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
255 )
--> 257 repr_ = pp.pformat(self)
259 # Use bruteforce ellipsis when there are a lot of non-blank characters
260 n_nonblank = len("".join(repr_.split()))
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
156 def pformat(self, object):
157 sio = _StringIO()
--> 158 self._format(object, sio, 0, 0, {}, 0)
159 return sio.getvalue()
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
173 self._readable = False
174 return
--> 175 rep = self._repr(object, context, level)
176 max_width = self._width - indent - allowance
177 if len(rep) > max_width:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
454 def _repr(self, object, context, level):
--> 455 repr, readable, recursive = self.format(object, context.copy(),
456 self._depth, level)
457 if not readable:
458 self._readable = False
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
188 def format(self, object, context, maxlevels, level):
--> 189 return _safe_repr(
190 object, context, maxlevels, level, changed_only=self._changed_only
191 )
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
438 recursive = False
439 if changed_only:
--> 440 params = _changed_params(object)
441 else:
442 params = object.get_params(deep=False)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
99 if k not in init_params: # happens if k is part of a **kwargs
100 return True
--> 101 if init_params[k] == inspect._empty: # k has no default value
102 return True
103 # try to avoid calling repr on nested estimators
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:344, in BaseFormatter.__call__(self, obj)
342 method = get_real_method(obj, self.print_method)
343 if method is not None:
--> 344 return method()
345 return None
346 else:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:627, in BaseEstimator._repr_html_inner(self)
622 def _repr_html_inner(self):
623 """This function is returned by the @property `_repr_html_` to make
624 `hasattr(estimator, "_repr_html_") return `True` or `False` depending
625 on `get_config()["display"]`.
626 """
--> 627 return estimator_html_repr(self)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_estimator_html_repr.py:393, in estimator_html_repr(estimator)
391 style_template = Template(_STYLE)
392 style_with_id = style_template.substitute(id=container_id)
--> 393 estimator_str = str(estimator)
395 # The fallback message is shown by default and loading the CSS sets
396 # div.sk-text-repr-fallback to display: none to hide the fallback message.
397 #
(...)
402 # The reverse logic applies to HTML repr div.sk-container.
403 # div.sk-container is hidden by default and the loading the CSS displays it.
404 fallback_msg = (
405 "In a Jupyter environment, please rerun this cell to show the HTML"
406 " representation or trust the notebook. <br />On GitHub, the"
407 " HTML representation is unable to render, please try loading this page"
408 " with nbviewer.org."
409 )
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
249 # use ellipsis for sequences with a lot of elements
250 pp = _EstimatorPrettyPrinter(
251 compact=True,
252 indent=1,
253 indent_at_name=True,
254 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
255 )
--> 257 repr_ = pp.pformat(self)
259 # Use bruteforce ellipsis when there are a lot of non-blank characters
260 n_nonblank = len("".join(repr_.split()))
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
156 def pformat(self, object):
157 sio = _StringIO()
--> 158 self._format(object, sio, 0, 0, {}, 0)
159 return sio.getvalue()
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
173 self._readable = False
174 return
--> 175 rep = self._repr(object, context, level)
176 max_width = self._width - indent - allowance
177 if len(rep) > max_width:
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
454 def _repr(self, object, context, level):
--> 455 repr, readable, recursive = self.format(object, context.copy(),
456 self._depth, level)
457 if not readable:
458 self._readable = False
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
188 def format(self, object, context, maxlevels, level):
--> 189 return _safe_repr(
190 object, context, maxlevels, level, changed_only=self._changed_only
191 )
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
438 recursive = False
439 if changed_only:
--> 440 params = _changed_params(object)
441 else:
442 params = object.get_params(deep=False)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
110 return True
111 return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
99 if k not in init_params: # happens if k is part of a **kwargs
100 return True
--> 101 if init_params[k] == inspect._empty: # k has no default value
102 return True
103 # try to avoid calling repr on nested estimators
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
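Note that the call above already passes train_sizes as a list, yet the error persists. In the traceback, has_changed() tests init_params[k], the default taken from the __init__ signature, not the value the caller passed. The sketch below illustrates this with a hypothetical ToyVisualizer whose array default is an assumed stand-in for LearningCurve's.

```python
import inspect

import numpy as np


class ToyVisualizer:
    """Hypothetical stand-in for LearningCurve: only the array default matters."""

    def __init__(self, train_sizes=np.linspace(0.1, 1.0, 5)):
        self.train_sizes = train_sizes


# The caller passes a plain list...
viz = ToyVisualizer(train_sizes=[0.1, 0.325, 0.55, 0.775, 1])

# ...but the default read from the signature is still an ndarray.
init_params = inspect.signature(ToyVisualizer.__init__).parameters
default = init_params["train_sizes"].default
print(type(viz.train_sizes).__name__, type(default).__name__)  # list ndarray

# So a has_changed()-style check on the default fails regardless of the argument.
try:
    if default == inspect.Parameter.empty:
        pass
    failed = False
except ValueError:
    failed = True
print(failed)  # True
```

Two untested suggestions that follow from the tracebacks. First, ending the line with a semicolon, or assigning the result to a variable, stops IPython from displaying the returned visualizer, so its repr is never built. Second, the else branch of _safe_repr uses object.get_params(deep=False) instead of _changed_params, so sklearn.set_config(print_changed_only=False) may avoid the comparison at the cost of a longer repr.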
@SPDA36, I think you can do the same thing I did to circumvent this problem. For your setup, it will probably be:
- Go to ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\yellowbrick\model_selection\learning_curve.py
- Find the LearningCurve class.
- Change the train_sizes parameter's default value to train_sizes=DEFAULT_TRAIN_SIZES.tolist() in the __init__() signature.
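If editing the installed file is inconvenient (a reinstall or upgrade overwrites it), a runtime variant of the same idea is to rewrite the __defaults__ tuple of __init__. The sketch uses a hypothetical ToyVisualizer and a hypothetical helper, listify_array_defaults. Applying the helper to yellowbrick's LearningCurve before creating a visualizer should in principle match the file edit, but that is an untested assumption.

```python
import inspect

import numpy as np


class ToyVisualizer:
    """Hypothetical stand-in for LearningCurve with an array-valued default."""

    def __init__(self, estimator=None, train_sizes=np.linspace(0.1, 1.0, 5), cv=None):
        self.estimator = estimator
        self.train_sizes = train_sizes
        self.cv = cv


def listify_array_defaults(cls):
    """Replace ndarray defaults of cls.__init__ with equivalent lists, in place."""
    init = cls.__init__
    if init.__defaults__:
        init.__defaults__ = tuple(
            d.tolist() if isinstance(d, np.ndarray) else d for d in init.__defaults__
        )
    return cls


listify_array_defaults(ToyVisualizer)

# inspect.signature reads __defaults__, so the reported default is now a list.
default = inspect.signature(ToyVisualizer.__init__).parameters["train_sizes"].default
print(type(default).__name__)  # list
```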
I guess the cause of the problem is that LearningCurve inherits from sklearn.base.BaseEstimator, whose __repr__ performs additional checks on the __init__ parameters (the traceback shows each default being compared against inspect._empty). These checks were probably added in a newer version of sklearn.