from typing import Any

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith.evaluation import EvaluationResults, evaluate
from langsmith.schemas import Example, Run

from backend.graph import OPENAI_MODEL_KEY, format_docs, graph

DATASET_NAME = "chat-langchain-qa"
EXPERIMENT_PREFIX = "chat-langchain-ci"

SCORE_RETRIEVAL_RECALL = "retrieval_recall"
SCORE_ANSWER_CORRECTNESS = "answer_correctness_score"
SCORE_ANSWER_VS_CONTEXT_CORRECTNESS = "answer_vs_context_correctness_score"

# claude sonnet / gpt-4o are a bit too expensive to use as the judge model
JUDGE_MODEL_NAME = "claude-3-haiku-20240307"

judge_llm = ChatAnthropic(model_name=JUDGE_MODEL_NAME)


# Evaluate retrieval


def evaluate_retrieval_recall(run: Run, example: Example) -> dict:
    """Check whether any retrieved document comes from one of the expected sources."""
    documents: list[Document] = run.outputs.get("documents") or []
    sources = [doc.metadata["source"] for doc in documents]
    expected_sources = set(example.outputs.get("sources") or [])
    # NOTE: since we currently assume only ~1 correct document per question,
    # this score is equivalent to recall@K, where K is the number of retrieved documents.
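    # For example (illustrative names): with expected_sources = {"doc_a"} and
    # sources = ["doc_b", "doc_a", "doc_c"] the score is 1.0; if "doc_a" is never
    # retrieved, the score is 0.0.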
    score = float(any(source in expected_sources for source in sources))
    return {"key": SCORE_RETRIEVAL_RECALL, "score": score}


# QA Evaluation Schema


class GradeAnswer(BaseModel):
    """Evaluate correctness of the answer and assign a continuous score."""

    reason: str = Field(
        description="1-2 short sentences with the reason why the score was assigned"
    )
    score: float = Field(
        description="Score that shows how correct the answer is. Use 1.0 if completely correct and 0.0 if completely incorrect",
        minimum=0.0,
        maximum=1.0,
    )
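
# Illustrative example of the structured output the judge is expected to return:
#   GradeAnswer(reason="Matches the reference answer on all key points.", score=1.0)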


# Evaluate the answer based on the reference answers


QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about LangChain.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.

Grade the student answer based ONLY on its factual accuracy. Ignore differences in punctuation and phrasing between the student answer and the true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements."""

QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n TRUE ANSWER: {true_answer} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

qa_chain = QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)


def evaluate_qa(run: Run, example: Example) -> dict:
    """Grade the final answer against the reference answer from the dataset."""
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"key": SCORE_ANSWER_CORRECTNESS, "score": 0.0}

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"key": SCORE_ANSWER_CORRECTNESS, "score": 0.0}

    score: GradeAnswer = qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "true_answer": example.outputs["answer"],
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_CORRECTNESS, "score": float(score.score)}


# Evaluate the answer based on the provided context

CONTEXT_QA_SYSTEM_PROMPT = """You are an expert programmer and problem-solver, tasked with grading answers to questions about LangChain.
You are given a question, the context for answering the question, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.

Grade the student answer BOTH on its factual accuracy AND on whether it is supported by the context. Ignore differences in punctuation and phrasing between the student answer and the context. It is OK if the student answer contains more information than the context, as long as it does not contain any conflicting statements."""

CONTEXT_QA_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", CONTEXT_QA_SYSTEM_PROMPT),
        (
            "human",
            "QUESTION: \n\n {question} \n\n CONTEXT: {context} \n\n STUDENT ANSWER: {answer}",
        ),
    ]
)

context_qa_chain = CONTEXT_QA_PROMPT | judge_llm.with_structured_output(GradeAnswer)


def evaluate_qa_context(run: Run, example: Example) -> dict:
    """Grade the final answer against the retrieved context it should be grounded in."""
    messages = run.outputs.get("messages") or []
    if not messages:
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    documents = run.outputs.get("documents") or []
    if not documents:
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    context = format_docs(documents)

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": 0.0}

    score: GradeAnswer = context_qa_chain.invoke(
        {
            "question": example.inputs["question"],
            "context": context,
            "answer": last_message.content,
        }
    )
    return {"key": SCORE_ANSWER_VS_CONTEXT_CORRECTNESS, "score": float(score.score)}


# Run evaluation


def run_graph(inputs: dict[str, Any], model_name: str) -> dict[str, Any]:
    """Invoke the chat-langchain graph on a single dataset example."""
    results = graph.invoke(
        {"messages": [("human", inputs["question"])]},
        config={"configurable": {"model_name": model_name}},
    )
    return results
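
# NOTE: the graph state returned above is expected to expose the "messages" and
# "documents" keys that the evaluators read from run.outputs.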


def evaluate_model(*, model_name: str):
    """Run the LangSmith experiment for a single answer-generation model."""
    results = evaluate(
        lambda inputs: run_graph(inputs, model_name=model_name),
        data=DATASET_NAME,
        evaluators=[evaluate_retrieval_recall, evaluate_qa, evaluate_qa_context],
        experiment_prefix=EXPERIMENT_PREFIX,
        metadata={"model_name": model_name, "judge_model_name": JUDGE_MODEL_NAME},
        max_concurrency=4,
    )
    return results


# Check results


def convert_single_example_results(evaluation_results: EvaluationResults):
    """Flatten the evaluation results for one example into a {metric_key: score} dict."""
    converted = {}
    for r in evaluation_results["results"]:
        converted[r.key] = r.score
    return converted
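
# Illustrative: [EvaluationResult(key="retrieval_recall", score=1.0), ...] becomes
# {"retrieval_recall": 1.0, ...}, i.e. one row of the score DataFrame built below.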


# NOTE: this is more of a regression test: it fails if the average scores drop
# below the thresholds asserted at the bottom.
def test_scores_regression():
    # test the most commonly used model
    experiment_results = evaluate_model(model_name=OPENAI_MODEL_KEY)
    experiment_result_df = pd.DataFrame(
        convert_single_example_results(result["evaluation_results"])
        for result in experiment_results._results
    )
    average_scores = experiment_result_df.mean()

    assert average_scores[SCORE_RETRIEVAL_RECALL] >= 0.65
    assert average_scores[SCORE_ANSWER_CORRECTNESS] >= 0.9
    assert average_scores[SCORE_ANSWER_VS_CONTEXT_CORRECTNESS] >= 0.9
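

# Minimal local entry point (illustrative, not used by CI, which runs this via pytest):
# assumes LangSmith, Anthropic, and model-provider credentials are already set in the
# environment.
if __name__ == "__main__":
    test_scores_regression()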