|
1 | | -# ruff: noqa: PLR2004 |
| 1 | +# ruff: noqa: D205, D209, PLR2004 |
2 | 2 |
|
3 | 3 |
|
4 | 4 | import pandas as pd |
5 | 5 | import pyarrow as pa |
6 | 6 | import pytest |
| 7 | +from duckdb import ParserException |
7 | 8 |
|
8 | 9 | from timdex_dataset_api.dataset import TIMDEX_DATASET_SCHEMA |
9 | 10 |
|
@@ -93,6 +94,64 @@ def test_read_transformed_records_yields_parsed_dictionary(timdex_dataset_multi_ |
93 | 94 | assert transformed_record == {"title": ["Hello World."]} |
94 | 95 |
|
95 | 96 |
|
| 97 | +def test_read_batches_where_filters_response(timdex_dataset_multi_source): |
| 98 | + df_all = timdex_dataset_multi_source.read_dataframe() |
| 99 | + total_count = len(df_all) |
| 100 | + |
| 101 | + where = ( |
| 102 | + "source = 'libguides' AND run_date = '2024-12-01' AND " |
| 103 | + "run_type = 'daily' AND action = 'index'" |
| 104 | + ) |
| 105 | + df_where = timdex_dataset_multi_source.read_dataframe(where=where) |
| 106 | + |
| 107 | + assert len(df_where) == 1_000 |
| 108 | + assert len(df_where) < total_count |
| 109 | + |
| 110 | + |
| 111 | +def test_read_batches_where_and_dataset_filters_are_combined(timdex_dataset_multi_source): |
| 112 | + """Test that when key/value DatasetFilters AND a SQL where clause is provided, they |
| 113 | + are combined in the final DuckDB SQL query.""" |
| 114 | + where = "run_date = '2024-12-01' AND run_type = 'daily'" |
| 115 | + df = timdex_dataset_multi_source.read_dataframe( |
| 116 | + where=where, source="libguides", action="index" |
| 117 | + ) |
| 118 | + assert len(df) == 1_000 |
| 119 | + assert set(df["source"].unique().tolist()) == {"libguides"} |
| 120 | + assert set(df["action"].unique().tolist()) == {"index"} |
| 121 | + |
| 122 | + |
| 123 | +@pytest.mark.parametrize( |
| 124 | + "bad_where", |
| 125 | + [ |
| 126 | + "SELECT * FROM current_records WHERE source = 'libguides'", |
| 127 | + "FROM records WHERE source = 'libguides'", |
| 128 | + "source = 'libguides';", |
| 129 | + " run_date = '2024-12-01'; ", |
| 130 | + ], |
| 131 | +) |
| 132 | +def test_read_batches_where_rejects_non_predicate_sql( |
| 133 | + timdex_dataset_multi_source, bad_where |
| 134 | +): |
| 135 | + with pytest.raises(ParserException): |
| 136 | + next(timdex_dataset_multi_source.read_batches_iter(where=bad_where)) |
| 137 | + |
| 138 | + |
| 139 | +def test_read_dataframe_respects_where(timdex_dataset_multi_source): |
| 140 | + where = "source = 'libguides' AND action = 'index'" |
| 141 | + df = timdex_dataset_multi_source.read_dataframe(where=where) |
| 142 | + assert len(df) > 0 |
| 143 | + assert set(df["source"].unique().tolist()) == {"libguides"} |
| 144 | + assert set(df["action"].unique().tolist()) == {"index"} |
| 145 | + |
| 146 | + |
| 147 | +def test_read_dicts_iter_respects_where_and_filters(timdex_dataset_multi_source): |
| 148 | + where = "run_type = 'daily'" |
| 149 | + it = timdex_dataset_multi_source.read_dicts_iter(where=where, source="libguides") |
| 150 | + first = next(it) |
| 151 | + assert first["run_type"] == "daily" |
| 152 | + assert first["source"] == "libguides" |
| 153 | + |
| 154 | + |
96 | 155 | def test_dataset_all_current_records_deduped(timdex_dataset_with_runs_with_metadata): |
97 | 156 | df = timdex_dataset_with_runs_with_metadata.read_dataframe( |
98 | 157 | table="current_records", |
|
0 commit comments