
Commit 5b8fe06

fix bug related to coordinate names and values in catchstats
1 parent 23d942b commit 5b8fe06

File tree

1 file changed (+81 -37 lines)


src/lisfloodutilities/catchstats/catchstats.py

Lines changed: 81 additions & 37 deletions
@@ -21,6 +21,21 @@
 # from tqdm.auto import tqdm
 
 
+def check_coordinates(ds: Union[xr.Dataset, xr.DataArray]) -> Union[xr.Dataset, xr.DataArray]:
+    """Makes sure that the geographical coordinates are named 'lat' and 'lon' and rounds the values to 9 decimals to avoid conflicts.
+    """
+
+    # check names of the coordinates
+    ds = ds.rename(
+        {k: v for k, v in {'x': 'lon', 'y': 'lat'}.items() if k in ds.sizes})
+
+    # round it to avoid issues when comparing with other datasets
+    ds['lon'] = ds['lon'].round(9)
+    ds['lat'] = ds['lat'].round(9)
+
+    return ds
+
+
 def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
     """It reads the input maps in NetCDF format from the input directory
 
@@ -38,27 +53,33 @@ def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
     if not inputmaps.is_dir():
         print(f'ERROR: {inputmaps} is missing or not a directory!')
         sys.exit(1)
-
+
     filepaths = list(inputmaps.glob('*.nc'))
     if not filepaths:
         print(f'ERROR: No NetCDF files found in "{inputmaps}"')
         sys.exit(2)
 
     print(f'{len(filepaths)} input NetCDF files found in "{inputmaps}"')
-
+
     try:
         # for dynamic maps
-        ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True, engine='netcdf4')
+        ds = xr.open_mfdataset(filepaths, chunks='auto',
+                               parallel=True, engine='netcdf4')
         # chunks is set to auto for general purpose processing
         # it could be optimized depending on input NetCDF
     except:
         # for static maps
-        ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(file, engine='netcdf4')['Band1'] for file in filepaths})
+        ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(
+            file, engine='netcdf4')['Band1'] for file in filepaths})
     if 'wgs_1984' in ds:
         ds = ds.drop_vars('wgs_1984')
 
+    # check coordinates
+    ds = check_coordinates(ds)
+
     return ds
 
+
 def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
     """It loads the catchment masks in NetCDF formal from the input directory
 
@@ -83,29 +104,31 @@ def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
     if not maskpaths:
         print(f'ERROR: No NetCDF files found in "{mask}"')
         sys.exit(2)
-
+
     print(f'{len(maskpaths)} mask NetCDF files found in "{mask}"')
 
     # load masks
     masks = {}
-    for maskpath in maskpaths:
+    for maskpath in maskpaths:
         ID = int(maskpath.stem)
         try:
             try:
                 aoi = xr.open_dataset(maskpath, engine='netcdf4')['Band1']
             except:
                 aoi = xr.open_dataarray(maskpath, engine='netcdf4')
             aoi = xr.where(aoi.notnull(), 1, aoi)
+            aoi = check_coordinates(aoi)
             masks[ID] = aoi
         except Exception as e:
             print(f'ERROR: The mask {maskpath} could not be read: {e}')
             continue
 
     return masks
 
+
 def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
     """It reads the LISFLOOD pixel area static map
-
+
     Parameters:
     -----------
     pixarea: string or Path
@@ -120,25 +143,29 @@ def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
     if not pixarea.is_file():
         print(f'ERROR: {pixarea} is not a file!')
         sys.exit(1)
-
+
     try:
         weight = xr.open_dataset(pixarea, engine='netcdf4')['Band1']
     except Exception as e:
         print(f'ERROR: The weighing map "{pixarea}" could not be loaded: {e}')
         sys.exit(2)
 
+    # check coordinates
+    weight = check_coordinates(weight)
+
     return weight
 
+
 def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
                          masks: Dict[int, xr.DataArray],
-                         statistic: Union[str, List[str]],
+                         statistic: Union[str, List[str]],
                          weight: Optional[xr.DataArray] = None,
                          output: Optional[Union[str, Path]] = None,
                          overwrite: bool = False
                          ) -> Optional[xr.Dataset]:
     """
     Given a set of input maps and catchment masks, it computes catchment statistics.
-
+
     Parameters:
     -----------
     maps: xarray.DataArray or xarray.Dataset
@@ -153,7 +180,7 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
         directory where the resulting NetCDF files will be saved. If not provided, the results are put out as a xr.Dataset
     overwrite: boolean
         whether to overwrite or skip catchments whose output NetCDF file already exists. By default is False, so the catchment will be skipped
-
+
     Returns:
     --------
     A xr.Dataset of all catchment statistics or a NetCDF file for each catchment in the "masks" dictionary
@@ -167,19 +194,21 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
     # check statistic
     if isinstance(statistic, str):
         statistic = [statistic]
-    possible_stats = ['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count']
-    assert all(stat in possible_stats for stat in statistic), "All values in 'statistic' should be one of these: {0}".format(', '.join(possible_stats))
+    possible_stats = ['mean', 'sum', 'std',
+                      'var', 'min', 'max', 'median', 'count']
+    assert all(stat in possible_stats for stat in statistic), "All values in 'statistic' should be one of these: {0}".format(
+        ', '.join(possible_stats))
     stats_dict = {var: statistic for var in maps}
-
+
     # output directory
     if output is None:
         results = []
     else:
         output = Path(output)
         output.mkdir(parents=True, exist_ok=True)
-
+
     # define coordinates and variables of the resulting Dataset
-    dims = dict(maps.dims)
+    dims = dict(maps.sizes)
     dimnames = [dim.lower() for dim in dims]
     if 'lat' in dimnames and 'lon' in dimnames:
         x_dim, y_dim = 'lon', 'lat'
@@ -188,34 +217,39 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
     del dims[x_dim]
     del dims[y_dim]
     coords = {dim: maps[dim] for dim in dims}
-    variables = [f'{var}_{stat}' for var, stats in stats_dict.items() for stat in stats]
-
+    variables = [f'{var}_{stat}' for var, stats in stats_dict.items()
+                 for stat in stats]
+
     # compute statistics for each catchemnt
     # for ID in tqdm(masks.keys(), desc='processing catchments'):
-    for ID in masks.keys():
+    for ID in masks.keys():
 
         if output is not None:
             fileout = output / f'{ID:04}.nc'
             if fileout.exists() and not overwrite:
-                print(f'Output file {fileout} already exists. Moving forward to the next catchment')
+                print(
+                    f'Output file {fileout} already exists. Moving forward to the next catchment')
                 continue
-
+
         # create empty Dataset
         coords.update({'id': [ID]})
-        maps_aoi = xr.Dataset({var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})
-
+        maps_aoi = xr.Dataset(
+            {var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})
+
         # apply mask to the dataset
         aoi = masks[ID]
-        masked_maps = maps.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
+        masked_maps = maps.sel(
+            {x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
         masked_maps = masked_maps.compute()
 
         # apply weighting
         if weight is not None:
-            masked_weight = weight.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
-            weighted_maps = masked_maps.weighted(masked_weight.fillna(0))
+            masked_weight = weight.sel(
+                {x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
+            weighted_maps = masked_maps.weighted(masked_weight.fillna(0))
 
         # compute statistics
-        for var, stats in stats_dict.items():
+        for var, stats in stats_dict.items():
             for stat in stats:
                 if (stat in ['mean', 'sum', 'std', 'var']) and (weight is not None):
                     x = getattr(weighted_maps, stat)(dim=[x_dim, y_dim])[var]
@@ -236,7 +270,8 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
     if output is None:
         results = xr.concat(results, dim='id')
         return results
-
+
+
 def main(argv=sys.argv):
     prog = os.path.basename(argv[0])
     parser = argparse.ArgumentParser(
@@ -247,26 +282,35 @@ def main(argv=sys.argv):
         """,
         prog=prog,
     )
-    parser.add_argument("-i", "--input", required=True, help="Directory containing the input NetCDF files")
-    parser.add_argument("-m", "--mask", required=True, help="Directory containing the mask NetCDF files")
-    parser.add_argument("-s", "--statistic", nargs='+', required=True, help='List of statistics to be computed. Possible values: mean, sum, std, var, min, max, median, count')
-    parser.add_argument("-o", "--output", required=True, help="Directory where the output NetCDF files will be saved")
-    parser.add_argument("-a", "--area", required=False, default=None, help="NetCDF file of pixel area used to weigh the statistics")
-    parser.add_argument("-w", "--overwrite", action="store_true", default=False, help="Overwrite existing output files")
-
+    parser.add_argument("-i", "--input", required=True,
+                        help="Directory containing the input NetCDF files")
+    parser.add_argument("-m", "--mask", required=True,
+                        help="Directory containing the mask NetCDF files")
+    parser.add_argument("-s", "--statistic", nargs='+', required=True,
+                        help='List of statistics to be computed. Possible values: mean, sum, std, var, min, max, median, count')
+    parser.add_argument("-o", "--output", required=True,
+                        help="Directory where the output NetCDF files will be saved")
+    parser.add_argument("-a", "--area", required=False, default=None,
+                        help="NetCDF file of pixel area used to weigh the statistics")
+    parser.add_argument("-w", "--overwrite", action="store_true",
+                        default=False, help="Overwrite existing output files")
+
     args = parser.parse_args()
 
     try:
         maps = read_inputmaps(args.input)
         masks = read_masks(args.mask)
        weight = read_pixarea(args.area) if args.area is not None else None
-        catchment_statistics(maps, masks, args.statistic, weight=weight, output=args.output, overwrite=args.overwrite)
+        catchment_statistics(maps, masks, args.statistic, weight=weight,
+                             output=args.output, overwrite=args.overwrite)
     except Exception as e:
         print(f'ERROR: {e}')
         sys.exit(1)
-
+
+
 def main_script():
     sys.exit(main())
 
+
 if __name__ == "__main__":
     main_script()
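
For context, the sketch below (not part of the commit; the coordinate values, grid size and helper usage are made up for illustration) reproduces the kind of mismatch the new check_coordinates helper guards against: a mask stored with 'x'/'y' coordinate names and tiny floating-point offsets will not align with input maps that use 'lat'/'lon', so the .sel() lookup in catchment_statistics fails or drops pixels. Renaming and rounding both objects the same way restores the alignment.

import numpy as np
import xarray as xr


def check_coordinates(ds):
    # mirrors the logic of the helper added in this commit:
    # rename 'x'/'y' to 'lon'/'lat' and round the coordinate values to 9 decimals
    ds = ds.rename({k: v for k, v in {'x': 'lon', 'y': 'lat'}.items() if k in ds.sizes})
    ds['lon'] = ds['lon'].round(9)
    ds['lat'] = ds['lat'].round(9)
    return ds


# input map on a clean 'lat'/'lon' grid (hypothetical values)
lon = np.array([10.05, 10.15, 10.25])
lat = np.array([45.05, 45.15])
maps = xr.DataArray(np.arange(6.0).reshape(2, 3),
                    coords={'lat': lat, 'lon': lon}, dims=('lat', 'lon'))

# catchment mask written by another tool: 'x'/'y' names plus floating-point noise
aoi = xr.DataArray(np.ones((2, 3)),
                   coords={'y': lat + 1e-12, 'x': lon + 1e-12}, dims=('y', 'x'))

# without normalisation, maps.sel({'lon': aoi['x'], 'lat': aoi['y']}) raises a
# KeyError because the noisy coordinate labels are not found in the index
maps, aoi = check_coordinates(maps), check_coordinates(aoi)

masked = maps.sel({'lon': aoi['lon'], 'lat': aoi['lat']}).where(aoi == 1)
print(float(masked.mean()))  # 2.5 once the grids line up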
