Skip to content

Commit 49771c5

Browse files
committed
feat: update package repository to accept new native extended type
Enforcing an MSSDK type will raise validation errors, so default to the native extended type. Also update the soon-to-be-deprecated filesystem repository to deal with paths in the new model.
1 parent d950637 commit 49771c5

1 file changed

Lines changed: 49 additions & 35 deletions

File tree

src/ted_sws/data_manager/adapters/mapping_package_repository.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import pathlib
44
import shutil
5-
from datetime import datetime
65
from typing import Iterator, List, Optional
76

87
from pymongo import MongoClient
@@ -64,7 +63,7 @@ def __init__(self, mongodb_client: MongoClient, database_name: str = None):
6463
self.database_name = database_name or config.MONGO_DB_AGGREGATES_DATABASE_NAME
6564
self.mongodb_client = mongodb_client
6665

67-
# Repositories for each MSSDK package type
66+
# Repositories for each package type
6867
self._repo_v1 = MongoDBRepository(
6968
model_class=MappingPackageV1,
7069
mongo_client=mongodb_client,
@@ -89,6 +88,12 @@ def __init__(self, mongodb_client: MongoClient, database_name: str = None):
8988
database_name=self.database_name,
9089
collection_name=self._collection_name
9190
)
91+
self._repo_legacy = MongoDBRepository(
92+
model_class=MappingPackage,
93+
mongo_client=mongodb_client,
94+
database_name=self.database_name,
95+
collection_name=self._collection_name
96+
)
9297

9398
def _get_repository(self, package: MappingPackage) -> MongoDBRepository:
9499
"""Get the appropriate repository based on package type."""
@@ -100,58 +105,64 @@ def _get_repository(self, package: MappingPackage) -> MongoDBRepository:
100105
return self._repo_v2
101106
elif isinstance(package, MappingPackageV1):
102107
return self._repo_v1
108+
elif isinstance(package, MappingPackage):
109+
return self._repo_legacy
103110
else:
104111
raise ValueError(f"Unsupported package type: {type(package).__name__}")
105112

106-
def add(self, package: MappingPackage) -> MappingPackage:
113+
def get_repository_by_class(self, package_class):
114+
if package_class == MappingPackageV1:
115+
return self._repo_v1
116+
elif package_class == MappingPackageV2:
117+
return self._repo_v2
118+
elif package_class == MappingPackageV3:
119+
return self._repo_v3
120+
elif package_class == MappingPackageV3Lightweight:
121+
return self._repo_v3_lightweight
122+
elif package_class == MappingPackage:
123+
return self._repo_legacy
124+
else:
125+
raise ValueError(f"Unsupported package class: {package_class.__name__}")
126+
127+
def add(self, mapping_package: MappingPackage) -> MappingPackage:
107128
"""Save a mapping package to MongoDB.
108129
109130
Args:
110-
package: The mapping package (legacy or MSSDK model)
131+
mapping_package: The mapping package (legacy or MSSDK model)
111132
112133
Returns:
113134
The saved package
114135
"""
115-
repo = self._get_repository(package)
116-
return repo.create(package)
136+
repo = self._get_repository(mapping_package)
137+
return repo.create(mapping_package)
117138

118-
def get(self, reference: str, package_class: MappingPackage) -> MappingPackage:
139+
def get(self, reference: str, package_class: type = MappingPackage) -> MappingPackage:
119140
"""Retrieve a mapping package from MongoDB.
120141
121142
Args:
122143
reference: The package identifier
123-
package_class: The expected package model class (defaults to V2)
144+
package_class: The expected package model class (defaults to MappingPackage)
124145
125146
Returns:
126147
The retrieved package
127148
128149
Raises:
129150
ModelNotFoundError: If package not found
130151
"""
131-
if package_class == MappingPackageV1:
132-
repo = self._repo_v1
133-
elif package_class == MappingPackageV2:
134-
repo = self._repo_v2
135-
elif package_class == MappingPackageV3:
136-
repo = self._repo_v3
137-
elif package_class == MappingPackageV3Lightweight:
138-
repo = self._repo_v3_lightweight
139-
else:
140-
raise ValueError(f"Unsupported package class: {package_class.__name__}")
141-
152+
repo = self.get_repository_by_class(package_class)
142153
return repo.read(reference)
143154

144-
def update(self, package: MappingPackage) -> MappingPackage:
155+
def update(self, mapping_package: MappingPackage) -> MappingPackage:
145156
"""Update a mapping package in MongoDB.
146157
147158
Args:
148-
package: The package to update
159+
mapping_package: The package to update
149160
150161
Returns:
151162
The updated package
152163
"""
153-
repo = self._get_repository(package)
154-
return repo.update(package)
164+
repo = self._get_repository(mapping_package)
165+
return repo.update(mapping_package)
155166

156167
def delete(self, reference: str) -> None:
157168
"""Delete a mapping package from MongoDB.
@@ -165,7 +176,7 @@ def delete(self, reference: str) -> None:
165176
if result.deleted_count < 1:
166177
raise ModelNotFoundError(f"Package with ID {reference} not found")
167178

168-
def list(self, package_class: MappingPackage) -> List[MappingPackage]:
179+
def list(self, package_class: type = MappingPackage) -> List[MappingPackage]:
169180
"""List mapping packages from MongoDB.
170181
171182
Args:
@@ -174,17 +185,7 @@ def list(self, package_class: MappingPackage) -> List[MappingPackage]:
174185
Returns:
175186
List of packages
176187
"""
177-
if package_class == MappingPackageV1:
178-
repo = self._repo_v1
179-
elif package_class == MappingPackageV2:
180-
repo = self._repo_v2
181-
elif package_class == MappingPackageV3:
182-
repo = self._repo_v3
183-
elif package_class == MappingPackageV3Lightweight:
184-
repo = self._repo_v3_lightweight
185-
else:
186-
raise ValueError(f"Unsupported package class: {package_class.__name__}")
187-
188+
repo = self.get_repository_by_class(package_class)
188189
return repo.read_many()
189190

190191

@@ -278,12 +279,25 @@ def _write_package_metadata(self, mapping_package: MappingPackage):
278279
:param mapping_package:
279280
:return:
280281
"""
282+
def convert_paths(obj):
283+
if isinstance(obj, pathlib.Path):
284+
return str(obj)
285+
elif isinstance(obj, dict):
286+
return {k: convert_paths(v) for k, v in obj.items()}
287+
elif isinstance(obj, list):
288+
return [convert_paths(i) for i in obj]
289+
elif isinstance(obj, tuple):
290+
return tuple(convert_paths(i) for i in obj)
291+
else:
292+
return obj
293+
281294
package_path = self.repository_path / mapping_package.identifier
282295
package_path.mkdir(parents=True, exist_ok=True)
283296
metadata_path = package_path / MS_METADATA_FILE_NAME
284297
package_metadata = mapping_package.model_dump()
285298
[package_metadata.pop(key, None) for key in
286299
["transformation_rule_set", "shacl_test_suites", "sparql_test_suites"]]
300+
package_metadata = convert_paths(package_metadata)
287301
with metadata_path.open("w", encoding="utf-8") as f:
288302
f.write(json.dumps(package_metadata))
289303

0 commit comments

Comments
 (0)