Skip to content

Commit e728cae

Browse files
committed
add tests
1 parent f96267a commit e728cae

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

src/flowsom/models/batch/_som.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from __future__ import annotations
1+
"""Code adapted from student assignment Computational Biology 2024, Ghent University."""
22

33
from typing import Callable
44

src/flowsom/models/batch/som_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import SOM_Batch, map_data_to_codes
99

1010

11+
# TODO: try to use the same code for both SOMEstimator and BatchSOMEstimator
1112
class BatchSOMEstimator(BaseClusterEstimator):
1213
"""Estimate a Self-Organizing Map (SOM) clustering model."""
1314

tests/models/test_BatchFlowSOM.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from sklearn.metrics import v_measure_score
2+
3+
from flowsom.models import BatchFlowSOMEstimator
4+
5+
6+
def test_clustering(X):
7+
fsom = BatchFlowSOMEstimator(n_clusters=10)
8+
y_pred = fsom.fit_predict(X)
9+
assert y_pred.shape == (100,)
10+
11+
12+
def test_clustering_v_measure(X_and_y):
13+
som = BatchFlowSOMEstimator(n_clusters=10)
14+
X, y_true = X_and_y
15+
y_pred = som.fit_predict(X)
16+
score = v_measure_score(y_true, y_pred)
17+
assert score > 0.7
18+
19+
20+
def test_reproducibility_no_seed(X):
21+
fsom_1 = BatchFlowSOMEstimator(n_clusters=10)
22+
fsom_2 = BatchFlowSOMEstimator(n_clusters=10)
23+
y_pred_1 = fsom_1.fit_predict(X)
24+
y_pred_2 = fsom_2.fit_predict(X)
25+
26+
assert not all(y_pred_1 == y_pred_2)
27+
28+
29+
def test_reproducibility_seed(X):
30+
fsom_1 = BatchFlowSOMEstimator(n_clusters=10, seed=0)
31+
fsom_2 = BatchFlowSOMEstimator(n_clusters=10, seed=0)
32+
y_pred_1 = fsom_1.fit_predict(X)
33+
y_pred_2 = fsom_2.fit_predict(X)
34+
35+
assert all(y_pred_1 == y_pred_2)

0 commit comments

Comments
 (0)