3636 RFDETR_ONNX_MAX_RESOLUTION ,
3737 VALID_INFERENCE_MODELS_BACKENDS ,
3838)
39+ from inference .core .exceptions import PostProcessingError
3940from inference .core .models .base import Model
4041from inference .core .roboflow_api import get_extra_weights_provider_headers
4142from inference .core .utils .image_utils import load_image_bgr , load_image_rgb
@@ -767,12 +768,17 @@ def prepare_multi_label_classification_response(
767768 """
768769 results = []
769770 for prediction , image_size in zip (post_processed_predictions , image_sizes ):
771+ class_confidences = _reshape_classification_confidences (
772+ confidence = prediction .confidence .cpu (),
773+ expected_num_images = 1 ,
774+ class_names = class_names ,
775+ )[0 ].tolist ()
770776 image_predictions_dict = {
771777 class_names [class_id ]: {
772778 "confidence" : confidence ,
773779 "class_id" : class_id ,
774780 }
775- for class_id , confidence in enumerate (prediction . confidence . cpu (). tolist () )
781+ for class_id , confidence in enumerate (class_confidences )
776782 }
777783 predicted_classes = [
778784 class_names [class_id ] for class_id in prediction .class_ids .tolist ()
@@ -795,9 +801,12 @@ def prepare_classification_response(
795801 confidence_threshold : float ,
796802) -> List [ClassificationInferenceResponse ]:
797803 responses = []
798- for classes_confidence , image_size in zip (
799- post_processed_predictions .confidence .cpu ().tolist (), image_sizes
800- ):
804+ batch_confidences = _reshape_classification_confidences (
805+ confidence = post_processed_predictions .confidence .cpu (),
806+ expected_num_images = len (image_sizes ),
807+ class_names = class_names ,
808+ )
809+ for classes_confidence , image_size in zip (batch_confidences .tolist (), image_sizes ):
801810 individual_classes_predictions = []
802811 for i , cls_name in enumerate (class_names ):
803812 class_score = float (classes_confidence [i ])
@@ -831,6 +840,26 @@ def prepare_classification_response(
831840 return responses
832841
833842
843+ def _reshape_classification_confidences (
844+ confidence : torch .Tensor ,
845+ expected_num_images : int ,
846+ class_names : List [str ],
847+ ) -> torch .Tensor :
848+ expected_num_classes = len (class_names )
849+ expected_num_scores = expected_num_images * expected_num_classes
850+ actual_num_scores = confidence .numel ()
851+ if actual_num_scores != expected_num_scores :
852+ raise PostProcessingError (
853+ "Classification model output has shape "
854+ f"{ tuple (confidence .shape )} containing { actual_num_scores } confidence "
855+ f"score(s), but response metadata expects { expected_num_images } image(s) "
856+ f"x { expected_num_classes } class name(s) = { expected_num_scores } score(s). "
857+ "This usually means the model package class names metadata does not match "
858+ "the classifier head."
859+ )
860+ return confidence .reshape (expected_num_images , expected_num_classes )
861+
862+
834863def draw_predictions (inference_request , inference_response , class_names : List [str ]):
835864 """Draw prediction visuals on an image.
836865
0 commit comments