7272 get_roboflow_instant_model_data ,
7373 get_roboflow_model_data ,
7474)
75+ from inference .core .telemetry import set_span_attribute , start_span
7576from inference .core .utils .image_utils import load_image
7677from inference .core .utils .onnx import get_onnxruntime_execution_providers
7778from inference .core .utils .preprocess import letterbox_image , prepare
@@ -282,21 +283,25 @@ def get_all_required_infer_bucket_file(self) -> List[str]:
282283 return [f for f in infer_bucket_files if f is not None ]
283284
284285 def download_model_artefacts_from_s3 (self ) -> None :
285- try :
286- logger .debug ("Downloading model artifacts from S3" )
287- infer_bucket_files = self .get_all_required_infer_bucket_file ()
288- cache_directory = get_cache_dir ()
289- s3_keys = [f"{ self .endpoint } /{ file } " for file in infer_bucket_files ]
290- download_s3_files_to_directory (
291- bucket = self .model_artifact_bucket ,
292- keys = s3_keys ,
293- target_dir = cache_directory ,
294- s3_client = S3_CLIENT ,
295- )
296- except Exception as error :
297- raise ModelArtefactError (
298- f"Could not obtain model artefacts from S3 with keys { s3_keys } . Cause: { error } "
299- ) from error
286+ with start_span (
287+ "model.artifacts.download" ,
288+ {"model.id" : self .endpoint , "model.artifacts.source" : "s3" },
289+ ):
290+ try :
291+ logger .debug ("Downloading model artifacts from S3" )
292+ infer_bucket_files = self .get_all_required_infer_bucket_file ()
293+ cache_directory = get_cache_dir ()
294+ s3_keys = [f"{ self .endpoint } /{ file } " for file in infer_bucket_files ]
295+ download_s3_files_to_directory (
296+ bucket = self .model_artifact_bucket ,
297+ keys = s3_keys ,
298+ target_dir = cache_directory ,
299+ s3_client = S3_CLIENT ,
300+ )
301+ except Exception as error :
302+ raise ModelArtefactError (
303+ f"Could not obtain model artefacts from S3 with keys { s3_keys } . Cause: { error } "
304+ ) from error
300305
301306 @property
302307 def model_artifact_bucket (self ):
@@ -309,174 +314,185 @@ def download_model_artifacts_from_roboflow_api(
309314 ** kwargs ,
310315 ) -> None :
311316 logger .debug ("Downloading model artifacts from Roboflow API" )
312-
313- # Use the same lock file pattern as in clear_cache
314- lock_dir = MODEL_CACHE_DIR + "/_file_locks" # Dedicated lock directory
315- os .makedirs (lock_dir , exist_ok = True ) # Ensure lock directory exists.
316- lock_file = os .path .join (lock_dir , f"{ os .path .basename (self .cache_dir )} .lock" )
317- try :
318- lock = FileLock (lock_file , timeout = 120 ) # 120 second timeout for downloads
319- with lock :
320- if self .version_id is not None :
321- api_data = get_roboflow_model_data (
322- api_key = self .api_key ,
323- model_id = self .endpoint ,
324- endpoint_type = ModelEndpointType .ORT ,
325- device_id = self .device_id ,
326- countinference = countinference ,
327- service_secret = service_secret ,
328- )
329- if "ort" not in api_data .keys ():
330- raise ModelArtefactError (
331- "Could not find `ort` key in roboflow API model description response."
332- )
333- api_data = api_data ["ort" ]
334- if "classes" in api_data :
335- save_text_lines_in_cache (
336- content = api_data ["classes" ],
337- file = "class_names.txt" ,
317+ with start_span (
318+ "model.artifacts.download" ,
319+ {"model.id" : self .endpoint , "model.artifacts.source" : "roboflow_api" },
320+ ):
321+ # Use the same lock file pattern as in clear_cache
322+ lock_dir = MODEL_CACHE_DIR + "/_file_locks" # Dedicated lock directory
323+ os .makedirs (lock_dir , exist_ok = True ) # Ensure lock directory exists.
324+ lock_file = os .path .join (
325+ lock_dir , f"{ os .path .basename (self .cache_dir )} .lock"
326+ )
327+ try :
328+ lock = FileLock (
329+ lock_file , timeout = 120
330+ ) # 120 second timeout for downloads
331+ with lock :
332+ if self .version_id is not None :
333+ api_data = get_roboflow_model_data (
334+ api_key = self .api_key ,
338335 model_id = self .endpoint ,
336+ endpoint_type = ModelEndpointType .ORT ,
337+ device_id = self .device_id ,
338+ countinference = countinference ,
339+ service_secret = service_secret ,
340+ )
341+ if "ort" not in api_data .keys ():
342+ raise ModelArtefactError (
343+ "Could not find `ort` key in roboflow API model description response."
344+ )
345+ api_data = api_data ["ort" ]
346+ if "classes" in api_data :
347+ save_text_lines_in_cache (
348+ content = api_data ["classes" ],
349+ file = "class_names.txt" ,
350+ model_id = self .endpoint ,
351+ )
352+ if "model" not in api_data :
353+ raise ModelArtefactError (
354+ "Could not find `model` key in roboflow API model description response."
355+ )
356+ if "environment" not in api_data :
357+ raise ModelArtefactError (
358+ "Could not find `environment` key in roboflow API model description response."
359+ )
360+ environment = get_from_url (api_data ["environment" ])
361+ model_weights_response = get_from_url (
362+ api_data ["model" ],
363+ json_response = False ,
339364 )
340- if "model" not in api_data :
341- raise ModelArtefactError (
342- "Could not find `model` key in roboflow API model description response."
365+ else :
366+ api_data = get_roboflow_instant_model_data (
367+ api_key = self .api_key ,
368+ model_id = self .endpoint ,
369+ countinference = countinference ,
370+ service_secret = service_secret ,
343371 )
344- if "environment" not in api_data :
345- raise ModelArtefactError (
346- "Could not find `environment` key in roboflow API model description response."
372+ if (
373+ "modelFiles" not in api_data
374+ or "ort" not in api_data ["modelFiles" ]
375+ or "model" not in api_data ["modelFiles" ]["ort" ]
376+ ):
377+ raise ModelArtefactError (
378+ "Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
379+ )
380+ if "environment" not in api_data :
381+ raise ModelArtefactError (
382+ "Could not find `environment` key in roboflow API model description response."
383+ )
384+ model_weights_response = get_from_url (
385+ api_data ["modelFiles" ]["ort" ]["model" ],
386+ json_response = False ,
347387 )
348- environment = get_from_url (api_data ["environment" ])
349- model_weights_response = get_from_url (
350- api_data ["model" ],
351- json_response = False ,
352- )
353- else :
354- api_data = get_roboflow_instant_model_data (
355- api_key = self .api_key ,
388+ environment = api_data ["environment" ]
389+ if "classes" in api_data :
390+ save_text_lines_in_cache (
391+ content = api_data ["classes" ],
392+ file = "class_names.txt" ,
393+ model_id = self .endpoint ,
394+ )
395+
396+ save_bytes_in_cache (
397+ content = model_weights_response .content ,
398+ file = self .weights_file ,
356399 model_id = self .endpoint ,
357- countinference = countinference ,
358- service_secret = service_secret ,
359400 )
360- if (
361- "modelFiles" not in api_data
362- or "ort" not in api_data ["modelFiles" ]
363- or "model" not in api_data ["modelFiles" ]["ort" ]
364- ):
365- raise ModelArtefactError (
366- "Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
367- )
368- if "environment" not in api_data :
369- raise ModelArtefactError (
370- "Could not find `environment` key in roboflow API model description response."
371- )
372- model_weights_response = get_from_url (
373- api_data ["modelFiles" ]["ort" ]["model" ],
374- json_response = False ,
401+ if "colors" in api_data :
402+ environment ["COLORS" ] = api_data ["colors" ]
403+ save_json_in_cache (
404+ content = environment ,
405+ file = "environment.json" ,
406+ model_id = self .endpoint ,
375407 )
376- environment = api_data [ "environment" ]
377- if "classes" in api_data :
378- save_text_lines_in_cache (
379- content = api_data ["classes " ],
380- file = "class_names.txt " ,
408+ if "keypoints_metadata" in api_data :
409+ # TODO: make sure backend provides that
410+ save_json_in_cache (
411+ content = api_data ["keypoints_metadata " ],
412+ file = "keypoints_metadata.json " ,
381413 model_id = self .endpoint ,
382414 )
415+ except Exception as e :
416+ logger .error (f"Error downloading model artifacts: { e } " )
417+ raise
383418
384- save_bytes_in_cache (
385- content = model_weights_response .content ,
386- file = self .weights_file ,
419+ def load_model_artifacts_from_cache (self ) -> None :
420+ logger .debug ("Model artifacts already downloaded, loading model from cache" )
421+ with start_span (
422+ "model.artifacts.load" ,
423+ {"model.id" : self .endpoint , "model.artifacts.source" : "local_cache" },
424+ ):
425+ infer_bucket_files = self .get_all_required_infer_bucket_file ()
426+ if "environment.json" in infer_bucket_files :
427+ self .environment = load_json_from_cache (
428+ file = "environment.json" ,
387429 model_id = self .endpoint ,
430+ object_pairs_hook = OrderedDict ,
388431 )
389- if "colors" in api_data :
390- environment ["COLORS" ] = api_data ["colors" ]
391- save_json_in_cache (
392- content = environment ,
393- file = "environment.json" ,
432+ if "class_names.txt" in infer_bucket_files :
433+ self .class_names = load_text_file_from_cache (
434+ file = "class_names.txt" ,
394435 model_id = self .endpoint ,
436+ split_lines = True ,
437+ strip_white_chars = True ,
395438 )
396- if "keypoints_metadata" in api_data :
397- # TODO: make sure backend provides that
398- save_json_in_cache (
399- content = api_data ["keypoints_metadata" ],
439+ else :
440+ self .class_names = get_class_names_from_environment_file (
441+ environment = self .environment
442+ )
443+ self .colors = get_color_mapping_from_environment (
444+ environment = self .environment ,
445+ class_names = self .class_names ,
446+ )
447+ if "keypoints_metadata.json" in infer_bucket_files :
448+ self .keypoints_metadata = parse_keypoints_metadata (
449+ load_json_from_cache (
400450 file = "keypoints_metadata.json" ,
401451 model_id = self .endpoint ,
452+ object_pairs_hook = OrderedDict ,
402453 )
403- except Exception as e :
404- logger .error (f"Error downloading model artifacts: { e } " )
405- raise
406-
407- def load_model_artifacts_from_cache (self ) -> None :
408- logger .debug ("Model artifacts already downloaded, loading model from cache" )
409- infer_bucket_files = self .get_all_required_infer_bucket_file ()
410- if "environment.json" in infer_bucket_files :
411- self .environment = load_json_from_cache (
412- file = "environment.json" ,
413- model_id = self .endpoint ,
414- object_pairs_hook = OrderedDict ,
415- )
416- if "class_names.txt" in infer_bucket_files :
417- self .class_names = load_text_file_from_cache (
418- file = "class_names.txt" ,
419- model_id = self .endpoint ,
420- split_lines = True ,
421- strip_white_chars = True ,
422- )
423- else :
424- self .class_names = get_class_names_from_environment_file (
425- environment = self .environment
426- )
427- self .colors = get_color_mapping_from_environment (
428- environment = self .environment ,
429- class_names = self .class_names ,
430- )
431- if "keypoints_metadata.json" in infer_bucket_files :
432- self .keypoints_metadata = parse_keypoints_metadata (
433- load_json_from_cache (
434- file = "keypoints_metadata.json" ,
435- model_id = self .endpoint ,
436- object_pairs_hook = OrderedDict ,
437454 )
438- )
439- self .num_classes = len (self .class_names )
440- if "PREPROCESSING" not in self .environment :
441- raise ModelArtefactError (
442- "Could not find `PREPROCESSING` key in environment file."
443- )
444- if issubclass (type (self .environment ["PREPROCESSING" ]), dict ):
445- self .preproc = self .environment ["PREPROCESSING" ]
446- else :
447- self .preproc = json .loads (self .environment ["PREPROCESSING" ])
448- if self .preproc .get ("resize" ):
449- self .resize_method = self .preproc ["resize" ].get ("format" , "Stretch to" )
450- if self .resize_method in [
451- "Fit (reflect edges) in" ,
452- "Fit within" ,
453- "Fill (with center crop) in" ,
454- ]:
455- fallback_resize_method = "Fit (black edges) in"
456- logger .warning (
457- "Unsupported resize method '%s', defaulting to '%s' - this may result in degraded model performance." ,
458- self .resize_method ,
459- fallback_resize_method ,
455+ self .num_classes = len (self .class_names )
456+ if "PREPROCESSING" not in self .environment :
457+ raise ModelArtefactError (
458+ "Could not find `PREPROCESSING` key in environment file."
460459 )
461- self .resize_method = fallback_resize_method
462- if self .resize_method not in [
463- "Stretch to" ,
464- "Fit (black edges) in" ,
465- "Fit (grey edges) in" ,
466- "Fit (white edges) in" ,
467- ]:
460+ if issubclass (type (self .environment ["PREPROCESSING" ]), dict ):
461+ self .preproc = self .environment ["PREPROCESSING" ]
462+ else :
463+ self .preproc = json .loads (self .environment ["PREPROCESSING" ])
464+ if self .preproc .get ("resize" ):
465+ self .resize_method = self .preproc ["resize" ].get ("format" , "Stretch to" )
466+ if self .resize_method in [
467+ "Fit (reflect edges) in" ,
468+ "Fit within" ,
469+ "Fill (with center crop) in" ,
470+ ]:
471+ fallback_resize_method = "Fit (black edges) in"
472+ logger .warning (
473+ "Unsupported resize method '%s', defaulting to '%s' - this may result in degraded model performance." ,
474+ self .resize_method ,
475+ fallback_resize_method ,
476+ )
477+ self .resize_method = fallback_resize_method
478+ if self .resize_method not in [
479+ "Stretch to" ,
480+ "Fit (black edges) in" ,
481+ "Fit (grey edges) in" ,
482+ "Fit (white edges) in" ,
483+ ]:
484+ logger .error (
485+ "Unsupported resize method '%s', defaulting to 'Stretch to' - this may result in degraded model performance." ,
486+ self .resize_method ,
487+ )
488+ self .resize_method = "Stretch to"
489+ else :
468490 logger .error (
469- "Unsupported resize method '%s', defaulting to 'Stretch to' - this may result in degraded model performance." ,
470- self .resize_method ,
491+ "Unknown resize method, defaulting to 'Stretch to' - this may result in degraded model performance."
471492 )
472493 self .resize_method = "Stretch to"
473- else :
474- logger .error (
475- "Unknown resize method, defaulting to 'Stretch to' - this may result in degraded model performance."
476- )
477- self .resize_method = "Stretch to"
478- logger .debug (f"Resize method is '{ self .resize_method } '" )
479- self .multiclass = self .environment .get ("MULTICLASS" , False )
494+ logger .debug (f"Resize method is '{ self .resize_method } '" )
495+ self .multiclass = self .environment .get ("MULTICLASS" , False )
480496
481497 def initialize_model (self , ** kwargs ) -> None :
482498 """Initialize the model.
0 commit comments