diff --git a/langchain_postgres/_utils.py b/langchain_postgres/_utils.py index 9d8055af..14eae3b7 100644 --- a/langchain_postgres/_utils.py +++ b/langchain_postgres/_utils.py @@ -30,10 +30,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: X = np.array(X, dtype=np.float32) Y = np.array(Y, dtype=np.float32) - Z = 1 - simd.cdist(X, Y, metric="cosine") - if isinstance(Z, float): - return np.array([Z]) - return np.array(Z) + Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) + return Z except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want "