Skip to content

Commit dd3f041

Browse files
committed
Add vector utility functions for UDFs
1 parent ba86736 commit dd3f041

File tree

2 files changed

+175
-92
lines changed

2 files changed

+175
-92
lines changed

singlestoredb/functions/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from .decorator import udf # noqa: F401
22
from .typing import Masked # noqa: F401
33
from .typing import Table # noqa: F401
4+
from .utils import pack_vector # noqa: F401
5+
from .utils import pack_vectors # noqa: F401
6+
from .utils import unpack_vector # noqa: F401
7+
from .utils import unpack_vectors # noqa: F401
48
from .utils import VectorTypes
59

610

singlestoredb/functions/utils.py

+171-92
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import Any
99
from typing import Dict
1010
from typing import Iterable
11+
from typing import Tuple
12+
from typing import Union
1113

1214
from .typing import Masked
1315

@@ -192,151 +194,228 @@ class VectorTypes(str, Enum):
192194
I64 = 'i64'
193195

194196

197+
def _vector_type_to_numpy_type(
198+
vector_type: VectorTypes,
199+
) -> str:
200+
"""Convert a vector type to a numpy type."""
201+
if vector_type == VectorTypes.F32:
202+
return 'f4'
203+
elif vector_type == VectorTypes.F64:
204+
return 'f8'
205+
elif vector_type == VectorTypes.I8:
206+
return 'i1'
207+
elif vector_type == VectorTypes.I16:
208+
return 'i2'
209+
elif vector_type == VectorTypes.I32:
210+
return 'i4'
211+
elif vector_type == VectorTypes.I64:
212+
return 'i8'
213+
raise ValueError(f'unsupported element type: {vector_type}')
214+
215+
216+
def _vector_type_to_struct_format(
217+
vec: Any,
218+
vector_type: VectorTypes,
219+
) -> str:
220+
"""Convert a vector type to a struct format string."""
221+
n = len(vec)
222+
if vector_type == VectorTypes.F32:
223+
if isinstance(vec, (bytes, bytearray)):
224+
n = n // 4
225+
return f'<{n}f'
226+
elif vector_type == VectorTypes.F64:
227+
if isinstance(vec, (bytes, bytearray)):
228+
n = n // 8
229+
return f'<{n}d'
230+
elif vector_type == VectorTypes.I8:
231+
return f'<{n}b'
232+
elif vector_type == VectorTypes.I16:
233+
if isinstance(vec, (bytes, bytearray)):
234+
n = n // 2
235+
return f'<{n}h'
236+
elif vector_type == VectorTypes.I32:
237+
if isinstance(vec, (bytes, bytearray)):
238+
n = n // 4
239+
return f'<{n}i'
240+
elif vector_type == VectorTypes.I64:
241+
if isinstance(vec, (bytes, bytearray)):
242+
n = n // 8
243+
return f'<{n}q'
244+
raise ValueError(f'unsupported element type: {vector_type}')
245+
246+
195247
def unpack_vector(
196-
obj: Any,
197-
element_type: VectorTypes = VectorTypes.F32,
198-
) -> Iterable[Any]:
248+
obj: Union[bytes, bytearray],
249+
vec_type: VectorTypes = VectorTypes.F32,
250+
) -> Tuple[Any]:
199251
"""
200252
Unpack a vector from bytes.
201253
202254
Parameters
203255
----------
204-
obj : Any
256+
obj : bytes or bytearray
205257
The object to unpack.
206-
element_type : VectorTypes
258+
vec_type : VectorTypes
207259
The type of the elements in the vector.
208260
Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
209261
Default is 'f32'.
210262
211263
Returns
212264
-------
213-
Iterable[Any]
265+
Tuple[Any]
214266
The unpacked vector.
215267
216268
"""
217-
if isinstance(obj, (bytes, bytearray, list, tuple)):
218-
if element_type == 'f32':
219-
n = len(obj) // 4
220-
fmt = 'f'
221-
elif element_type == 'f64':
222-
n = len(obj) // 8
223-
fmt = 'd'
224-
elif element_type == 'i8':
225-
n = len(obj)
226-
fmt = 'b'
227-
elif element_type == 'i16':
228-
n = len(obj) // 2
229-
fmt = 'h'
230-
elif element_type == 'i32':
231-
n = len(obj) // 4
232-
fmt = 'i'
233-
elif element_type == 'i64':
234-
n = len(obj) // 8
235-
fmt = 'q'
236-
else:
237-
raise ValueError(f'unsupported element type: {element_type}')
238-
239-
if isinstance(obj, (bytes, bytearray)):
240-
return struct.unpack(f'<{n}{fmt}', obj)
241-
return tuple([struct.unpack(f'<{n}{fmt}', x) for x in obj])
242-
243-
if element_type == 'f32':
244-
np_type = 'f4'
245-
elif element_type == 'f64':
246-
np_type = 'f8'
247-
elif element_type == 'i8':
248-
np_type = 'i1'
249-
elif element_type == 'i16':
250-
np_type = 'i2'
251-
elif element_type == 'i32':
252-
np_type = 'i4'
253-
elif element_type == 'i64':
254-
np_type = 'i8'
255-
else:
256-
raise ValueError(f'unsupported element type: {element_type}')
269+
return struct.unpack(_vector_type_to_struct_format(obj, vec_type), obj)
270+
271+
272+
def pack_vector(
273+
obj: Any,
274+
vec_type: VectorTypes = VectorTypes.F32,
275+
) -> bytes:
276+
"""
277+
Pack a vector into bytes.
278+
279+
Parameters
280+
----------
281+
obj : Any
282+
The object to pack.
283+
vec_type : VectorTypes
284+
The type of the elements in the vector.
285+
Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
286+
Default is 'f32'.
287+
288+
Returns
289+
-------
290+
bytes
291+
The packed vector.
292+
293+
"""
294+
if isinstance(obj, (list, tuple)):
295+
return struct.pack(_vector_type_to_struct_format(obj, vec_type), *obj)
257296

