Skip to content

Commit c9d23be

Browse files
authored
Merge pull request #3 from saswati2/califorest-saswati
Califorest saswati
2 parents e132eb4 + 1d0eee5 commit c9d23be

File tree

6 files changed

+288
-0
lines changed

6 files changed

+288
-0
lines changed

califorest_tests/__init__.py

Whitespace-only changes.

califorest_tests/test_datasets.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import json
2+
from califorest_tests.utils import create_temp_dataset
3+
4+
"""
5+
Dataset tests using small synthetic EHR data.
6+
7+
These tests verify:
8+
- JSON dataset loading
9+
- Patient and visit structure
10+
- Event field integrity
11+
12+
All tests use temporary directories and synthetic data to ensure
13+
fast execution and full isolation.
14+
"""
15+
16+
def test_dataset_loading():
    """Verify that the synthetic dataset JSON can be loaded correctly."""
    temp_dir, data_path = create_temp_dataset()
    try:
        with open(data_path) as f:
            data = json.load(f)

        # Basic integrity: the dataset is a non-empty list of patients.
        assert isinstance(data, list)
        assert len(data) > 0
    finally:
        # Clean up even when an assertion fails, so no temp dir leaks.
        temp_dir.cleanup()
28+
29+
30+
def test_patient_structure():
    """Verify each patient record contains the required fields."""
    temp_dir, data_path = create_temp_dataset()
    try:
        with open(data_path) as f:
            patients = json.load(f)

        patient = patients[0]
        assert "patient_id" in patient
        assert "visits" in patient
        # The first synthetic patient is generated with exactly two visits.
        assert len(patient["visits"]) == 2
    finally:
        # Clean up even when an assertion fails, so no temp dir leaks.
        temp_dir.cleanup()
44+
45+
46+
def test_visit_structure():
    """Verify each visit contains the required event fields."""
    temp_dir, data_path = create_temp_dataset()
    try:
        with open(data_path) as f:
            patients = json.load(f)

        visit = patients[0]["visits"][0]

        assert "conditions" in visit
        assert "procedures" in visit
        assert "drugs" in visit
        assert "label" in visit
    finally:
        # Clean up even when an assertion fails, so no temp dir leaks.
        temp_dir.cleanup()

califorest_tests/test_models.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import torch.nn as nn
3+
from califorest_tests.utils import create_synthetic_ehr
4+
5+
"""
6+
Model unit tests using tiny synthetic tensors.
7+
8+
These tests verify:
9+
- Model instantiation
10+
- Forward pass correctness
11+
- Output shape validation
12+
- Gradient computation during backpropagation
13+
"""
14+
15+
class TinyModel(nn.Module):
    """Minimal single-layer model used as a stand-in for real architectures."""

    def __init__(self, in_features=8):
        super().__init__()
        # One linear layer mapping the input features to a single logit.
        self.fc = nn.Linear(in_features, 1)

    def forward(self, x):
        """Return one logit per input row."""
        logits = self.fc(x)
        return logits
22+
23+
24+
def test_model_instantiation():
    """The tiny model can be constructed without errors."""
    assert TinyModel() is not None
28+
29+
30+
def test_forward_pass():
    """Forward pass returns one output row per input sample."""
    features, _ = create_synthetic_ehr()
    net = TinyModel()

    inputs = torch.tensor(features, dtype=torch.float32)
    predictions = net(inputs)

    # Batch dimension must be preserved end to end.
    assert predictions.shape[0] == features.shape[0]
39+
40+
41+
def test_backward_pass():
    """Backward pass populates gradients on the model parameters."""
    features, labels = create_synthetic_ehr()
    net = TinyModel()

    inputs = torch.tensor(features, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.float32).view(-1, 1)

    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(net(inputs), targets)
    loss.backward()

    # After backward(), the linear layer's weight gradient must exist.
    assert net.fc.weight.grad is not None

califorest_tests/test_tasks.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
from califorest_tests.utils import create_synthetic_patient_records
3+
4+
"""
5+
Task pipeline tests using synthetic patient records.
6+
7+
These tests validate:
8+
- Sample processing
9+
- Feature extraction
10+
- Label generation
11+
- Edge case handling
12+
"""
13+
14+
def process_samples(patients):
    """Fake task pipeline: flatten visits into feature counts and labels."""
    pairs = [
        (len(visit["conditions"]) + len(visit["drugs"]), visit["label"])
        for patient in patients
        for visit in patient["visits"]
    ]
    features = np.array([count for count, _ in pairs])
    labels = np.array([label for _, label in pairs])
    return features, labels
