@@ -58,34 +58,49 @@ def get_s3_filesystem() -> fs.FileSystem:
5858 session_token = credentials .token ,
5959 )
6060
@classmethod
def parse_location(
    cls,
    location: str | list[str],
) -> tuple[fs.FileSystem, str | list[str]]:
    """Resolve a location (or list of locations) to a filesystem and path(s).

    A single string is handed to ``_parse_single_location``; a list is
    handed to ``_parse_multiple_locations``. The returned source has the
    same shape (str or list) as the input.

    Raises:
        TypeError: If ``location`` is neither a ``str`` nor a ``list``.
    """
    # Guard clauses: dispatch on the runtime type, fail loudly otherwise.
    if isinstance(location, str):
        return cls._parse_single_location(location)
    if isinstance(location, list):
        return cls._parse_multiple_locations(location)
    raise TypeError("Location type must be str or list[str].")
77+
@classmethod
def _parse_single_location(
    cls, location: str
) -> tuple[fs.FileSystem, str]:
    """Get filesystem and normalized location for a single location.

    Args:
        location: A path string. A leading ``s3://`` selects the S3
            filesystem; anything else is treated as a local path.

    Returns:
        A ``(filesystem, source)`` pair. For S3 locations the ``s3://``
        prefix is stripped from ``source``; local paths are returned
        unchanged.
    """
    if location.startswith("s3://"):
        # Go through ``cls`` (not a hard-coded class name) so subclasses
        # that override ``get_s3_filesystem`` are honored.
        return cls.get_s3_filesystem(), location.removeprefix("s3://")
    return fs.LocalFileSystem(), location
8890
@classmethod
def _parse_multiple_locations(
    cls, location: list[str]
) -> tuple[fs.FileSystem, list[str]]:
    """Get filesystem and normalized locations for multiple locations.

    Args:
        location: A list of path strings that must be either all
            ``s3://`` URIs or all non-S3 paths.

    Returns:
        A ``(filesystem, sources)`` pair. For S3 locations the ``s3://``
        prefix is stripped from every entry; local paths are returned
        unchanged.

    Raises:
        ValueError: If the list mixes S3 and non-S3 paths.
    """
    # NOTE: an empty list satisfies this ``all(...)`` and therefore takes
    # the S3 branch, exactly as before.
    if all(loc.startswith("s3://") for loc in location):
        # Go through ``cls`` (not a hard-coded class name) so subclasses
        # that override ``get_s3_filesystem`` are honored.
        filesystem = cls.get_s3_filesystem()
        return filesystem, [loc.removeprefix("s3://") for loc in location]
    if not any(loc.startswith("s3://") for loc in location):
        return fs.LocalFileSystem(), location
    raise ValueError("Mixed S3 and local paths are not supported.")
90105
91106 def load_dataset (self ) -> ds .Dataset :
0 commit comments