Skip to content

Commit 9cd1225

Browse files
authored
Add more test cases for CaliForest model
1 parent 2e3deac commit 9cd1225

File tree

6 files changed

+72
-269
lines changed

6 files changed

+72
-269
lines changed

califorest_tests/__init__.py

Whitespace-only changes.

califorest_tests/test_datasets.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

califorest_tests/test_models.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

califorest_tests/test_tasks.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

califorest_tests/utils.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

tests/test_califorest.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,75 @@ def test_same_random_state_reproducible(self, minimal_data):
170170
probas2 = model2.predict_proba(X_test)
171171

172172
assert np.allclose(probas1, probas2)
173+
174+
class TestCaliForestPrediction:
    """Tests for model prediction - verifies forward pass equivalent."""

    def test_predict_output_shape(self, fitted_model, minimal_data):
        """predict must return a 1-D array with one label per input row."""
        features, _ = minimal_data
        labels = fitted_model.predict(features)
        assert labels.shape == (features.shape[0],)

    def test_predict_proba_output_shape(self, fitted_model, minimal_data):
        """predict_proba must return one probability pair per input row."""
        features, _ = minimal_data
        scores = fitted_model.predict_proba(features)
        assert scores.shape == (features.shape[0], 2)  # Binary classification

    def test_predict_proba_valid_probabilities(self, fitted_model, minimal_data):
        """Every probability must lie in [0, 1] and each row must sum to 1."""
        features, _ = minimal_data
        scores = fitted_model.predict_proba(features)
        # Bounds check combined into one element-wise mask; equivalent to
        # checking the lower and upper bound separately.
        assert np.all((scores >= 0.0) & (scores <= 1.0))
        assert np.allclose(scores.sum(axis=1), 1.0)

    def test_predict_returns_valid_classes(self, fitted_model, minimal_data):
        """Predicted labels must come from the binary label set {0, 1}."""
        features, _ = minimal_data
        labels = fitted_model.predict(features)
        # Operator form of issubset.
        assert set(labels) <= {0, 1}
204+
class TestCaliForestCalibration:
    """Tests for calibration functionality."""

    def test_calibration_applied(self, fitted_model):
        """Fitting must leave a trained calibrator attached to the model."""
        assert fitted_model.calibrator_ is not None

    def test_isotonic_vs_platt(self, minimal_data):
        """Isotonic and Platt calibration must both yield valid probabilities."""
        features, targets = minimal_data

        iso = CaliForest(n_estimators=5, calibration_method="isotonic", random_state=42)
        platt = CaliForest(n_estimators=5, calibration_method="platt", random_state=42)

        iso.fit(features, targets)
        platt.fit(features, targets)

        iso_scores = iso.predict_proba(features)
        platt_scores = platt.predict_proba(features)

        # Both calibration methods should emit same-shaped, well-formed
        # probability matrices.
        assert iso_scores.shape == platt_scores.shape
        assert np.all((iso_scores >= 0) & (iso_scores <= 1))
        assert np.all((platt_scores >= 0) & (platt_scores <= 1))
230+
class TestCaliForestEdgeCases:
    """Tests for edge cases."""

    def test_predict_before_fit_raises(self):
        """Calling predict on an unfitted model must raise NotFittedError."""
        unfitted = CaliForest()
        samples = np.random.randn(5, 3)
        with pytest.raises(NotFittedError):
            unfitted.predict(samples)

    def test_single_sample_prediction(self, fitted_model):
        """predict_proba must accept a single-row input and keep 2-D output."""
        one_row = np.random.randn(1, 3)
        scores = fitted_model.predict_proba(one_row)
        assert scores.shape == (1, 2)

0 commit comments

Comments
 (0)