@@ -516,85 +516,88 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
516
516
assert isinstance (result [0 ], AIMessageChunk )
517
517
518
518
519
- def test_json_formatted_output () -> None :
520
- """Test that json_mode works as expected with a json_schema."""
521
-
522
- class MyModel ( typing_extensions . TypedDict ):
523
- item : str
524
- price : float
519
+ def test_output_matches_prompt_keys () -> None :
520
+ """
521
+ Validate that when response_mime_type="application/json" is specified,
522
+ the output is a valid JSON format and contains the expected structure
523
+ based on the prompt.
524
+ """
525
525
526
526
llm = ChatGoogleGenerativeAI (
527
527
model = _VISION_MODEL ,
528
528
response_mime_type = "application/json" ,
529
- response_schema = list [MyModel ],
530
529
)
531
530
531
+ prompt_key_names = {
532
+ "list_key" : "grocery_items" ,
533
+ "item_key" : "item" ,
534
+ "price_key" : "price" ,
535
+ }
536
+
532
537
messages = [
533
538
("system" , "You are a helpful assistant" ),
534
- ("human" , "List the prices of common grocery items" ),
539
+ (
540
+ "human" ,
541
+ "Provide a list of grocery items with key 'grocery_items', "
542
+ "and for each item, use 'item' for the name and 'price' for the cost." ,
543
+ ),
535
544
]
536
545
537
546
response = llm .invoke (messages )
538
- response_data = json .loads (response .content )
539
- assert isinstance (response_data , list )
540
- assert len (response .content ) > 1
541
- assert isinstance (response .content [0 ], MyModel )
542
- for item in response .content :
543
- assert isinstance (item , MyModel )
544
-
545
-
546
- def test_json_formatted_output_with_nested_schema () -> None :
547
- """Test that json_mode works as expected with a nested json_schema."""
548
-
549
- class PriceDetail (typing_extensions .TypedDict ):
550
- amount : float
551
- currency : str
552
-
553
- class MyModel (TypedDict ):
554
- item : str
555
- price : PriceDetail
556
-
557
- llm = ChatGoogleGenerativeAI (
558
- model = _VISION_MODEL ,
559
- response_mime_type = "application/json" ,
560
- response_schema = list [MyModel ],
561
- )
562
547
563
- messages = [
564
- ("system" , "You are a helpful assistant" ),
565
- ("human" , "List the price details of a common grocery item" ),
566
- ]
548
+ # Ensure the response content is a JSON string
549
+ assert isinstance (response .content , str ), "Response content should be a string."
567
550
568
- response = llm . invoke ( messages )
569
- assert isinstance ( response . content , list )
570
- assert len (response .content ) > 0
571
- assert isinstance ( response . content [ 0 ], MyModel )
572
- assert isinstance ( response . content [ 0 ][ "price" ], PriceDetail )
551
+ # Attempt to parse the JSON
552
+ try :
553
+ response_data = json . loads (response .content )
554
+ except json . JSONDecodeError as e :
555
+ pytest . fail ( f"Response is not valid JSON: { e } " )
573
556
557
+ list_key = prompt_key_names ["list_key" ]
558
+ assert list_key in response_data , f"Expected key '{ list_key } ' is missing in the response."
559
+ grocery_items = response_data [list_key ]
560
+ assert isinstance (grocery_items , list ), f"'{ list_key } ' should be a list."
574
561
575
- def test_enum_formatted_output () -> None :
576
- """Test that response_mime_type works as expected with text/x.enum."""
562
+ item_key = prompt_key_names [ "item_key" ]
563
+ price_key = prompt_key_names [ "price_key" ]
577
564
578
- def test_enum_formatted_output () -> None :
579
- """Test that response_mime_type works as expected with text/x.enum."""
580
- import enum
565
+ for item in grocery_items :
566
+ assert isinstance (item , dict ), "Each grocery item should be a dictionary."
567
+ assert item_key in item , f"Each item should have the key '{ item_key } '."
568
+ assert price_key in item , f"Each item should have the key '{ price_key } '."
581
569
582
- class Types (enum .Enum ):
583
- PERCUSSION = "Percussion"
584
- STRING = "String"
585
- WOODWIND = "Woodwind"
586
- BRASS = "Brass"
587
- KEYBOARD = "Keyboard"
570
+ print ("Response matches the key names specified in the prompt." )
588
571
589
- llm = ChatGoogleGenerativeAI (
590
- model = _VISION_MODEL , response_mime_type = "text/x.enum" , response_schema = Types
591
- )
572
+ def test_validate_response_mime_type_and_schema () -> None :
573
+ """
574
+ Test that `response_mime_type` and `response_schema` are validated correctly.
575
+ Ensure valid combinations of `response_mime_type` and `response_schema` pass,
576
+ and invalid ones raise appropriate errors.
577
+ """
592
578
593
- messages = [
594
- ("system" , "You are a helpful assistant" ),
595
- ("human" , "What kind of instrument is an organ?" ),
596
- ]
579
+ valid_model = ChatGoogleGenerativeAI (
580
+ model = "gemini-1.5-pro" ,
581
+ response_mime_type = "application/json" ,
582
+ response_schema = {"type" : "list" , "items" : {"type" : "object" }}, # Example schema
583
+ )
597
584
598
- response = llm .invoke (messages )
599
- assert isinstance (response .content , str )
600
- assert response .content in Types ._value2member_map_
585
+ try :
586
+ valid_model .validate_environment ()
587
+ except ValueError as e :
588
+ pytest .fail (f"Validation failed unexpectedly with valid parameters: { e } " )
589
+
590
+ with pytest .raises (ValueError , match = "response_mime_type must be either .*" ):
591
+ ChatGoogleGenerativeAI (
592
+ model = "gemini-1.5-pro" ,
593
+ response_mime_type = "invalid/type" ,
594
+ response_schema = {"type" : "list" , "items" : {"type" : "object" }},
595
+ ).validate_environment ()
596
+
597
+ try :
598
+ ChatGoogleGenerativeAI (
599
+ model = "gemini-1.5-pro" ,
600
+ response_mime_type = "application/json" ,
601
+ ).validate_environment ()
602
+ except ValueError as e :
603
+ pytest .fail (f"Validation failed unexpectedly with a valid MIME type and no schema: { e } " )
0 commit comments