Skip to content

Commit ba86736

Browse files
authored
UDF annotation work (#61)
* Refactor to use Table / Masked types
* Downgrade mypy for Python 3.8
* Fix annotations for older versions of Python
* Fix annotations for older versions of Python
* Fix annotations for older versions of Python
* Fix annotations for older versions of Python
* Fix annotations for older versions of Python
* Add null masks
* Fix difference in numpy detection
* Short circuit common valid types
* Add 3.13 checks
* Update autopep
* Add 3.13 to smoke tests
* Fix Table wrappers
* Fix masks in table results
* Vector utility functions
1 parent bdc06fb commit ba86736

18 files changed

+1339
-477
lines changed

.github/workflows/pre-commit.yml

+1
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ jobs:
1414
- "3.10"
1515
- "3.11"
1616
- "3.12"
17+
- "3.13"
1718

1819
steps:
1920
- uses: actions/checkout@v3

.github/workflows/smoke-test.yml

+1
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ jobs:
5454
- "3.10"
5555
- "3.11"
5656
- "3.12"
57+
- "3.13"
5758
driver:
5859
- mysql
5960
- https

.pre-commit-config.yaml

+2-2
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ repos:
2121
exclude: singlestoredb/clients/pymysqlsv/
2222
additional_dependencies: [flake8-typing-imports==1.12.0]
2323
- repo: https://github.com/hhatto/autopep8
24-
rev: v2.0.4
24+
rev: v2.3.1
2525
hooks:
2626
- id: autopep8
2727
args: [--diff]
@@ -40,7 +40,7 @@ repos:
4040
hooks:
4141
- id: setup-cfg-fmt
4242
- repo: https://github.com/pre-commit/mirrors-mypy
43-
rev: v1.6.1
43+
rev: v1.14.1
4444
hooks:
4545
- id: mypy
4646
additional_dependencies: [types-requests]

requirements.txt

+1
Original file line number | Diff line number | Diff line change
@@ -5,4 +5,5 @@ requests
55
setuptools
66
sqlparams
77
tomli>=1.1.0; python_version < '3.11'
8+
typing_extensions<=4.13.2
89
wheel

setup.cfg

+1
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ install_requires =
2727
sqlparams
2828
wheel
2929
tomli>=1.1.0;python_version < '3.11'
30+
typing-extensions<=4.13.2;python_version < '3.11'
3031
python_requires = >=3.8
3132
include_package_data = True
3233
tests_require =

singlestoredb/functions/__init__.py

+10-4
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,12 @@
1-
from .decorator import tvf # noqa: F401
2-
from .decorator import tvf_with_null_masks # noqa: F401
31
from .decorator import udf # noqa: F401
4-
from .decorator import udf_with_null_masks # noqa: F401
52
from .typing import Masked # noqa: F401
6-
from .typing import MaskedNDArray # noqa: F401
3+
from .typing import Table # noqa: F401
4+
from .utils import VectorTypes
5+
6+
7+
F32 = VectorTypes.F32
8+
F64 = VectorTypes.F64
9+
I8 = VectorTypes.I8
10+
I16 = VectorTypes.I16
11+
I32 = VectorTypes.I32
12+
I64 = VectorTypes.I64

singlestoredb/functions/decorator.py

+5-194
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import functools
22
import inspect
3-
import typing
43
from typing import Any
54
from typing import Callable
65
from typing import List
@@ -60,40 +59,6 @@ def is_valid_callable(obj: Any) -> bool:
6059
)
6160

6261

