Skip to content

Commit 402e9ae

Browse files
Add support for 'or' conditions to filtering method (#99)
1 parent 9c6bdc4 commit 402e9ae

2 files changed

Lines changed: 25 additions & 11 deletions

File tree

tests/test_dataset.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ def test_dataset_load_with_multi_nonpartition_filters_success(fixed_local_datase
115115
assert fixed_local_dataset.row_count == 1
116116

117117

118+
def test_dataset_get_filtered_dataset_with_single_nonpartition_success(
119+
fixed_local_dataset,
120+
):
121+
fixed_local_dataset.load() # initial load dataset, no filters passed
122+
123+
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
124+
run_id="abc123",
125+
)
126+
filtered_local_df = filtered_local_dataset.to_table().to_pandas()
127+
128+
# fixed_local_dataset consists of a single 'run_id' value
129+
# therefore, filtered_local_dataset includes all records
130+
assert len(filtered_local_df) == filtered_local_dataset.count_rows()
131+
assert filtered_local_df["run_id"].unique() == ["abc123"]
132+
133+
118134
def test_dataset_get_filtered_dataset_with_multi_nonpartition_filters_success(
119135
fixed_local_dataset,
120136
):
@@ -133,20 +149,17 @@ def test_dataset_get_filtered_dataset_with_multi_nonpartition_filters_success(
133149
assert filtered_local_df["timdex_record_id"].iloc[0] == "alma:0"
134150

135151

136-
def test_dataset_get_filtered_dataset_with_single_nonpartition_success(
152+
def test_dataset_get_filtered_dataset_with_or_nonpartition_filters_success(
137153
fixed_local_dataset,
138154
):
139-
fixed_local_dataset.load() # initial load dataset, no filters passed
155+
fixed_local_dataset.load()
140156

141157
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
142-
run_id="abc123",
158+
timdex_record_id=["alma:0", "alma:1"]
143159
)
144160
filtered_local_df = filtered_local_dataset.to_table().to_pandas()
145-
146-
# fixed_local_dataset consists of single 'run_id' value
147-
# therefore, filtered_local_dataset includes all records
148-
assert len(filtered_local_df) == filtered_local_dataset.count_rows()
149-
assert filtered_local_df["run_id"].unique() == ["abc123"]
161+
assert len(filtered_local_df) == 2 # noqa: PLR2004
162+
assert filtered_local_df["timdex_record_id"].tolist() == ["alma:0", "alma:1"]
150163

151164

152165
def test_dataset_get_filtered_dataset_with_run_date_str_successs(fixed_local_dataset):

timdex_dataset_api/dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import boto3
1414
import pandas as pd
1515
import pyarrow as pa
16-
import pyarrow.compute as pc
1716
import pyarrow.dataset as ds
1817
from pyarrow import fs
1918

@@ -171,8 +170,10 @@ def _get_filtered_dataset(
171170
# create filter expressions: equality for scalar values, membership ("or") for list values
172171
expressions = []
173172
for field, value in filters.items():
174-
if value:
175-
expressions.append(pc.equal(pc.field(field), value))
173+
if isinstance(value, list):
174+
expressions.append(ds.field(field).isin(value))
175+
else:
176+
expressions.append(ds.field(field) == value)
176177

177178
# if filter expressions not found, return original dataset
178179
if not expressions:

0 commit comments

Comments
 (0)