258297
if is_numpy(obj):
259-
import numpy as np
260-
return np.array([np.frombuffer(x, dtype=np_type) for x in obj])
298+
return obj.tobytes()
261299

262300
if is_pandas_series(obj):
263-
import numpy as np
264301
import pandas as pd
265-
return pd.Series([np.frombuffer(x, dtype=np_type) for x in obj])
302+
return pd.Series(obj).to_numpy().tobytes()
266303

267304
if is_polars_series(obj):
268-
import numpy as np
269305
import polars as pl
270-
return pl.Series([np.frombuffer(x, dtype=np_type) for x in obj])
306+
return pl.Series(obj).to_numpy().tobytes()
271307

272308
if is_pyarrow_array(obj):
273-
import numpy as np
274309
import pyarrow as pa
275-
return pa.array([np.frombuffer(x, dtype=np_type) for x in obj])
310+
return pa.array(obj).to_numpy().tobytes()
276311

277312
raise ValueError(
278313
f'unsupported object type: {type(obj)}',
279314
)
280315

281316

282-
def pack_vector(
283-
obj: Any,
284-
element_type: VectorTypes = VectorTypes.F32,
285-
) -> bytes:
317+
def unpack_vectors(
318+
arr_of_vec: Any,
319+
vec_type: VectorTypes = VectorTypes.F32,
320+
) -> Iterable[Any]:
286321
"""
287-
Pack a vector into bytes.
322+
Unpack a vector from an array of bytes.
288323
289324
Parameters
290325
----------
291-
obj : Any
292-
The object to pack.
293-
element_type : VectorTypes
326+
arr_of_vec : Iterable[Any]
327+
The array of bytes to unpack.
328+
vec_type : VectorTypes
294329
The type of the elements in the vector.
295330
Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
296331
Default is 'f32'.
297332
298333
Returns
299334
-------
300-
bytes
301-
The packed vector.
335+
Iterable[Any]
336+
The unpacked vector.
302337
303338
"""
304-
if element_type == 'f32':
305-
fmt = 'f'
306-
elif element_type == 'f64':
307-
fmt = 'd'
308-
elif element_type == 'i8':
309-
fmt = 'b'
310-
elif element_type == 'i16':
311-
fmt = 'h'
312-
elif element_type == 'i32':
313-
fmt = 'i'
314-
elif element_type == 'i64':
315-
fmt = 'q'
316-
else:
317-
raise ValueError(f'unsupported element type: {element_type}')
339+
if isinstance(arr_of_vec, (list, tuple)):
340+
return [unpack_vector(x, vec_type) for x in arr_of_vec]
318341

