Skip to content

Commit 973bd70

Browse files
committed
Fix issue DistrictDataLabs#1328: update ContribEstimator wrapper to work with sklearn 1.6+ tags API
Starting in scikit-learn 1.6, the type-checking functions is_classifier(), is_regressor(), is_clusterer(), and is_outlier_detector() were changed to use a new tags-based mechanism. Instead of inspecting the _estimator_type attribute directly, they now call get_tags(estimator), which in turn invokes estimator.__sklearn_tags__() and checks the estimator_type field on the returned Tags dataclass. ContribEstimator did not implement __sklearn_tags__(), so when sklearn called it, the call fell through to ContribEstimator.__getattr__(), which proxied it to the wrapped third-party estimator. Since third-party estimators (the whole reason ContribEstimator exists) typically don't implement __sklearn_tags__() either, this raised an AttributeError, causing all four type checks to fail. The fix adds a __sklearn_tags__() method to ContribEstimator that builds a default Tags object (via BaseEstimator) and replaces the estimator_type field with the value from self._estimator_type when set. This preserves the existing behavior where wrap(est, "classifier") makes the estimator pass is_classifier() checks, while remaining forward-compatible with sklearn's tags infrastructure.
1 parent e156692 commit 973bd70

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

yellowbrick/contrib/wrapper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
## Imports
2020
##########################################################################
2121

22+
import dataclasses
23+
24+
from sklearn.base import BaseEstimator
2225
from yellowbrick.exceptions import YellowbrickAttributeError
2326

2427

@@ -123,6 +126,13 @@ def __init__(self, estimator, estimator_type=None):
123126
if estimator_type:
124127
self._estimator_type = estimator_type
125128

129+
def __sklearn_tags__(self):
130+
tags = BaseEstimator().__sklearn_tags__()
131+
estimator_type = getattr(self, "_estimator_type", None)
132+
if estimator_type is not None:
133+
tags = dataclasses.replace(tags, estimator_type=estimator_type)
134+
return tags
135+
126136
def __getattr__(self, attr):
127137
# proxy to the wrapped object
128138
try:

0 commit comments

Comments
 (0)