Skip to content

Commit eb81f4a

Browse files
committed
Update code to address some Copilot review comments
1 parent 547672d commit eb81f4a

File tree

1 file changed

+125
-52
lines changed

1 file changed

+125
-52
lines changed

lisa/ai/log_agent.py

Lines changed: 125 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,22 @@ def _clean_json_markers(text: str) -> str:
351351
return text.strip()
352352

353353

354+
def _get_keywords(answer: Union[Dict[str, List[str]], List[str], str]) -> str:
355+
"""Extract keywords from ground truth data."""
356+
if isinstance(answer, dict):
357+
keywords: List[str] = answer.get("problem_keywords", [""])
358+
elif isinstance(answer, list):
359+
keywords = answer
360+
else:
361+
# ground_truth is a string
362+
keywords = [answer]
363+
assert isinstance(keywords, list), f"Expected list, got {type(keywords)}"
364+
# Sort alphabetically and join.
365+
keywords_str = ", ".join(sorted(keywords))
366+
367+
return keywords_str
368+
369+
354370
def _parse_lisa_storage_log_link(log_link: str) -> StorageLogLinkInfo:
355371
"""
356372
Parse a LISA-generated Azure Portal storage container link.
def _get_storage_credential() -> Any:
    """
    Get Azure Storage credential using AAD only.

    Raises ModuleNotFoundError (with install guidance) when the
    azure.identity optional dependency is not installed.
    """
    try:
        from azure.identity import DefaultAzureCredential
    except ModuleNotFoundError as import_error:
        # Surface a friendly message pointing at the repo's [azure] extras.
        _raise_missing_azure_dependency_error("azure.identity", import_error)

    logger.info("Using DefaultAzureCredential (AAD) for storage access.")
    return DefaultAzureCredential()
433451

434452

453+
def _raise_missing_azure_dependency_error(
454+
module_name: str, error: ModuleNotFoundError
455+
) -> None:
456+
raise ModuleNotFoundError(
457+
"Missing Azure SDK dependencies required for --log-link. "
458+
f"Failed to import '{module_name}'. "
459+
"Install the repo Azure optional dependencies and retry. "
460+
"From the lisa package root, run: pip install -e .[azure]"
461+
) from error
462+
463+
464+
def _to_local_fs_path(path: str) -> str:
465+
absolute_path = os.path.abspath(path)
466+
if os.name == "nt" and not absolute_path.startswith("\\\\?\\"):
467+
return "\\\\?\\" + absolute_path
468+
return absolute_path
469+
470+
471+
def _resolve_safe_local_blob_path(
472+
root_path: str, relative_blob_path: str, source_blob_name: str
473+
) -> str:
474+
# Blob names should be relative and use '/' as separators.
475+
if not relative_blob_path:
476+
raise ValueError(
477+
f"Invalid blob path: empty relative path for '{source_blob_name}'"
478+
)
479+
if relative_blob_path.startswith(("/", "\\")):
480+
raise ValueError(
481+
f"Invalid blob path: absolute path is not allowed: '{source_blob_name}'"
482+
)
483+
if re.match(r"^[a-zA-Z]:", relative_blob_path):
484+
raise ValueError(
485+
f"Invalid blob path: drive letter is not allowed: '{source_blob_name}'"
486+
)
487+
if os.name == "nt" and "\\" in relative_blob_path:
488+
raise ValueError(
489+
f"Invalid blob path: backslash is not allowed on Windows: "
490+
f"'{source_blob_name}'."
491+
)
492+
493+
safe_parts = [
494+
part for part in relative_blob_path.split("/") if part and part != "."
495+
]
496+
if any(part == ".." for part in safe_parts):
497+
raise ValueError(
498+
f"Invalid blob path: parent directory traversal is not allowed: "
499+
f"'{source_blob_name}'."
500+
)
501+
502+
if not safe_parts:
503+
raise ValueError(
504+
f"Invalid blob path: no valid path segments in '{source_blob_name}'."
505+
)
506+
507+
root_abs = os.path.abspath(root_path)
508+
destination_abs = os.path.abspath(os.path.join(root_abs, *safe_parts))
509+
if os.path.commonpath([root_abs, destination_abs]) != root_abs:
510+
raise ValueError(
511+
f"Invalid blob path: resolved outside destination root: "
512+
f"'{source_blob_name}'."
513+
)
514+
515+
return destination_abs
516+
517+
518+
def _download_blobs_to_local(
    container_client: Any,
    blobs: list,
    prefix_with_sep: str,
    local_selected_root: str,
) -> int:
    """Download every listed blob beneath local_selected_root.

    Each blob name is made relative to prefix_with_sep (when it applies),
    validated against path traversal, and written to disk.

    Returns the number of blobs written.
    """
    written = 0
    for entry in blobs:
        full_name = entry.name
        relative_name = full_name
        if prefix_with_sep and full_name.startswith(prefix_with_sep):
            relative_name = full_name[len(prefix_with_sep) :]
        # A bare prefix marker has no content to write locally.
        if not relative_name:
            continue

        destination = _resolve_safe_local_blob_path(
            root_path=local_selected_root,
            relative_blob_path=relative_name,
            source_blob_name=full_name,
        )
        destination_fs = _to_local_fs_path(destination)
        os.makedirs(os.path.dirname(destination_fs), exist_ok=True)

        with open(destination_fs, "wb") as output_file:
            downloader = container_client.download_blob(full_name)
            output_file.write(downloader.readall())
        written += 1
    return written
547+
548+
435549
def _download_logs_from_link(log_link: str) -> str:
436550
"""
437551
Download logs from an Azure Storage portal link to lisa/ai/logs.
438552
439553
Returns the local folder path used for analysis.
440554
"""
441555

442-
from azure.storage.blob import BlobServiceClient
556+
try:
557+
from azure.storage.blob import BlobServiceClient
558+
except ModuleNotFoundError as e:
559+
_raise_missing_azure_dependency_error("azure.storage.blob", e)
443560

444561
link_info = _parse_lisa_storage_log_link(log_link)
445562
logger.info(
@@ -465,13 +582,6 @@ def _download_logs_from_link(log_link: str) -> str:
465582
)
466583

467584
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
468-
469-
def _to_local_fs_path(path: str) -> str:
470-
absolute_path = os.path.abspath(path)
471-
if os.name == "nt" and not absolute_path.startswith("\\\\?\\"):
472-
return "\\\\?\\" + absolute_path
473-
return absolute_path
474-
475585
local_root = os.path.join(
476586
get_current_directory(),
477587
"logs",
@@ -492,59 +602,22 @@ def _to_local_fs_path(path: str) -> str:
492602
normalized_prefix = link_info.blob_prefix.strip("/")
493603
prefix_with_sep = f"{normalized_prefix}/" if normalized_prefix else ""
494604

495-
# Keep only the selected case folder name locally, and preserve subfolders below it.
496605
selected_root_name = (
497606
normalized_prefix.rsplit("/", maxsplit=1)[-1]
498607
if normalized_prefix
499608
else link_info.container
500609
)
501610
local_selected_root = os.path.join(local_root, selected_root_name)
502611

503-
downloaded_count = 0
504-
for blob in blobs:
505-
blob_name = blob.name
506-
relative_blob_name = blob_name
507-
if prefix_with_sep and blob_name.startswith(prefix_with_sep):
508-
relative_blob_name = blob_name[len(prefix_with_sep) :]
509-
if not relative_blob_name:
510-
continue
511-
512-
local_blob_path = os.path.join(
513-
local_selected_root, *relative_blob_name.split("/")
514-
)
515-
local_blob_path_fs = _to_local_fs_path(local_blob_path)
516-
local_parent_fs = os.path.dirname(local_blob_path_fs)
517-
os.makedirs(local_parent_fs, exist_ok=True)
518-
519-
with open(local_blob_path_fs, "wb") as output_file:
520-
stream = container_client.download_blob(blob_name)
521-
output_file.write(stream.readall())
522-
downloaded_count += 1
523-
524-
analysis_root = local_selected_root
612+
downloaded_count = _download_blobs_to_local(
613+
container_client, blobs, prefix_with_sep, local_selected_root
614+
)
525615

526616
logger.info(
527617
f"Downloaded {downloaded_count} blob(s) from "
528-
f"'{link_info.container}/{link_info.blob_prefix}' to: {analysis_root}"
618+
f"'{link_info.container}/{link_info.blob_prefix}' to: {local_selected_root}"
529619
)
530-
return analysis_root
531-
532-
533-
def _get_keywords(answer: Union[Dict[str, List[str]], List[str], str]) -> str:
534-
"""Extract keywords from ground truth data."""
535-
if isinstance(answer, dict):
536-
keywords: List[str] = answer.get("problem_keywords", [""])
537-
elif isinstance(answer, list):
538-
keywords = answer
539-
else:
540-
# ground_truth is a string
541-
keywords = [answer]
542-
543-
assert isinstance(keywords, list), f"Expected list, got {type(keywords)}"
544-
# Sort alphabetically and join.
545-
keywords_str = ", ".join(sorted(keywords))
546-
547-
return keywords_str
620+
return local_selected_root
548621

549622

550623
@retry(tries=3, delay=2) # type: ignore

0 commit comments

Comments
 (0)