Skip to content

Commit 829775e

Browse files
committed
Vector utility functions
1 parent 5039ec2 commit 829775e

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

singlestoredb/functions/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11
from .decorator import udf # noqa: F401
22
from .typing import Masked # noqa: F401
33
from .typing import Table # noqa: F401
4+
from .utils import VectorTypes
5+
6+
7+
# Module-level shorthands for vector element types, so callers can refer
# to them without going through VectorTypes. VectorTypes.F16 has no alias.
F32, F64 = VectorTypes.F32, VectorTypes.F64
I8, I16 = VectorTypes.I8, VectorTypes.I16
I32, I64 = VectorTypes.I32, VectorTypes.I64

singlestoredb/functions/utils.py

+164
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import dataclasses
22
import inspect
3+
import struct
34
import sys
45
import types
56
import typing
7+
from enum import Enum
68
from typing import Any
79
from typing import Dict
10+
from typing import Iterable
811

912
from .typing import Masked
1013

@@ -176,3 +179,164 @@ def is_pydantic(obj: Any) -> bool:
176179
if get_module(x) == 'pydantic'
177180
and get_type_name(x) == 'BaseModel'
178181
])
182+
183+
184+
class VectorTypes(str, Enum):
    """
    Enum for vector element types.

    Every member is also a ``str``, so it compares equal to its plain
    string value (for example ``VectorTypes.F32 == 'f32'``).

    """
    F16 = 'f16'
    F32 = 'f32'
    F64 = 'f64'
    I8 = 'i8'
    I16 = 'i16'
    I32 = 'i32'
    I64 = 'i64'


# Element type -> (struct format character, little-endian numpy dtype).
# 'f16' exists on VectorTypes but has no entry here, so pack_vector and
# unpack_vector reject it as unsupported.
_VECTOR_FORMATS: Dict[str, typing.Tuple[str, str]] = {
    'f32': ('f', '<f4'),
    'f64': ('d', '<f8'),
    'i8': ('b', '<i1'),
    'i16': ('h', '<i2'),
    'i32': ('i', '<i4'),
    'i64': ('q', '<i8'),
}


def _vector_format(element_type: Any) -> typing.Tuple[str, str]:
    """Return the (struct format, numpy dtype) pair for an element type."""
    try:
        return _VECTOR_FORMATS[element_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values; report them the same way
        # as any other unknown element type.
        raise ValueError(
            f'unsupported element type: {element_type}',
        ) from None


def _unpack_buffer(buf: Any, fmt: str) -> typing.Tuple[Any, ...]:
    """Unpack one little-endian buffer of ``fmt`` items into a tuple."""
    # The item count must come from this buffer's own byte length.
    count = len(buf) // struct.calcsize(f'<{fmt}')
    return struct.unpack(f'<{count}{fmt}', buf)


def unpack_vector(
    obj: Any,
    element_type: VectorTypes = VectorTypes.F32,
) -> Iterable[Any]:
    """
    Unpack a vector from bytes.

    Parameters
    ----------
    obj : Any
        The object to unpack. A ``bytes`` / ``bytearray`` value is
        treated as a single packed vector. A ``list`` / ``tuple``, numpy
        array, pandas Series, polars Series or pyarrow array is treated
        as a collection of packed vectors, each unpacked on its own.
    element_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Iterable[Any]
        The unpacked vector: a tuple of values for a single buffer, a
        tuple of tuples for a list / tuple of buffers, or a container of
        the same library as the input (numpy array, pandas Series,
        polars Series, pyarrow array) otherwise.

    Raises
    ------
    ValueError
        If ``element_type`` or the type of ``obj`` is not supported.
    struct.error
        If a buffer's length is not a multiple of the element size.

    """
    fmt, np_type = _vector_format(element_type)

    if isinstance(obj, (bytes, bytearray)):
        return _unpack_buffer(obj, fmt)

    if isinstance(obj, (list, tuple)):
        # Size each buffer individually; the number of buffers says
        # nothing about how many elements each one holds.
        return tuple(_unpack_buffer(x, fmt) for x in obj)

    # Packed data is little-endian (see the '<' struct prefix), so the
    # numpy dtypes used below are explicitly little-endian as well.
    if is_numpy(obj):
        import numpy as np
        return np.array([np.frombuffer(x, dtype=np_type) for x in obj])

    if is_pandas_series(obj):
        import numpy as np
        import pandas as pd
        return pd.Series([np.frombuffer(x, dtype=np_type) for x in obj])

    if is_polars_series(obj):
        import numpy as np
        import polars as pl
        return pl.Series([np.frombuffer(x, dtype=np_type) for x in obj])

    if is_pyarrow_array(obj):
        import numpy as np
        import pyarrow as pa
        return pa.array([np.frombuffer(x, dtype=np_type) for x in obj])

    raise ValueError(
        f'unsupported object type: {type(obj)}',
    )


def pack_vector(
    obj: Any,
    element_type: VectorTypes = VectorTypes.F32,
) -> bytes:
    """
    Pack a vector into bytes.

    Parameters
    ----------
    obj : Any
        The object to pack: a ``list`` / ``tuple`` of numbers, a numpy
        array, a pandas Series, a polars Series or a pyarrow array.
    element_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    bytes
        The packed vector, with elements in little-endian byte order.

    Raises
    ------
    ValueError
        If ``element_type`` or the type of ``obj`` is not supported.
    struct.error
        If a list / tuple value does not fit the element type.

    """
    fmt, np_type = _vector_format(element_type)

    if isinstance(obj, (list, tuple)):
        return struct.pack(f'<{len(obj)}{fmt}', *obj)

    # For array-like inputs, cast to the requested element type first so
    # the output does not depend on the dtype the array happened to have.
    if is_numpy(obj):
        return obj.astype(np_type, copy=False).tobytes()

    if is_pandas_series(obj):
        # TODO: Nested vectors
        import pandas as pd
        values = pd.Series(obj).to_numpy()
        return values.astype(np_type, copy=False).tobytes()

    if is_polars_series(obj):
        # TODO: Nested vectors
        import polars as pl
        values = pl.Series(obj).to_numpy()
        return values.astype(np_type, copy=False).tobytes()

    if is_pyarrow_array(obj):
        # TODO: Nested vectors
        import pyarrow as pa
        values = pa.array(obj).to_numpy()
        return values.astype(np_type, copy=False).tobytes()

    raise ValueError(
        f'unsupported object type: {type(obj)}',
    )

0 commit comments

Comments
 (0)