|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import json |
| 4 | +import logging |
| 5 | +from typing import ( |
| 6 | + Any, |
| 7 | + Dict, |
| 8 | + List, |
| 9 | + Optional, |
| 10 | + Union, |
| 11 | +) |
| 12 | + |
| 13 | +import google.ai.generativelanguage as glm |
| 14 | +import google.cloud.aiplatform_v1beta1.types as gapic |
| 15 | +from langchain_core.tools import BaseTool |
| 16 | +from langchain_core.utils.json_schema import dereference_refs |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | +TYPE_ENUM = { |
| 21 | + "string": glm.Type.STRING, |
| 22 | + "number": glm.Type.NUMBER, |
| 23 | + "integer": glm.Type.INTEGER, |
| 24 | + "boolean": glm.Type.BOOLEAN, |
| 25 | + "array": glm.Type.ARRAY, |
| 26 | + "object": glm.Type.OBJECT, |
| 27 | + "null": None, |
| 28 | +} |
| 29 | + |
| 30 | + |
| 31 | +_ALLOWED_SCHEMA_FIELDS = [] |
| 32 | +_ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields]) |
| 33 | +_ALLOWED_SCHEMA_FIELDS.extend( |
| 34 | + [ |
| 35 | + f |
| 36 | + for f in gapic.Schema.to_dict( |
| 37 | + gapic.Schema(), preserving_proto_field_name=False |
| 38 | + ).keys() |
| 39 | + ] |
| 40 | +) |
| 41 | +_ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS) |
| 42 | + |
| 43 | + |
| 44 | +def dict_to_gapic_json_schema( |
| 45 | + schema: Dict[str, Any], pydantic_version: str = "v1" |
| 46 | +) -> str: |
| 47 | + # Resolve refs in schema because $refs and $defs are not supported |
| 48 | + # by the Gemini API. |
| 49 | + dereferenced_schema = dereference_refs(schema) |
| 50 | + |
| 51 | + if pydantic_version == "v1": |
| 52 | + formatted_schema = _format_json_schema_to_gapic_v1(dereferenced_schema) |
| 53 | + else: |
| 54 | + formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) |
| 55 | + |
| 56 | + return json.dumps(formatted_schema) |
| 57 | + |
| 58 | + |
| 59 | +def _format_json_schema_to_gapic_v1(schema: Dict[str, Any]) -> Dict[str, Any]: |
| 60 | + """Format a JSON schema from a Pydantic V1 BaseModel to gapic.""" |
| 61 | + converted_schema: Dict[str, Any] = {} |
| 62 | + for key, value in schema.items(): |
| 63 | + if key == "definitions": |
| 64 | + continue |
| 65 | + elif key == "items": |
| 66 | + converted_schema["items"] = _format_json_schema_to_gapic_v1(value) |
| 67 | + elif key == "properties": |
| 68 | + converted_schema["properties"] = _get_properties_from_schema(value) |
| 69 | + continue |
| 70 | + elif key in ["type", "_type"]: |
| 71 | + converted_schema["type"] = str(value).upper() |
| 72 | + elif key == "allOf": |
| 73 | + if len(value) > 1: |
| 74 | + logger.warning( |
| 75 | + "Only first value for 'allOf' key is supported. " |
| 76 | + f"Got {len(value)}, ignoring other than first value!" |
| 77 | + ) |
| 78 | + return _format_json_schema_to_gapic_v1(value[0]) |
| 79 | + elif key not in _ALLOWED_SCHEMA_FIELDS_SET: |
| 80 | + logger.warning(f"Key '{key}' is not supported in schema, ignoring") |
| 81 | + else: |
| 82 | + converted_schema[key] = value |
| 83 | + return converted_schema |
| 84 | + |
| 85 | + |
| 86 | +def _format_json_schema_to_gapic( |
| 87 | + schema: Dict[str, Any], |
| 88 | + parent_key: Optional[str] = None, |
| 89 | + required_fields: Optional[list] = None, |
| 90 | +) -> Dict[str, Any]: |
| 91 | + """Format a JSON schema from a Pydantic V2 BaseModel to gapic.""" |
| 92 | + converted_schema: Dict[str, Any] = {} |
| 93 | + for key, value in schema.items(): |
| 94 | + if key == "$defs": |
| 95 | + continue |
| 96 | + elif key == "items": |
| 97 | + converted_schema["items"] = _format_json_schema_to_gapic( |
| 98 | + value, parent_key, required_fields |
| 99 | + ) |
| 100 | + elif key == "properties": |
| 101 | + if "properties" not in converted_schema: |
| 102 | + converted_schema["properties"] = {} |
| 103 | + for pkey, pvalue in value.items(): |
| 104 | + converted_schema["properties"][pkey] = _format_json_schema_to_gapic( |
| 105 | + pvalue, pkey, schema.get("required", []) |
| 106 | + ) |
| 107 | + continue |
| 108 | + elif key in ["type", "_type"]: |
| 109 | + converted_schema["type"] = str(value).upper() |
| 110 | + elif key == "allOf": |
| 111 | + if len(value) > 1: |
| 112 | + logger.warning( |
| 113 | + "Only first value for 'allOf' key is supported. " |
| 114 | + f"Got {len(value)}, ignoring other than first value!" |
| 115 | + ) |
| 116 | + return _format_json_schema_to_gapic(value[0], parent_key, required_fields) |
| 117 | + elif key == "anyOf": |
| 118 | + if len(value) == 2 and any(v.get("type") == "null" for v in value): |
| 119 | + non_null_type = next(v for v in value if v.get("type") != "null") |
| 120 | + converted_schema.update( |
| 121 | + _format_json_schema_to_gapic( |
| 122 | + non_null_type, parent_key, required_fields |
| 123 | + ) |
| 124 | + ) |
| 125 | + # Remove the field from required if it exists |
| 126 | + if required_fields and parent_key in required_fields: |
| 127 | + required_fields.remove(parent_key) |
| 128 | + continue |
| 129 | + elif key not in _ALLOWED_SCHEMA_FIELDS_SET: |
| 130 | + logger.warning(f"Key '{key}' is not supported in schema, ignoring") |
| 131 | + else: |
| 132 | + converted_schema[key] = value |
| 133 | + return converted_schema |
| 134 | + |
| 135 | + |
| 136 | +# Get Properties from Schema |
| 137 | +def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]: |
| 138 | + if isinstance(schema, Dict): |
| 139 | + return _get_properties_from_schema(schema) |
| 140 | + return {} |
| 141 | + |
| 142 | + |
| 143 | +def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: |
| 144 | + properties: Dict[str, Any] = {} |
| 145 | + for k, v in schema.items(): |
| 146 | + if not isinstance(k, str): |
| 147 | + logger.warning(f"Key '{k}' is not supported in schema, type={type(k)}") |
| 148 | + continue |
| 149 | + if not isinstance(v, Dict): |
| 150 | + logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}") |
| 151 | + continue |
| 152 | + properties_item: Dict[str, Union[str, int, Dict, List]] = {} |
| 153 | + if v.get("type") or v.get("anyOf") or v.get("type_"): |
| 154 | + item_type_ = _get_type_from_schema(v) |
| 155 | + properties_item["type_"] = item_type_ |
| 156 | + if _is_nullable_schema(v): |
| 157 | + properties_item["nullable"] = True |
| 158 | + |
| 159 | + # Replace `v` with chosen definition for array / object json types |
| 160 | + any_of_types = v.get("anyOf") |
| 161 | + if any_of_types and item_type_ in [glm.Type.ARRAY, glm.Type.OBJECT]: |
| 162 | + json_type_ = "array" if item_type_ == glm.Type.ARRAY else "object" |
| 163 | + # Use Index -1 for consistency with `_get_nullable_type_from_schema` |
| 164 | + v = [val for val in any_of_types if val.get("type") == json_type_][-1] |
| 165 | + |
| 166 | + if v.get("enum"): |
| 167 | + properties_item["enum"] = v["enum"] |
| 168 | + |
| 169 | + v_title = v.get("title") |
| 170 | + if v_title and isinstance(v_title, str): |
| 171 | + properties_item["title"] = v_title |
| 172 | + |
| 173 | + description = v.get("description") |
| 174 | + if description and isinstance(description, str): |
| 175 | + properties_item["description"] = description |
| 176 | + |
| 177 | + if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): |
| 178 | + properties_item["items"] = _get_items_from_schema_any(v.get("items")) |
| 179 | + |
| 180 | + if properties_item.get("type_") == glm.Type.OBJECT: |
| 181 | + if ( |
| 182 | + v.get("anyOf") |
| 183 | + and isinstance(v["anyOf"], list) |
| 184 | + and isinstance(v["anyOf"][0], dict) |
| 185 | + ): |
| 186 | + v = v["anyOf"][0] |
| 187 | + v_properties = v.get("properties") |
| 188 | + if v_properties: |
| 189 | + properties_item["properties"] = _get_properties_from_schema_any( |
| 190 | + v_properties |
| 191 | + ) |
| 192 | + if isinstance(v_properties, dict): |
| 193 | + properties_item["required"] = [ |
| 194 | + k for k, v in v_properties.items() if "default" not in v |
| 195 | + ] |
| 196 | + else: |
| 197 | + # Providing dummy type for object without properties |
| 198 | + properties_item["type_"] = glm.Type.STRING |
| 199 | + |
| 200 | + if k == "title" and "description" not in properties_item: |
| 201 | + properties_item["description"] = k + " is " + str(v) |
| 202 | + properties[k] = properties_item |
| 203 | + |
| 204 | + return properties |
| 205 | + |
| 206 | + |
| 207 | +def _get_items_from_schema_any(schema: Any) -> Dict[str, Any]: |
| 208 | + if isinstance(schema, (dict, list, str)): |
| 209 | + return _get_items_from_schema(schema) |
| 210 | + return {} |
| 211 | + |
| 212 | + |
| 213 | +def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]: |
| 214 | + items: Dict = {} |
| 215 | + if isinstance(schema, List): |
| 216 | + for i, v in enumerate(schema): |
| 217 | + items[f"item{i}"] = _get_properties_from_schema_any(v) |
| 218 | + elif isinstance(schema, Dict): |
| 219 | + items["type_"] = _get_type_from_schema(schema) |
| 220 | + if items["type_"] == glm.Type.OBJECT and "properties" in schema: |
| 221 | + items["properties"] = _get_properties_from_schema_any(schema["properties"]) |
| 222 | + if items["type_"] == glm.Type.ARRAY and "items" in schema: |
| 223 | + items["items"] = _format_json_schema_to_gapic_v1(schema["items"]) |
| 224 | + if "title" in schema or "description" in schema: |
| 225 | + items["description"] = ( |
| 226 | + schema.get("description") or schema.get("title") or "" |
| 227 | + ) |
| 228 | + if _is_nullable_schema(schema): |
| 229 | + items["nullable"] = True |
| 230 | + if "required" in schema: |
| 231 | + items["required"] = schema["required"] |
| 232 | + else: |
| 233 | + # str |
| 234 | + items["type_"] = _get_type_from_schema({"type": schema}) |
| 235 | + if _is_nullable_schema({"type": schema}): |
| 236 | + items["nullable"] = True |
| 237 | + |
| 238 | + return items |
| 239 | + |
| 240 | + |
| 241 | +def _get_type_from_schema(schema: Dict[str, Any]) -> int: |
| 242 | + return _get_nullable_type_from_schema(schema) or glm.Type.STRING |
| 243 | + |
| 244 | + |
| 245 | +def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[int]: |
| 246 | + if "anyOf" in schema: |
| 247 | + types = [ |
| 248 | + _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"] |
| 249 | + ] |
| 250 | + types = [t for t in types if t is not None] # Remove None values |
| 251 | + if types: |
| 252 | + return types[-1] # TODO: update FunctionDeclaration and pass all types? |
| 253 | + else: |
| 254 | + pass |
| 255 | + elif "type" in schema or "type_" in schema: |
| 256 | + type_ = schema["type"] if "type" in schema else schema["type_"] |
| 257 | + if isinstance(type_, int): |
| 258 | + return type_ |
| 259 | + stype = str(schema["type"]) if "type" in schema else str(schema["type_"]) |
| 260 | + return TYPE_ENUM.get(stype, glm.Type.STRING) |
| 261 | + else: |
| 262 | + pass |
| 263 | + return glm.Type.STRING # Default to string if no valid types found |
| 264 | + |
| 265 | + |
| 266 | +def _is_nullable_schema(schema: Dict[str, Any]) -> bool: |
| 267 | + if "anyOf" in schema: |
| 268 | + types = [ |
| 269 | + _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"] |
| 270 | + ] |
| 271 | + return any(t is None for t in types) |
| 272 | + elif "type" in schema or "type_" in schema: |
| 273 | + type_ = schema["type"] if "type" in schema else schema["type_"] |
| 274 | + if isinstance(type_, int): |
| 275 | + return False |
| 276 | + stype = str(schema["type"]) if "type" in schema else str(schema["type_"]) |
| 277 | + return TYPE_ENUM.get(stype, glm.Type.STRING) is None |
| 278 | + else: |
| 279 | + pass |
| 280 | + return False |
0 commit comments