|
8 | 8 | from typing import Any
|
9 | 9 | from typing import Dict
|
10 | 10 | from typing import Iterable
|
| 11 | +from typing import Tuple |
| 12 | +from typing import Union |
11 | 13 |
|
12 | 14 | from .typing import Masked
|
13 | 15 |
|
@@ -192,151 +194,228 @@ class VectorTypes(str, Enum):
|
192 | 194 | I64 = 'i64'
|
193 | 195 |
|
194 | 196 |
|
| 197 | +def _vector_type_to_numpy_type( |
| 198 | + vector_type: VectorTypes, |
| 199 | +) -> str: |
| 200 | + """Convert a vector type to a numpy type.""" |
| 201 | + if vector_type == VectorTypes.F32: |
| 202 | + return 'f4' |
| 203 | + elif vector_type == VectorTypes.F64: |
| 204 | + return 'f8' |
| 205 | + elif vector_type == VectorTypes.I8: |
| 206 | + return 'i1' |
| 207 | + elif vector_type == VectorTypes.I16: |
| 208 | + return 'i2' |
| 209 | + elif vector_type == VectorTypes.I32: |
| 210 | + return 'i4' |
| 211 | + elif vector_type == VectorTypes.I64: |
| 212 | + return 'i8' |
| 213 | + raise ValueError(f'unsupported element type: {vector_type}') |
| 214 | + |
| 215 | + |
| 216 | +def _vector_type_to_struct_format( |
| 217 | + vec: Any, |
| 218 | + vector_type: VectorTypes, |
| 219 | +) -> str: |
| 220 | + """Convert a vector type to a struct format string.""" |
| 221 | + n = len(vec) |
| 222 | + if vector_type == VectorTypes.F32: |
| 223 | + if isinstance(vec, (bytes, bytearray)): |
| 224 | + n = n // 4 |
| 225 | + return f'<{n}f' |
| 226 | + elif vector_type == VectorTypes.F64: |
| 227 | + if isinstance(vec, (bytes, bytearray)): |
| 228 | + n = n // 8 |
| 229 | + return f'<{n}d' |
| 230 | + elif vector_type == VectorTypes.I8: |
| 231 | + return f'<{n}b' |
| 232 | + elif vector_type == VectorTypes.I16: |
| 233 | + if isinstance(vec, (bytes, bytearray)): |
| 234 | + n = n // 2 |
| 235 | + return f'<{n}h' |
| 236 | + elif vector_type == VectorTypes.I32: |
| 237 | + if isinstance(vec, (bytes, bytearray)): |
| 238 | + n = n // 4 |
| 239 | + return f'<{n}i' |
| 240 | + elif vector_type == VectorTypes.I64: |
| 241 | + if isinstance(vec, (bytes, bytearray)): |
| 242 | + n = n // 8 |
| 243 | + return f'<{n}q' |
| 244 | + raise ValueError(f'unsupported element type: {vector_type}') |
| 245 | + |
| 246 | + |
195 | 247 | def unpack_vector(
|
196 |
| - obj: Any, |
197 |
| - element_type: VectorTypes = VectorTypes.F32, |
198 |
| -) -> Iterable[Any]: |
| 248 | + obj: Union[bytes, bytearray], |
| 249 | + vec_type: VectorTypes = VectorTypes.F32, |
| 250 | +) -> Tuple[Any]: |
199 | 251 | """
|
200 | 252 | Unpack a vector from bytes.
|
201 | 253 |
|
202 | 254 | Parameters
|
203 | 255 | ----------
|
204 |
| - obj : Any |
| 256 | + obj : bytes or bytearray |
205 | 257 | The object to unpack.
|
206 |
| - element_type : VectorTypes |
| 258 | + vec_type : VectorTypes |
207 | 259 | The type of the elements in the vector.
|
208 | 260 | Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
|
209 | 261 | Default is 'f32'.
|
210 | 262 |
|
211 | 263 | Returns
|
212 | 264 | -------
|
213 |
| - Iterable[Any] |
| 265 | + Tuple[Any] |
214 | 266 | The unpacked vector.
|
215 | 267 |
|
216 | 268 | """
|
217 |
| - if isinstance(obj, (bytes, bytearray, list, tuple)): |
218 |
| - if element_type == 'f32': |
219 |
| - n = len(obj) // 4 |
220 |
| - fmt = 'f' |
221 |
| - elif element_type == 'f64': |
222 |
| - n = len(obj) // 8 |
223 |
| - fmt = 'd' |
224 |
| - elif element_type == 'i8': |
225 |
| - n = len(obj) |
226 |
| - fmt = 'b' |
227 |
| - elif element_type == 'i16': |
228 |
| - n = len(obj) // 2 |
229 |
| - fmt = 'h' |
230 |
| - elif element_type == 'i32': |
231 |
| - n = len(obj) // 4 |
232 |
| - fmt = 'i' |
233 |
| - elif element_type == 'i64': |
234 |
| - n = len(obj) // 8 |
235 |
| - fmt = 'q' |
236 |
| - else: |
237 |
| - raise ValueError(f'unsupported element type: {element_type}') |
238 |
| - |
239 |
| - if isinstance(obj, (bytes, bytearray)): |
240 |
| - return struct.unpack(f'<{n}{fmt}', obj) |
241 |
| - return tuple([struct.unpack(f'<{n}{fmt}', x) for x in obj]) |
242 |
| - |
243 |
| - if element_type == 'f32': |
244 |
| - np_type = 'f4' |
245 |
| - elif element_type == 'f64': |
246 |
| - np_type = 'f8' |
247 |
| - elif element_type == 'i8': |
248 |
| - np_type = 'i1' |
249 |
| - elif element_type == 'i16': |
250 |
| - np_type = 'i2' |
251 |
| - elif element_type == 'i32': |
252 |
| - np_type = 'i4' |
253 |
| - elif element_type == 'i64': |
254 |
| - np_type = 'i8' |
255 |
| - else: |
256 |
| - raise ValueError(f'unsupported element type: {element_type}') |
| 269 | + return struct.unpack(_vector_type_to_struct_format(obj, vec_type), obj) |
| 270 | + |
| 271 | + |
| 272 | +def pack_vector( |
| 273 | + obj: Any, |
| 274 | + vec_type: VectorTypes = VectorTypes.F32, |
| 275 | +) -> bytes: |
| 276 | + """ |
| 277 | + Pack a vector into bytes. |
| 278 | +
|
| 279 | + Parameters |
| 280 | + ---------- |
| 281 | + obj : Any |
| 282 | + The object to pack. |
| 283 | + vec_type : VectorTypes |
| 284 | + The type of the elements in the vector. |
| 285 | + Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'. |
| 286 | + Default is 'f32'. |
| 287 | +
|
| 288 | + Returns |
| 289 | + ------- |
| 290 | + bytes |
| 291 | + The packed vector. |
| 292 | +
|
| 293 | + """ |
| 294 | + if isinstance(obj, (list, tuple)): |
| 295 | + return struct.pack(_vector_type_to_struct_format(obj, vec_type), *obj) |
257 | 296 |
|
258 | 297 | if is_numpy(obj):
|
259 |
| - import numpy as np |
260 |
| - return np.array([np.frombuffer(x, dtype=np_type) for x in obj]) |
| 298 | + return obj.tobytes() |
261 | 299 |
|
262 | 300 | if is_pandas_series(obj):
|
263 |
| - import numpy as np |
264 | 301 | import pandas as pd
|
265 |
| - return pd.Series([np.frombuffer(x, dtype=np_type) for x in obj]) |
| 302 | + return pd.Series(obj).to_numpy().tobytes() |
266 | 303 |
|
267 | 304 | if is_polars_series(obj):
|
268 |
| - import numpy as np |
269 | 305 | import polars as pl
|
270 |
| - return pl.Series([np.frombuffer(x, dtype=np_type) for x in obj]) |
| 306 | + return pl.Series(obj).to_numpy().tobytes() |
271 | 307 |
|
272 | 308 | if is_pyarrow_array(obj):
|
273 |
| - import numpy as np |
274 | 309 | import pyarrow as pa
|
275 |
| - return pa.array([np.frombuffer(x, dtype=np_type) for x in obj]) |
| 310 | + return pa.array(obj).to_numpy().tobytes() |
276 | 311 |
|
277 | 312 | raise ValueError(
|
278 | 313 | f'unsupported object type: {type(obj)}',
|
279 | 314 | )
|
280 | 315 |
|
281 | 316 |
|
282 |
| -def pack_vector( |
283 |
| - obj: Any, |
284 |
| - element_type: VectorTypes = VectorTypes.F32, |
285 |
| -) -> bytes: |
| 317 | +def unpack_vectors( |
| 318 | + arr_of_vec: Any, |
| 319 | + vec_type: VectorTypes = VectorTypes.F32, |
| 320 | +) -> Iterable[Any]: |
286 | 321 | """
|
287 |
| - Pack a vector into bytes. |
| 322 | + Unpack a vector from an array of bytes. |
288 | 323 |
|
289 | 324 | Parameters
|
290 | 325 | ----------
|
291 |
| - obj : Any |
292 |
| - The object to pack. |
293 |
| - element_type : VectorTypes |
| 326 | + arr_of_vec : Iterable[Any] |
| 327 | + The array of bytes to unpack. |
| 328 | + vec_type : VectorTypes |
294 | 329 | The type of the elements in the vector.
|
295 | 330 | Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
|
296 | 331 | Default is 'f32'.
|
297 | 332 |
|
298 | 333 | Returns
|
299 | 334 | -------
|
300 |
| - bytes |
301 |
| - The packed vector. |
| 335 | + Iterable[Any] |
| 336 | + The unpacked vector. |
302 | 337 |
|
303 | 338 | """
|
304 |
| - if element_type == 'f32': |
305 |
| - fmt = 'f' |
306 |
| - elif element_type == 'f64': |
307 |
| - fmt = 'd' |
308 |
| - elif element_type == 'i8': |
309 |
| - fmt = 'b' |
310 |
| - elif element_type == 'i16': |
311 |
| - fmt = 'h' |
312 |
| - elif element_type == 'i32': |
313 |
| - fmt = 'i' |
314 |
| - elif element_type == 'i64': |
315 |
| - fmt = 'q' |
316 |
| - else: |
317 |
| - raise ValueError(f'unsupported element type: {element_type}') |
| 339 | + if isinstance(arr_of_vec, (list, tuple)): |
| 340 | + return [unpack_vector(x, vec_type) for x in arr_of_vec] |
318 | 341 |
|
319 |
| - if isinstance(obj, (list, tuple)): |
320 |
| - return struct.pack(f'<{len(obj)}{fmt}', *obj) |
| 342 | + import numpy as np |
321 | 343 |
|
322 |
| - elif is_numpy(obj): |
323 |
| - return obj.tobytes() |
| 344 | + dtype = _vector_type_to_numpy_type(vec_type) |
324 | 345 |
|
325 |
| - elif is_pandas_series(obj): |
326 |
| - # TODO: Nested vectors |
| 346 | + np_arr = np.array( |
| 347 | + [np.frombuffer(x, dtype=dtype) for x in arr_of_vec], |
| 348 | + dtype=dtype, |
| 349 | + ) |
| 350 | + |
| 351 | + if is_numpy(arr_of_vec): |
| 352 | + return np_arr |
| 353 | + |
| 354 | + if is_pandas_series(arr_of_vec): |
327 | 355 | import pandas as pd
|
328 |
| - return pd.Series(obj).to_numpy().tobytes() |
| 356 | + return pd.Series(np_arr) |
329 | 357 |
|
330 |
| - elif is_polars_series(obj): |
331 |
| - # TODO: Nested vectors |
| 358 | + if is_polars_series(arr_of_vec): |
332 | 359 | import polars as pl
|
333 |
| - return pl.Series(obj).to_numpy().tobytes() |
| 360 | + return pl.Series(np_arr) |
334 | 361 |
|
335 |
| - elif is_pyarrow_array(obj): |
336 |
| - # TODO: Nested vectors |
| 362 | + if is_pyarrow_array(arr_of_vec): |
337 | 363 | import pyarrow as pa
|
338 |
| - return pa.array(obj).to_numpy().tobytes() |
| 364 | + return pa.array(np_arr) |
339 | 365 |
|
340 | 366 | raise ValueError(
|
341 |
| - f'unsupported object type: {type(obj)}', |
| 367 | + f'unsupported object type: {type(arr_of_vec)}', |
| 368 | + ) |
| 369 | + |
| 370 | + |
| 371 | +def pack_vectors( |
| 372 | + arr_of_arr: Iterable[Any], |
| 373 | + vec_type: VectorTypes = VectorTypes.F32, |
| 374 | +) -> Iterable[Any]: |
| 375 | + """ |
| 376 | + Pack a vector into an array of bytes. |
| 377 | +
|
| 378 | + Parameters |
| 379 | + ---------- |
| 380 | + arr_of_arr : Iterable[Any] |
| 381 | + The array of bytes to pack. |
| 382 | + vec_type : VectorTypes |
| 383 | + The type of the elements in the vector. |
| 384 | + Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'. |
| 385 | + Default is 'f32'. |
| 386 | +
|
| 387 | + Returns |
| 388 | + ------- |
| 389 | + Iterable[Any] |
| 390 | + The array of packed vectors. |
| 391 | +
|
| 392 | + """ |
| 393 | + if isinstance(arr_of_arr, (list, tuple)): |
| 394 | + if not arr_of_arr: |
| 395 | + return [] |
| 396 | + fmt = _vector_type_to_struct_format(arr_of_arr[0], vec_type) |
| 397 | + return [struct.pack(fmt, x) for x in arr_of_arr] |
| 398 | + |
| 399 | + import numpy as np |
| 400 | + |
| 401 | + # Use object type because numpy truncates nulls at the end of fixed binary |
| 402 | + np_arr = np.array([x.tobytes() for x in arr_of_arr], dtype=np.object_) |
| 403 | + |
| 404 | + if is_numpy(arr_of_arr): |
| 405 | + return np_arr |
| 406 | + |
| 407 | + if is_pandas_series(arr_of_arr): |
| 408 | + import pandas as pd |
| 409 | + return pd.Series(np_arr) |
| 410 | + |
| 411 | + if is_polars_series(arr_of_arr): |
| 412 | + import polars as pl |
| 413 | + return pl.Series(np_arr) |
| 414 | + |
| 415 | + if is_pyarrow_array(arr_of_arr): |
| 416 | + import pyarrow as pa |
| 417 | + return pa.array(np_arr) |
| 418 | + |
| 419 | + raise ValueError( |
| 420 | + f'unsupported object type: {type(arr_of_arr)}', |
342 | 421 | )
|
0 commit comments