Skip to content

Commit 45ef287

Browse files
authored
Fix classification confidence shape handling (#2262)
* Fix classification confidence shape handling
* Cover short classification confidence vectors
* Run make style
1 parent 2eb9419 commit 45ef287

2 files changed

Lines changed: 74 additions & 5 deletions

File tree

inference/core/models/inference_models_adapters.py

Lines changed: 33 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
RFDETR_ONNX_MAX_RESOLUTION,
3737
VALID_INFERENCE_MODELS_BACKENDS,
3838
)
39+
from inference.core.exceptions import PostProcessingError
3940
from inference.core.models.base import Model
4041
from inference.core.roboflow_api import get_extra_weights_provider_headers
4142
from inference.core.utils.image_utils import load_image_bgr, load_image_rgb
@@ -767,12 +768,17 @@ def prepare_multi_label_classification_response(
767768
"""
768769
results = []
769770
for prediction, image_size in zip(post_processed_predictions, image_sizes):
771+
class_confidences = _reshape_classification_confidences(
772+
confidence=prediction.confidence.cpu(),
773+
expected_num_images=1,
774+
class_names=class_names,
775+
)[0].tolist()
770776
image_predictions_dict = {
771777
class_names[class_id]: {
772778
"confidence": confidence,
773779
"class_id": class_id,
774780
}
775-
for class_id, confidence in enumerate(prediction.confidence.cpu().tolist())
781+
for class_id, confidence in enumerate(class_confidences)
776782
}
777783
predicted_classes = [
778784
class_names[class_id] for class_id in prediction.class_ids.tolist()
@@ -795,9 +801,12 @@ def prepare_classification_response(
795801
confidence_threshold: float,
796802
) -> List[ClassificationInferenceResponse]:
797803
responses = []
798-
for classes_confidence, image_size in zip(
799-
post_processed_predictions.confidence.cpu().tolist(), image_sizes
800-
):
804+
batch_confidences = _reshape_classification_confidences(
805+
confidence=post_processed_predictions.confidence.cpu(),
806+
expected_num_images=len(image_sizes),
807+
class_names=class_names,
808+
)
809+
for classes_confidence, image_size in zip(batch_confidences.tolist(), image_sizes):
801810
individual_classes_predictions = []
802811
for i, cls_name in enumerate(class_names):
803812
class_score = float(classes_confidence[i])
@@ -831,6 +840,26 @@ def prepare_classification_response(
831840
return responses
832841

833842

843+
def _reshape_classification_confidences(
844+
confidence: torch.Tensor,
845+
expected_num_images: int,
846+
class_names: List[str],
847+
) -> torch.Tensor:
848+
expected_num_classes = len(class_names)
849+
expected_num_scores = expected_num_images * expected_num_classes
850+
actual_num_scores = confidence.numel()
851+
if actual_num_scores != expected_num_scores:
852+
raise PostProcessingError(
853+
"Classification model output has shape "
854+
f"{tuple(confidence.shape)} containing {actual_num_scores} confidence "
855+
f"score(s), but response metadata expects {expected_num_images} image(s) "
856+
f"x {expected_num_classes} class name(s) = {expected_num_scores} score(s). "
857+
"This usually means the model package class names metadata does not match "
858+
"the classifier head."
859+
)
860+
return confidence.reshape(expected_num_images, expected_num_classes)
861+
862+
834863
def draw_predictions(inference_request, inference_response, class_names: List[str]):
835864
"""Draw prediction visuals on an image.
836865

tests/inference/unit_tests/core/models/test_inference_models_adapters.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
import torch
55

66
from inference.core.models.inference_models_adapters import (
7+
prepare_classification_response,
78
prepare_multi_label_classification_response,
89
)
9-
from inference_models import MultiLabelClassificationPrediction
10+
from inference.core.exceptions import PostProcessingError
11+
from inference_models import (
12+
ClassificationPrediction,
13+
MultiLabelClassificationPrediction,
14+
)
1015

1116

1217
def test_prepare_multi_label_response_uses_class_ids_for_predicted_classes() -> None:
@@ -46,3 +51,38 @@ def test_prepare_multi_label_response_uses_class_ids_for_predicted_classes() ->
4651
assert r.predictions["d"].confidence == pytest.approx(0.9)
4752
# Only the model's filtered class_ids show up in predicted_classes.
4853
assert r.predicted_classes == ["c"]
54+
55+
56+
def test_prepare_classification_response_flattens_singleton_output_dimensions() -> None:
    """A (1, 1, C) confidence tensor is reshaped to one score row per image."""
    prediction = ClassificationPrediction(
        class_id=torch.tensor([[1]], dtype=torch.long),
        confidence=torch.tensor([[[0.1, 0.9]]]),
    )

    responses = prepare_classification_response(
        post_processed_predictions=prediction,
        image_sizes=[(10, 20)],
        class_names=["cat", "dog"],
        confidence_threshold=0.0,
    )

    # Exactly one response for the single image, with "dog" on top.
    assert len(responses) == 1
    assert responses[0].top == "dog"
    assert responses[0].confidence == pytest.approx(0.9)
    # Predictions are ordered by descending confidence.
    assert [p.class_name for p in responses[0].predictions] == ["dog", "cat"]
74+
75+
76+
def test_prepare_classification_response_fails_on_class_count_mismatch() -> None:
    """One confidence score against two class names must raise a clear error."""
    mismatched_prediction = ClassificationPrediction(
        class_id=torch.tensor([0], dtype=torch.long),
        confidence=torch.tensor([[0.7]]),
    )

    # The error message should point at the class-names metadata mismatch.
    with pytest.raises(PostProcessingError, match="class names metadata"):
        prepare_classification_response(
            post_processed_predictions=mismatched_prediction,
            image_sizes=[(10, 20)],
            class_names=["cat", "dog"],
            confidence_threshold=0.0,
        )

0 commit comments

Comments (0)