Skip to content

Commit 7eac67a

Browse files
committed
Merge branch 'main' into eugene/add_implementations
2 parents 0ded551 + ee42d24 commit 7eac67a

File tree

7 files changed

+1110
-648
lines changed

7 files changed

+1110
-648
lines changed

langchain_postgres/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from importlib import metadata
22

33
from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
4+
from langchain_postgres.translator import PGVectorTranslator
45
from langchain_postgres.vectorstores import PGVector
56

67
try:
@@ -13,4 +14,5 @@
1314
"__version__",
1415
"PostgresChatMessageHistory",
1516
"PGVector",
17+
"PGVectorTranslator",
1618
]

langchain_postgres/translator.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Dict, Tuple, Union
2+
3+
from langchain_core.structured_query import (
4+
Comparator,
5+
Comparison,
6+
Operation,
7+
Operator,
8+
StructuredQuery,
9+
Visitor,
10+
)
11+
12+
13+
class PGVectorTranslator(Visitor):
14+
"""Translate `PGVector` internal query language elements to valid filters."""
15+
16+
allowed_operators = [Operator.AND, Operator.OR]
17+
"""Subset of allowed logical operators."""
18+
allowed_comparators = [
19+
Comparator.EQ,
20+
Comparator.NE,
21+
Comparator.GT,
22+
Comparator.LT,
23+
Comparator.IN,
24+
Comparator.NIN,
25+
Comparator.CONTAIN,
26+
Comparator.LIKE,
27+
]
28+
"""Subset of allowed logical comparators."""
29+
30+
def _format_func(self, func: Union[Operator, Comparator]) -> str:
31+
self._validate_func(func)
32+
return f"${func.value}"
33+
34+
def visit_operation(self, operation: Operation) -> Dict:
35+
args = [arg.accept(self) for arg in operation.arguments]
36+
return {self._format_func(operation.operator): args}
37+
38+
def visit_comparison(self, comparison: Comparison) -> Dict:
39+
return {
40+
comparison.attribute: {
41+
self._format_func(comparison.comparator): comparison.value
42+
}
43+
}
44+
45+
def visit_structured_query(
46+
self, structured_query: StructuredQuery
47+
) -> Tuple[str, dict]:
48+
if structured_query.filter is None:
49+
kwargs = {}
50+
else:
51+
kwargs = {"filter": structured_query.filter.accept(self)}
52+
return structured_query.query, kwargs

poetry.lock

Lines changed: 966 additions & 646 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "langchain-postgres"
3-
version = "0.0.7"
3+
version = "0.0.9"
44
description = "An integration package connecting Postgres and LangChain"
55
authors = []
66
readme = "README.md"
@@ -11,7 +11,7 @@ license = "MIT"
1111
"Source Code" = "https://github.com/langchain-ai/langchain-postgres/tree/master/langchain_postgres"
1212

1313
[tool.poetry.dependencies]
14-
python = "^3.9"
14+
python = "^3.8.1"
1515
langchain-core = "^0.2.13"
1616
psycopg = "^3"
1717
psycopg-pool = "^3.2.1"

tests/unit_tests/query_constructors/__init__.py

Whitespace-only changes.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from typing import Dict, Tuple
2+
3+
import pytest as pytest
4+
from langchain_core.structured_query import (
5+
Comparator,
6+
Comparison,
7+
Operation,
8+
Operator,
9+
StructuredQuery,
10+
)
11+
12+
from langchain_postgres import PGVectorTranslator
13+
14+
DEFAULT_TRANSLATOR = PGVectorTranslator()
15+
16+
17+
def test_visit_comparison() -> None:
18+
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
19+
expected = {"foo": {"$lt": 1}}
20+
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
21+
assert expected == actual
22+
23+
24+
@pytest.mark.skip("Not implemented")
25+
def test_visit_operation() -> None:
26+
op = Operation(
27+
operator=Operator.AND,
28+
arguments=[
29+
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
30+
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
31+
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
32+
],
33+
)
34+
expected = {
35+
"foo": {"$lt": 2},
36+
"bar": {"$eq": "baz"},
37+
"abc": {"$gt": 2.0},
38+
}
39+
actual = DEFAULT_TRANSLATOR.visit_operation(op)
40+
assert expected == actual
41+
42+
43+
def test_visit_structured_query() -> None:
44+
query = "What is the capital of France?"
45+
structured_query = StructuredQuery(
46+
query=query,
47+
filter=None,
48+
)
49+
expected: Tuple[str, Dict] = (query, {})
50+
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
51+
assert expected == actual
52+
53+
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
54+
structured_query = StructuredQuery(
55+
query=query,
56+
filter=comp,
57+
)
58+
expected = (query, {"filter": {"foo": {"$lt": 1}}})
59+
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
60+
assert expected == actual
61+
62+
op = Operation(
63+
operator=Operator.AND,
64+
arguments=[
65+
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
66+
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
67+
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
68+
],
69+
)
70+
structured_query = StructuredQuery(
71+
query=query,
72+
filter=op,
73+
)
74+
expected = (
75+
query,
76+
{
77+
"filter": {
78+
"$and": [
79+
{"foo": {"$lt": 2}},
80+
{"bar": {"$eq": "baz"}},
81+
{"abc": {"$gt": 2.0}},
82+
]
83+
}
84+
},
85+
)
86+
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
87+
assert expected == actual

tests/unit_tests/test_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
EXPECTED_ALL = [
44
"__version__",
55
"PGVector",
6+
"PGVectorTranslator",
67
"PostgresChatMessageHistory",
78
]
89

0 commit comments

Comments
 (0)