(sklearn 0.24) LearningCurve object Ipython REPL prettyprinting raise error #1152

Open
bole1 opened this issue Feb 13, 2021 · 4 comments
Labels
type: bug something isn't working

Comments

@bole1

bole1 commented Feb 13, 2021

Describe the bug
On scikit-learn 0.24, displaying a LearningCurve object's representation at the end of an interactive cell raises an error. A ValidationCurve object works fine, and on scikit-learn 0.23 both work. I think some changes to sklearn/utils/_pprint.py in 0.24 are the cause.

To Reproduce
In an IPython cell (Jupyter notebook):

from sklearn.linear_model import LogisticRegression
from yellowbrick.model_selection import LearningCurve
lc=LearningCurve(LogisticRegression())
lc

After running the cell:



---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/dev/python38/myenv/lib/python3.8/site-packages/IPython/core/formatters.py in __call__(self, obj, include, exclude)
    968 
    969             if method is not None:
--> 970                 return method(include=include, exclude=exclude)
    971             return None
    972         else:

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/base.py in _repr_mimebundle_(self, **kwargs)
    462     def _repr_mimebundle_(self, **kwargs):
    463         """Mime bundle used by jupyter kernels to display estimator"""
--> 464         output = {"text/plain": repr(self)}
    465         if get_config()["display"] == 'diagram':
    466             output["text/html"] = estimator_html_repr(self)

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/base.py in __repr__(self, N_CHAR_MAX)
    258             n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
    259 
--> 260         repr_ = pp.pformat(self)
    261 
    262         # Use bruteforce ellipsis when there are a lot of non-blank characters

/usr/lib/python3.8/pprint.py in pformat(self, object)
    151     def pformat(self, object):
    152         sio = _StringIO()
--> 153         self._format(object, sio, 0, 0, {}, 0)
    154         return sio.getvalue()
    155 

/usr/lib/python3.8/pprint.py in _format(self, object, stream, indent, allowance, context, level)
    168             self._readable = False
    169             return
--> 170         rep = self._repr(object, context, level)
    171         max_width = self._width - indent - allowance
    172         if len(rep) > max_width:

/usr/lib/python3.8/pprint.py in _repr(self, object, context, level)
    402 
    403     def _repr(self, object, context, level):
--> 404         repr, readable, recursive = self.format(object, context.copy(),
    405                                                 self._depth, level)
    406         if not readable:

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in format(self, object, context, maxlevels, level)
    178 
    179     def format(self, object, context, maxlevels, level):
--> 180         return _safe_repr(object, context, maxlevels, level,
    181                           changed_only=self._changed_only)
    182 

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
    423         recursive = False
    424         if changed_only:
--> 425             params = _changed_params(object)
    426         else:
    427             params = object.get_params(deep=False)

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _changed_params(estimator)
    110         return False
    111 
--> 112     return {k: v for k, v in params.items() if has_changed(k, v)}
    113 
    114 

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in <dictcomp>(.0)
    110         return False
    111 
--> 112     return {k: v for k, v in params.items() if has_changed(k, v)}
    113 
    114 

~/dev/python38/myenv/lib/python3.8/site-packages/sklearn/utils/_pprint.py in has_changed(k, v)
     98         if k not in init_params:  # happens if k is part of a **kwargs
     99             return True
--> 100         if init_params[k] == inspect._empty:  # k has no default value
    101             return True
    102         # try to avoid calling repr on nested estimators

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Desktop (please complete the following information):

  • OS: LINUX
  • Python Version 3.8
  • scikit-learn Version 0.24
  • Yellowbrick Version v1.3
@bbengfort bbengfort added the type: bug something isn't working label Feb 15, 2021
@bbengfort
Member

@bole1 Thank you for using Yellowbrick and thank you for reporting this bug! My guess is that the issue is arising from the comparison of the LearningCurve parameter train_sizes=DEFAULT_TRAIN_SIZES; DEFAULT_TRAIN_SIZES is an np.array and this ValueError seems to happen a lot with arrays. Unfortunately, none of the code in the stack trace you provided is yellowbrick code, which makes it really difficult to get in the way of this error or make changes that might modify it. I assume that pprint is using inspect to get the default values of the Visualizer, hence the comparison. This leads me to some questions:

  1. Do no other scikit-learn estimators use numpy arrays as initial arguments? If they do, then this is a bug for them as well, and we could issue an upstream PR to try to fix scikit-learn itself. If not, then we could model what scikit-learn estimators do when they could take array input (e.g. alphas? maybe?)

  2. Is there a way that we could pprint a visualizer in a custom fashion that doesn't require the scikit-learn _pprint.py utility? It's getting tough to maintain our extension of scikit-learn and we'd like to depend on them a bit less.

As always, we're happy for PRs or your thoughts on the matter!
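
The comparison described above can be reproduced outside of yellowbrick. The following is a minimal sketch, assuming only that the default is a multi-element NumPy array; `make_visualizer` and the two helper functions are hypothetical names used for illustration, and the identity-based check is one possible alternative rather than anything scikit-learn is known to ship.

```python
import inspect

import numpy as np


# Hypothetical stand-in for a visualizer __init__ whose default is an array.
def make_visualizer(train_sizes=np.linspace(0.1, 1.0, 5)):
    return train_sizes


# Read the default back the way signature introspection would.
default = inspect.signature(make_visualizer).parameters["train_sizes"].default


def no_default_by_equality(value):
    # Mirrors the failing line in the traceback: `==` is applied elementwise,
    # so the `if` receives a multi-element boolean array.
    if value == inspect.Parameter.empty:
        return True
    return False


def no_default_by_identity(value):
    # Identity never calls ndarray.__eq__, so the result is a plain bool.
    return value is inspect.Parameter.empty


try:
    no_default_by_equality(default)
except ValueError as exc:
    print("equality check failed:", exc)

print(no_default_by_identity(default))
```

`inspect.Parameter.empty` is the public name for the same sentinel that the traceback refers to as `inspect._empty`.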

@chungyaoc

I was trying Yellowbrick today and ran into the same problem. It does indeed come from train_sizes. I got around it by changing the code of the LearningCurve class in yellowbrick/model_selection/learning_curve.py:

class LearningCurve(ModelVisualizer): 
...
    def __init__(
        ...
        # train_sizes=DEFAULT_TRAIN_SIZES,
        train_sizes=DEFAULT_TRAIN_SIZES.tolist(),
        ...
    ):
        ...

Hope this helps.
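
To illustrate why a list default sidesteps the error, here is a small sketch using only the standard library. It applies the same `== inspect._empty` test that appears in the traceback to two hypothetical signatures; the function names and the list values are assumptions for illustration, not yellowbrick code.

```python
import inspect


def default_counts_as_missing(func, name):
    """Apply the `== empty` test from the traceback to one parameter."""
    default = inspect.signature(func).parameters[name].default
    # With a list (or no default at all) the comparison yields a plain bool.
    return bool(default == inspect.Parameter.empty)


# Hypothetical signature with a list default, as produced by .tolist().
def with_list_default(train_sizes=[0.1, 0.325, 0.55, 0.775, 1.0]):
    return train_sizes


# Hypothetical signature with no default, where the sentinel is returned.
def without_default(train_sizes):
    return train_sizes


print(default_counts_as_missing(with_list_default, "train_sizes"))  # False
print(default_counts_as_missing(without_default, "train_sizes"))    # True
```

A list compared with the sentinel class evaluates to a single `False`, so the `if` statement has an unambiguous truth value.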

@SPDA36

SPDA36 commented Jul 12, 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.12.0
yellowbrick: 1.5
xgboost: 1.7.4

from xgboost import XGBRegressor
import yellowbrick.model_selection as ms

ms.learning_curve(estimator=model1, X=X, y=y, train_sizes=[0.1, 0.325, 0.55, 0.775, 1],scoring='r2', n_jobs=-1, random_state=seed)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:974, in MimeBundleFormatter.__call__(self, obj, include, exclude)
    971     method = get_real_method(obj, self.print_method)
    973     if method is not None:
--> 974         return method(include=include, exclude=exclude)
    975     return None
    976 else:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:631, in BaseEstimator._repr_mimebundle_(self, **kwargs)
    629 def _repr_mimebundle_(self, **kwargs):
    630     """Mime bundle used by jupyter kernels to display estimator"""
--> 631     output = {"text/plain": repr(self)}
    632     if get_config()["display"] == "diagram":
    633         output["text/html"] = estimator_html_repr(self)

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    249 # use ellipsis for sequences with a lot of elements
    250 pp = _EstimatorPrettyPrinter(
    251     compact=True,
    252     indent=1,
    253     indent_at_name=True,
    254     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    255 )
--> 257 repr_ = pp.pformat(self)
    259 # Use bruteforce ellipsis when there are a lot of non-blank characters
    260 n_nonblank = len("".join(repr_.split()))

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
    156 def pformat(self, object):
    157     sio = _StringIO()
--> 158     self._format(object, sio, 0, 0, {}, 0)
    159     return sio.getvalue()

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    173     self._readable = False
    174     return
--> 175 rep = self._repr(object, context, level)
    176 max_width = self._width - indent - allowance
    177 if len(rep) > max_width:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
    454 def _repr(self, object, context, level):
--> 455     repr, readable, recursive = self.format(object, context.copy(),
    456                                             self._depth, level)
    457     if not readable:
    458         self._readable = False

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
     99 if k not in init_params:  # happens if k is part of a **kwargs
    100     return True
--> 101 if init_params[k] == inspect._empty:  # k has no default value
    102     return True
    103 # try to avoid calling repr on nested estimators

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\lib\pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\lib\pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    249 # use ellipsis for sequences with a lot of elements
    250 pp = _EstimatorPrettyPrinter(
    251     compact=True,
    252     indent=1,
    253     indent_at_name=True,
    254     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    255 )
--> 257 repr_ = pp.pformat(self)
    259 # Use bruteforce ellipsis when there are a lot of non-blank characters
    260 n_nonblank = len("".join(repr_.split()))

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
    156 def pformat(self, object):
    157     sio = _StringIO()
--> 158     self._format(object, sio, 0, 0, {}, 0)
    159     return sio.getvalue()

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    173     self._readable = False
    174     return
--> 175 rep = self._repr(object, context, level)
    176 max_width = self._width - indent - allowance
    177 if len(rep) > max_width:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
    454 def _repr(self, object, context, level):
--> 455     repr, readable, recursive = self.format(object, context.copy(),
    456                                             self._depth, level)
    457     if not readable:
    458         self._readable = False

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
     99 if k not in init_params:  # happens if k is part of a **kwargs
    100     return True
--> 101 if init_params[k] == inspect._empty:  # k has no default value
    102     return True
    103 # try to avoid calling repr on nested estimators

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\IPython\core\formatters.py:344, in BaseFormatter.__call__(self, obj)
    342     method = get_real_method(obj, self.print_method)
    343     if method is not None:
--> 344         return method()
    345     return None
    346 else:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:627, in BaseEstimator._repr_html_inner(self)
    622 def _repr_html_inner(self):
    623     """This function is returned by the @property `_repr_html_` to make
    624     `hasattr(estimator, "_repr_html_") return `True` or `False` depending
    625     on `get_config()["display"]`.
    626     """
--> 627     return estimator_html_repr(self)

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_estimator_html_repr.py:393, in estimator_html_repr(estimator)
    391 style_template = Template(_STYLE)
    392 style_with_id = style_template.substitute(id=container_id)
--> 393 estimator_str = str(estimator)
    395 # The fallback message is shown by default and loading the CSS sets
    396 # div.sk-text-repr-fallback to display: none to hide the fallback message.
    397 #
   (...)
    402 # The reverse logic applies to HTML repr div.sk-container.
    403 # div.sk-container is hidden by default and the loading the CSS displays it.
    404 fallback_msg = (
    405     "In a Jupyter environment, please rerun this cell to show the HTML"
    406     " representation or trust the notebook. <br />On GitHub, the"
    407     " HTML representation is unable to render, please try loading this page"
    408     " with nbviewer.org."
    409 )

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\base.py:257, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    249 # use ellipsis for sequences with a lot of elements
    250 pp = _EstimatorPrettyPrinter(
    251     compact=True,
    252     indent=1,
    253     indent_at_name=True,
    254     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    255 )
--> 257 repr_ = pp.pformat(self)
    259 # Use bruteforce ellipsis when there are a lot of non-blank characters
    260 n_nonblank = len("".join(repr_.split()))

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:158, in PrettyPrinter.pformat(self, object)
    156 def pformat(self, object):
    157     sio = _StringIO()
--> 158     self._format(object, sio, 0, 0, {}, 0)
    159     return sio.getvalue()

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:175, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    173     self._readable = False
    174     return
--> 175 rep = self._repr(object, context, level)
    176 max_width = self._width - indent - allowance
    177 if len(rep) > max_width:

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\pprint.py:455, in PrettyPrinter._repr(self, object, context, level)
    454 def _repr(self, object, context, level):
--> 455     repr, readable, recursive = self.format(object, context.copy(),
    456                                             self._depth, level)
    457     if not readable:
    458         self._readable = False

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in _changed_params(estimator)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:113, in <dictcomp>(.0)
    110         return True
    111     return False
--> 113 return {k: v for k, v in params.items() if has_changed(k, v)}

File ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\sklearn\utils\_pprint.py:101, in _changed_params.<locals>.has_changed(k, v)
     99 if k not in init_params:  # happens if k is part of a **kwargs
    100     return True
--> 101 if init_params[k] == inspect._empty:  # k has no default value
    102     return True
    103 # try to avoid calling repr on nested estimators

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

@chungyaoc

chungyaoc commented Jul 13, 2023

@SPDA36, I think you can work around this problem the same way I did.

For your setup, it will probably be:

  1. Go to ~\miniconda3\envs\DS_Python311_JAN2023\Lib\site-packages\yellowbrick\model_selection\learning_curve.py
  2. Find the LearningCurve class.
  3. Change the train_sizes parameter's default value to train_sizes=DEFAULT_TRAIN_SIZES.tolist() in the __init__() signature.
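
If the install location differs from the path in step 1, the file can be located programmatically. This is a rough sketch using the standard library's `importlib.util.find_spec`; `locate_module` is a hypothetical helper name, and the dotted module name is inferred from the file path quoted earlier in this thread.

```python
import importlib.util


def locate_module(dotted_name):
    """Return the source path of a module, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(dotted_name)
    except ImportError:
        # Raised when a parent package is missing or fails to import.
        return None
    return spec.origin if spec is not None else None


# Prints the path to edit, or None when yellowbrick is not installed.
print(locate_module("yellowbrick.model_selection.learning_curve"))
```

Note that resolving a dotted name imports its parent packages, so this may take a moment the first time it runs.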

I guess the cause of the problem is that LearningCurve inherits from sklearn.base.BaseEstimator, which performs some additional checks on __init__'s parameters when building the repr. These additional checks were probably added in a newer version of sklearn.
