
Commit 2ea6a72

add CI evals (#343)
1 parent 0e6da4c · commit 2ea6a72

File tree

5 files changed: +372 -4 lines

.github/workflows/eval.yml

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
name: Eval

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master
  workflow_dispatch:

concurrency:
  group: eval-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run_eval:
    runs-on: ubuntu-latest
    environment: Evaluation
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3

      - name: Set up Python + Poetry
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: "3.11"
          poetry-version: "1.7.1"
          cache-key: lint

      - name: Install dependencies
        run: poetry install --with dev

      - name: Evaluate
        env:
          LANGCHAIN_API_KEY: ${{ secrets.LANGCHAIN_API_KEY }}
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
          WEAVIATE_URL: ${{ secrets.WEAVIATE_URL }}
          WEAVIATE_API_KEY: ${{ secrets.WEAVIATE_API_KEY }}
        run: poetry run pytest backend/tests/evals
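
Note: reproducing this job locally requires the same five secrets to be exported as environment variables before invoking poetry run pytest backend/tests/evals. A hypothetical pre-flight check (not part of this commit) could look like:

# Hypothetical helper, not included in this diff: confirm the environment
# variables the workflow injects as secrets are set before running the
# eval suite locally with `poetry run pytest backend/tests/evals`.
import os

REQUIRED_ENV_VARS = [
    "LANGCHAIN_API_KEY",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "WEAVIATE_URL",
    "WEAVIATE_API_KEY",
]

missing = [name for name in REQUIRED_ENV_VARS if not os.environ.get(name)]
if missing:
    raise SystemExit(f"Missing required environment variables: {', '.join(missing)}")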

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ jobs:
        run: poetry lock --check

      - name: Install dependencies
-       run: poetry install --with lint
+       run: poetry install --with dev

      - name: Get .mypy_cache to speed up mypy
        uses: actions/cache@v3

backend/tests/evals/test_e2e.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
from typing import Any

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith.evaluation import EvaluationResults, evaluate
from langsmith.schemas import Example, Run

from backend.graph import OPENAI_MODEL_KEY, format_docs, graph

DATASET_NAME = "chat-langchain-qa"
EXPERIMENT_PREFIX = "chat-langchain-ci"

SCORE_RETRIEVAL_RECALL = "retrieval_recall"
SCORE_ANSWER_CORRECTNESS = "answer_correctness_score"
SCORE_ANSWER_VS_CONTEXT_CORRECTNESS = "answer_vs_context_correctness_score"

# claude sonnet / gpt-4o are a bit too expensive
JUDGE_MODEL_NAME = "claude-3-haiku-20240307"

judge_llm = ChatAnthropic(model_name=JUDGE_MODEL_NAME)


# Evaluate retrieval


def evaluate_retrieval_recall(run: Run, example: Example) -> dict:
    documents: list[Document] = run.outputs.get("documents") or []
    sources = [doc.metadata["source"] for doc in documents]
    expected_sources = set(example.outputs.get("sources") or [])
    # NOTE: since we're currently assuming only ~1 correct document per question
    # this score is equivalent to recall @K where K is number of retrieved documents
    score = float(any(source in expected_sources for source in sources))
    return {"key": SCORE_RETRIEVAL_RECALL, "score": score}


# QA Evaluation Schema


class GradeAnswer(BaseModel):
    """Evaluate correctness of the answer and assign a continuous score."""

    reason: str = Field(
        description="1-2 short sentences with the reason why the score was assigned"
    )
    score: float = Field(
        description="Score that shows how correct the answer is. Use 1.0 if completely correct and 0.0 if completely incorrect",
        minimum=0.0,
        maximum=1.0,
    )


# Evaluate the answer based on the reference answers


QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about Langchain.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements."""

QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n TRUE ANSWER: {true_answer} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

qa_chain = QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)


def evaluate_qa(run: Run, example: Example) -> dict:
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"score": 0.0}

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"score": 0.0}

    score: GradeAnswer = qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "true_answer": example.outputs["answer"],
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_CORRECTNESS, "score": float(score.score)}


# Evaluate the answer based on the provided context

CONTEXT_QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about Langchain.
You are given a question, the context for answering the question, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.

Grade the student answer BOTH based on its factual accuracy AND on whether it is supported by the context. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements."""

CONTEXT_QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", CONTEXT_QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n CONTEXT: {context} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

context_qa_chain = CONTEXT_QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)


def evaluate_qa_context(run: Run, example: Example) -> dict:
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"score": 0.0}

    documents = run.outputs.get("documents") or []
    if not documents:
        return {"score": 0.0}

    context = format_docs(documents)

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"score": 0.0}

    score: GradeAnswer = context_qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "context": context,
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": float(score.score)}


# Run evaluation


def run_graph(inputs: dict[str, Any], model_name: str) -> dict[str, Any]:
    results = graph.invoke(
        {"messages": [("human", inputs["question"])]},
        config={"configurable": {"model_name": model_name}},
    )
    return results


def evaluate_model(*, model_name: str):
    results = evaluate(
        lambda inputs: run_graph(inputs, model_name=model_name),
        data=DATASET_NAME,
        evaluators=[evaluate_retrieval_recall, evaluate_qa, evaluate_qa_context],
        experiment_prefix=EXPERIMENT_PREFIX,
        metadata={"model_name": model_name, "judge_model_name": JUDGE_MODEL_NAME},
        max_concurrency=4,
    )
    return results


# Check results


def convert_single_example_results(evaluation_results: EvaluationResults):
    converted = {}
    for r in evaluation_results["results"]:
        converted[r.key] = r.score
    return converted


# NOTE: this is more of a regression test
def test_scores_regression():
    # test most commonly used model
    experiment_results = evaluate_model(model_name=OPENAI_MODEL_KEY)
    experiment_result_df = pd.DataFrame(
        convert_single_example_results(result["evaluation_results"])
        for result in experiment_results._results
    )
    average_scores = experiment_result_df.mean()

    assert average_scores[SCORE_RETRIEVAL_RECALL] >= 0.65
    assert average_scores[SCORE_ANSWER_CORRECTNESS] >= 0.9
    assert average_scores[SCORE_ANSWER_VS_CONTEXT_CORRECTNESS] >= 0.9
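
The evaluators above expect each example in the chat-langchain-qa LangSmith dataset to carry a question input plus answer and sources outputs. Below is a minimal sketch (not part of this commit) of populating such a dataset with the langsmith SDK; the field names mirror what the evaluators read, while the concrete question, answer, and source URL are made-up placeholders.

from langsmith import Client

client = Client()  # reads LANGCHAIN_API_KEY from the environment

# Placeholder example: inputs["question"], outputs["answer"], and
# outputs["sources"] match what the evaluators access; the values here
# are illustrative only.
examples = [
    {
        "inputs": {"question": "What is a placeholder question about LangChain?"},
        "outputs": {
            "answer": "A placeholder reference answer.",
            "sources": ["https://python.langchain.com/docs/placeholder"],
        },
    },
]

dataset = client.create_dataset(dataset_name="chat-langchain-qa")
client.create_examples(
    inputs=[e["inputs"] for e in examples],
    outputs=[e["outputs"] for e in examples],
    dataset_id=dataset.id,
)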

0 commit comments
