Skip to content

Commit 69ae96b

Browse files
committed
Breakout single and multiple location parsing
1 parent b782cac commit 69ae96b

1 file changed

Lines changed: 34 additions & 19 deletions

File tree

timdex_dataset_api/dataset.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,34 +58,49 @@ def get_s3_filesystem() -> fs.FileSystem:
5858
session_token=credentials.token,
5959
)
6060

61-
@staticmethod
61+
@classmethod
6262
def parse_location(
63+
cls,
6364
location: str | list[str],
6465
) -> tuple[fs.FileSystem, str | list[str]]:
6566
"""Parse and return the filesystem and normalized source location(s).
6667
6768
Handles both single location strings and lists of Parquet file paths.
6869
"""
69-
source: str | list[str]
70-
if isinstance(location, str):
71-
if location.startswith("s3://"):
72-
filesystem = TIMDEXDataset.get_s3_filesystem()
73-
source = location.removeprefix("s3://")
74-
else:
75-
filesystem = fs.LocalFileSystem()
76-
source = location
77-
elif isinstance(location, list):
78-
if all(loc.startswith("s3://") for loc in location):
79-
filesystem = TIMDEXDataset.get_s3_filesystem()
80-
source = [loc.removeprefix("s3://") for loc in location]
81-
elif all(not loc.startswith("s3://") for loc in location):
82-
filesystem = fs.LocalFileSystem()
83-
source = location
84-
else:
85-
raise ValueError("Mixed S3 and local paths are not supported.")
70+
match location:
71+
case str():
72+
return cls._parse_single_location(location)
73+
case list():
74+
return cls._parse_multiple_locations(location)
75+
case _:
76+
raise TypeError("Location type must be str or list[str].")
77+
78+
@classmethod
79+
def _parse_single_location(
80+
cls, location: str
81+
) -> tuple[fs.FileSystem, str | list[str]]:
82+
"""Get filesystem and normalized location for single location."""
83+
if location.startswith("s3://"):
84+
filesystem = TIMDEXDataset.get_s3_filesystem()
85+
source = location.removeprefix("s3://")
8686
else:
87-
raise TypeError("Location type must be str or list[str].")
87+
filesystem = fs.LocalFileSystem()
88+
source = location
89+
return filesystem, source
8890

91+
@classmethod
92+
def _parse_multiple_locations(
93+
cls, location: list[str]
94+
) -> tuple[fs.FileSystem, str | list[str]]:
95+
"""Get filesystem and normalized location for multiple locations."""
96+
if all(loc.startswith("s3://") for loc in location):
97+
filesystem = TIMDEXDataset.get_s3_filesystem()
98+
source = [loc.removeprefix("s3://") for loc in location]
99+
elif all(not loc.startswith("s3://") for loc in location):
100+
filesystem = fs.LocalFileSystem()
101+
source = location
102+
else:
103+
raise ValueError("Mixed S3 and local paths are not supported.")
89104
return filesystem, source
90105

91106
def load_dataset(self) -> ds.Dataset:

0 commit comments

Comments
 (0)