Skip to content

Commit 4118592

Browse files
authored
Add OTel spans to model artifact loading (#2244)
* Add OTel spans to model artifact loading to distinguish cache vs. remote source. Adds child spans under the existing model.load span: - model.artifacts.load (source=local_cache) for cache hits - model.artifacts.download (source=s3 or roboflow_api) for remote downloads. Each span carries model.id and model.artifacts.source attributes.
1 parent 93b0eee commit 4118592

1 file changed

Lines changed: 177 additions & 161 deletions

File tree

inference/core/models/roboflow.py

Lines changed: 177 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
get_roboflow_instant_model_data,
7373
get_roboflow_model_data,
7474
)
75+
from inference.core.telemetry import set_span_attribute, start_span
7576
from inference.core.utils.image_utils import load_image
7677
from inference.core.utils.onnx import get_onnxruntime_execution_providers
7778
from inference.core.utils.preprocess import letterbox_image, prepare
@@ -282,21 +283,25 @@ def get_all_required_infer_bucket_file(self) -> List[str]:
282283
return [f for f in infer_bucket_files if f is not None]
283284

284285
def download_model_artefacts_from_s3(self) -> None:
285-
try:
286-
logger.debug("Downloading model artifacts from S3")
287-
infer_bucket_files = self.get_all_required_infer_bucket_file()
288-
cache_directory = get_cache_dir()
289-
s3_keys = [f"{self.endpoint}/{file}" for file in infer_bucket_files]
290-
download_s3_files_to_directory(
291-
bucket=self.model_artifact_bucket,
292-
keys=s3_keys,
293-
target_dir=cache_directory,
294-
s3_client=S3_CLIENT,
295-
)
296-
except Exception as error:
297-
raise ModelArtefactError(
298-
f"Could not obtain model artefacts from S3 with keys {s3_keys}. Cause: {error}"
299-
) from error
286+
with start_span(
287+
"model.artifacts.download",
288+
{"model.id": self.endpoint, "model.artifacts.source": "s3"},
289+
):
290+
try:
291+
logger.debug("Downloading model artifacts from S3")
292+
infer_bucket_files = self.get_all_required_infer_bucket_file()
293+
cache_directory = get_cache_dir()
294+
s3_keys = [f"{self.endpoint}/{file}" for file in infer_bucket_files]
295+
download_s3_files_to_directory(
296+
bucket=self.model_artifact_bucket,
297+
keys=s3_keys,
298+
target_dir=cache_directory,
299+
s3_client=S3_CLIENT,
300+
)
301+
except Exception as error:
302+
raise ModelArtefactError(
303+
f"Could not obtain model artefacts from S3 with keys {s3_keys}. Cause: {error}"
304+
) from error
300305

