"""Code adapted from a student assignment for Computational Biology 2024, Ghent University."""

from typing import Callable

import numpy as np
from numba import jit, prange
from sklearn.neighbors import BallTree

from flowsom.models.numpy_numba import nb_median_axis_0


@jit(nopython=True, fastmath=True)
def eucl_without_sqrt(p1: np.ndarray, p2: np.ndarray):
    """Function that computes the Euclidean distance between two points without taking the square root.

    For performance reasons, the square root is not taken. This is useful when comparing distances,
    because the square root is a monotonic function, meaning that the order of the distances is preserved.

    Args:
        p1 (np.ndarray): The first point.
        p2 (np.ndarray): The second point.

    Returns
    -------
    float: The squared Euclidean distance between the two points.

    >>> eucl_without_sqrt(np.array([1, 2, 3]), np.array([4, 5, 6]))
    27.0
    """
    distance = 0.0
    for j in range(p1.shape[0]):
        diff = p1[j] - p2[j]
        distance += diff * diff
    return distance
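# Illustrative note (not from the original code): because the square root is monotonic,
# ranking codes by squared distance selects the same nearest code as ranking by the true
# Euclidean distance, e.g.:
#
#     p = np.array([0.0, 0.0])
#     cs = np.array([[1.0, 1.0], [3.0, 4.0]])
#     np.argmin([eucl_without_sqrt(p, c) for c in cs])    # 0 (squared distances: 2.0 vs 25.0)
#     np.argmin([np.linalg.norm(p - c) for c in cs])      # 0 (true distances: ~1.41 vs 5.0)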


@jit(nopython=True, parallel=True, fastmath=True)
def SOM_Batch(
    data: np.ndarray,
    codes: np.ndarray,
    nhbrdist: np.ndarray,
    alphas: tuple,
    radii: tuple,
    ncodes: int,
    rlen: int,
    num_batches: int = 10,
    distf: Callable[[np.ndarray, np.ndarray], float] = eucl_without_sqrt,
    seed=None,
):
    """Function that computes the Self-Organizing Map in parallel batches.

    Args:
        data (np.ndarray): The data to be clustered, split into one array per batch.
        codes (np.ndarray): The initial codes.
        nhbrdist (np.ndarray): The neighbourhood distances between the codes on the SOM grid.
        alphas (tuple): The start and end learning rate.
        radii (tuple): The start and end radius of the neighbourhood.
        ncodes (int): The number of codes.
        rlen (int): The number of iterations over the data.
        num_batches (int): The number of batches that are trained in parallel.
        distf (function): The distance function.
        seed (int): The seed for the random number generator.

    Returns
    -------
    np.ndarray: The computed codes.
    """
    if seed is not None:
        np.random.seed(seed)

    # Number of data points per batch (taken from the last batch)
    n = data[-1].shape[0]

    # Dimension of the data
    px = data[0].shape[1]

    # Number of iterations
    niter = n

    # The threshold is the radius of the neighbourhood, i.e. the range within which codes are updated.
    # The threshold step decides how much the threshold is decreased each iteration.
    threshold_step = (radii[0] - radii[1]) / niter
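    # For example (hypothetical values, for illustration only): with radii = (4.0, 1.0) and
    # niter = 300 data points per batch, threshold_step = (4.0 - 1.0) / 300 = 0.01, so the
    # neighbourhood radius shrinks slightly after every sample that is presented.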

    # Keep the temporary codes, using the given codes as the initial codes, for every batch
    tmp_codes_all = np.empty((num_batches, ncodes, px), dtype=np.float64)

    # Copy the codes as float64, because the codes are updated in place by the algorithm
    copy_codes = codes.copy().astype(np.float64)

    # Execute some initial serial iterations to get a good initial clustering
    xdist = np.empty(ncodes, dtype=np.float64)
    init_threshold = radii[0]
    init_alpha = alphas[0]

    for i in range(niter):
        # Choose a random data point (kept separate from the loop counter i, which drives the learning-rate decay)
        idx = np.random.choice(n)

        # Compute the nearest code
        nearest = 0
        for cd in range(ncodes):
            xdist[cd] = distf(data[0][idx, :], copy_codes[cd, :])
            if xdist[cd] < xdist[nearest]:
                nearest = cd

        init_alpha = alphas[0] - (alphas[0] - alphas[1]) * i / (niter * rlen)

        for cd in range(ncodes):
            # The neighbourhood distance decides whether the code is updated: a code is only updated
            # if it lies close enough to the winning code. Otherwise, its value stays the same.
            if nhbrdist[cd, nearest] <= init_threshold:
                # Update the code based on the difference between the chosen data point and the code.
                for j in range(px):
                    tmp = data[0][idx, j] - copy_codes[cd, j]
                    copy_codes[cd, j] += tmp * init_alpha

        init_threshold -= threshold_step
    # Choose random data points, for the different batches and for the rlen iterations
    data_points_random = np.random.choice(n, num_batches * rlen * n, replace=True)

    # Decrease the number of iterations, because the initial serial iterations have already been done
    rlen = int(rlen / 2)

    for iteration in range(rlen):
        # Execute the batches in parallel
        for batch_nr in prange(num_batches):
            # Keep the temporary codes, using the given codes as the initial codes
            tmp_codes = copy_codes.copy()

            # Array for the distances
            xdists = np.empty(ncodes, dtype=np.float64)

            # IMPORTANT: keeping the threshold fixed at radii[0] would cause big changes every iteration,
            # which is not wanted because the algorithm should converge. Therefore, the threshold is
            # decreased every iteration; a factor of 2 is added to make it decrease faster.
            threshold = init_threshold - radii[0] * 2 * iteration / rlen
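            # For example (hypothetical values, for illustration only): with radii[0] = 4.0 and
            # rlen = 5 after the halving above, the starting threshold drops by 2 * 4.0 / 5 = 1.6
            # per outer iteration, on top of the per-sample threshold_step decrease below.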

            for k in range(iteration * niter, (iteration + 1) * niter):
                # Get the data point; each batch reads from its own contiguous block of data_points_random
                i = data_points_random[n * rlen * batch_nr + k]

                # Compute the nearest code
                nearest = 0
                for cd in range(ncodes):
                    xdists[cd] = distf(data[batch_nr][i, :], tmp_codes[cd, :])
                    if xdists[cd] < xdists[nearest]:
                        nearest = cd

                # Keep a small positive threshold so that at least the winning code is still updated
                if threshold < 1.0:
                    threshold = 0.5
                alpha = init_alpha - (alphas[0] - alphas[1]) * k / (niter * rlen)

                for cd in range(ncodes):
                    # The neighbourhood distance decides whether the code is updated: a code is only updated
                    # if it lies close enough to the winning code. Otherwise, its value stays the same.
                    if nhbrdist[cd, nearest] <= threshold:
                        # Update the code based on the difference between the used data point and the code.
                        for j in range(px):
                            tmp = data[batch_nr][i, j] - tmp_codes[cd, j]
                            tmp_codes[cd, j] += tmp * alpha

                threshold -= threshold_step

            tmp_codes_all[batch_nr] = tmp_codes

        # Merge the SOMs of the different batches together
        copy_codes = nb_median_axis_0(tmp_codes_all).astype(np.float64)

    return copy_codes


# ChatGPT generated alternative to map_data_to_codes
def map_data_to_codes(data, codes):
    """Returns a tuple with the indices and distances of the nearest code for each data point.

    Args:
        data (np.ndarray): The data points.
        codes (np.ndarray): The codes that the data points are mapped to.

    Returns
    -------
    np.ndarray: The indices of the nearest code for each data point.
    np.ndarray: The distances of the nearest code for each data point.

    >>> data_ = np.array([[1, 2, 3], [4, 5, 6]])
    >>> codes_ = np.array([[1, 2, 3], [4, 5, 6]])
    >>> map_data_to_codes(data_, codes_)
    (array([0, 1]), array([0., 0.]))
    """
    # Create a BallTree for the codes (this is an efficient data structure for nearest neighbor search)
    tree = BallTree(codes, metric="euclidean")

    # Query the BallTree to find the nearest code for each data point (k=1 means we only want the nearest neighbor)
    dists, indices = tree.query(data, k=1)

    # Flatten the results and return them
    return indices.flatten(), dists.flatten()
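

# Illustrative usage sketch (not part of the original module): the parameter values, the random
# toy data, and the Chebyshev grid-distance construction below are assumptions chosen for this
# example only; they show one way SOM_Batch and map_data_to_codes could be wired together.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    num_batches, n_per_batch, n_features = 4, 500, 5
    xdim = ydim = 5
    ncodes = xdim * ydim

    # Toy data, pre-split into one array per batch: shape (num_batches, n, features)
    data = rng.normal(size=(num_batches, n_per_batch, n_features))

    # Initial codes: ncodes points drawn from the first batch
    codes = data[0][rng.choice(n_per_batch, ncodes, replace=False)].copy()

    # Neighbourhood distances between codes on a rectangular xdim x ydim grid (Chebyshev distance)
    grid = np.array([[x, y] for x in range(xdim) for y in range(ydim)], dtype=np.float64)
    nhbrdist = np.max(np.abs(grid[:, None, :] - grid[None, :, :]), axis=2)

    trained_codes = SOM_Batch(
        data,
        codes,
        nhbrdist,
        alphas=(0.05, 0.01),
        radii=(4.0, 1.0),
        ncodes=ncodes,
        rlen=10,
        num_batches=num_batches,
        seed=1,
    )

    # Map every data point to its nearest trained code
    clusters, dists = map_data_to_codes(data.reshape(-1, n_features), trained_codes)
    print(clusters[:10], dists[:10])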