|
7 | 7 | from langchain_core.language_models.chat_models import BaseChatModel
|
8 | 8 | from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
9 | 9 | from langchain_core.outputs import ChatGeneration, ChatResult
|
| 10 | +from langchain_core.runnables import RunnableConfig |
10 | 11 | from langchain_core.tools import BaseTool, tool
|
| 12 | +from langgraph.graph import MessagesState, StateGraph |
11 | 13 | from langgraph.prebuilt import create_react_agent
|
12 | 14 |
|
13 | 15 | from langgraph_supervisor import create_supervisor
|
@@ -545,3 +547,59 @@ def get_tool_calls(msg):
|
545 | 547 | },
|
546 | 548 | ]
|
547 | 549 | assert received == expected
|
| 550 | + |
| 551 | + |
| 552 | +def test_metadata_passed_to_subagent() -> None: |
| 553 | + """Test that metadata from config is passed to sub-agents. |
| 554 | +
|
| 555 | + This test verifies that when a config object with metadata is passed to the supervisor, |
| 556 | + the metadata is correctly passed to the sub-agent when it is invoked. |
| 557 | + """ |
| 558 | + |
| 559 | + # Create a tracking agent to verify metadata is passed |
| 560 | + def test_node(_state: MessagesState, config: RunnableConfig): |
| 561 | + # Assert that the metadata is passed to the sub-agent |
| 562 | + assert config["metadata"]["test_key"] == "test_value" |
| 563 | + assert config["metadata"]["another_key"] == 123 |
| 564 | + # Return a new message if the assertion passes. |
| 565 | + return {"messages": [AIMessage(content="Test response")]} |
| 566 | + |
| 567 | + tracking_agent_workflow = StateGraph(MessagesState) |
| 568 | + tracking_agent_workflow.add_node("test_node", test_node) |
| 569 | + tracking_agent_workflow.set_entry_point("test_node") |
| 570 | + tracking_agent_workflow.set_finish_point("test_node") |
| 571 | + tracking_agent = tracking_agent_workflow.compile() |
| 572 | + tracking_agent.name = "test_agent" |
| 573 | + |
| 574 | + # Create a supervisor with the tracking agent |
| 575 | + supervisor_model = FakeChatModel( |
| 576 | + responses=[ |
| 577 | + AIMessage( |
| 578 | + content="", |
| 579 | + tool_calls=[ |
| 580 | + { |
| 581 | + "name": "transfer_to_test_agent", |
| 582 | + "args": {}, |
| 583 | + "id": "call_123", |
| 584 | + "type": "tool_call", |
| 585 | + } |
| 586 | + ], |
| 587 | + ), |
| 588 | + AIMessage(content="Final response"), |
| 589 | + ] |
| 590 | + ) |
| 591 | + |
| 592 | + supervisor = create_supervisor( |
| 593 | + agents=[tracking_agent], |
| 594 | + model=supervisor_model, |
| 595 | + ).compile() |
| 596 | + |
| 597 | + # Create config with metadata |
| 598 | + test_metadata = {"test_key": "test_value", "another_key": 123} |
| 599 | + config = {"metadata": test_metadata} |
| 600 | + |
| 601 | + # Invoke the supervisor with the config |
| 602 | + result = supervisor.invoke({"messages": [HumanMessage(content="Test message")]}, config=config) |
| 603 | + # Get the last message in the messages list & verify it matches the value |
| 604 | + # returned from the node. |
| 605 | + assert result["messages"][-1].content == "Final response" |
0 commit comments