Skip to content

Commit 8eb227a

Browse files
committed
added tests to verify json feature
1 parent 13184a1 commit 8eb227a

File tree

1 file changed

+65
-62
lines changed

1 file changed

+65
-62
lines changed

libs/genai/tests/integration_tests/test_chat_models.py

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -516,85 +516,88 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
516516
assert isinstance(result[0], AIMessageChunk)
517517

518518

519-
def test_json_formatted_output() -> None:
520-
"""Test that json_mode works as expected with a json_schema."""
521-
522-
class MyModel(typing_extensions.TypedDict):
523-
item: str
524-
price: float
519+
def test_output_matches_prompt_keys() -> None:
520+
"""
521+
Validate that when response_mime_type="application/json" is specified,
522+
the output is a valid JSON format and contains the expected structure
523+
based on the prompt.
524+
"""
525525

526526
llm = ChatGoogleGenerativeAI(
527527
model=_VISION_MODEL,
528528
response_mime_type="application/json",
529-
response_schema=list[MyModel],
530529
)
531530

531+
prompt_key_names = {
532+
"list_key": "grocery_items",
533+
"item_key": "item",
534+
"price_key": "price",
535+
}
536+
532537
messages = [
533538
("system", "You are a helpful assistant"),
534-
("human", "List the prices of common grocery items"),
539+
(
540+
"human",
541+
"Provide a list of grocery items with key 'grocery_items', "
542+
"and for each item, use 'item' for the name and 'price' for the cost.",
543+
),
535544
]
536545

537546
response = llm.invoke(messages)
538-
response_data = json.loads(response.content)
539-
assert isinstance(response_data, list)
540-
assert len(response.content) > 1
541-
assert isinstance(response.content[0], MyModel)
542-
for item in response.content:
543-
assert isinstance(item, MyModel)
544-
545-
546-
def test_json_formatted_output_with_nested_schema() -> None:
547-
"""Test that json_mode works as expected with a nested json_schema."""
548-
549-
class PriceDetail(typing_extensions.TypedDict):
550-
amount: float
551-
currency: str
552-
553-
class MyModel(TypedDict):
554-
item: str
555-
price: PriceDetail
556-
557-
llm = ChatGoogleGenerativeAI(
558-
model=_VISION_MODEL,
559-
response_mime_type="application/json",
560-
response_schema=list[MyModel],
561-
)
562547

563-
messages = [
564-
("system", "You are a helpful assistant"),
565-
("human", "List the price details of a common grocery item"),
566-
]
548+
# Ensure the response content is a JSON string
549+
assert isinstance(response.content, str), "Response content should be a string."
567550

568-
response = llm.invoke(messages)
569-
assert isinstance(response.content, list)
570-
assert len(response.content) > 0
571-
assert isinstance(response.content[0], MyModel)
572-
assert isinstance(response.content[0]["price"], PriceDetail)
551+
# Attempt to parse the JSON
552+
try:
553+
response_data = json.loads(response.content)
554+
except json.JSONDecodeError as e:
555+
pytest.fail(f"Response is not valid JSON: {e}")
573556

557+
list_key = prompt_key_names["list_key"]
558+
assert list_key in response_data, f"Expected key '{list_key}' is missing in the response."
559+
grocery_items = response_data[list_key]
560+
assert isinstance(grocery_items, list), f"'{list_key}' should be a list."
574561

575-
def test_enum_formatted_output() -> None:
576-
"""Test that response_mime_type works as expected with text/x.enum."""
562+
item_key = prompt_key_names["item_key"]
563+
price_key = prompt_key_names["price_key"]
577564

578-
def test_enum_formatted_output() -> None:
579-
"""Test that response_mime_type works as expected with text/x.enum."""
580-
import enum
565+
for item in grocery_items:
566+
assert isinstance(item, dict), "Each grocery item should be a dictionary."
567+
assert item_key in item, f"Each item should have the key '{item_key}'."
568+
assert price_key in item, f"Each item should have the key '{price_key}'."
581569

582-
class Types(enum.Enum):
583-
PERCUSSION = "Percussion"
584-
STRING = "String"
585-
WOODWIND = "Woodwind"
586-
BRASS = "Brass"
587-
KEYBOARD = "Keyboard"
570+
print("Response matches the key names specified in the prompt.")
588571

589-
llm = ChatGoogleGenerativeAI(
590-
model=_VISION_MODEL, response_mime_type="text/x.enum", response_schema=Types
591-
)
572+
def test_validate_response_mime_type_and_schema() -> None:
573+
"""
574+
Test that `response_mime_type` and `response_schema` are validated correctly.
575+
Ensure valid combinations of `response_mime_type` and `response_schema` pass,
576+
and invalid ones raise appropriate errors.
577+
"""
592578

593-
messages = [
594-
("system", "You are a helpful assistant"),
595-
("human", "What kind of instrument is an organ?"),
596-
]
579+
valid_model = ChatGoogleGenerativeAI(
580+
model="gemini-1.5-pro",
581+
response_mime_type="application/json",
582+
response_schema={"type": "list", "items": {"type": "object"}}, # Example schema
583+
)
597584

598-
response = llm.invoke(messages)
599-
assert isinstance(response.content, str)
600-
assert response.content in Types._value2member_map_
585+
try:
586+
valid_model.validate_environment()
587+
except ValueError as e:
588+
pytest.fail(f"Validation failed unexpectedly with valid parameters: {e}")
589+
590+
with pytest.raises(ValueError, match="response_mime_type must be either .*"):
591+
ChatGoogleGenerativeAI(
592+
model="gemini-1.5-pro",
593+
response_mime_type="invalid/type",
594+
response_schema={"type": "list", "items": {"type": "object"}},
595+
).validate_environment()
596+
597+
try:
598+
ChatGoogleGenerativeAI(
599+
model="gemini-1.5-pro",
600+
response_mime_type="application/json",
601+
).validate_environment()
602+
except ValueError as e:
603+
pytest.fail(f"Validation failed unexpectedly with a valid MIME type and no schema: {e}")

0 commit comments

Comments
 (0)