Skip to content

Commit dc4301d

Browse files
committed
2 parents dfe36c8 + 29b03d1 commit dc4301d

File tree

10 files changed

+45
-45
lines changed

10 files changed

+45
-45
lines changed

.github/workflows/install-test-conda-forge.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
- name: Install package
3131
run: |
3232
conda info
33-
conda install homonim>=0.4.0
33+
conda install homonim>=0.4.1
3434
conda list
3535
3636
- name: Run CLI fusion test

homonim/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,13 @@
1919
import os
2020
import pathlib
2121
import logging
22-
import warnings
2322

24-
from rasterio.errors import NotGeoreferencedWarning
2523
from homonim.compare import RasterCompare
2624
from homonim.enums import Model, ProcCrs
2725
from homonim.fuse import RasterFuse
2826
from homonim.kernel_model import KernelModel
2927
from homonim.stats import ParamStats
3028

31-
# suppress NotGeoreferencedWarning which rasterio can raise incorrectly
32-
warnings.simplefilter('ignore', category=NotGeoreferencedWarning)
33-
3429
# Add a NullHandler to the package logger to hide logs by default. Applications can then add
3530
# their own handler(s).
3631
log = logging.getLogger(__name__)

homonim/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import yaml
3333
from click.core import ParameterSource
3434
from rasterio.warp import SUPPORTED_RESAMPLING
35+
from rasterio.errors import NotGeoreferencedWarning
3536

3637
from homonim import utils, version, RasterFuse, RasterCompare, ParamStats, ProcCrs, Model
3738
from homonim.errors import ImageFormatError
@@ -145,6 +146,9 @@ def showwarning(message, category, filename, lineno, file=None, line=None):
145146
logger = logging.getLogger(module_name)
146147
logger.warning(str(message))
147148

149+
# suppress NotGeoreferencedWarning which rasterio can raise incorrectly
150+
warnings.simplefilter('ignore', category=NotGeoreferencedWarning)
151+
148152
# redirect orthority warnings to module logger
149153
orig_show_warning = warnings.showwarning
150154
warnings.showwarning = showwarning

homonim/kernel_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _r2_array(
188188
# The above can be expanded and expressed in terms of cv.boxFilter kernel sums as:
189189
ss_res_array = (
190190
((param_array[0] ** 2) * src2_sum) +
191-
(2 * np.product(param_array[:2], axis=0) * src_sum) -
191+
(2 * np.prod(param_array[:2], axis=0) * src_sum) -
192192
(2 * param_array[0] * src_ref_sum) -
193193
(2 * param_array[1] * ref_sum) +
194194
ref2_sum + (mask_sum * (param_array[1] ** 2))

homonim/matched_pair.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def _match_pair_bands(
248248
# absolute & relative distance matrix between src and ref center wavelengths
249249
abs_dist = np.abs(src_wavelengths[:, np.newaxis] - ref_wavelengths[np.newaxis, :])
250250
rel_dist = abs_dist / src_wavelengths[:, np.newaxis]
251+
251252
def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
252253
"""
253254
Greedy matching of src to ref bands based on the provided center wavelength distance matrix,
@@ -257,30 +258,29 @@ def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257258
"""
258259
# match_idx[i] is the index of the ref band that matches with the ith src band
259260
match_idx = np.array([np.nan] * dist.shape[0])
260-
match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches
261-
262-
# suppress runtime warning on all-Nan slice, as it is expected in normal operation
263-
with warnings.catch_warnings():
264-
warnings.simplefilter('ignore', category=RuntimeWarning)
265-
# repeat until all src or ref bands have been matched
266-
while not all(np.isnan(np.nanmin(dist, axis=1))) or not all(np.isnan(np.nanmin(dist, axis=0))):
267-
# find the row with the smallest distance in it
268-
min_dist = np.nanmin(dist, axis=1)
269-
min_dist_row_idx = np.nanargmin(min_dist)
270-
min_dist_row = dist[min_dist_row_idx, :]
271-
# store match idx and distance for this row
272-
match_idx[min_dist_row_idx] = np.nanargmin(min_dist_row)
273-
match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx]
274-
# set the matched row and col to nan, so that it is not used again
275-
dist[:, int(match_idx[min_dist_row_idx])] = np.nan
276-
dist[min_dist_row_idx, :] = np.nan
261+
match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches
262+
263+
# use masked array rather than nan pass-through to avoid all-nan slice warnings
264+
dist = np.ma.array(dist, mask=np.isnan(dist))
265+
266+
# repeat until all src or ref bands have been matched
267+
while not dist.mask.all():
268+
# find the row with the smallest distance in it
269+
min_dist = dist.min(axis=1)
270+
min_dist_row_idx = np.ma.argmin(min_dist)
271+
min_dist_row = dist[min_dist_row_idx, :]
272+
# store match idx and distance for this row
273+
match_idx[min_dist_row_idx] = np.ma.argmin(min_dist_row)
274+
match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx]
275+
# set the matched row and col to nan, so that it is not used again
276+
dist[:, int(match_idx[min_dist_row_idx])] = np.ma.masked
277+
dist[min_dist_row_idx, :] = np.ma.masked
277278

278279
return match_dist, match_idx
279280

280281
match_dist, match_idx = greedy_match(rel_dist)
281282

282-
# if any of the matched distances are greater than a threshold, raise an informative error,
283-
# or log a warning, depending on `self._force`
283+
# if any of the matched distances are greater than a threshold, raise an informative error
284284
if any(match_dist > MatchedPairReader._max_rel_wavelength_diff):
285285
err_idx = match_dist > MatchedPairReader._max_rel_wavelength_diff
286286
src_err_band_names = list(src_band_names[err_idx])

homonim/raster_array.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,6 @@ def to_rio_dataset(
478478
f'The length of indexes ({len(indexes)}) exceeds the number of bands in the '
479479
f'RasterArray ({self.count})'
480480
)
481-
if rio_dataset.nodata is not None and (
482-
self.nodata is None or not utils.nan_equals(self.nodata, rio_dataset.nodata)
483-
):
484-
warnings.warn(
485-
f"The dataset nodata: {rio_dataset.nodata} does not match the RasterArray nodata: {self.nodata}",
486-
category=ImageFormatWarning
487-
)
488481

489482
if window is None:
490483
# a window defining the region in the dataset corresponding to the RasterArray extents

homonim/raster_pair.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]:
230230
proc_win = self._ref_win if self.proc_crs == ProcCrs.ref else self._src_win
231231
# adjust max_block_mem to represent the size of a block in the highest resolution image, but scaled to the
232232
# equivalent in proc_crs.
233-
src_pix_area = np.product(np.abs(self._src_im.res))
234-
ref_pix_area = np.product(np.abs(self._ref_im.res))
233+
src_pix_area = np.prod(np.abs(self._src_im.res))
234+
ref_pix_area = np.prod(np.abs(self._ref_im.res))
235235
if self.proc_crs == ProcCrs.ref:
236236
mem_scale = src_pix_area / ref_pix_area if ref_pix_area > src_pix_area else 1.
237237
elif self.proc_crs == ProcCrs.src:
@@ -247,7 +247,7 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]:
247247
block_shape = np.array((proc_win.height, proc_win.width)).astype('float')
248248

