Skip to content

Commit 52db5af

Browse files
committed
Add support for dataclasses and pydantic in UDF/TVF parameters/returns
1 parent 31e8124 commit 52db5af

File tree

5 files changed

+300
-120
lines changed

5 files changed

+300
-120
lines changed

singlestoredb/functions/decorator.py

+118-74
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
import dataclasses
12
import datetime
23
import functools
3-
import string
4+
import inspect
45
from typing import Any
56
from typing import Callable
67
from typing import Dict
78
from typing import List
89
from typing import Optional
10+
from typing import Tuple
911
from typing import Union
1012

1113
from . import dtypes
1214
from .dtypes import DataType
15+
from .signature import simplify_dtype
16+
17+
try:
18+
import pydantic
19+
has_pydantic = True
20+
except ImportError:
21+
has_pydantic = False
1322

1423
python_type_map: Dict[Any, Callable[..., str]] = {
1524
str: dtypes.TEXT,
@@ -33,88 +42,123 @@ def listify(x: Any) -> List[Any]:
3342
return [x]
3443

3544

45+
def process_annotation(annotation: Any) -> Tuple[Any, bool]:
46+
types = simplify_dtype(annotation)
47+
if isinstance(types, list):
48+
nullable = False
49+
if type(None) in types:
50+
nullable = True
51+
types = [x for x in types if x is not type(None)]
52+
if len(types) > 1:
53+
raise ValueError(f'multiple types not supported: {annotation}')
54+
return types[0], nullable
55+
return types, True
56+
57+
58+
def process_types(params: Any) -> Any:
59+
if params is None:
60+
return params, []
61+
62+
elif isinstance(params, (list, tuple)):
63+
params = list(params)
64+
for i, item in enumerate(params):
65+
if params[i] in python_type_map:
66+
params[i] = python_type_map[params[i]]()
67+
elif callable(item):
68+
params[i] = item()
69+
for item in params:
70+
if not isinstance(item, str):
71+
raise TypeError(f'unrecognized type for parameter: {item}')
72+
return params, []
73+
74+
elif isinstance(params, dict):
75+
names = []
76+
params = dict(params)
77+
for k, v in list(params.items()):
78+
names.append(k)
79+
if params[k] in python_type_map:
80+
params[k] = python_type_map[params[k]]()
81+
elif callable(v):
82+
params[k] = v()
83+
for item in params.values():
84+
if not isinstance(item, str):
85+
raise TypeError(f'unrecognized type for parameter: {item}')
86+
return params, names
87+
88+
elif dataclasses.is_dataclass(params):
89+
names = []
90+
out = []
91+
for item in dataclasses.fields(params):
92+
typ, nullable = process_annotation(item.type)
93+
sql_type = process_types(typ)[0]
94+
if not nullable:
95+
sql_type = sql_type.replace('NULL', 'NOT NULL')
96+
out.append(sql_type)
97+
names.append(item.name)
98+
return out, names
99+
100+
elif has_pydantic and inspect.isclass(params) \
101+
and issubclass(params, pydantic.BaseModel):
102+
names = []
103+
out = []
104+
for name, item in params.model_fields.items():
105+
typ, nullable = process_annotation(item.annotation)
106+
sql_type = process_types(typ)[0]
107+
if not nullable:
108+
sql_type = sql_type.replace('NULL', 'NOT NULL')
109+
out.append(sql_type)
110+
names.append(name)
111+
return out, names
112+
113+
elif params in python_type_map:
114+
return python_type_map[params](), []
115+
116+
elif callable(params):
117+
return params(), []
118+
119+
elif isinstance(params, str):
120+
return params, []
121+
122+
raise TypeError(f'unrecognized data type for args: {params}')
123+
124+
36125
def _func(
37126
func: Optional[Callable[..., Any]] = None,
38127
*,
39128
name: Optional[str] = None,
40-
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
41-
returns: Optional[Union[str, List[DataType], List[type]]] = None,
129+
args: Optional[
130+
Union[
131+
DataType,
132+
List[DataType],
133+
Dict[str, DataType],
134+
'pydantic.BaseModel',
135+
type,
136+
]
137+
] = None,
138+
returns: Optional[
139+
Union[
140+
str,
141+
List[DataType],
142+
List[type],
143+
'pydantic.BaseModel',
144+
type,
145+
]
146+
] = None,
42147
data_format: Optional[str] = None,
43148
include_masks: bool = False,
44149
function_type: str = 'udf',
45150
output_fields: Optional[List[str]] = None,
46151
) -> Callable[..., Any]:
47152
"""Generic wrapper for UDF and TVF decorators."""
48-
if args is None:
49-
pass
50-
elif isinstance(args, (list, tuple)):
51-
args = list(args)
52-
for i, item in enumerate(args):
53-
if args[i] in python_type_map:
54-
args[i] = python_type_map[args[i]]()
55-
elif callable(item):
56-
args[i] = item()
57-
for item in args:
58-
if not isinstance(item, str):
59-
raise TypeError(f'unrecognized type for parameter: {item}')
60-
elif isinstance(args, dict):
61-
args = dict(args)
62-
for k, v in list(args.items()):
63-
if args[k] in python_type_map:
64-
args[k] = python_type_map[args[k]]()
65-
elif callable(v):
66-
args[k] = v()
67-
for item in args.values():
68-
if not isinstance(item, str):
69-
raise TypeError(f'unrecognized type for parameter: {item}')
70-
elif args in python_type_map:
71-
args = python_type_map[args]()
72-
elif callable(args):
73-
args = args()
74-
elif isinstance(args, str):
75-
args = args
76-
else:
77-
raise TypeError(f'unrecognized data type for args: {args}')
78-
79-
if returns is None:
80-
pass
81-
elif isinstance(returns, (list, tuple)):
82-
returns = list(returns)
83-
for i, item in enumerate(returns):
84-
if item in python_type_map:
85-
returns[i] = python_type_map[item]()
86-
elif callable(item):
87-
returns[i] = item()
88-
for item in returns:
89-
if not isinstance(item, str):
90-
raise TypeError(f'unrecognized return type: {item}')
91-
elif returns in python_type_map:
92-
returns = python_type_map[returns]()
93-
elif callable(returns):
94-
returns = returns()
95-
elif isinstance(returns, str):
96-
returns = returns
97-
else:
98-
raise TypeError(f'unrecognized return type: {returns}')
99-
100-
if returns is None:
101-
pass
102-
elif isinstance(returns, list):
103-
for item in returns:
104-
if not isinstance(item, str):
105-
raise TypeError(f'unrecognized return type: {item}')
106-
elif not isinstance(returns, str):
107-
raise TypeError(f'unrecognized return type: {returns}')
108-
109-
if not output_fields:
110-
if isinstance(returns, list):
111-
output_fields = []
112-
for i, _ in enumerate(returns):
113-
output_fields.append(string.ascii_letters[i])
114-
else:
115-
output_fields = [string.ascii_letters[0]]
116-
117-
if isinstance(returns, list) and len(output_fields) != len(returns):
153+
args, _ = process_types(args)
154+
returns, fields = process_types(returns)
155+
156+
if not output_fields and fields:
157+
output_fields = fields
158+
159+
if isinstance(returns, list) \
160+
and isinstance(output_fields, list) \
161+
and len(output_fields) != len(returns):
118162
raise ValueError(
119163
'The number of output fields must match the number of return types',
120164
)
@@ -133,7 +177,7 @@ def _func(
133177
data_format=data_format,
134178
include_masks=include_masks,
135179
function_type=function_type,
136-
output_fields=output_fields,
180+
output_fields=output_fields or None,
137181
).items() if v is not None
138182
}
139183

singlestoredb/functions/ext/asgi.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"""
2525
import argparse
2626
import asyncio
27+
import dataclasses
2728
import importlib.util
2829
import io
2930
import itertools
@@ -136,6 +137,14 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]:
136137
return out
137138

138139

140+
def as_tuple(x: Any) -> Any:
141+
if hasattr(x, 'model_fields'):
142+
return tuple(x.model_fields.values())
143+
if dataclasses.is_dataclass(x):
144+
return dataclasses.astuple(x)
145+
return x
146+
147+
139148
def make_func(
140149
name: str,
141150
func: Callable[..., Any],
@@ -174,7 +183,7 @@ async def do_func(
174183
out_ids: List[int] = []
175184
out = []
176185
for i, res in zip(row_ids, func_map(func, rows)):
177-
out.extend(res)
186+
out.extend(as_tuple(res))
178187
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
179188
return out_ids, out
180189

@@ -234,7 +243,7 @@ async def do_func(
234243
List[Tuple[Any]],
235244
]:
236245
'''Call function on given rows of data.'''
237-
return row_ids, list(zip(func_map(func, rows)))
246+
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
238247

239248
else:
240249
# Vector formats use the same function wrapper

0 commit comments

Comments
 (0)