
Commit 9fd525a

Refactor evals.py for clarity, formatting, and maintainability
1 parent 1980105 commit 9fd525a

1 file changed: server/api/services/evals.py

Lines changed: 61 additions & 34 deletions
@@ -2,7 +2,7 @@
 Evaluate LLM outputs using multiple metrics and compute associated costs
 """
 
-#TODO: Add tests on a small dummy dataset to confirm it handles errors gracefully and produces expected outputs
+# TODO: Add tests on a small dummy dataset to confirm it handles errors gracefully and produces expected outputs
 
 import argparse
 import logging
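
Note: the TODO kept in this hunk asks for tests on a small dummy dataset. A minimal sketch of what such a test might look like, assuming pytest and unittest.mock are available and that the module is importable as "evals"; the patch target, model name, and dummy values are illustrative and not part of this commit. The extractiveness metric still runs for real here, so lighteval would need to be installed.

    from unittest.mock import MagicMock, patch

    import pytest

    def test_evaluate_response_returns_one_row_with_expected_columns():
        # Fake handler so no real model call or API key is needed; values are arbitrary.
        fake_usage = MagicMock(input_tokens=100, output_tokens=20)
        fake_handler = MagicMock()
        fake_handler.handle_request.return_value = (
            "A short summary of the context.",   # output_text
            fake_usage,                          # token_usage
            {"input": 1.0, "output": 2.0},       # pricing, USD per million tokens
            0.5,                                 # duration in seconds
        )
        with patch("evals.ModelFactory.get_handler", return_value=fake_handler):
            import evals
            df = evals.evaluate_response(
                "dummy-model", "Summarize.", "Some context text.", "reference text"
            )
        assert len(df) == 1
        # Same arithmetic as the function: (1.0/1e6)*100 + (2.0/1e6)*20
        assert df.loc[0, "Cost (USD)"] == pytest.approx(0.00014)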
@@ -13,11 +13,14 @@
 
 from services import ModelFactory
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 
 
-#TODO: Rename model to model name and query to instructions for clarity
-def evaluate_response(model_name: str, query: str, context: str, reference: str) -> pd.DataFrame:
+def evaluate_response(
+    model_name: str, query: str, context: str, reference: str
+) -> pd.DataFrame:
     """
     Evaluates the response of a model to a given query and context, computes extractiveness metrics, token usage, and cost
 
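
For reference, a hypothetical call to the reformatted evaluate_response; the model name and texts are made up, and the function returns a single-row DataFrame with the columns shown in the next hunk.

    df = evaluate_response(
        model_name="some-supported-model",   # must be a key in ModelFactory.HANDLERS
        query="Summarize the passage in two sentences.",
        context="...the source passage to summarize...",
        reference="...a gold summary used for comparison...",
    )
    print(df[["Output Text", "Cost (USD)", "Duration (s)"]])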
@@ -33,35 +36,47 @@ def evaluate_response(model_name: str, query: str, context: str, reference: str)
 
     handler = ModelFactory.get_handler(model_name)
 
-    #TODO: Add error handling for unsupported models
-
+    # TODO: Add error handling for unsupported models
+
     output_text, token_usage, pricing, duration = handler.handle_request(query, context)
 
     doc = Doc(query="", choices=[], gold_index=0, specific={"text": context})
-    extractiveness = Extractiveness().compute(formatted_doc=doc, predictions=[output_text])
+    extractiveness = Extractiveness().compute(
+        formatted_doc=doc, predictions=[output_text]
+    )
 
-    input_cost_dollars = (pricing['input'] / 1000000) * token_usage.input_tokens
-    output_cost_dollars = (pricing['output'] / 1000000) * token_usage.output_tokens
+    input_cost_dollars = (pricing["input"] / 1000000) * token_usage.input_tokens
+    output_cost_dollars = (pricing["output"] / 1000000) * token_usage.output_tokens
 
     total_cost_dollars = input_cost_dollars + output_cost_dollars
 
-    return pd.DataFrame([{
-        "Output Text": output_text,
-        "Extractiveness Coverage": extractiveness['summarization_coverage'],
-        "Extractiveness Density": extractiveness['summarization_density'],
-        "Extractiveness Compression": extractiveness['summarization_compression'],
-        "Input Token Usage": token_usage.input_tokens,
-        "Output Token Usage": token_usage.output_tokens,
-        "Cost (USD)": total_cost_dollars,
-        "Duration (s)": duration
-    }])
+    return pd.DataFrame(
+        [
+            {
+                "Output Text": output_text,
+                "Extractiveness Coverage": extractiveness["summarization_coverage"],
+                "Extractiveness Density": extractiveness["summarization_density"],
+                "Extractiveness Compression": extractiveness[
+                    "summarization_compression"
+                ],
+                "Input Token Usage": token_usage.input_tokens,
+                "Output Token Usage": token_usage.output_tokens,
+                "Cost (USD)": total_cost_dollars,
+                "Duration (s)": duration,
+            }
+        ]
+    )
 
 
 if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="Evaluate LLM outputs using multiple metrics and compute associated costs")
+    # TODO: Add CLI argument to specify the metrics to be computed
+    parser = argparse.ArgumentParser(
+        description="Evaluate LLM outputs using multiple metrics and compute associated costs"
+    )
     parser.add_argument("--config", "-c", required=True, help="Path to config CSV file")
-    parser.add_argument("--reference", "-r", required=True, help="Path to reference CSV file")
+    parser.add_argument(
+        "--reference", "-r", required=True, help="Path to reference CSV file"
+    )
     parser.add_argument("--output", "-o", required=True, help="Path to output CSV file")
 
     args = parser.parse_args()
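
As a sanity check on the cost formula in this hunk, a worked example with hypothetical per-million-token pricing (the numbers are illustrative, not any provider's real rates):

    pricing = {"input": 3.00, "output": 15.00}   # USD per 1,000,000 tokens (hypothetical)
    input_tokens, output_tokens = 1200, 350

    input_cost_dollars = (pricing["input"] / 1000000) * input_tokens     # 0.0036
    output_cost_dollars = (pricing["output"] / 1000000) * output_tokens  # 0.00525
    total_cost_dollars = input_cost_dollars + output_cost_dollars        # 0.00885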
@@ -73,34 +88,46 @@ def evaluate_response(model_name: str, query: str, context: str, reference: str)
     # Remove the trailing whitespace from column names
     df_config.columns = df_config.columns.str.strip()
 
-    #TODO: Check if the required columns are present
+    # TODO: Check if the required columns are present
 
     # Check if all models in the config are supported by ModelFactory
-    if not all(model in ModelFactory.HANDLERS.keys() for model in df_config['Model'].unique()):
-        raise ValueError(f"Unsupported model(s) found in config: {set(df_config['Model'].unique()) - set(ModelFactory.HANDLERS.keys())}")
-
+    if not all(
+        model in ModelFactory.HANDLERS.keys()
+        for model in df_config["Model Name"].unique()
+    ):
+        raise ValueError(
+            f"Unsupported model(s) found in config: {set(df_config['Model Name'].unique()) - set(ModelFactory.HANDLERS.keys())}"
+        )
+
     df_reference = pd.read_csv(args.reference)
     logging.info(f"Reference DataFrame shape: {df_reference.shape}")
     logging.info(f"Reference DataFrame columns: {df_reference.columns.tolist()}")
-
+
     # Cross join the config and reference DataFrames
-    df_in = df_config.merge(df_reference, how='cross')
+    df_in = df_config.merge(df_reference, how="cross")
 
     # TODO: Parallelize the evaluation process for each row in df_in using concurrent.futures or similar libraries
     df_evals = pd.DataFrame()
     for index, row in df_in.iterrows():
+        df_evals = pd.concat(
+            [
+                df_evals,
+                evaluate_response(
+                    row["Model Name"], row["Query"], row["Context"], row["Reference"]
+                ),
+            ],
+            axis=0,
+        )
 
-        #TODO: Rename Model to Model name for clarity
-        df_evals = pd.concat([df_evals, evaluate_response(row['Model'], row['Query'], row['Context'], row['Reference'])], axis=0)
-
         logging.info(f"Processed row {index + 1}/{len(df_in)}")
 
-
     # Concatenate the input and evaluations DataFrames
 
-    df_out = pd.concat([df_in.reset_index(drop=True), df_evals.reset_index(drop=True)], axis=1)
+    df_out = pd.concat(
+        [df_in.reset_index(drop=True), df_evals.reset_index(drop=True)], axis=1
+    )
 
     df_out.to_csv(args.output, index=False)
     logging.info(f"Output DataFrame shape: {df_out.shape}")
    logging.info(f"Results saved to {args.output}")
-    logging.info("Evaluation completed successfully.")
+    logging.info("Evaluation completed successfully.")
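
The last hunk keeps a TODO about parallelizing the per-row evaluation with concurrent.futures. One possible sketch, not part of this commit, assuming the model handlers are safe to call from multiple threads:

    from concurrent.futures import ThreadPoolExecutor

    def _evaluate_row(row):
        # Same call as inside the current loop, one row at a time.
        return evaluate_response(
            row["Model Name"], row["Query"], row["Context"], row["Reference"]
        )

    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(_evaluate_row, (row for _, row in df_in.iterrows())))

    # Concatenate once at the end instead of growing df_evals inside the loop.
    df_evals = pd.concat(results, ignore_index=True)

Collecting the per-row frames and concatenating once would also avoid the repeated-copy cost of calling pd.concat on every iteration.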
