Skip to content

Commit 22088a1

Browse files
feat: Added Hybrid Search Config and Tests [1/N]
1 parent d3dbe18 commit 22088a1

File tree

2 files changed

+363
-0
lines changed

2 files changed

+363
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from abc import ABC
2+
from dataclasses import dataclass, field
3+
from typing import Any, Callable, Optional, Sequence
4+
5+
from sqlalchemy import RowMapping
6+
7+
8+
def weighted_sum_ranking(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """Fuse two result sets using a weighted sum of their scores.

    Rows are read positionally: the first column of each row is used as the
    document id (compared as ``str``) and the last column as its score. The
    fused score is written to the ``"distance"`` key of a copy of the row;
    the input rows themselves are never modified.

    Args:
        primary_search_results: Row mappings returned by the primary search.
        secondary_search_results: Row mappings returned by the secondary
            search.
        primary_results_weight: Multiplier applied to primary scores.
            Defaults to 0.5.
        secondary_results_weight: Multiplier applied to secondary scores.
            Defaults to 0.5.
        fetch_top_k: Maximum number of fused rows to return. Defaults to 4.

    Returns:
        At most ``fetch_top_k`` dicts, one per distinct document id, sorted
        by fused ``"distance"`` in descending order. For a document found in
        both inputs the remaining columns come from the secondary row.
    """
    # doc_id -> copy of the most recently seen row for that document, with
    # "distance" holding the weighted score accumulated so far.
    weighted_scores: dict[str, dict[str, Any]] = {}

    # Primary pass: a repeated id simply overwrites the earlier entry.
    for row in primary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        row_values = dict(row)
        row_values["distance"] = primary_results_weight * distance
        weighted_scores[doc_id] = row_values

    # Secondary pass: add to whatever score is already stored for the id
    # (0.0 if none) and replace the stored row with the secondary one.
    for row in secondary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        score_so_far = (
            weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
        )
        row_values = dict(row)
        row_values["distance"] = distance * secondary_results_weight + score_so_far
        weighted_scores[doc_id] = row_values

    # Highest fused score first; sorted() is stable, so ties keep the order
    # in which the ids were first inserted.
    ranked_results = sorted(
        weighted_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]
65+
66+
67+
def reciprocal_rank_fusion(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    rrf_k: float = 60,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """Fuse two result sets using Reciprocal Rank Fusion (RRF).

    Each input is first sorted by its ``"distance"`` value in descending
    order. A row at zero-based position ``rank`` in that ordering contributes
    ``1.0 / (rank + rrf_k)`` to its document's fused score. The document id
    is the first column of the row, compared as ``str``. The fused score is
    written to the ``"distance"`` key of a copy of the row; the input rows
    are never modified.

    Args:
        primary_search_results: Row mappings returned by the primary search;
            each must have a ``"distance"`` key.
        secondary_search_results: Row mappings returned by the secondary
            search; each must have a ``"distance"`` key.
        rrf_k: The RRF constant added to the rank. Defaults to 60.
        fetch_top_k: Maximum number of fused rows to return. Defaults to 4.

    Returns:
        At most ``fetch_top_k`` dicts, one per distinct document id, sorted
        by fused ``"distance"`` (the RRF score) in descending order. For a
        document found in both inputs the remaining columns come from the
        secondary row.
    """
    # doc_id -> copy of the most recently seen row for that document, with
    # "distance" holding the RRF score accumulated so far.
    rrf_scores: dict[str, dict[str, Any]] = {}

    # Both sources are scored the same way. Primary goes first so that the
    # secondary row is the one kept for documents present in both.
    for search_results in (primary_search_results, secondary_search_results):
        ordered_rows = sorted(
            search_results, key=lambda item: item["distance"], reverse=True
        )
        for rank, row in enumerate(ordered_rows):
            values = list(row.values())
            doc_id = str(values[0])  # first value is doc_id
            score_so_far = (
                rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
            )
            row_values = dict(row)
            row_values["distance"] = score_so_far + 1.0 / (rank + rrf_k)
            rrf_scores[doc_id] = row_values

    # Highest RRF score first; sorted() is stable, so ties keep the order in
    # which the ids were first inserted.
    ranked_results = sorted(
        rrf_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]
127+
128+
129+
@dataclass
class HybridSearchConfig(ABC):
    """Configuration for hybrid search on the vector store.

    Groups the text-search settings, the fusion function used to merge the
    primary and secondary result sets, and the per-source result limits.
    How each field is consumed is decided by the vector store that reads
    this config, which is not part of this module.
    """

    # NOTE(review): presumably the name of the tsvector column to search;
    # an empty string looks like "not configured" - confirm with the caller.
    tsv_column: Optional[str] = ""
    # NOTE(review): presumably the text-search configuration/language name.
    tsv_lang: Optional[str] = "pg_catalog.english"
    # NOTE(review): presumably the full-text-search query string.
    fts_query: Optional[str] = ""
    # Callable that merges the primary and secondary result sets; defaults to
    # the weighted-sum ranking defined in this module.
    fusion_function: Callable[
        [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any]
    ] = weighted_sum_ranking
    # NOTE(review): presumably extra keyword arguments forwarded to
    # ``fusion_function``; a fresh dict is created for every instance.
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): presumably how many rows to fetch from each source
    # before fusion - confirm with the caller.
    primary_top_k: int = 4
    secondary_top_k: int = 4
    # NOTE(review): presumably the name and type of the text-search index.
    index_name: str = "langchain_tsv_index"
    index_type: str = "GIN"
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import pytest
2+
3+
from langchain_postgres.v2.hybrid_search_config import (reciprocal_rank_fusion,
4+
weighted_sum_ranking)
5+
6+
7+
# Helper to create mock input items that mimic RowMapping for the fusion functions
8+
def get_row(doc_id: str, score: float, content: str = "content") -> dict:
    """Build a plain dict that stands in for a RowMapping in these tests.

    The fusion functions read the document id from the first value of a row
    and (for the weighted sum) the score from the last value, then store the
    fused score under the ``"distance"`` key.

    Args:
        doc_id: Value stored under ``"id_val"`` (the first key).
        score: Value stored under ``"distance"`` (the last key).
        content: Value stored under ``"content_field"``.

    Returns:
        A dict with keys ``"id_val"``, ``"content_field"`` and ``"distance"``
        in that order.
    """
    # Dicts keep insertion order (Python 3.7+), so list(row.values())[0] is
    # doc_id and list(row.values())[-1] is score.
    return {"id_val": doc_id, "content_field": content, "distance": score}
19+
20+
21+
class TestWeightedSumRanking:
    """Unit tests for ``weighted_sum_ranking``."""

    def test_empty_inputs(self):
        """Two empty inputs produce an empty result."""
        results = weighted_sum_ranking([], [])
        assert results == []

    def test_primary_only(self):
        """With no secondary rows, primary scores are just scaled."""
        primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
        # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
        results = weighted_sum_ranking(
            primary, [], primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(0.4)
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(0.3)

    def test_secondary_only(self):
        """With no primary rows, secondary scores are just scaled."""
        secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
        # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
        results = weighted_sum_ranking(
            [], secondary, primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(0.45)
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(0.35)

    def test_mixed_results_default_weights(self):
        """A document in both inputs gets the sum of both weighted scores."""
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        # Weights are 0.5, 0.5
        # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85
        # p_only_score = (0.7 * 0.5) = 0.35
        # s_only_score = (0.6 * 0.5) = 0.30
        # Order: common (0.85), p_only (0.35), s_only (0.30)

        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(0.85)
        assert results[1]["id_val"] == "p_only"
        assert results[1]["distance"] == pytest.approx(0.35)
        assert results[2]["id_val"] == "s_only"
        assert results[2]["distance"] == pytest.approx(0.30)

    def test_mixed_results_custom_weights(self):
        """Non-default weights are applied to their respective source."""
        primary = [get_row("d1", 1.0)]  # p_w=0.2 -> 0.2
        secondary = [get_row("d1", 0.5)]  # s_w=0.8 -> 0.4
        # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6

        results = weighted_sum_ranking(
            primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8
        )
        assert len(results) == 1
        assert results[0]["id_val"] == "d1"
        assert results[0]["distance"] == pytest.approx(0.6)

    def test_fetch_top_k(self):
        """Only the ``fetch_top_k`` best-scoring rows are returned."""
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        # Scores: 1.0, 0.9, 0.8, 0.7, 0.6
        # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
        secondary = []
        results = weighted_sum_ranking(primary, secondary, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(0.45)
91+
92+
93+
class TestReciprocalRankFusion:
    """Unit tests for ``reciprocal_rank_fusion``.

    The last test in this class exercises ``weighted_sum_ranking`` instead;
    see the note on that method.
    """

    def test_empty_inputs(self):
        """Two empty inputs produce an empty result."""
        results = reciprocal_rank_fusion([], [])
        assert results == []

    def test_primary_only(self):
        """With no secondary rows, scores depend only on primary rank."""
        primary = [
            get_row("p1", 0.8),
            get_row("p2", 0.6),
        ]  # p1 rank 0, p2 rank 1
        rrf_k = 60
        # p1_score = 1 / (0 + 60)
        # p2_score = 1 / (1 + 60)
        results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_secondary_only(self):
        """With no primary rows, scores depend only on secondary rank."""
        secondary = [
            get_row("s1", 0.9),
            get_row("s2", 0.7),
        ]  # s1 rank 0, s2 rank 1
        rrf_k = 60
        results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_mixed_results_default_k(self):
        """A document ranked first in both inputs gets both contributions."""
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        rrf_k = 60
        # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
        # p_only_score = (1/(1+k))_prim = 1/(k+1)
        # s_only_score = (1/(1+k))_sec = 1/(k+1)
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
        # Check the next two elements, their order might vary due to tie in score
        next_ids = {results[1]["id_val"], results[2]["id_val"]}
        next_scores = {results[1]["distance"], results[2]["distance"]}
        assert next_ids == {"p_only", "s_only"}
        for score in next_scores:
            assert score == pytest.approx(1.0 / (1 + rrf_k))

    def test_fetch_top_k_rrf(self):
        """Only the ``fetch_top_k`` best-scoring rows are returned."""
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        secondary = []
        rrf_k = 1
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_rrf_content_preservation(self):
        """Non-score columns survive fusion; the secondary row wins on overlap."""
        primary = [get_row("doc1", 0.9, content="Primary Content")]
        secondary = [get_row("doc1", 0.8, content="Secondary Content")]
        # RRF processes primary then secondary. If a doc is in both,
        # the content from the secondary list will overwrite primary's.
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=60)
        assert len(results) == 1
        assert results[0]["id_val"] == "doc1"
        assert results[0]["content_field"] == "Secondary Content"

        # If only in primary
        results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60)
        assert results_prim_only[0]["content_field"] == "Primary Content"

    def test_reordering_from_inputs_rrf(self):
        """
        Tests that RRF fused ranking can be different from both primary and secondary
        input rankings.
        Primary Order: A, B, C
        Secondary Order: C, B, A
        Fused Order: (A, C) tied, then B
        """
        primary = [
            get_row("docA", 0.9),
            get_row("docB", 0.8),
            get_row("docC", 0.1),
        ]
        secondary = [
            get_row("docC", 0.9),
            get_row("docB", 0.5),
            get_row("docA", 0.2),
        ]
        rrf_k = 1.0  # Using 1.0 for k to simplify rank score calculation
        # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
        # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
        # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
        assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[2]["id_val"] == "docB"
        assert results[2]["distance"] == pytest.approx(1.0)

    # NOTE(review): this test targets weighted_sum_ranking, not RRF; it
    # probably belongs in TestWeightedSumRanking.
    def test_reordering_from_inputs_weighted_sum(self):
        """
        Tests that the weighted-sum fused ranking can reverse the primary
        input ranking.
        Primary Order: A (0.9), B (0.7)
        Secondary Order: B (0.8), A (0.2)
        Fusion (0.5/0.5 weights):
        docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
        docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
        Expected Fused Order: docB (0.75), docA (0.55)
        The fused order matches the secondary order (B, A) and reverses the
        primary order (A, B), because docB's secondary score outweighs
        docA's lead in the primary results.
        """
        primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
        secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]

        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 2
        assert results[0]["id_val"] == "docB"
        assert results[0]["distance"] == pytest.approx(0.75)
        assert results[1]["id_val"] == "docA"
        assert results[1]["distance"] == pytest.approx(0.55)

0 commit comments

Comments
 (0)