Skip to content

Commit a19cb78

Browse files
committed
parse and use recommendedParameters
1 parent 5df1795 commit a19cb78

File tree

25 files changed

+1922
-64
lines changed

25 files changed

+1922
-64
lines changed

inference/core/entities/requests/inference.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,14 @@ class ObjectDetectionInferenceRequest(CVInferenceRequest):
146146
description="If provided, only predictions for the listed classes will be returned",
147147
)
148148
confidence: Optional[float] = Field(
149-
default=0.4,
149+
default=None,
150150
examples=[0.5],
151-
description="The confidence threshold used to filter out predictions",
151+
description=(
152+
"The confidence threshold used to filter out predictions. If omitted, "
153+
"the server uses the model's F1-optimal threshold from model evaluation "
154+
"when available, otherwise falls back to 0.4. Pass an explicit value to "
155+
"override both."
156+
),
152157
)
153158
fix_batch_size: Optional[bool] = Field(
154159
default=False,

inference/core/models/inference_models_adapters.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
)
6363
from inference_models.models.base.types import PreprocessingMetadata
6464

65+
6566
DEFAULT_COLOR_PALETTE = [
6667
"#A351FB",
6768
"#FF4040",
@@ -154,7 +155,10 @@ def postprocess(
154155
**kwargs,
155156
) -> List[ObjectDetectionInferenceResponse]:
156157
mapped_kwargs = self.map_inference_kwargs(kwargs)
157-
detections_list = self._model.post_process(
158+
# The model owns the recommendedParameters priority chain (user → per-class
159+
# → global → default) and per-class refinement. The adapter just passes
160+
# the user's confidence kwarg through unchanged.
161+
detections_list = self._model.post_process_with_confidence_filter(
158162
predictions, preprocess_return_metadata, **mapped_kwargs
159163
)
160164

@@ -305,7 +309,8 @@ def postprocess(
305309
**kwargs,
306310
) -> List[InstanceSegmentationInferenceResponse]:
307311
mapped_kwargs = self.map_inference_kwargs(kwargs)
308-
detections_list = self._model.post_process(
312+
# See OD adapter — the model owns the recommendedParameters filter chain.
313+
detections_list = self._model.post_process_with_confidence_filter(
309314
predictions, preprocess_return_metadata, **mapped_kwargs
310315
)
311316

@@ -465,7 +470,8 @@ def postprocess(
465470
**kwargs,
466471
) -> List[KeypointsDetectionInferenceResponse]:
467472
mapped_kwargs = self.map_inference_kwargs(kwargs)
468-
keypoints_list, detections_list = self._model.post_process(
473+
# See OD adapter — the model owns the recommendedParameters filter chain.
474+
keypoints_list, detections_list = self._model.post_process_with_confidence_filter(
469475
predictions, preprocess_return_metadata, **mapped_kwargs
470476
)
471477
if detections_list is None:
@@ -677,25 +683,36 @@ def postprocess(
677683
List[ClassificationInferenceResponse],
678684
]:
679685
mapped_kwargs = self.map_inference_kwargs(kwargs)
680-
post_processed_predictions = self._model.post_process(
681-
predictions, **mapped_kwargs
682-
)
683-
if isinstance(post_processed_predictions, list):
684-
# multi-label classification
685-
return prepare_multi_label_classification_response(
686-
post_processed_predictions,
687-
image_sizes=returned_metadata,
688-
class_names=self.class_names,
689-
confidence_threshold=kwargs.get("confidence", 0.5),
686+
if isinstance(self._model, MultiLabelClassificationModel):
687+
# Model owns the recommendedParameters filter chain — its
688+
# post_process_with_confidence_filter applies per-class refinement
689+
# to `class_ids`. The response builder reads `class_ids` directly
690+
# rather than re-thresholding the full confidence vector, so the
691+
# per-class decision makes it through to the API response.
692+
post_processed_predictions = (
693+
self._model.post_process_with_confidence_filter(
694+
predictions, **mapped_kwargs
695+
)
690696
)
691-
else:
692-
# single-label classification
693-
return prepare_classification_response(
697+
return prepare_multi_label_classification_response(
694698
post_processed_predictions,
695699
image_sizes=returned_metadata,
696700
class_names=self.class_names,
697-
confidence_threshold=kwargs.get("confidence", 0.5),
698701
)
702+
# Single-label classification: top-1 always wins regardless of
703+
# confidence, so per-class refinement isn't meaningful here. The base
704+
# class deliberately opts out of recommendedParameters entirely. The
705+
# response builder still uses kwargs.get("confidence", 0.5) for the
706+
# cutoff that decides which alternative classes show up.
707+
post_processed_predictions = self._model.post_process(
708+
predictions, **mapped_kwargs
709+
)
710+
return prepare_classification_response(
711+
post_processed_predictions,
712+
image_sizes=returned_metadata,
713+
class_names=self.class_names,
714+
confidence_threshold=kwargs.get("confidence", 0.5),
715+
)
699716

700717
def clear_cache(self, delete_from_disk: bool = True) -> None:
701718
"""Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.
@@ -747,20 +764,29 @@ def prepare_multi_label_classification_response(
747764
post_processed_predictions: List[MultiLabelClassificationPrediction],
748765
image_sizes: List[Tuple[int, int]],
749766
class_names: List[str],
750-
confidence_threshold: float,
751767
) -> List[MultiLabelClassificationInferenceResponse]:
768+
"""Build the API response from a model's post-processed predictions.
769+
770+
`prediction.class_ids` is the authoritative list of "passed" classes —
771+
the model's `post_process_with_confidence_filter` already applied the
772+
full priority chain (user → per-class → global → default), so the
773+
response builder doesn't re-threshold here. The full per-class score
774+
vector is still emitted in `image_predictions_dict` for UI display.
775+
"""
752776
results = []
753777
for prediction, image_size in zip(post_processed_predictions, image_sizes):
754-
image_predictions_dict = dict()
755-
predicted_classes = []
756-
for class_id, confidence in enumerate(prediction.confidence.cpu().tolist()):
757-
cls_name = class_names[class_id]
758-
image_predictions_dict[cls_name] = {
778+
image_predictions_dict = {
779+
class_names[class_id]: {
759780
"confidence": confidence,
760781
"class_id": class_id,
761782
}
762-
if confidence > confidence_threshold:
763-
predicted_classes.append(cls_name)
783+
for class_id, confidence in enumerate(
784+
prediction.confidence.cpu().tolist()
785+
)
786+
}
787+
predicted_classes = [
788+
class_names[class_id] for class_id in prediction.class_ids.tolist()
789+
]
764790
results.append(
765791
MultiLabelClassificationInferenceResponse(
766792
predictions=image_predictions_dict,

inference/core/workflows/core_steps/models/roboflow/instance_segmentation/v2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,15 @@ class BlockManifest(WorkflowBlockManifest):
8585
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
8686
model_id: Union[Selector(kind=[ROBOFLOW_MODEL_ID_KIND]), str] = RoboflowModelField
8787
confidence: Union[
88-
FloatZeroToOne,
88+
Optional[FloatZeroToOne],
8989
Selector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
9090
] = Field(
91-
default=0.4,
92-
description="Confidence threshold for predictions.",
91+
default=None,
92+
description=(
93+
"Confidence threshold for predictions. If omitted, the inference "
94+
"server uses the model's F1-optimal threshold from model evaluation "
95+
"when available, otherwise falls back to 0.4."
96+
),
9397
examples=[0.3, "$inputs.confidence_threshold"],
9498
)
9599
class_filter: Union[Optional[List[str]], Selector(kind=[LIST_OF_VALUES_KIND])] = (

inference/core/workflows/core_steps/models/roboflow/keypoint_detection/v2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,18 @@ class BlockManifest(WorkflowBlockManifest):
8484
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
8585
model_id: Union[Selector(kind=[ROBOFLOW_MODEL_ID_KIND]), str] = RoboflowModelField
8686
confidence: Union[
87-
FloatZeroToOne,
87+
Optional[FloatZeroToOne],
8888
Selector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
8989
] = Field(
90-
default=0.4,
91-
description="Confidence threshold for predictions.",
90+
default=None,
91+
description=(
92+
"Per-instance confidence threshold for predictions. If omitted, the "
93+
"inference server uses the model's F1-optimal threshold from model "
94+
"evaluation when available, otherwise falls back to 0.4. Note that "
95+
"this filters which detected instances (e.g. people, animals) are "
96+
"returned at all — separately, `keypoint_confidence` filters which "
97+
"individual joints within an accepted instance are marked visible."
98+
),
9299
examples=[0.3, "$inputs.confidence_threshold"],
93100
)
94101
keypoint_confidence: Union[

inference/core/workflows/core_steps/models/roboflow/object_detection/v2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,15 @@ class BlockManifest(WorkflowBlockManifest):
8282
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
8383
model_id: Union[Selector(kind=[ROBOFLOW_MODEL_ID_KIND]), str] = RoboflowModelField
8484
confidence: Union[
85-
FloatZeroToOne,
85+
Optional[FloatZeroToOne],
8686
Selector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
8787
] = Field(
88-
default=0.4,
89-
description="Confidence threshold for predictions.",
88+
default=None,
89+
description=(
90+
"Confidence threshold for predictions. If omitted, the inference "
91+
"server uses the model's F1-optimal threshold from model evaluation "
92+
"when available, otherwise falls back to 0.4."
93+
),
9094
examples=[0.3, "$inputs.confidence_threshold"],
9195
)
9296
class_filter: Union[Optional[List[str]], Selector(kind=[LIST_OF_VALUES_KIND])] = (

inference_models/inference_models/models/auto_loaders/auto_resolution_cache.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
TaskType,
1818
)
1919
from inference_models.utils.file_system import dump_json, read_json
20-
from inference_models.weights_providers.entities import ModelDependency
20+
from inference_models.weights_providers.entities import (
21+
ModelDependency,
22+
RecommendedParameters,
23+
)
2124

2225

2326
class AutoResolutionCacheEntry(BaseModel):
@@ -30,6 +33,8 @@ class AutoResolutionCacheEntry(BaseModel):
3033
model_dependencies: Optional[List[ModelDependency]] = Field(default=None)
3134
created_at: datetime
3235
model_features: Optional[dict] = Field(default=None)
36+
# Cached so auto-load cache hits don't need to re-fetch model metadata.
37+
recommended_parameters: Optional[RecommendedParameters] = Field(default=None)
3338

3439

3540
class AutoResolutionCache(ABC):

inference_models/inference_models/models/auto_loaders/core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
ModelDependency,
8282
ModelPackageMetadata,
8383
Quantization,
84+
RecommendedParameters,
8485
)
8586

8687
MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT = {
@@ -926,6 +927,7 @@ def model_directory_pointer(model_dir: str) -> None:
926927
model_dependencies=model_metadata.model_dependencies,
927928
model_dependencies_instances=model_dependencies_instances,
928929
model_dependencies_directories=model_dependencies_directories,
930+
recommended_parameters=model_metadata.recommended_parameters,
929931
max_package_loading_attempts=max_package_loading_attempts,
930932
model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
931933
verify_hash_while_download=verify_hash_while_download,
@@ -1081,6 +1083,11 @@ def attempt_loading_model_with_auto_load_cache(
10811083
model = model_class.from_pretrained(
10821084
model_package_cache_dir, **model_init_kwargs
10831085
)
1086+
# See initialize_model() for the hasattr-gated injection rationale.
1087+
if cache_entry.recommended_parameters is not None and hasattr(
1088+
type(model), "recommended_parameters"
1089+
):
1090+
model.recommended_parameters = cache_entry.recommended_parameters
10841091
verbose_info(
10851092
message=f"Successfully loaded model {model_name_or_path} using auto-loading cache.",
10861093
verbose_requested=verbose,
@@ -1113,6 +1120,7 @@ def attempt_loading_matching_model_packages(
11131120
model_dependencies: Optional[List[ModelDependency]],
11141121
model_dependencies_instances: Dict[str, AnyModel],
11151122
model_dependencies_directories: Dict[str, str],
1123+
recommended_parameters: Optional[RecommendedParameters] = None,
11161124
max_package_loading_attempts: Optional[int] = None,
11171125
model_download_file_lock_acquire_timeout: int = FILE_LOCK_ACQUIRE_TIMEOUT,
11181126
verbose: bool = True,
@@ -1153,6 +1161,7 @@ def attempt_loading_matching_model_packages(
11531161
model_dependencies=model_dependencies,
11541162
model_dependencies_instances=model_dependencies_instances,
11551163
model_dependencies_directories=model_dependencies_directories,
1164+
recommended_parameters=recommended_parameters,
11561165
verify_hash_while_download=verify_hash_while_download,
11571166
download_files_without_hash=download_files_without_hash,
11581167
on_file_created=partial(
@@ -1218,6 +1227,7 @@ def initialize_model(
12181227
model_dependencies: Optional[List[ModelDependency]],
12191228
model_dependencies_instances: Dict[str, AnyModel],
12201229
model_dependencies_directories: Dict[str, str],
1230+
recommended_parameters: Optional[RecommendedParameters] = None,
12211231
model_download_file_lock_acquire_timeout: int = FILE_LOCK_ACQUIRE_TIMEOUT,
12221232
verify_hash_while_download: bool = True,
12231233
download_files_without_hash: bool = False,
@@ -1308,6 +1318,14 @@ def initialize_model(
13081318
resolved_files.update(dependencies_resolved_files)
13091319
model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
13101320
model = model_class.from_pretrained(model_package_cache_dir, **model_init_kwargs)
1321+
# Inject recommended parameters onto model classes that opt in by declaring
1322+
# `recommended_parameters` at the class level (default = None). hasattr on
1323+
# `type(model)` checks the class, not instance state — so model types that
1324+
# don't care (single-label classification, embeddings, etc.) silently no-op.
1325+
if recommended_parameters is not None and hasattr(
1326+
type(model), "recommended_parameters"
1327+
):
1328+
model.recommended_parameters = recommended_parameters
13111329
dump_auto_resolution_cache(
13121330
use_auto_resolution_cache=use_auto_resolution_cache,
13131331
auto_resolution_cache=auto_resolution_cache,
@@ -1320,6 +1338,7 @@ def initialize_model(
13201338
resolved_files=resolved_files,
13211339
model_dependencies=model_dependencies,
13221340
model_features=model_package.model_features,
1341+
recommended_parameters=recommended_parameters,
13231342
)
13241343
return model, model_package_cache_dir
13251344

@@ -1484,6 +1503,7 @@ def dump_auto_resolution_cache(
14841503
resolved_files: Set[str],
14851504
model_dependencies: Optional[List[ModelDependency]],
14861505
model_features: Optional[dict],
1506+
recommended_parameters: Optional[RecommendedParameters] = None,
14871507
) -> None:
14881508
if not use_auto_resolution_cache:
14891509
return None
@@ -1497,6 +1517,7 @@ def dump_auto_resolution_cache(
14971517
created_at=datetime.now(),
14981518
model_dependencies=model_dependencies,
14991519
model_features=model_features,
1520+
recommended_parameters=recommended_parameters,
15001521
)
15011522
auto_resolution_cache.register(
15021523
auto_negotiation_hash=auto_negotiation_hash, cache_entry=cache_content

inference_models/inference_models/models/auto_loaders/entities.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass, field
2-
from enum import Enum
32
from typing import Optional, Union
43

4+
from inference_models.models.auto_loaders.types import BackendType
55
from inference_models.models.base.classification import (
66
ClassificationModel,
77
MultiLabelClassificationModel,
@@ -22,17 +22,6 @@
2222
MODEL_CONFIG_FILE_NAME = "model_config.json"
2323

2424

25-
class BackendType(str, Enum):
26-
TORCH = "torch"
27-
TORCH_SCRIPT = "torch-script"
28-
ONNX = "onnx"
29-
TRT = "trt"
30-
HF = "hugging-face"
31-
ULTRALYTICS = "ultralytics"
32-
MEDIAPIPE = "mediapipe"
33-
CUSTOM = "custom"
34-
35-
3625
AnyModel = Union[
3726
ClassificationModel,
3827
MultiLabelClassificationModel,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Leaf-level auto-loader types with no model-class dependencies. Split out from
3+
`auto_loaders/entities.py` so that `weights_providers/entities.py` can import
4+
`BackendType` without pulling in the model base class tree (which would cause
5+
a cycle — base classes depend on `weights_providers.entities.RecommendedParameters`).
6+
"""
7+
8+
from enum import Enum
9+
10+
11+
class BackendType(str, Enum):
12+
TORCH = "torch"
13+
TORCH_SCRIPT = "torch-script"
14+
ONNX = "onnx"
15+
TRT = "trt"
16+
HF = "hugging-face"
17+
ULTRALYTICS = "ultralytics"
18+
MEDIAPIPE = "mediapipe"
19+
CUSTOM = "custom"

0 commit comments

Comments (0)