Skip to content

Commit 9f9477a

Browse files
authored
Merge pull request #2 from MITLibraries/TIMX-415-load-dataset
TIMX 415 - load dataset
2 parents d5f3549 + 69ae96b commit 9f9477a

14 files changed

Lines changed: 844 additions & 208 deletions

File tree

Pipfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ pytest = "*"
2121
ruff = "*"
2222
setuptools = "*"
2323
pandas-stubs = "*"
24+
moto = "*"
25+
pytest-mock = "*"
2426

2527
[requires]
2628
python_version = "3.12"

Pipfile.lock

Lines changed: 483 additions & 208 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ timdex_dataset_api = {git = "https://github.com/MITLibraries/timdex-dataset-api.
2626
... other dependencies...
2727
```
2828

29+
## Environment Variables
30+
31+
### Required
32+
33+
### Optional
34+
```shell
35+
TDA_LOG_LEVEL=# log level for timdex-dataset-api, accepts [DEBUG, INFO, WARNING, ERROR], default INFO
36+
```
37+
2938
## Usage
3039

3140
_TODO..._

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ ignore = [
9494
"S320",
9595
"S321",
9696
"S608",
97+
"TRY003"
9798
]
9899

99100
fixable = ["E", "F", "I", "Q"]

tests/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""tests/conftest.py"""
2+
3+
import pytest
4+
5+
from timdex_dataset_api import TIMDEXDataset
6+
7+
8+
@pytest.fixture(autouse=True)
def _test_env(monkeypatch):
    """Pin a known log level and fake AWS credentials for every test."""
    test_env_vars = {
        "TDA_LOG_LEVEL": "INFO",
        "AWS_ACCESS_KEY_ID": "fake_access_key",
        "AWS_SECRET_ACCESS_KEY": "fake_secret_key",
        "AWS_SESSION_TOKEN": "fake_session_token",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    for env_var, value in test_env_vars.items():
        monkeypatch.setenv(env_var, value)
15+
16+
17+
@pytest.fixture
def local_dataset_location():
    """Filesystem path of the on-disk fixture dataset."""
    fixture_dataset_path = "tests/fixtures/local_datasets/dataset"
    return fixture_dataset_path
20+
21+
22+
@pytest.fixture
def local_dataset(local_dataset_location):
    """A TIMDEXDataset instance loaded from the local fixture dataset."""
    loaded = TIMDEXDataset.load(local_dataset_location)
    return loaded

tests/test_config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from timdex_dataset_api.config import configure_logger
2+
3+
4+
def test_configure_logger_default_info_level(caplog):
    """Logger configured at INFO level captures INFO but not DEBUG messages.

    The autouse _test_env fixture sets TDA_LOG_LEVEL=INFO, so no per-test env
    setup is needed here; the previously-injected (and unused) monkeypatch
    fixture has been dropped. caplog is lowered to DEBUG so the final
    assertion exercises the logger's own level, not the capture threshold.
    """
    caplog.set_level("DEBUG")
    logger = configure_logger(__name__)

    info_msg = "hello INFO world"
    logger.info(info_msg)
    assert info_msg in caplog.text

    debug_msg = "hello DEBUG world"
    logger.debug(debug_msg)
    assert debug_msg not in caplog.text  # NOT captured
15+
16+
17+
def test_configure_logger_env_var_sets_debug_level(monkeypatch, caplog):
    """Setting TDA_LOG_LEVEL=DEBUG makes the logger emit DEBUG records."""
    caplog.set_level("DEBUG")
    monkeypatch.setenv("TDA_LOG_LEVEL", "DEBUG")
    logger = configure_logger(__name__)

    messages = {
        "info": "hello INFO world",
        "debug": "hello DEBUG world",
    }

    logger.info(messages["info"])
    assert messages["info"] in caplog.text

    logger.debug(messages["debug"])
    assert messages["debug"] in caplog.text  # IS captured

tests/test_dataset.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# ruff: noqa: S105, S106, SLF001
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
import pyarrow as pa
6+
import pytest
7+
from pyarrow import fs
8+
9+
from timdex_dataset_api.dataset import DatasetNotLoadedError, TIMDEXDataset
10+
11+
12+
@pytest.mark.parametrize(
    ("location", "expected_filesystem", "expected_source"),
    [
        ("/path/to/dataset", fs.LocalFileSystem, "/path/to/dataset"),
        (
            ["/path/to/records1.parquet", "/path/to/records2.parquet"],
            fs.LocalFileSystem,
            ["/path/to/records1.parquet", "/path/to/records2.parquet"],
        ),
        ("s3://bucket/path/to/dataset", fs.S3FileSystem, "bucket/path/to/dataset"),
        (
            [
                "s3://bucket/path/to/dataset/records1.parquet",
                "s3://bucket/path/to/dataset/records2.parquet",
            ],
            fs.S3FileSystem,
            [
                "bucket/path/to/dataset/records1.parquet",
                "bucket/path/to/dataset/records2.parquet",
            ],
        ),
    ],
)
@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
def test_parse_location_single_local_directory(
    mock_get_s3_filesystem,
    location,
    expected_filesystem,
    expected_source,
):
    """parse_location resolves a location into (filesystem, source).

    NOTE(review): despite the name, the parametrization also covers S3
    locations and lists of files, not only a single local directory.
    """
    mock_get_s3_filesystem.return_value = fs.S3FileSystem()
    resolved_filesystem, resolved_source = TIMDEXDataset.parse_location(location)
    assert isinstance(resolved_filesystem, expected_filesystem)
    assert resolved_source == expected_source
46+
47+
48+
def test_get_s3_filesystem_success(mocker):
    """get_s3_filesystem builds an S3FileSystem from session credentials.

    The expected credential values come from the env vars set by the autouse
    _test_env fixture. The isinstance check now uses the public
    pyarrow.fs.S3FileSystem class instead of reaching into the private
    pa._s3fs module (same class object, no SLF001 violation).
    """
    mocked_s3_filesystem = mocker.spy(fs, "S3FileSystem")
    s3_filesystem = TIMDEXDataset.get_s3_filesystem()

    assert mocked_s3_filesystem.call_args[1] == {
        "secret_key": "fake_secret_key",
        "access_key": "fake_access_key",
        "region": "us-east-1",
        "session_token": "fake_session_token",
    }
    assert isinstance(s3_filesystem, fs.S3FileSystem)
59+
60+
61+
@patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
@patch("timdex_dataset_api.dataset.ds.dataset")
def test_load_local_dataset_correct_filesystem_and_source(mock_pyarrow_ds, mock_local_fs):
    """A local path is loaded via LocalFileSystem with hive partitioning."""
    fake_filesystem = MagicMock()
    fake_dataset = MagicMock()
    mock_local_fs.return_value = fake_filesystem
    mock_pyarrow_ds.return_value = fake_dataset

    timdex_dataset = TIMDEXDataset(location="local/path/to/dataset")
    loaded_dataset = timdex_dataset.load_dataset()

    mock_pyarrow_ds.assert_called_once_with(
        "local/path/to/dataset",
        schema=timdex_dataset.schema,
        format="parquet",
        partitioning="hive",
        filesystem=fake_filesystem,
    )
    assert loaded_dataset == fake_dataset
78+
79+
80+
@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
@patch("timdex_dataset_api.dataset.ds.dataset")
def test_load_s3_dataset_correct_filesystem_and_source(mock_pyarrow_ds, mock_get_s3_fs):
    """An s3:// location strips the scheme and loads via an S3 filesystem."""
    fake_s3_filesystem = MagicMock()
    fake_dataset = MagicMock()
    mock_get_s3_fs.return_value = fake_s3_filesystem
    mock_pyarrow_ds.return_value = fake_dataset

    timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
    loaded_dataset = timdex_dataset.load_dataset()

    mock_get_s3_fs.assert_called_once()
    mock_pyarrow_ds.assert_called_once_with(
        "bucket/path/to/dataset",
        schema=timdex_dataset.schema,
        format="parquet",
        partitioning="hive",
        filesystem=fake_s3_filesystem,
    )
    assert loaded_dataset == fake_dataset
98+
99+
100+
@patch("timdex_dataset_api.dataset.TIMDEXDataset.load_dataset")
def test_load_method_loads_dataset_and_returns_timdexdataset_instance(mock_load_dataset):
    """TIMDEXDataset.load constructs an instance and eagerly loads the dataset."""
    mock_load_dataset.return_value = MagicMock()

    timdex_dataset = TIMDEXDataset.load("s3://bucket/path/to/dataset")

    mock_load_dataset.assert_called_once()
    assert isinstance(timdex_dataset, TIMDEXDataset)
    assert timdex_dataset.location == "s3://bucket/path/to/dataset"
109+
110+
111+
def test_local_dataset_is_valid(local_dataset):
    """The fixture dataset materializes to a well-formed pyarrow Table."""
    fixture_table = local_dataset.dataset.to_table()
    # Table.validate() raises on malformed data and returns None when valid
    assert fixture_table.validate() is None
113+
114+
115+
def test_local_dataset_row_count_success(local_dataset):
    """row_count agrees with a direct pyarrow count_rows over the dataset."""
    expected_rows = local_dataset.dataset.count_rows()
    assert expected_rows == local_dataset.row_count
117+
118+
119+
def test_local_dataset_row_count_missing_dataset_exception():
    """row_count raises DatasetNotLoadedError before load_dataset is called.

    The local_dataset fixture previously requested here was never used and
    forced an unnecessary dataset load, so it has been removed.
    """
    td = TIMDEXDataset(location="path/to/nowhere")
    with pytest.raises(DatasetNotLoadedError):
        _ = td.row_count

0 commit comments

Comments
 (0)