63-
def verify_mask(obj: Any) -> bool:
64-
"""Verify that the object is a tuple of two vector types."""
65-
if typing.get_origin(obj) is not tuple or len(typing.get_args(obj)) != 2:
66-
raise TypeError(
67-
f'Expected a tuple of two vector types, but got {type(obj)}',
68-
)
69-
70-
args = typing.get_args(obj)
71-
72-
if not utils.is_vector(args[0]):
73-
raise TypeError(
74-
f'Expected a vector type for the first element, but got {args[0]}',
75-
)
76-
77-
if not utils.is_vector(args[1]):
78-
raise TypeError(
79-
f'Expected a vector type for the second element, but got {args[1]}',
80-
)
81-
82-
return True
83-
84-
85-
def verify_masks(obj: Callable[..., Any]) -> bool:
86-
"""Verify that the function parameters and return value are all masks."""
87-
ann = utils.get_annotations(obj)
88-
for name, value in ann.items():
89-
if not verify_mask(value):
90-
raise TypeError(
91-
f'Expected a vector type for the parameter {name} '
92-
f'in function {obj.__name__}, but got {value}',
93-
)
94-
return True
95-
96-
9762
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
9863
"""Expand the types for the function arguments / return values."""
9964
if args is None:
@@ -135,8 +100,6 @@ def _func(
135100
name: Optional[str] = None,
136101
args: Optional[ParameterType] = None,
137102
returns: Optional[ReturnType] = None,
138-
with_null_masks: bool = False,
139-
function_type: str = 'udf',
140103
) -> Callable[..., Any]:
141104
"""Generic wrapper for UDF and TVF decorators."""
142105

@@ -145,8 +108,6 @@ def _func(
145108
name=name,
146109
args=expand_types(args),
147110
returns=expand_types(returns),
148-
with_null_masks=with_null_masks,
149-
function_type=function_type,
150111
).items() if v is not None
151112
}
152113

