Commit d144317

add e2e test
1 parent 0e6da4c commit d144317

File tree

1 file changed: +186, -0 lines changed

backend/tests/evals/test_e2e.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
from typing import Any

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith.evaluation import EvaluationResults, evaluate
from langsmith.schemas import Example, Run

from backend.graph import OPENAI_MODEL_KEY, format_docs, graph

DATASET_NAME = "chat-langchain-qa"
EXPERIMENT_PREFIX = "chat-langchain-ci"

SCORE_RETRIEVAL_RECALL = "retrieval_recall"
SCORE_ANSWER_CORRECTNESS = "answer_correctness_score"
SCORE_ANSWER_VS_CONTEXT_CORRECTNESS = "answer_vs_context_correctness_score"

# claude sonnet / gpt-4o are a bit too expensive to use as the judge
JUDGE_MODEL_NAME = "claude-3-haiku-20240307"

judge_llm = ChatAnthropic(model_name=JUDGE_MODEL_NAME)


# Evaluate retrieval


def evaluate_retrieval_recall(run: Run, example: Example) -> dict:
    documents: list[Document] = run.outputs.get("documents") or []
    sources = [doc.metadata["source"] for doc in documents]
    expected_sources = set(example.outputs.get("sources") or [])
    # NOTE: since we currently assume only ~1 correct document per question,
    # this score is equivalent to recall@K, where K is the number of retrieved documents
    score = float(any(source in expected_sources for source in sources))
    return {"key": SCORE_RETRIEVAL_RECALL, "score": score}


# QA Evaluation Schema

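# GradeAnswer is the structured-output schema the judge LLM fills in via
# `with_structured_output`; `score` is read as a continuous value in [0.0, 1.0].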
class GradeAnswer(BaseModel):
    """Continuous score to assess the correctness of the answer."""

    score: float = Field(
        description="How correct is the answer? 1.0 if completely correct or 0.0 if completely incorrect",
        minimum=0.0,
        maximum=1.0,
    )
    reason: str = Field(
        description="1-2 short sentences with the reason why the score was assigned"
    )


# Evaluate the answer based on the reference answers


QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about Langchain.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements."""

QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n TRUE ANSWER: {true_answer} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

qa_chain = QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)

def evaluate_qa(run: Run, example: Example) -> dict:
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"key": SCORE_ANSWER_CORRECTNESS, "score": 0.0}

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"key": SCORE_ANSWER_CORRECTNESS, "score": 0.0}

    score: GradeAnswer = qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "true_answer": example.outputs["answer"],
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_CORRECTNESS, "score": float(score.score)}

# Evaluate the answer based on the provided context

CONTEXT_QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about Langchain.
You are given a question, the context for answering the question, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.

Grade the student answer BOTH based on its factual accuracy AND on whether it is supported by the context. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements."""

CONTEXT_QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", CONTEXT_QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n CONTEXT: {context} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

context_qa_chain = CONTEXT_QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)


def evaluate_qa_context(run: Run, example: Example) -> dict:
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    documents = run.outputs.get("documents") or []
    if not documents:
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    context = format_docs(documents)

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    score: GradeAnswer = context_qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "context": context,
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": float(score.score)}


# Run evaluation

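# run_graph adapts a dataset example ({"question": ...}) to the graph's message-based
# input and returns the final graph state, whose "messages" and "documents" keys are
# what the evaluators above read from `run.outputs`.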
def run_graph(inputs: dict[str, Any], model_name: str) -> dict[str, Any]:
    results = graph.invoke(
        {"messages": [("human", inputs["question"])]},
        config={"configurable": {"model_name": model_name}},
    )
    return results


def evaluate_model(*, model_name: str):
    results = evaluate(
        lambda inputs: run_graph(inputs, model_name=model_name),
        data=DATASET_NAME,
        evaluators=[evaluate_retrieval_recall, evaluate_qa, evaluate_qa_context],
        experiment_prefix=EXPERIMENT_PREFIX,
        metadata={"model_name": model_name, "judge_model_name": JUDGE_MODEL_NAME},
        max_concurrency=4,
    )
    return results


# Check results

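# Flatten one example's EvaluationResults into a {metric_key: score} dict so the
# per-example results can be collected into a DataFrame, one row per example.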
def convert_single_example_results(evaluation_results: EvaluationResults):
    converted = {}
    for r in evaluation_results["results"]:
        converted[r.key] = r.score
    return converted


# NOTE: this is more of a regression test
def test_scores_regression():
    # test the most commonly used model
    experiment_results = evaluate_model(model_name=OPENAI_MODEL_KEY)
    experiment_result_df = pd.DataFrame(
        convert_single_example_results(result["evaluation_results"])
        for result in experiment_results._results
    )
    average_scores = experiment_result_df.mean()

    assert average_scores[SCORE_RETRIEVAL_RECALL] >= 0.65
    assert average_scores[SCORE_ANSWER_CORRECTNESS] >= 0.9
    assert average_scores[SCORE_ANSWER_VS_CONTEXT_CORRECTNESS] >= 0.9
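To re-run the same evaluation locally and inspect the scores rather than assert on them, a minimal sketch is below. It assumes the chat-langchain-qa dataset already exists in the LangSmith workspace, that LANGSMITH_API_KEY, ANTHROPIC_API_KEY, and the target model's provider key are set, and that the test module is importable as backend.tests.evals.test_e2e (inferred from the file path, not confirmed by this commit):

import pandas as pd

from backend.graph import OPENAI_MODEL_KEY
# Import path is an assumption based on the file location in this diff.
from backend.tests.evals.test_e2e import convert_single_example_results, evaluate_model

# Run the experiment against the LangSmith dataset and print the per-metric averages
# rather than asserting thresholds (useful when tuning prompts or swapping models).
experiment_results = evaluate_model(model_name=OPENAI_MODEL_KEY)
scores_df = pd.DataFrame(
    convert_single_example_results(result["evaluation_results"])
    for result in experiment_results._results
)
print(scores_df.mean())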

0 commit comments