
RFECV is much slower than Sklearn's implementation #1047

Open
jc639 opened this issue Mar 5, 2020 · 10 comments · May be fixed by #1048
Labels: priority: high (should be done before next release); type: technical debt (work to optimize or generalize code)

Comments

@jc639 commented Mar 5, 2020

I am aware that yellowbrick is using RFE and CV separately to produce the visualiser, but the approach is several times slower than sklearn's implementation of RFECV.

Running the following in a jupyter notebook:

import yellowbrick
print('yellowbrick version: ', yellowbrick.__version__)
import sklearn
print('sklearn version: ', sklearn.__version__)

yellowbrick version: 1.1
sklearn version: 0.22.1

from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV as skrfecv
from yellowbrick.model_selection import RFECV
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression

# Build a classification task using 4 out of 50 informative features
X, y = make_classification(n_samples=200, n_features=50, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)

log_reg = LogisticRegression()


def rfe_time_test(yb=True):
    if yb:
        rfecv = RFECV(log_reg, step=1, cv=StratifiedKFold(5),
                      scoring='accuracy')
    else:
        rfecv = skrfecv(log_reg, step=1, cv=StratifiedKFold(5),
                      scoring='accuracy')
    _ = rfecv.fit(X, y)

%timeit rfe_time_test(yb=True)

1min 23s ± 8.18 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit rfe_time_test(yb=False)

3.73 s ± 430 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

If this is unavoidable due to using CV separately to get the full per-fold scores, then it would be nice to note it in the documentation, so that users know they can use sklearn's RFECV to drop the bottom ~50% of features before running the visualiser.
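As a rough sketch of that two-stage workaround (illustrative only; the 50% cut via min_features_to_select is an arbitrary choice, not a recommendation):

from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV as skrfecv
from sklearn.linear_model import LogisticRegression
from yellowbrick.model_selection import RFECV

X, y = make_classification(n_samples=200, n_features=50, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)
log_reg = LogisticRegression()

# Stage 1: fast sklearn pass to discard roughly the weakest half of the features
pre_selector = skrfecv(log_reg, step=1, cv=5, scoring='accuracy',
                       min_features_to_select=X.shape[1] // 2)
X_reduced = pre_selector.fit_transform(X, y)

# Stage 2: yellowbrick visualiser on the surviving features only
viz = RFECV(log_reg, step=1, cv=5, scoring='accuracy')
viz.fit(X_reduced, y)
viz.show()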

This got me interested, so I did some digging into what might affect the difference between sklearn's and yellowbrick's RFECV:

import matplotlib.pyplot as plt
import numpy as np

def plot_timings(x_range, yb_timings, sk_timings, x_axis, titles):
    f, ax = plt.subplots(1, 2)
    s_times = np.array([t.average for t in sk_timings])
    y_times = np.array([t.average for t in yb_timings])
    ax[0].plot(x_range, y_times, 'ro-')
    ax[0].plot(x_range, s_times, 'bo-')
    ax[0].legend(['yellowbrick', 'sklearn'])
    ax[0].set_ylabel('Time (seconds)')
    ax[1].set_ylabel('YB time / SK time')

    ratio = y_times/s_times
    ax[1].plot(x_range, ratio, 'og-')
    for i, title in enumerate(titles):
        ax[i].set_title(title)
        ax[i].set_xlabel(x_axis)
    
    f.subplots_adjust(wspace=0.25)
    f.set_size_inches(10, 6)
        
    plt.show()
    
    return f

yb_timings = []
sk_timings = []
n_obs = [i for i in range(200, 1001, 100)]
for i in n_obs:
    # Build a classification task using 4 informative features
    X, y = make_classification(n_samples=i, n_features=10, n_informative=4,
                               n_redundant=2, n_repeated=0, n_classes=4,
                               n_clusters_per_class=1, random_state=0)
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
obs = plot_timings(n_obs, yb_timings, 
                   sk_timings, x_axis='Number of observations', 
                   titles=['Timings', 'Ratio'])

[Figure: Timings and observations]

The ratio of the time difference is fairly stable over the number of observations.

yb_timings = []
sk_timings = []
n_feats = [i for i in range(10, 51, 10)]
for i in n_feats:
    # Build a classification task using 4 informative features
    X, y = make_classification(n_samples=200, n_features=i, n_informative=4,
                               n_redundant=2, n_repeated=0, n_classes=4,
                               n_clusters_per_class=1, random_state=0)
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
feats = plot_timings(n_feats, yb_timings, 
                     sk_timings, x_axis='Number of input features', 
                     titles=['Timings', 'Ratio'])

[Figure: Timings and features]

As the number of starting features increases, YB becomes even slower relative to sklearn.

# Build a classification task using 4 informative features
X, y = make_classification(n_samples=200, n_features=10, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)

log_reg = LogisticRegression()

yb_timings = []
sk_timings = []
cvs = [i for i in range(2, 11, 2)]
for i in cvs:
    def rfe_time_test(yb=True):
        if yb:
            rfecv = RFECV(log_reg, step=1, cv=StratifiedKFold(i),
                          scoring='accuracy')
        else:
            rfecv = skrfecv(log_reg, step=1, cv=StratifiedKFold(i),
                          scoring='accuracy')
        _ = rfecv.fit(X, y)
        
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
cv = plot_timings(cvs, yb_timings, 
                     sk_timings, x_axis='Number of CV folds', 
                     titles=['Timings', 'Ratio'])

[Figure: Timings and CV folds]

YB also becomes relatively slower as the number of CV folds increases!

@lwgray (Contributor) commented Mar 6, 2020

Hi @jc639,
Thank you for using yellowbrick and doing this analysis. It is quite thorough.

I would like to explain a bit about the performance differences between YB and Sklearn. The YB model does in fact utilize Sklearn to provide the RFE functionality. However, our model wraps sklearn.feature_selection.RFE and not sklearn.feature_selection.RFECV, because access to the internals of the CV and RFE estimators is required for the visualization. The visualizer takes similar arguments, but it does not expose the same internal attributes. Furthermore, the CV step has to be performed separately. Due to these limitations, performance can't be optimized to the levels demonstrated by RFECV, as far as I can tell. If you have any suggestions on how to implement this better, we would be grateful for your input.
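As a simplified illustration (not the actual yellowbrick code) of why wrapping RFE and running CV separately costs more: a fresh recursive elimination has to be fitted for every candidate feature count in every fold, whereas sklearn's RFECV eliminates once per fold and scores every step along the way.

import numpy as np
from sklearn.feature_selection import RFE
from sklearn.model_selection import cross_val_score

def rfe_with_separate_cv(model, X, y, cv=5, scoring='accuracy', step=1):
    # Cross-validate a separate RFE for each candidate number of features.
    feature_counts = np.arange(1, X.shape[1] + 1, step)
    cv_scores = []
    for k in feature_counts:
        rfe = RFE(model, n_features_to_select=int(k), step=step)
        # Each call re-runs the full elimination in every fold, which is
        # roughly a factor of n_features more model fits than RFECV needs.
        cv_scores.append(cross_val_score(rfe, X, y, cv=cv, scoring=scoring))
    return feature_counts, np.array(cv_scores)  # shape (n_counts, n_folds)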

You can find the code for this visualizer here. My response is mostly taken directly from a comment within our code.

Additionally, although not pertinent to your question, the RFE model can be accessed via the rfe_estimator_ attribute. Once fitted, the visualizer acts as a wrapper for this estimator and not for the original model passed to the visualizer. This way the visualizer can be used to make predictions.
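Something along these lines (a small sketch, reusing the classification data from the issue above):

viz = RFECV(LogisticRegression(), cv=5, scoring='accuracy')
viz.fit(X, y)

viz.rfe_estimator_   # the underlying fitted sklearn RFE
viz.predict(X)       # predictions are delegated to that fitted estimator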

If you would like to discuss this more, please drop us a comment. Also, maybe @bbengfort can provide a more nuanced answer if I didn't cover everything.

@jc639 (Author) commented Mar 6, 2020

@lwgray thanks for the reply, I thought this might be the case.

As far as I can tell from the code the reason to separate the two is to compute the following:

        sigmas = self.cv_scores_.std(axis=1)

        # Plot one standard deviation above and below the mean
        self.ax.fill_between(x, means - sigmas, means + sigmas, alpha=0.25)

This allows plotting of the std (a very good thing!), but might it make sense to offer the option of a fast implementation that uses sklearn's RFECV at the expense of the std band, while still producing the very nice visualisation of the mean score?

Also, maybe I am missing something, but couldn't the sklearn RFECV fit method be overridden to keep the per-fold scores before this line collapses them across folds:
scores = np.sum(scores, axis=0)

Aware this may be a bit hacky, but something like this addition:

        # THIS IS THE NEW ADDITION 
        self.cv_scores_ = scores

In the following:

from joblib import Parallel, delayed, effective_n_jobs
from sklearn.utils import check_X_y, safe_sqr
from sklearn.base import is_classifier, clone
from sklearn.model_selection import check_cv
from sklearn.model_selection._validation import _score
from sklearn.metrics import check_scoring
from sklearn.feature_selection._rfe import RFE, RFECV, _rfe_single_fit
import numpy as np


class NewRFECV(RFECV):
    def __init__(self, estimator, step=1, min_features_to_select=1, cv=None,
                 scoring=None, verbose=0, n_jobs=None):
        # Pass the constructor arguments through rather than hard-coded defaults
        super(NewRFECV, self).__init__(estimator, step=step,
                                       min_features_to_select=min_features_to_select,
                                       cv=cv, scoring=scoring, verbose=verbose,
                                       n_jobs=n_jobs)
        
    def fit(self, X, y, groups=None):
        """Fit the RFE model and automatically tune the number of selected
           features.
        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples and
            `n_features` is the total number of features.
        y : array-like of shape (n_samples,)
            Target values (integers for classification, real numbers for
            regression).
        groups : array-like of shape (n_samples,) or None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
        """
        X, y = check_X_y(X, y, "csr", ensure_min_features=2,
                         force_all_finite=False)

        # Initialization
        cv = check_cv(self.cv, y, is_classifier(self.estimator))
        scorer = check_scoring(self.estimator, scoring=self.scoring)
        n_features = X.shape[1]

        if 0.0 < self.step < 1.0:
            step = int(max(1, self.step * n_features))
        else:
            step = int(self.step)
        if step <= 0:
            raise ValueError("Step must be >0")

        # Build an RFE object, which will evaluate and score each possible
        # feature count, down to self.min_features_to_select
        rfe = RFE(estimator=self.estimator,
                  n_features_to_select=self.min_features_to_select,
                  step=self.step, verbose=self.verbose)

        # Determine the number of subsets of features by fitting across
        # the train folds and choosing the "features_to_select" parameter
        # that gives the least averaged error across all folds.

        # Note that joblib raises a non-picklable error for bound methods
        # even if n_jobs is set to 1 with the default multiprocessing
        # backend.
        # This branching is done so that to
        # make sure that user code that sets n_jobs to 1
        # and provides bound methods as scorers is not broken with the
        # addition of n_jobs parameter in version 0.18.

        if effective_n_jobs(self.n_jobs) == 1:
            parallel, func = list, _rfe_single_fit
        else:
            parallel = Parallel(n_jobs=self.n_jobs)
            func = delayed(_rfe_single_fit)

        scores = parallel(
            func(rfe, self.estimator, X, y, train, test, scorer)
            for train, test in cv.split(X, y, groups))

        # THIS IS THE NEW ADDITION 
        self.cv_scores_ = scores

        scores = np.sum(scores, axis=0)
        scores_rev = scores[::-1]
        argmax_idx = len(scores) - np.argmax(scores_rev) - 1
        n_features_to_select = max(
            n_features - (argmax_idx * step),
            self.min_features_to_select)

        # Re-execute an elimination with best_k over the whole set
        rfe = RFE(estimator=self.estimator,
                  n_features_to_select=n_features_to_select, step=self.step,
                  verbose=self.verbose)

        rfe.fit(X, y)

        # Set final attributes
        self.support_ = rfe.support_
        self.n_features_ = rfe.n_features_
        self.ranking_ = rfe.ranking_
        self.estimator_ = clone(self.estimator)
        self.estimator_.fit(self.transform(X), y)

        # Fixing a normalization error, n is equal to get_n_splits(X, y) - 1
        # here, the scores are normalized by get_n_splits(X, y)
        self.grid_scores_ = scores[::-1] / cv.get_n_splits(X, y, groups)
        return self
    
from sklearn.linear_model import SGDClassifier

new = NewRFECV(estimator=SGDClassifier(max_iter=1000), cv=5, scoring='accuracy')

# make a classification dataset
X, y = make_classification(n_samples=1000, n_features=30, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)

%timeit new.fit(X, y)

5.3 s ± 582 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

np.array(new.cv_scores_).shape

(5, 30)
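With the per-fold scores kept around like this, the mean ± one standard deviation band that the visualiser draws could be reproduced directly. A sketch, assuming the fitted new estimator from above (the per-fold arrays run from most features down to fewest, so they are reversed to ascend):

import matplotlib.pyplot as plt

cv_scores = np.array(new.cv_scores_)[:, ::-1]  # (n_folds, n_feature_counts), ascending
means = cv_scores.mean(axis=0)
sigmas = cv_scores.std(axis=0)
x = np.arange(1, cv_scores.shape[1] + 1)       # feature counts (step=1)

fig, ax = plt.subplots()
ax.plot(x, means, 'o-')
ax.fill_between(x, means - sigmas, means + sigmas, alpha=0.25)
ax.set_xlabel('Number of features selected')
ax.set_ylabel('Mean CV accuracy')
plt.show()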

# comparison to sklearn original
orig = RFECV(estimator=LogisticRegression(), cv=5, scoring='accuracy')
%timeit orig.fit(X, y)

5.55 s ± 843 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# comparison to yb
from yellowbrick.model_selection import RFECV as ybRFECV
yb = ybRFECV(model=LogisticRegression(), cv=5, scoring='accuracy')
%timeit yb.fit(X, y)

1min 10s ± 2.63 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

Again maybe I am missing something here...

Happy to have a go at contributing if you think this might be a good idea - not done it before but willing to learn!

@lwgray (Contributor) commented Mar 6, 2020

@bbengfort can you chime in here? I don't know of any instance where we have altered a sklearn estimator.

bbengfort added the type: technical debt label Mar 10, 2020
@bbengfort (Member)

@jc639 first off, I have to say - this is an excellent issue - thank you so much for taking the time to write benchmarks, plot them, and dive into both the yb and sklearn code! It's a lot easier for us to dive into these things when we get such excellent write-ups and investigations from contributors!

I have also been noticing a slowdown in RFECV, and I do think it has to do with a divergence between the sklearn implementation and our internal implementation. You're right that we need the reimplementation in order to get data for the visualization that isn't provided by the sklearn implementation. Perhaps we should open a scikit-learn issue with your proposed RFE change so that we can unify the API; if sklearn updates RFE, then it will be automatically updated for us as well.

In the meantime, I'm going to create a PR to investigate your proposal and loop you into it so we can modify it together, does that sound ok?

@lwgray (Contributor) commented Mar 10, 2020

@bbengfort Thanks for the help. I agree that opening a PR with sklearn might be the way forward.

bbengfort linked pull request #1048 Mar 10, 2020 that will close this issue
@bbengfort (Member)

@jc639 ok, check out #1048 -- would you mind running your timings on that code? Hopefully you should be able to clone my branch and push to it; if not, let me know and I'll see what I can do permissions-wise. You may have to fork and PR into my fork's branch.

@jc639 (Author) commented Mar 13, 2020

@bbengfort @lwgray sorry for the radio silence, it's been a bit busy this week. I have some time this weekend and next week to have a look into this.

Perhaps we should open a scikit-learn issue with your proposed RFE change so that we can unify the API, and if sklearn updates RFE, then it will be automatically updated for us as well.

Agree this would be the best approach, to save importing the presumably "non-public" methods from the internal _rfe.py.
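If the private imports are kept in the meantime, a guarded import along these lines might soften breakage between versions (a sketch; the older module path is an assumption about how sklearn reorganised feature_selection before 0.22):

try:
    # sklearn >= 0.22 keeps these in the private _rfe module
    from sklearn.feature_selection._rfe import RFE, RFECV, _rfe_single_fit
except ImportError:
    # earlier releases exposed the same names from feature_selection.rfe
    from sklearn.feature_selection.rfe import RFE, RFECV, _rfe_single_fit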

@bbengfort (Member)

@jc639 any word on this? It would be great if we could update the performance of RFECV before the next release.

bbengfort added the priority: high label Apr 9, 2020
@jc639 (Author) commented Apr 13, 2020

@bbengfort I have forked your branch, made a few edits, and re-ran the test notebook. Speed is now comparable to sklearn's. I'll just PR into your fork's branch.

import warnings
warnings.filterwarnings(action='ignore')
import yellowbrick
print('yellowbrick version: ', yellowbrick.__version__)
import sklearn
print('sklearn version: ', sklearn.__version__)

yellowbrick version: 1.1
sklearn version: 0.22.1

from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV as skrfecv
from yellowbrick.model_selection import RFECV
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression

# Build a classification task using 4 informative features
X, y = make_classification(n_samples=200, n_features=10, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)

log_reg = LogisticRegression()


def rfe_time_test(yb=True):
    if yb:
        rfecv = RFECV(log_reg, step=1, cv=5,
                      scoring='accuracy')
    else:
        rfecv = skrfecv(log_reg, step=1, cv=5,
                      scoring='accuracy')
    _ = rfecv.fit(X, y)

%timeit rfe_time_test(yb=True)

328 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit rfe_time_test(yb=False)

326 ms ± 23.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

import matplotlib.pyplot as plt
import numpy as np

def plot_timings(x_range, yb_timings, sk_timings, x_axis, titles):
    f, ax = plt.subplots(1, 2)
    s_times = np.array([t.average for t in sk_timings])
    y_times = np.array([t.average for t in yb_timings])
    ax[0].plot(x_range, y_times, 'ro-')
    ax[0].plot(x_range, s_times, 'bo-')
    ax[0].legend(['yellowbrick', 'sklearn'])
    ax[0].set_ylabel('Time (seconds)')
    ax[1].set_ylabel('YB time / SK time')

    ratio = y_times/s_times
    ax[1].plot(x_range, ratio, 'og-')
    for i, title in enumerate(titles):
        ax[i].set_title(title)
        ax[i].set_xlabel(x_axis)
    
    f.subplots_adjust(wspace=0.25)
    f.set_size_inches(10, 6)
        
    plt.show()
    
    return f

yb_timings = []
sk_timings = []
n_obs = [i for i in range(200, 1001, 100)]
for i in n_obs:
    # Build a classification task using 4 informative features
    X, y = make_classification(n_samples=i, n_features=10, n_informative=4,
                               n_redundant=2, n_repeated=0, n_classes=4,
                               n_clusters_per_class=1, random_state=0)
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
obs = plot_timings(n_obs, yb_timings, 
                   sk_timings, x_axis='Number of observations', 
                   titles=['Timings', 'Ratio'])

[Figure: Timings and observations]

yb_timings = []
sk_timings = []
n_feats = [i for i in range(10, 51, 10)]
for i in n_feats:
    # Build a classification task using 4 informative features
    X, y = make_classification(n_samples=200, n_features=i, n_informative=4,
                               n_redundant=2, n_repeated=0, n_classes=4,
                               n_clusters_per_class=1, random_state=0)
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
feats = plot_timings(n_feats, yb_timings, 
                     sk_timings, x_axis='Number of input features', 
                     titles=['Timings', 'Ratio'])
feats.savefig('Timings and features.png')

[Figure: Timings and features]

# Build a classification task using 4 informative features
X, y = make_classification(n_samples=200, n_features=10, n_informative=4,
                           n_redundant=2, n_repeated=0, n_classes=4,
                           n_clusters_per_class=1, random_state=0)

log_reg = LogisticRegression()

yb_timings = []
sk_timings = []
cvs = [i for i in range(2, 11, 2)]
for i in cvs:
    def rfe_time_test(yb=True):
        if yb:
            rfecv = RFECV(log_reg, step=1, cv=StratifiedKFold(i),
                          scoring='accuracy')
        else:
            rfecv = skrfecv(log_reg, step=1, cv=StratifiedKFold(i),
                          scoring='accuracy')
        _ = rfecv.fit(X, y)
        
    yb_time = %timeit -o rfe_time_test(yb=True)
    yb_timings.append(yb_time)
    
    sk_time = %timeit -o rfe_time_test(yb=False)
    sk_timings.append(sk_time)
    
cv = plot_timings(cvs, yb_timings, 
                     sk_timings, x_axis='Number of CV folds', 
                     titles=['Timings', 'Ratio'])

cv.savefig('Timings and CV folds.png')

[Figure: Timings and CV folds]

@Guilherme-Garcia

I'm running the yellowbrick implementation of RFECV on 738 features and 2843 samples.

While the sklearn implementation ran in ~20 min, the yellowbrick one has been running for days (I'm still waiting for the outcome).