319-
if isinstance(obj, (list, tuple)):
320-
return struct.pack(f'<{len(obj)}{fmt}', *obj)
342+
import numpy as np
321343

322-
elif is_numpy(obj):
323-
return obj.tobytes()
344+
dtype = _vector_type_to_numpy_type(vec_type)
324345

325-
elif is_pandas_series(obj):
326-
# TODO: Nested vectors
346+
np_arr = np.array(
347+
[np.frombuffer(x, dtype=dtype) for x in arr_of_vec],
348+
dtype=dtype,
349+
)
350+
351+
if is_numpy(arr_of_vec):
352+
return np_arr
353+
354+
if is_pandas_series(arr_of_vec):
327355
import pandas as pd
328-
return pd.Series(obj).to_numpy().tobytes()
356+
return pd.Series(np_arr)
329357

330-
elif is_polars_series(obj):
331-
# TODO: Nested vectors
358+
if is_polars_series(arr_of_vec):
332359
import polars as pl
333-
return pl.Series(obj).to_numpy().tobytes()
360+
return pl.Series(np_arr)
334361

335-
elif is_pyarrow_array(obj):
336-
# TODO: Nested vectors
362+
if is_pyarrow_array(arr_of_vec):
337363
import pyarrow as pa
338-
return pa.array(obj).to_numpy().tobytes()
364+
return pa.array(np_arr)
339365

340366
raise ValueError(
341-
f'unsupported object type: {type(obj)}',
367+
f'unsupported object type: {type(arr_of_vec)}',
368+
)
369+
370+
371+
def pack_vectors(
372+
arr_of_arr: Iterable[Any],
373+
vec_type: VectorTypes = VectorTypes.F32,
374+
) -> Iterable[Any]:
375+
"""
376+
Pack a vector into an array of bytes.
377+
378+
Parameters
379+
----------
380+
arr_of_arr : Iterable[Any]
381+
The array of bytes to pack.
382+
vec_type : VectorTypes
383+
The type of the elements in the vector.
384+
Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
385+
Default is 'f32'.
386+
387+
Returns
388+
-------
389+
Iterable[Any]
390+
The array of packed vectors.
391+
392+
"""
393+
if isinstance(arr_of_arr, (list, tuple)):
394+
if not arr_of_arr:
395+
return []
396+
fmt = _vector_type_to_struct_format(arr_of_arr[0], vec_type)
397+
return [struct.pack(fmt, x) for x in arr_of_arr]
398+
399+
import numpy as np
400+
401+
# Use object type because numpy truncates nulls at the end of fixed binary
402+
np_arr = np.array([x.tobytes() for x in arr_of_arr], dtype=np.object_)
403+
404+
if is_numpy(arr_of_arr):
405+
return np_arr
406+
407+
if is_pandas_series(arr_of_arr):
408+
import pandas as pd
409+
return pd.Series(np_arr)
410+
411+
if is_polars_series(arr_of_arr):
412+
import polars as pl
413+
return pl.Series(np_arr)
414+
415+
if is_pyarrow_array(arr_of_arr):
416+
import pyarrow as pa
417+
return pa.array(np_arr)
418+
419+
raise ValueError(
420+
f'unsupported object type: {type(arr_of_arr)}',
342421
)

0 commit comments

Comments
 (0)