Skip to content

Commit 9e8078b

Browse files
committed
Add ability to unflip individual signs in get_corrected_centroid_guesses() helper.
1 parent 2b628c3 commit 9e8078b

File tree

5 files changed

+57
-13
lines changed

5 files changed

+57
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Add `print_selected_statements()` presenter for inspecting `PolisClusteringResult`.
1414
- Add `print_consensus_statements()` presenter for inspecting `PolisClusteringResult`.
1515
- Allow `pick_max` and `confidence` interval args to be set in `polis.run_clustering()`.
16+
- Allow `get_corrected_centroid_guesses()` to unflip each axis if correction not needed.
1617

1718
### Chores
1819

docs/notebooks/polis-implementation-demo.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
"# We can run this from scratch, but kmeans is non-deterministic and might find slightly different clusters\n",
133133
"# or even different k-values (number of groups) if the silhouette scores it finds are better.\n",
134134
"# To show how to reproduce Polis results, we'll set init guess coordinates that we know polis platform got:\n",
135-
"init_cluster_center_guesses = get_corrected_centroid_guesses(math_data, skip_correction=False)\n",
135+
"init_cluster_center_guesses = get_corrected_centroid_guesses(math_data)\n",
136136
"print(f\"{init_cluster_center_guesses=}\")"
137137
],
138138
"metadata": {

docs/notebooks/polis-implementation-results-docs.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
"# Prepare some optional data to kickstart\n",
173173
"\n",
174174
"# KMeans is only reproducible when it starts with previous cluster center guesses.\n",
175-
"INIT_CLUSTER_CENTER_GUESSES = get_corrected_centroid_guesses(loader.math_data, skip_correction=False)\n",
175+
"INIT_CLUSTER_CENTER_GUESSES = get_corrected_centroid_guesses(loader.math_data)\n",
176176
"\n",
177177
"# Polis has some edge-cases logic that keeps arbitrary [early] participants in\n",
178178
"# the clustering algorithm for reasons that are hard to reproduce, so we\n",

reddwarf/utils/polismath.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,16 @@ def create_bidirectional_id_maps(base_clusters):
156156
def get_corrected_centroid_guesses(
157157
polis_math_data: dict,
158158
source: Literal["group-clusters", "base-clusters"] = "group-clusters",
159-
skip_correction: bool = False,
159+
flip_x: bool = True,
160+
flip_y: bool = True,
160161
):
161162
"""
162163
A helper to extract and correct centroid guesses from polismath data.
163164
164165
This is helpful to seed new rounds of KMeans locally.
165166
166167
NOTE: It seems that there's an inversion somewhere in the polis codebase, and so
167-
we need to invert their coordinates in relation to ours.
168+
we usually (but not always) need to invert their coordinates in relation to ours.
168169
169170
Given the consistency of this intervention, it's likely not a PCA artifact,
170171
and instead relates to using inverting the sign of agree/disagree in
@@ -177,7 +178,8 @@ def get_corrected_centroid_guesses(
177178
Arguments:
178179
polis_math_data (dict): The polismath data from the Polis API
179180
source (str): Where to extract centroid data from. One of "group-clusters" or "base-clusters".
180-
skip_correction (bool): Whether to skip correction (helpful for debugging)
181+
flip_x (bool): Whether to correctively flip X sign (usually helpful)
182+
flip_y (bool): Whether to correctively flip Y sign (usually helpful)
181183
182184
Returns:
183185
centroids: A list of centroid [x,y] coord guesses
@@ -197,8 +199,11 @@ def get_corrected_centroid_guesses(
197199
# Defensive fallback in case `source` comes from dynamic or untyped input
198200
raise ValueError(f"Unknown source '{source}'. Must be 'group-clusters' or 'base-clusters'.")
199201

200-
if skip_correction:
201-
return extracted_centroids
202-
else:
203-
corrected_centroids = [[-xy[0], -xy[1]] for xy in extracted_centroids]
204-
return corrected_centroids
202+
corrected_centroids = [
203+
[
204+
-xy[0] if flip_x else xy[0],
205+
-xy[1] if flip_y else xy[1],
206+
] for xy in extracted_centroids
207+
]
208+
209+
return corrected_centroids

tests/utils/test_polismath.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from reddwarf.utils import polismath
33
from tests.fixtures import polis_convo_data
44
from reddwarf.utils import matrix as MatrixUtils
5+
from reddwarf.utils.polismath import get_corrected_centroid_guesses
56
from reddwarf.data_loader import Loader
67

78
# TODO: Figure out if in-conv is always equivalent. If so, remove this test.
@@ -45,6 +46,43 @@ def test_user_vote_count_sanity_check(polis_convo_data):
4546

4647
assert expected_user_vote_counts == actual_user_vote_counts.to_dict()
4748

48-
@pytest.mark.skip
49-
def test_get_corrected_centroid_guesses():
50-
raise NotImplementedError
49+
@pytest.fixture
50+
def mock_math_data():
51+
return {
52+
"group-clusters": [
53+
{"center": [1.0, 2.0]},
54+
{"center": [-3.0, 4.0]}
55+
],
56+
"base-clusters": {
57+
"x": [5.0, -6.0],
58+
"y": [7.0, -8.0]
59+
}
60+
}
61+
62+
def test_get_corrected_centroid_guesses_default_flip(mock_math_data):
63+
result = get_corrected_centroid_guesses(mock_math_data)
64+
assert result == [[-1.0, -2.0], [3.0, -4.0]]
65+
66+
def test_get_corrected_centroid_guesses_unflip_x(mock_math_data):
67+
result = get_corrected_centroid_guesses(mock_math_data, flip_x=False)
68+
assert result == [[1.0, -2.0], [-3.0, -4.0]]
69+
70+
def test_get_corrected_centroid_guesses_unflip_y(mock_math_data):
71+
result = get_corrected_centroid_guesses(mock_math_data, flip_y=False)
72+
assert result == [[-1.0, 2.0], [3.0, 4.0]]
73+
74+
def test_get_corrected_centroid_guesses_group_unflip_both(mock_math_data):
75+
result = get_corrected_centroid_guesses(mock_math_data, source="group-clusters", flip_x=False, flip_y=False)
76+
assert result == [[1.0, 2.0], [-3.0, 4.0]]
77+
78+
def test_get_corrected_centroid_guesses_base_clusters_unflip_x(mock_math_data):
79+
result = get_corrected_centroid_guesses(mock_math_data, source="base-clusters", flip_x=False)
80+
assert result == [[5.0, -7.0], [-6.0, 8.0]]
81+
82+
def test_get_corrected_centroid_guesses_base_clusters_unflip_y(mock_math_data):
83+
result = get_corrected_centroid_guesses(mock_math_data, source="base-clusters", flip_y=False)
84+
assert result == [[-5.0, 7.0], [6.0, -8.0]]
85+
86+
def test_get_corrected_centroid_guesses_unknown_source(mock_math_data):
87+
with pytest.raises(ValueError, match="Unknown source 'invalid-source'"):
88+
get_corrected_centroid_guesses(mock_math_data, source="invalid-source")

0 commit comments

Comments
 (0)