
Commit 05c602f

Merge pull request #13 from saeyslab/batchSOM
Initial implementation with batch SOM
2 parents 211ad7c + e728cae

File tree

13 files changed: +1017 −2 lines changed

docs/api.md

Lines changed: 2 additions & 0 deletions
@@ -62,7 +62,9 @@ For more background information, see the paper for this software package {cite:p
 :toctree: generated

 models.FlowSOMEstimator
+models.BatchFlowSOMEstimator
 models.SOMEstimator
+models.BatchSOMEstimator
 models.ConsensusCluster
 models.BaseClusterEstimator
 models.BaseFlowSOMEstimator
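
The API listing now exposes the batch variants alongside the serial estimators. A minimal usage sketch, assuming `BatchFlowSOMEstimator` mirrors the scikit-learn-style `fit_predict` interface of `FlowSOMEstimator`; the `n_clusters` argument and the exact constructor signature are illustrative, not confirmed by this diff:

```python
import numpy as np

from flowsom.models import BatchFlowSOMEstimator

# Toy expression matrix: 1000 cells x 10 markers (synthetic stand-in data).
X = np.random.rand(1000, 10)

# Hypothetical construction; the real constructor arguments may differ.
est = BatchFlowSOMEstimator(n_clusters=10)

# fit_predict is assumed to return one metacluster label per cell,
# following the scikit-learn estimator convention.
labels = est.fit_predict(X)
```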

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 :maxdepth: 1

 notebooks/example
+notebooks/parallel
 api.md
 changelog.md
 contributing.md

docs/notebooks/parallel.ipynb

Lines changed: 369 additions & 0 deletions
Large diffs are not rendered by default.

src/flowsom/models/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -2,5 +2,6 @@
 from .base_cluster_estimator import BaseClusterEstimator  # isort:skip
 from .som_estimator import SOMEstimator  # isort:skip
 from .base_flowsom_estimator import BaseFlowSOMEstimator  # isort:skip
-from .consensus_cluster import ConsensusCluster
-from .flowsom_estimator import FlowSOMEstimator
+from .consensus_cluster import ConsensusCluster  # isort:skip
+from .flowsom_estimator import FlowSOMEstimator  # isort:skip
+from .batch_flowsom_estimator import BatchFlowSOMEstimator  # isort:skip

src/flowsom/models/batch/__init__.py

Lines changed: 2 additions & 0 deletions
from ._som import SOM_Batch, map_data_to_codes  # isort:skip
from .som_estimator import BatchSOMEstimator  # isort:skip

src/flowsom/models/batch/_som.py

Lines changed: 198 additions & 0 deletions
"""Code adapted from student assignment Computational Biology 2024, Ghent University."""

from typing import Callable

import numpy as np
from numba import jit, prange
from sklearn.neighbors import BallTree

from flowsom.models.numpy_numba import nb_median_axis_0


@jit(nopython=True, fastmath=True)
def eucl_without_sqrt(p1: np.ndarray, p2: np.ndarray):
    """Function that computes the Euclidean distance between two points without taking the square root.

    For performance reasons, the square root is not taken. This is useful when comparing distances, because the square
    root is a monotonic function, meaning that the order of the distances is preserved.

    Args:
        p1 (np.ndarray): The first point.
        p2 (np.ndarray): The second point.

    Returns
    -------
    float: The squared Euclidean distance between the two points.

    >>> eucl_without_sqrt(np.array([1, 2, 3]), np.array([4, 5, 6]))
    27.0
    """
    distance = 0.0
    for j in range(p1.shape[0]):
        diff = p1[j] - p2[j]
        distance += diff * diff
    return distance


@jit(nopython=True, parallel=True, fastmath=True)
def SOM_Batch(
    data: np.ndarray,
    codes: np.ndarray,
    nhbrdist: np.ndarray,
    alphas: tuple,
    radii: tuple,
    ncodes: int,
    rlen: int,
    num_batches: int = 10,
    distf: Callable[[np.ndarray, np.ndarray], float] = eucl_without_sqrt,
    seed=None,
):
    """Function that computes the Self-Organizing Map.

    Args:
        data (np.ndarray): The data to be clustered, one 2-D array per batch.
        codes (np.ndarray): The initial codes.
        nhbrdist (np.ndarray): The neighbourhood distances between the SOM units.
        alphas (tuple): The start and end learning rate.
        radii (tuple): The start and end neighbourhood radius.
        ncodes (int): The number of codes.
        rlen (int): The number of iterations.
        num_batches (int): The number of batches.
        distf (function): The distance function.
        seed (int): The seed for the random number generator.

    Returns
    -------
    np.ndarray: The computed codes.
    """
    if seed is not None:
        np.random.seed(seed)

    # Number of data points per batch
    n = data[-1].shape[0]

    # Dimension of the data
    px = data[0].shape[1]

    # Number of iterations per pass
    niter = n

    # The threshold is the radius of the neighbourhood, i.e. the range within which codes are updated.
    # The threshold step decides by how much the threshold is decreased each iteration.
    threshold_step = (radii[0] - radii[1]) / niter

    # Keep the temporary codes, using the given codes as the initial codes, for every batch
    tmp_codes_all = np.empty((num_batches, ncodes, px), dtype=np.float64)

    # Copy the codes as float64, because the codes are updated in place by the algorithm
    copy_codes = codes.copy().astype(np.float64)

    # Execute some initial serial iterations to get a good initial clustering
    xdist = np.empty(ncodes, dtype=np.float64)
    init_threshold = radii[0]
    init_alpha = alphas[0]

    for it in range(niter):
        # Choose a random data point
        i = np.random.choice(n)

        # Compute the nearest code
        nearest = 0
        for cd in range(ncodes):
            xdist[cd] = distf(data[0][i, :], copy_codes[cd, :])
            if xdist[cd] < xdist[nearest]:
                nearest = cd

        # Decay the learning rate linearly with the iteration number
        init_alpha = alphas[0] - (alphas[0] - alphas[1]) * it / (niter * rlen)

        for cd in range(ncodes):
            # The neighbourhood distance decides whether a code is updated: the code is only updated
            # if it is close enough to the winning code. Otherwise, its value stays the same.
            if nhbrdist[cd, nearest] <= init_threshold:
                # Update the code based on the difference between the chosen data point and the code.
                for j in range(px):
                    tmp = data[0][i, j] - copy_codes[cd, j]
                    copy_codes[cd, j] += tmp * init_alpha

        init_threshold -= threshold_step

    # Choose random data points, for the different batches and the rlen iterations
    data_points_random = np.random.choice(n, num_batches * rlen * n, replace=True)

    # Halve the number of remaining iterations, because the initial serial pass is already done
    rlen = int(rlen / 2)

    for iteration in range(rlen):
        # Execute the batches in parallel
        for batch_nr in prange(num_batches):
            # Keep the temporary codes, using the current codes as the initial codes
            tmp_codes = copy_codes.copy()

            # Array for the distances
            xdists = np.empty(ncodes, dtype=np.float64)

            # IMPORTANT: keeping the threshold at radii[0] would cause big changes every iteration.
            # This is not wanted, because the algorithm should converge, so the threshold is decreased
            # every iteration. A factor 2 is added to make the threshold decrease faster.
            threshold = init_threshold - radii[0] * 2 * iteration / rlen

            for k in range(iteration * niter, (iteration + 1) * niter):
                # Get the data point
                i = data_points_random[n * rlen * batch_nr + k]

                # Compute the nearest code
                nearest = 0
                for cd in range(ncodes):
                    xdists[cd] = distf(data[batch_nr][i, :], tmp_codes[cd, :])
                    if xdists[cd] < xdists[nearest]:
                        nearest = cd

                # Keep a minimal neighbourhood so at least the winning code is still updated
                if threshold < 1.0:
                    threshold = 0.5
                alpha = init_alpha - (alphas[0] - alphas[1]) * k / (niter * rlen)

                for cd in range(ncodes):
                    # The neighbourhood distance decides whether a code is updated: the code is only updated
                    # if it is close enough to the winning code. Otherwise, its value stays the same.
                    if nhbrdist[cd, nearest] <= threshold:
                        # Update the code based on the difference between the chosen data point and the code.
                        for j in range(px):
                            tmp = data[batch_nr][i, j] - tmp_codes[cd, j]
                            tmp_codes[cd, j] += tmp * alpha

                threshold -= threshold_step

            tmp_codes_all[batch_nr] = tmp_codes

        # Merge the SOMs of the different batches by taking the element-wise median of their codes
        copy_codes = nb_median_axis_0(tmp_codes_all).astype(np.float64)

    return copy_codes


# ChatGPT-generated alternative to map_data_to_codes
def map_data_to_codes(data, codes):
    """Return a tuple with the indices and distances of the nearest code for each data point.

    Args:
        data (np.ndarray): The data points.
        codes (np.ndarray): The codes that the data points are mapped to.

    Returns
    -------
    np.ndarray: The indices of the nearest code for each data point.
    np.ndarray: The distances of the nearest code for each data point.

    >>> data_ = np.array([[1, 2, 3], [4, 5, 6]])
    >>> codes_ = np.array([[1, 2, 3], [4, 5, 6]])
    >>> map_data_to_codes(data_, codes_)
    (array([0, 1]), array([0., 0.]))
    """
    # Create a BallTree for the codes (an efficient data structure for nearest-neighbour search)
    tree = BallTree(codes, metric="euclidean")

    # Query the BallTree to find the nearest code for each data point (k=1: only the nearest neighbour)
    dists, indices = tree.query(data, k=1)

    # Flatten the results and return them
    return indices.flatten(), dists.flatten()
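
For reference, a self-contained sketch of how `SOM_Batch` and `map_data_to_codes` might be driven directly, under stated assumptions: the data is stacked as one 2-D array per batch (the code indexes `data[batch_nr]`), and `nhbrdist` holds pairwise distances between SOM grid units. The Chebyshev grid metric and the parameter values below are illustrative choices, not taken from this commit:

```python
import numpy as np

from flowsom.models.batch import SOM_Batch, map_data_to_codes

rng = np.random.default_rng(42)

# Synthetic setup: a 3x3 SOM grid (ncodes = 9) trained on 2 batches
# of 500 four-dimensional points each; data is indexed as data[batch_nr].
xdim, ydim, px, num_batches, n = 3, 3, 4, 2, 500
ncodes = xdim * ydim
data = rng.standard_normal((num_batches, n, px))

# Initialise the codes from a random sample of the first batch
# (one common initialisation choice; the estimator may do this differently).
codes = data[0][rng.choice(n, ncodes, replace=False)].astype(np.float64)

# Pairwise neighbourhood distances between grid units; Chebyshev distance
# on the grid coordinates is an illustrative choice of grid metric.
grid = np.array([[x, y] for y in range(ydim) for x in range(xdim)], dtype=np.float64)
nhbrdist = np.abs(grid[:, None, :] - grid[None, :, :]).max(axis=2)

trained = SOM_Batch(
    data,
    codes,
    nhbrdist,
    alphas=(0.05, 0.01),              # start and end learning rate
    radii=(nhbrdist.max() / 2, 0.0),  # start and end neighbourhood radius
    ncodes=ncodes,
    rlen=10,
    num_batches=num_batches,
    seed=42,
)

# Assign every point of one batch to its nearest trained code.
clusters, dists = map_data_to_codes(data[0], trained)
```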

0 commit comments
