1
+ import dataclasses
1
2
import datetime
2
3
import functools
3
- import string
4
+ import inspect
4
5
from typing import Any
5
6
from typing import Callable
6
7
from typing import Dict
7
8
from typing import List
8
9
from typing import Optional
10
+ from typing import Tuple
9
11
from typing import Union
10
12
11
13
from . import dtypes
12
14
from .dtypes import DataType
15
+ from .signature import simplify_dtype
16
+
17
+ try :
18
+ import pydantic
19
+ has_pydantic = True
20
+ except ImportError :
21
+ has_pydantic = False
13
22
14
23
python_type_map : Dict [Any , Callable [..., str ]] = {
15
24
str : dtypes .TEXT ,
@@ -33,88 +42,123 @@ def listify(x: Any) -> List[Any]:
33
42
return [x ]
34
43
35
44
45
+ def process_annotation (annotation : Any ) -> Tuple [Any , bool ]:
46
+ types = simplify_dtype (annotation )
47
+ if isinstance (types , list ):
48
+ nullable = False
49
+ if type (None ) in types :
50
+ nullable = True
51
+ types = [x for x in types if x is not type (None )]
52
+ if len (types ) > 1 :
53
+ raise ValueError (f'multiple types not supported: { annotation } ' )
54
+ return types [0 ], nullable
55
+ return types , True
56
+
57
+
58
+ def process_types (params : Any ) -> Any :
59
+ if params is None :
60
+ return params , []
61
+
62
+ elif isinstance (params , (list , tuple )):
63
+ params = list (params )
64
+ for i , item in enumerate (params ):
65
+ if params [i ] in python_type_map :
66
+ params [i ] = python_type_map [params [i ]]()
67
+ elif callable (item ):
68
+ params [i ] = item ()
69
+ for item in params :
70
+ if not isinstance (item , str ):
71
+ raise TypeError (f'unrecognized type for parameter: { item } ' )
72
+ return params , []
73
+
74
+ elif isinstance (params , dict ):
75
+ names = []
76
+ params = dict (params )
77
+ for k , v in list (params .items ()):
78
+ names .append (k )
79
+ if params [k ] in python_type_map :
80
+ params [k ] = python_type_map [params [k ]]()
81
+ elif callable (v ):
82
+ params [k ] = v ()
83
+ for item in params .values ():
84
+ if not isinstance (item , str ):
85
+ raise TypeError (f'unrecognized type for parameter: { item } ' )
86
+ return params , names
87
+
88
+ elif dataclasses .is_dataclass (params ):
89
+ names = []
90
+ out = []
91
+ for item in dataclasses .fields (params ):
92
+ typ , nullable = process_annotation (item .type )
93
+ sql_type = process_types (typ )[0 ]
94
+ if not nullable :
95
+ sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
96
+ out .append (sql_type )
97
+ names .append (item .name )
98
+ return out , names
99
+
100
+ elif has_pydantic and inspect .isclass (params ) \
101
+ and issubclass (params , pydantic .BaseModel ):
102
+ names = []
103
+ out = []
104
+ for name , item in params .model_fields .items ():
105
+ typ , nullable = process_annotation (item .annotation )
106
+ sql_type = process_types (typ )[0 ]
107
+ if not nullable :
108
+ sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
109
+ out .append (sql_type )
110
+ names .append (name )
111
+ return out , names
112
+
113
+ elif params in python_type_map :
114
+ return python_type_map [params ](), []
115
+
116
+ elif callable (params ):
117
+ return params (), []
118
+
119
+ elif isinstance (params , str ):
120
+ return params , []
121
+
122
+ raise TypeError (f'unrecognized data type for args: { params } ' )
123
+
124
+
36
125
def _func (
37
126
func : Optional [Callable [..., Any ]] = None ,
38
127
* ,
39
128
name : Optional [str ] = None ,
40
- args : Optional [Union [DataType , List [DataType ], Dict [str , DataType ]]] = None ,
41
- returns : Optional [Union [str , List [DataType ], List [type ]]] = None ,
129
+ args : Optional [
130
+ Union [
131
+ DataType ,
132
+ List [DataType ],
133
+ Dict [str , DataType ],
134
+ 'pydantic.BaseModel' ,
135
+ type ,
136
+ ]
137
+ ] = None ,
138
+ returns : Optional [
139
+ Union [
140
+ str ,
141
+ List [DataType ],
142
+ List [type ],
143
+ 'pydantic.BaseModel' ,
144
+ type ,
145
+ ]
146
+ ] = None ,
42
147
data_format : Optional [str ] = None ,
43
148
include_masks : bool = False ,
44
149
function_type : str = 'udf' ,
45
150
output_fields : Optional [List [str ]] = None ,
46
151
) -> Callable [..., Any ]:
47
152
"""Generic wrapper for UDF and TVF decorators."""
48
- if args is None :
49
- pass
50
- elif isinstance (args , (list , tuple )):
51
- args = list (args )
52
- for i , item in enumerate (args ):
53
- if args [i ] in python_type_map :
54
- args [i ] = python_type_map [args [i ]]()
55
- elif callable (item ):
56
- args [i ] = item ()
57
- for item in args :
58
- if not isinstance (item , str ):
59
- raise TypeError (f'unrecognized type for parameter: { item } ' )
60
- elif isinstance (args , dict ):
61
- args = dict (args )
62
- for k , v in list (args .items ()):
63
- if args [k ] in python_type_map :
64
- args [k ] = python_type_map [args [k ]]()
65
- elif callable (v ):
66
- args [k ] = v ()
67
- for item in args .values ():
68
- if not isinstance (item , str ):
69
- raise TypeError (f'unrecognized type for parameter: { item } ' )
70
- elif args in python_type_map :
71
- args = python_type_map [args ]()
72
- elif callable (args ):
73
- args = args ()
74
- elif isinstance (args , str ):
75
- args = args
76
- else :
77
- raise TypeError (f'unrecognized data type for args: { args } ' )
78
-
79
- if returns is None :
80
- pass
81
- elif isinstance (returns , (list , tuple )):
82
- returns = list (returns )
83
- for i , item in enumerate (returns ):
84
- if item in python_type_map :
85
- returns [i ] = python_type_map [item ]()
86
- elif callable (item ):
87
- returns [i ] = item ()
88
- for item in returns :
89
- if not isinstance (item , str ):
90
- raise TypeError (f'unrecognized return type: { item } ' )
91
- elif returns in python_type_map :
92
- returns = python_type_map [returns ]()
93
- elif callable (returns ):
94
- returns = returns ()
95
- elif isinstance (returns , str ):
96
- returns = returns
97
- else :
98
- raise TypeError (f'unrecognized return type: { returns } ' )
99
-
100
- if returns is None :
101
- pass
102
- elif isinstance (returns , list ):
103
- for item in returns :
104
- if not isinstance (item , str ):
105
- raise TypeError (f'unrecognized return type: { item } ' )
106
- elif not isinstance (returns , str ):
107
- raise TypeError (f'unrecognized return type: { returns } ' )
108
-
109
- if not output_fields :
110
- if isinstance (returns , list ):
111
- output_fields = []
112
- for i , _ in enumerate (returns ):
113
- output_fields .append (string .ascii_letters [i ])
114
- else :
115
- output_fields = [string .ascii_letters [0 ]]
116
-
117
- if isinstance (returns , list ) and len (output_fields ) != len (returns ):
153
+ args , _ = process_types (args )
154
+ returns , fields = process_types (returns )
155
+
156
+ if not output_fields and fields :
157
+ output_fields = fields
158
+
159
+ if isinstance (returns , list ) \
160
+ and isinstance (output_fields , list ) \
161
+ and len (output_fields ) != len (returns ):
118
162
raise ValueError (
119
163
'The number of output fields must match the number of return types' ,
120
164
)
@@ -133,7 +177,7 @@ def _func(
133
177
data_format = data_format ,
134
178
include_masks = include_masks ,
135
179
function_type = function_type ,
136
- output_fields = output_fields ,
180
+ output_fields = output_fields or None ,
137
181
).items () if v is not None
138
182
}
139
183
0 commit comments