301306
@property
302307
def model_artifact_bucket(self):
@@ -309,174 +314,185 @@ def download_model_artifacts_from_roboflow_api(
309314
**kwargs,
310315
) -> None:
311316
logger.debug("Downloading model artifacts from Roboflow API")
312-
313-
# Use the same lock file pattern as in clear_cache
314-
lock_dir = MODEL_CACHE_DIR + "/_file_locks" # Dedicated lock directory
315-
os.makedirs(lock_dir, exist_ok=True) # Ensure lock directory exists.
316-
lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
317-
try:
318-
lock = FileLock(lock_file, timeout=120) # 120 second timeout for downloads
319-
with lock:
320-
if self.version_id is not None:
321-
api_data = get_roboflow_model_data(
322-
api_key=self.api_key,
323-
model_id=self.endpoint,
324-
endpoint_type=ModelEndpointType.ORT,
325-
device_id=self.device_id,
326-
countinference=countinference,
327-
service_secret=service_secret,
328-
)
329-
if "ort" not in api_data.keys():
330-
raise ModelArtefactError(
331-
"Could not find `ort` key in roboflow API model description response."
332-
)
333-
api_data = api_data["ort"]
334-
if "classes" in api_data:
335-
save_text_lines_in_cache(
336-
content=api_data["classes"],
337-
file="class_names.txt",
317+
with start_span(
318+
"model.artifacts.download",
319+
{"model.id": self.endpoint, "model.artifacts.source": "roboflow_api"},
320+
):
321+
# Use the same lock file pattern as in clear_cache
322+
lock_dir = MODEL_CACHE_DIR + "/_file_locks" # Dedicated lock directory
323+
os.makedirs(lock_dir, exist_ok=True) # Ensure lock directory exists.
324+
lock_file = os.path.join(
325+
lock_dir, f"{os.path.basename(self.cache_dir)}.lock"
326+
)
327+
try:
328+
lock = FileLock(
329+
lock_file, timeout=120
330+
) # 120 second timeout for downloads
331+
with lock:
332+
if self.version_id is not None:
333+
api_data = get_roboflow_model_data(
334+
api_key=self.api_key,
338335
model_id=self.endpoint,
336+
endpoint_type=ModelEndpointType.ORT,
337+
device_id=self.device_id,
338+
countinference=countinference,
339+
service_secret=service_secret,
340+
)
341+
if "ort" not in api_data.keys():
342+
raise ModelArtefactError(
343+
"Could not find `ort` key in roboflow API model description response."
344+
)
345+
api_data = api_data["ort"]
346+
if "classes" in api_data:
347+
save_text_lines_in_cache(
348+
content=api_data["classes"],
349+
file="class_names.txt",
350+
model_id=self.endpoint,
351+
)
352+
if "model" not in api_data:
353+
raise ModelArtefactError(
354+
"Could not find `model` key in roboflow API model description response."
355+
)
356+
if "environment" not in api_data:
357+
raise ModelArtefactError(
358+
"Could not find `environment` key in roboflow API model description response."
359+
)
360+
environment = get_from_url(api_data["environment"])
361+
model_weights_response = get_from_url(
362+
api_data["model"],
363+
json_response=False,
339364
)
340-
if "model" not in api_data:
341-
raise ModelArtefactError(
342-
"Could not find `model` key in roboflow API model description response."
365+
else:
366+
api_data = get_roboflow_instant_model_data(
367+
api_key=self.api_key,
368+
model_id=self.endpoint,
369+
countinference=countinference,
370+
service_secret=service_secret,
343371
)
344-
if "environment" not in api_data:
345-
raise ModelArtefactError(
346-
"Could not find `environment` key in roboflow API model description response."
372+
if (
373+
"modelFiles" not in api_data
374+
or "ort" not in api_data["modelFiles"]
375+
or "model" not in api_data["modelFiles"]["ort"]
376+
):
377+
raise ModelArtefactError(
378+
"Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
379+
)
380+
if "environment" not in api_data:
381+
raise ModelArtefactError(
382+
"Could not find `environment` key in roboflow API model description response."
383+
)
384+
model_weights_response = get_from_url(
385+
api_data["modelFiles"]["ort"]["model"],
386+
json_response=False,
347387
)
348-
environment = get_from_url(api_data["environment"])
349-
model_weights_response = get_from_url(
350-
api_data["model"],
351-
json_response=False,
352-
)
353-
else:
354-
api_data = get_roboflow_instant_model_data(
355-
api_key=self.api_key,
388+
environment = api_data["environment"]
389+
if "classes" in api_data:
390+
save_text_lines_in_cache(
391+
content=api_data["classes"],
392+
file="class_names.txt",
393+
model_id=self.endpoint,
394+
)
395+
396+
save_bytes_in_cache(
397+
content=model_weights_response.content,
398+
file=self.weights_file,
356399
model_id=self.endpoint,
357-
countinference=countinference,
358-
service_secret=service_secret,
359400
)
360-
if (
361-
"modelFiles" not in api_data
362-
or "ort" not in api_data["modelFiles"]
363-
or "model" not in api_data["modelFiles"]["ort"]
364-
):
365-
raise ModelArtefactError(
366-
"Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
367-
)
368-
if "environment" not in api_data:
369-
raise ModelArtefactError(
370-
"Could not find `environment` key in roboflow API model description response."
371-
)
372-
model_weights_response = get_from_url(
373-
api_data["modelFiles"]["ort"]["model"],
374-
json_response=False,
401+
if "colors" in api_data:
402+
environment["COLORS"] = api_data["colors"]
403+
save_json_in_cache(
404+
content=environment,
405+
file="environment.json",
406+
model_id=self.endpoint,
375407
)
376-
environment = api_data["environment"]
377-
if "classes" in api_data:
378-
save_text_lines_in_cache(
379-
content=api_data["classes"],
380-
file="class_names.txt",
408+
if "keypoints_metadata" in api_data:
409+
# TODO: make sure backend provides that
410+
save_json_in_cache(
411+
content=api_data["keypoints_metadata"],
412+
file="keypoints_metadata.json",
381413
model_id=self.endpoint,
382414
)
415+
except Exception as e:
416+
logger.error(f"Error downloading model artifacts: {e}")
417+
raise
383418

384-
save_bytes_in_cache(
385-
content=model_weights_response.content,
386-
file=self.weights_file,
419+
def load_model_artifacts_from_cache(self) -> None:
420+
logger.debug("Model artifacts already downloaded, loading model from cache")
421+
with start_span(
422+
"model.artifacts.load",
423+
{"model.id": self.endpoint, "model.artifacts.source": "local_cache"},
424+
):
425+
infer_bucket_files = self.get_all_required_infer_bucket_file()
426+
if "environment.json" in infer_bucket_files:
427+
self.environment = load_json_from_cache(
428+
file="environment.json",
387429
model_id=self.endpoint,
430+
object_pairs_hook=OrderedDict,
388431
)
389-
if "colors" in api_data:
390-
environment["COLORS"] = api_data["colors"]
391-
save_json_in_cache(
392-
content=environment,
393-
file="environment.json",
432+
if "class_names.txt" in infer_bucket_files:
433+
self.class_names = load_text_file_from_cache(
434+
file="class_names.txt",
394435
model_id=self.endpoint,
436+
split_lines=True,
437+
strip_white_chars=True,
395438
)
396-
if "keypoints_metadata" in api_data:
397-
# TODO: make sure backend provides that
398-
save_json_in_cache(
399-
content=api_data["keypoints_metadata"],
439+
else:
440+
self.class_names = get_class_names_from_environment_file(
441+
environment=self.environment
442+
)
443+
self.colors = get_color_mapping_from_environment(
444+
environment=self.environment,
445+
class_names=self.class_names,
446+
)
447+
if "keypoints_metadata.json" in infer_bucket_files:
448+
self.keypoints_metadata = parse_keypoints_metadata(
449+
load_json_from_cache(
400450
file="keypoints_metadata.json",
401451
model_id=self.endpoint,
452+
object_pairs_hook=OrderedDict,
402453
)
403-
except Exception as e:
404-
logger.error(f"Error downloading model artifacts: {e}")
405-
raise
406-
407-
def load_model_artifacts_from_cache(self) -> None:
408-
logger.debug("Model artifacts already downloaded, loading model from cache")
409-
infer_bucket_files = self.get_all_required_infer_bucket_file()
410-
if "environment.json" in infer_bucket_files:
411-
self.environment = load_json_from_cache(
412-
file="environment.json",
413-
model_id=self.endpoint,
414-
object_pairs_hook=OrderedDict,
415-
)
416-
if "class_names.txt" in infer_bucket_files:
417-
self.class_names = load_text_file_from_cache(
418-
file="class_names.txt",
419-
model_id=self.endpoint,
420-
split_lines=True,
421-
strip_white_chars=True,
422-
)
423-
else:
424-
self.class_names = get_class_names_from_environment_file(
425-
environment=self.environment
426-
)
427-
self.colors = get_color_mapping_from_environment(
428-
environment=self.environment,
429-
class_names=self.class_names,
430-
)
431-
if "keypoints_metadata.json" in infer_bucket_files:
432-
self.keypoints_metadata = parse_keypoints_metadata(
433-
load_json_from_cache(
434-
file="keypoints_metadata.json",
435-
model_id=self.endpoint,
436-
object_pairs_hook=OrderedDict,
437454
)
438-
)
439-
self.num_classes = len(self.class_names)
440-
if "PREPROCESSING" not in self.environment:
441-
raise ModelArtefactError(
442-
"Could not find `PREPROCESSING` key in environment file."
443-
)
444-
if issubclass(type(self.environment["PREPROCESSING"]), dict):
445-
self.preproc = self.environment["PREPROCESSING"]
446-
else:
447-
self.preproc = json.loads(self.environment["PREPROCESSING"])
448-
if self.preproc.get("resize"):
449-
self.resize_method = self.preproc["resize"].get("format", "Stretch to")
450-
if self.resize_method in [
451-
"Fit (reflect edges) in",
452-
"Fit within",
453-
"Fill (with center crop) in",
454-
]:
455-
fallback_resize_method = "Fit (black edges) in"
456-
logger.warning(
457-
"Unsupported resize method '%s', defaulting to '%s' - this may result in degraded model performance.",
458-
self.resize_method,
459-
fallback_resize_method,
455+
self.num_classes = len(self.class_names)
456+
if "PREPROCESSING" not in self.environment:
457+
raise ModelArtefactError(
458+
"Could not find `PREPROCESSING` key in environment file."
460459
)
461-
self.resize_method = fallback_resize_method
462-
if self.resize_method not in [
463-
"Stretch to",
464-
"Fit (black edges) in",
465-
"Fit (grey edges) in",
466-
"Fit (white edges) in",
467-
]:
460+
if issubclass(type(self.environment["PREPROCESSING"]), dict):
461+
self.preproc = self.environment["PREPROCESSING"]
462+
else:
463+
self.preproc = json.loads(self.environment["PREPROCESSING"])
464+
if self.preproc.get("resize"):
465+
self.resize_method = self.preproc["resize"].get("format", "Stretch to")
466+
if self.resize_method in [
467+
"Fit (reflect edges) in",
468+
"Fit within",
469+
"Fill (with center crop) in",
470+
]:
471+
fallback_resize_method = "Fit (black edges) in"
472+
logger.warning(
473+
"Unsupported resize method '%s', defaulting to '%s' - this may result in degraded model performance.",
474+
self.resize_method,
475+
fallback_resize_method,
476+
)
477+
self.resize_method = fallback_resize_method
478+
if self.resize_method not in [
479+
"Stretch to",
480+
"Fit (black edges) in",
481+
"Fit (grey edges) in",
482+
"Fit (white edges) in",
483+
]:
484+
logger.error(
485+
"Unsupported resize method '%s', defaulting to 'Stretch to' - this may result in degraded model performance.",
486+
self.resize_method,
487+
)
488+
self.resize_method = "Stretch to"
489+
else:
468490
logger.error(
469-
"Unsupported resize method '%s', defaulting to 'Stretch to' - this may result in degraded model performance.",
470-
self.resize_method,
491+
"Unknown resize method, defaulting to 'Stretch to' - this may result in degraded model performance."
471492
)
472493
self.resize_method = "Stretch to"
473-
else:
474-
logger.error(
475-
"Unknown resize method, defaulting to 'Stretch to' - this may result in degraded model performance."
476-
)
477-
self.resize_method = "Stretch to"
478-
logger.debug(f"Resize method is '{self.resize_method}'")
479-
self.multiclass = self.environment.get("MULTICLASS", False)
494+
logger.debug(f"Resize method is '{self.resize_method}'")
495+
self.multiclass = self.environment.get("MULTICLASS", False)
480496

481497
def initialize_model(self, **kwargs) -> None:
482498
"""Initialize the model.

0 commit comments

Comments
 (0)