1- # ruff: noqa: S105, S106, SLF001
1+ # ruff: noqa: S105, S106, SLF001, PLR2004
22import os
33from datetime import date
44from unittest .mock import MagicMock , patch
77import pytest
88from pyarrow import fs
99
10- from timdex_dataset_api .dataset import DatasetNotLoadedError , TIMDEXDataset
10+ from timdex_dataset_api .dataset import (
11+ DatasetNotLoadedError ,
12+ TIMDEXDataset ,
13+ TIMDEXDatasetConfig ,
14+ )
1115
1216
1317@pytest .mark .parametrize (
@@ -23,6 +27,24 @@ def test_dataset_init_success(location, expected_file_system, expected_source):
2327 assert timdex_dataset .source == expected_source
2428
2529
def test_dataset_init_env_vars_set_config(monkeypatch, local_dataset_location):
    """Verify that TDA_READ_BATCH_SIZE overrides the default read batch size.

    Args:
        monkeypatch: pytest fixture used to manipulate environment variables.
        local_dataset_location: fixture value passed as the dataset location.
    """
    # Clear any ambient value first so the default assertion does not depend on
    # the environment the test suite happens to be run in.
    monkeypatch.delenv("TDA_READ_BATCH_SIZE", raising=False)
    default_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
    default_read_batch_config = default_timdex_dataset.config.read_batch_size
    assert default_read_batch_config == 1_000

    # A dataset created after the variable is set should reflect the override.
    monkeypatch.setenv("TDA_READ_BATCH_SIZE", "100_000")
    env_var_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
    env_var_read_batch_config = env_var_timdex_dataset.config.read_batch_size
    assert env_var_read_batch_config == 100_000
39+
40+
def test_dataset_init_custom_config_object(monkeypatch, local_dataset_location):
    """Verify that a caller-supplied config object is the one the dataset exposes."""
    expected_max_rows = 42

    # Start from a default config and override a single setting.
    custom_config = TIMDEXDatasetConfig()
    custom_config.max_rows_per_file = expected_max_rows

    dataset = TIMDEXDataset(location=local_dataset_location, config=custom_config)

    assert dataset.config.max_rows_per_file == expected_max_rows
46+
47+
2648@patch ("timdex_dataset_api.dataset.fs.LocalFileSystem" )
2749@patch ("timdex_dataset_api.dataset.ds.dataset" )
2850def test_dataset_load_local_sets_filesystem_and_dataset_success (
@@ -73,28 +95,28 @@ def test_dataset_load_without_filters_success(fixed_local_dataset):
7395 fixed_local_dataset .load ()
7496
7597 assert os .path .exists (fixed_local_dataset .location )
76- assert fixed_local_dataset .row_count == 5_000 # noqa: PLR2004
98+ assert fixed_local_dataset .row_count == 5_000
7799
78100
def test_dataset_load_with_run_date_str_filters_success(fixed_local_dataset):
    """Load the fixture dataset with a run_date string and check the row count."""
    fixed_local_dataset.load(run_date="2024-12-01")

    location_exists = os.path.exists(fixed_local_dataset.location)
    loaded_rows = fixed_local_dataset.row_count

    assert location_exists
    assert loaded_rows == 5_000
84106
85107
def test_dataset_load_with_run_date_obj_filters_success(fixed_local_dataset):
    """Load the fixture dataset with a run_date date object and check the row count."""
    run_date = date(2024, 12, 1)
    fixed_local_dataset.load(run_date=run_date)

    location_exists = os.path.exists(fixed_local_dataset.location)
    loaded_rows = fixed_local_dataset.row_count

    assert location_exists
    assert loaded_rows == 5_000
91113
92114
def test_dataset_load_with_ymd_filters_success(fixed_local_dataset):
    """Load the fixture dataset with year/month/day filters and check the row count."""
    ymd_filters = {"year": "2024", "month": "12", "day": "01"}
    fixed_local_dataset.load(**ymd_filters)

    location_exists = os.path.exists(fixed_local_dataset.location)
    loaded_rows = fixed_local_dataset.row_count

    assert location_exists
    assert loaded_rows == 5_000
98120
99121
100122def test_dataset_load_with_single_nonpartition_filters_success (fixed_local_dataset ):
@@ -158,7 +180,7 @@ def test_dataset_get_filtered_dataset_with_or_nonpartition_filters_success(
158180 timdex_record_id = ["alma:0" , "alma:1" ]
159181 )
160182 filtered_local_df = filtered_local_dataset .to_table ().to_pandas ()
161- assert len (filtered_local_df ) == 2 # noqa: PLR2004
183+ assert len (filtered_local_df ) == 2
162184 assert filtered_local_df ["timdex_record_id" ].tolist () == ["alma:0" , "alma:1" ]
163185
164186
0 commit comments