249249
# keep halving the block_shape along the longest dimension until it satisfies max_block_mem
250-
while (np.product(block_shape) * dtype_size) > max_block_mem:
250+
while (np.prod(block_shape) * dtype_size) > max_block_mem:
251251
div_dim = np.argmax(block_shape)
252252
block_shape[div_dim] /= 2
253253

homonim/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ def validate_kernel_shape(kernel_shape: Tuple[int, int], model: Model = Model.ga
121121
if not np.all(np.mod(kernel_shape, 2) == 1):
122122
raise ValueError('`kernel_shape` must be odd in both dimensions.')
123123
if model == Model.gain_offset:
124-
if np.product(kernel_shape) < 2:
124+
if np.prod(kernel_shape) < 2:
125125
raise ValueError('`kernel_shape` area should contain at least 2 elements for the gain-offset model.')
126-
elif np.product(kernel_shape) < 25:
126+
elif np.prod(kernel_shape) < 25:
127127
warnings.warn(
128128
'A `kernel_shape` of at least 25 elements is recommended for the gain-offset model.',
129129
category=ConfigWarning
@@ -155,6 +155,8 @@ def overlap_for_kernel(kernel_shape: Tuple[int, int]) -> Tuple[int, int]:
155155

156156
def validate_threads(threads: int) -> int:
157157
""" Parse number of threads parameter. """
158+
# TODO: Memory increases ~linearly with number of threads, but does processing speed? The bottleneck is often
159+
# file IO & I am not sure >2 threads as a default is justified.
158160
_cpu_count = cpu_count()
159161
threads = _cpu_count if threads == 0 else threads
160162
if threads > _cpu_count:

homonim/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.4.0'
1+
__version__ = '0.4.1'

tests/test_matched_pair.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
"""
1919

2020
from pathlib import Path
21-
from typing import Tuple, List, Dict
21+
from typing import Tuple, List
22+
import warnings
2223

2324
import numpy as np
2425
import pytest
2526
import rasterio as rio
26-
from rasterio import Affine
27-
from rasterio.windows import Window
2827

2928
from homonim import utils
3029
from homonim.matched_pair import MatchedPairReader
30+
from homonim.errors import HomonimWarning
3131

3232

3333
@pytest.mark.parametrize(['file', 'bands', 'exp_bands', 'exp_band_names', 'exp_wavelengths'], [
@@ -104,9 +104,15 @@ def test_match(
104104
""" Test matching of different source and reference files. """
105105
src_file: Path = request.getfixturevalue(src_file)
106106
ref_file: Path = request.getfixturevalue(ref_file)
107-
with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair:
108-
assert all(np.array(matched_pair.src_bands) == exp_src_bands)
109-
assert all(np.array(matched_pair.ref_bands) == exp_ref_bands)
107+
108+
with warnings.catch_warnings():
109+
# test there are no all-nan warnings by turning them RuntimeWarning into an error, while allowing
110+
# HomonimWarning which sub-classes RuntimeWarning
111+
warnings.simplefilter("error", category=RuntimeWarning)
112+
warnings.simplefilter("default", category=HomonimWarning)
113+
with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair:
114+
assert all(np.array(matched_pair.src_bands) == exp_src_bands)
115+
assert all(np.array(matched_pair.ref_bands) == exp_ref_bands)
110116

111117

112118
def test_match_fewer_ref_bands_error(s2_ref_file, landsat_ref_file):

0 commit comments

Comments
 (0)