diff --git a/yellowbrick/features/decomposition.py b/yellowbrick/features/decomposition.py index 31afbbd20..d8ab67dd9 100644 --- a/yellowbrick/features/decomposition.py +++ b/yellowbrick/features/decomposition.py @@ -11,6 +11,8 @@ ########################################################################## ## Imports ########################################################################## +import numpy as np +import bisect from yellowbrick.style import palettes from yellowbrick.features.base import FeatureVisualizer @@ -123,15 +125,9 @@ class ExplainedVariance(FeatureVisualizer): """ - def __init__( - self, - ax=None, - scale=True, - center=True, - n_components=None, - colormap=palettes.DEFAULT_SEQUENCE, - **kwargs - ): + def __init__(self, n_components=None, ax=None, scale=True, center=True, + colormap=palettes.DEFAULT_SEQUENCE, cumulative=False, cutoff=95, + **kwargs): super(ExplainedVariance, self).__init__(ax=ax, **kwargs) @@ -146,10 +142,16 @@ def __init__( ] ) self.pca_features = None - + self.cumulative = cumulative + self.cutoff = cutoff + @property def explained_variance_(self): return self.pipeline.steps[-1][1].explained_variance_ + + @property + def explained_variance_ratio_(self): + return self.pipeline.steps[-1][1].explained_variance_ratio_ def fit(self, X, y=None): self.pipeline.fit(X) @@ -161,8 +163,17 @@ def transform(self, X): return self.pca_features def draw(self): - X = self.explained_variance_ - self.ax.plot(X) + X = self.explained_variance_ratio_ + self.ax.plot(X, label = "Explained Variance") + if (self.cumulative): + X = np.cumsum(self.explained_variance_ratio_) + self.ax.plot(X, label = "Cumulative Variance") + + n_comp = bisect.bisect_left(X, self.cutoff/100); + self.ax.vlines(n_comp, 0, X[n_comp], linestyle = "dashed", + label=str(self.cutoff)+"% Variance") + self.ax.hlines(X[n_comp], 0, n_comp, linestyle = "dashed") + return self.ax def finalize(self, **kwargs): @@ -170,5 +181,6 @@ def finalize(self, **kwargs): self.set_title("Explained Variance Plot") # Set the axes labels - self.ax.set_ylabel("Explained Variance") - self.ax.set_xlabel("Number of Components") + self.ax.set_ylabel('Explained Variance') + self.ax.set_xlabel('Number of Components') + self.ax.legend(loc="center right")