26+
27+
28+
def test_sample_processing():
    """Features and labels stay aligned, and features form a flat vector."""
    records = create_synthetic_patient_records()
    features, labels = process_samples(records)

    assert len(features) == len(labels)
    assert features.ndim == 1
34+
35+
36+
def test_label_generation():
    """Labels produced by the pipeline are strictly binary."""
    records = create_synthetic_patient_records()
    _, labels = process_samples(records)

    assert set(labels).issubset({0, 1})
41+
42+
43+
def test_edge_cases_empty_patient():
    """An empty patient list yields an empty feature vector."""
    features, _labels = process_samples([])
    assert len(features) == 0

califorest_tests/utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Utility functions for generating small synthetic data used in tests.
3+
4+
These helpers ensure tests run quickly, use no real datasets, and
5+
remain fully isolated from external dependencies.
6+
"""
7+
8+
import json
9+
import numpy as np
10+
import tempfile
11+
from pathlib import Path
12+
13+
14+
# ---------------------------------------------------
15+
# Synthetic DATASET generator (for dataset tests)
16+
# ---------------------------------------------------
17+
18+
def create_temp_dataset():
    """
    Write 2 synthetic patients to a JSON file in a fresh temp directory.

    Returns
    -------
    temp_dir : TemporaryDirectory
        Temporary directory object (must be cleaned up by tests).
    data_path : Path
        Path to the generated JSON file.
    """
    temp_dir = tempfile.TemporaryDirectory()
    data_path = Path(temp_dir.name) / "patients.json"

    # Serialize the synthetic records directly to the target file.
    data_path.write_text(json.dumps(create_synthetic_patient_records()))

    return temp_dir, data_path
38+
39+
40+
# ---------------------------------------------------
41+
# Synthetic PATIENT records (for dataset + task tests)
42+
# ---------------------------------------------------
43+
44+
def create_synthetic_patient_records():
    """
    Generate tiny synthetic EHR patient records.

    Returns
    -------
    list
        List of patient dictionaries with visit information.
    """

    def visit(conditions, procedures, drugs, label):
        # Assemble one visit record with the fields the tests rely on.
        return {
            "conditions": conditions,
            "procedures": procedures,
            "drugs": drugs,
            "label": label,
        }

    return [
        {
            "patient_id": "p1",
            "visits": [
                visit(["c1", "c2"], ["p1"], ["d1"], 1),
                visit(["c3"], ["p2"], ["d2", "d3"], 0),
            ],
        },
        {
            "patient_id": "p2",
            "visits": [
                visit(["c4"], [], ["d4"], 0),
            ],
        },
    ]
83+
84+
85+
# ---------------------------------------------------
86+
# Synthetic MODEL tensors (for model tests)
87+
# ---------------------------------------------------
88+
89+
def create_synthetic_ehr(n_samples: int = 4, n_features: int = 8, seed=None):
    """
    Generate tiny synthetic tensors for model testing.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    n_features : int
        Number of input features.
    seed : int or None, optional
        Seed for the random generator. Pass an int for reproducible
        output; the default of None keeps the previous non-deterministic
        behavior, so existing callers are unaffected.

    Returns
    -------
    X : np.ndarray
        Feature matrix of shape (n_samples, n_features), float32 in [0, 1).
    y : np.ndarray
        Binary labels of shape (n_samples,), float32 values in {0, 1}.
    """
    # Use a local Generator instead of the global numpy random state so
    # seeding here cannot perturb other tests' randomness.
    rng = np.random.default_rng(seed)
    X = rng.random((n_samples, n_features)).astype(np.float32)
    y = rng.integers(0, 2, size=n_samples).astype(np.float32)
    return X, y

tests/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
def create_synthetic_ehr(num_patients=10, num_features=8):
    """
    Create fake patient EHR data for testing.

    Parameters
    ----------
    num_patients : int
        Number of patient rows to generate.
    num_features : int
        Number of feature columns per patient.

    Returns
    -------
    X : np.ndarray
        Feature matrix of shape (num_patients, num_features), floats in [0, 1).
    y : np.ndarray
        Binary integer labels of shape (num_patients,).
    """
    # NOTE: the original placed a stray debug print() before the docstring,
    # which made the string literal a no-op statement rather than a docstring;
    # both issues are fixed here.
    # Fixed seed keeps the generated data deterministic across calls.
    np.random.seed(0)
    X = np.random.rand(num_patients, num_features)
    y = np.random.randint(0, 2, size=num_patients)

    return X, y
14+
15+
16+
if __name__ == "__main__":
    # Quick manual smoke check when run as a script.
    features, labels = create_synthetic_ehr()
    print("X shape:", features.shape)
    print("y shape:", labels.shape)

0 commit comments

Comments
 (0)