@@ -155,8 +116,6 @@ def _func(
155116
# in at that time.
156117
if func is None:
157118
def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
158-
if with_null_masks:
159-
verify_masks(func)
160119

161120
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
162121
return func(*args, **kwargs) # type: ignore
@@ -167,9 +126,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
167126

168127
return decorate
169128

170-
if with_null_masks:
171-
verify_masks(func)
172-
173129
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
174130
return func(*args, **kwargs) # type: ignore
175131

@@ -194,151 +150,7 @@ def udf(
194150
The UDF to apply parameters to
195151
name : str, optional
196152
The name to use for the UDF in the database
197-
args : str | Callable | List[str | Callable], optional
198-
Specifies the data types of the function arguments. Typically,
199-
the function data types are derived from the function parameter
200-
annotations. These annotations can be overridden. If the function
201-
takes a single type for all parameters, `args` can be set to a
202-
SQL string describing all parameters. If the function takes more
203-
than one parameter and all of the parameters are being manually
204-
defined, a list of SQL strings may be used (one for each parameter).
205-
A dictionary of SQL strings may be used to specify a parameter type
206-
for a subset of parameters; the keys are the names of the
207-
function parameters. Callables may also be used for datatypes. This
208-
is primarily for using the functions in the ``dtypes`` module that
209-
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
210-
returns : str, optional
211-
Specifies the return data type of the function. If not specified,
212-
the type annotation from the function is used.
213-
214-
Returns
215-
-------
216-
Callable
217-
218-
"""
219-
return _func(
220-
func=func,
221-
name=name,
222-
args=args,
223-
returns=returns,
224-
with_null_masks=False,
225-
function_type='udf',
226-
)
227-
228-
229-
def udf_with_null_masks(
230-
func: Optional[Callable[..., Any]] = None,
231-
*,
232-
name: Optional[str] = None,
233-
args: Optional[ParameterType] = None,
234-
returns: Optional[ReturnType] = None,
235-
) -> Callable[..., Any]:
236-
"""
237-
Define a user-defined function (UDF) with null masks.
238-
239-
Parameters
240-
----------
241-
func : callable, optional
242-
The UDF to apply parameters to
243-
name : str, optional
244-
The name to use for the UDF in the database
245-
args : str | Callable | List[str | Callable], optional
246-
Specifies the data types of the function arguments. Typically,
247-
the function data types are derived from the function parameter
248-
annotations. These annotations can be overridden. If the function
249-
takes a single type for all parameters, `args` can be set to a
250-
SQL string describing all parameters. If the function takes more
251-
than one parameter and all of the parameters are being manually
252-
defined, a list of SQL strings may be used (one for each parameter).
253-
A dictionary of SQL strings may be used to specify a parameter type
254-
for a subset of parameters; the keys are the names of the
255-
function parameters. Callables may also be used for datatypes. This
256-
is primarily for using the functions in the ``dtypes`` module that
257-
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
258-
returns : str, optional
259-
Specifies the return data type of the function. If not specified,
260-
the type annotation from the function is used.
261-
262-
Returns
263-
-------
264-
Callable
265-
266-
"""
267-
return _func(
268-
func=func,
269-
name=name,
270-
args=args,
271-
returns=returns,
272-
with_null_masks=True,
273-
function_type='udf',
274-
)
275-
276-
277-
def tvf(
278-
func: Optional[Callable[..., Any]] = None,
279-
*,
280-
name: Optional[str] = None,
281-
args: Optional[ParameterType] = None,
282-
returns: Optional[ReturnType] = None,
283-
) -> Callable[..., Any]:
284-
"""
285-
Define a table-valued function (TVF).
286-
287-
Parameters
288-
----------
289-
func : callable, optional
290-
The TVF to apply parameters to
291-
name : str, optional
292-
The name to use for the TVF in the database
293-
args : str | Callable | List[str | Callable], optional
294-
Specifies the data types of the function arguments. Typically,
295-
the function data types are derived from the function parameter
296-
annotations. These annotations can be overridden. If the function
297-
takes a single type for all parameters, `args` can be set to a
298-
SQL string describing all parameters. If the function takes more
299-
than one parameter and all of the parameters are being manually
300-
defined, a list of SQL strings may be used (one for each parameter).
301-
A dictionary of SQL strings may be used to specify a parameter type
302-
for a subset of parameters; the keys are the names of the
303-
function parameters. Callables may also be used for datatypes. This
304-
is primarily for using the functions in the ``dtypes`` module that
305-
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
306-
returns : str, optional
307-
Specifies the return data type of the function. If not specified,
308-
the type annotation from the function is used.
309-
310-
Returns
311-
-------
312-
Callable
313-
314-
"""
315-
return _func(
316-
func=func,
317-
name=name,
318-
args=args,
319-
returns=returns,
320-
with_null_masks=False,
321-
function_type='tvf',
322-
)
323-
324-
325-
def tvf_with_null_masks(
326-
func: Optional[Callable[..., Any]] = None,
327-
*,
328-
name: Optional[str] = None,
329-
args: Optional[ParameterType] = None,
330-
returns: Optional[ReturnType] = None,
331-
) -> Callable[..., Any]:
332-
"""
333-
Define a table-valued function (TVF) using null masks.
334-
335-
Parameters
336-
----------
337-
func : callable, optional
338-
The TVF to apply parameters to
339-
name : str, optional
340-
The name to use for the TVF in the database
341-
args : str | Callable | List[str | Callable], optional
153+
args : str | Type | Callable | List[str | Callable], optional
342154
Specifies the data types of the function arguments. Typically,
343155
the function data types are derived from the function parameter
344156
annotations. These annotations can be overridden. If the function
@@ -351,9 +163,10 @@ def tvf_with_null_masks(
351163
function parameters. Callables may also be used for datatypes. This
352164
is primarily for using the functions in the ``dtypes`` module that
353165
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
354-
returns : str, optional
355-
Specifies the return data type of the function. If not specified,
356-
the type annotation from the function is used.
166+
returns : str | Type | Callable | List[str | Callable] | Table, optional
167+
Specifies the return data type of the function. This parameter
168+
works the same way as `args`. If the function is a table-valued
169+
function, the return type should be a `Table` object.
357170
358171
Returns
359172
-------
@@ -365,6 +178,4 @@ def tvf_with_null_masks(
365178
name=name,
366179
args=args,
367180
returns=returns,
368-
with_null_masks=True,
369-
function_type='tvf',
370181
)

0 commit comments

Comments
 (0)