# from tqdm.auto import tqdm


+ def check_coordinates(ds: Union[xr.Dataset, xr.DataArray]) -> Union[xr.Dataset, xr.DataArray]:
+     """Makes sure that the geographical coordinates are named 'lat' and 'lon',
+     and rounds their values to 9 decimals to avoid precision conflicts when
+     comparing with other datasets.
+     """
+
+     # rename the spatial dimensions to 'lon'/'lat' if they are named 'x'/'y'
+     ds = ds.rename(
+         {k: v for k, v in {'x': 'lon', 'y': 'lat'}.items() if k in ds.sizes})
+
+     # round the values to avoid issues when comparing with other datasets
+     ds['lon'] = ds['lon'].round(9)
+     ds['lat'] = ds['lat'].round(9)
+
+     return ds
+
+
def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
    """It reads the input maps in NetCDF format from the input directory

@@ -38,27 +53,33 @@ def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
    if not inputmaps.is_dir():
        print(f'ERROR: {inputmaps} is missing or not a directory!')
        sys.exit(1)
-
+
    filepaths = list(inputmaps.glob('*.nc'))
    if not filepaths:
        print(f'ERROR: No NetCDF files found in "{inputmaps}"')
        sys.exit(2)

    print(f'{len(filepaths)} input NetCDF files found in "{inputmaps}"')
-
+
    try:
        # for dynamic maps
-         ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True, engine='netcdf4')
+         ds = xr.open_mfdataset(filepaths, chunks='auto',
+                                parallel=True, engine='netcdf4')
        # chunks is set to auto for general purpose processing
        # it could be optimized depending on input NetCDF
    except:
        # for static maps
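        # static maps are expected to hold a single 'Band1' variable (the
        # name GDAL gives NetCDF bands); each output variable is named after
        # the file stem up to the first '_'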
-         ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(file, engine='netcdf4')['Band1'] for file in filepaths})
+         ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(
+             file, engine='netcdf4')['Band1'] for file in filepaths})

    if 'wgs_1984' in ds:
        ds = ds.drop_vars('wgs_1984')

+     # check coordinates
+     ds = check_coordinates(ds)
+
    return ds

+
def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
    """It loads the catchment masks in NetCDF format from the input directory

@@ -83,29 +104,31 @@ def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
    if not maskpaths:
        print(f'ERROR: No NetCDF files found in "{mask}"')
        sys.exit(2)
-
+
    print(f'{len(maskpaths)} mask NetCDF files found in "{mask}"')

    # load masks
    masks = {}
-     for maskpath in maskpaths:
+     for maskpath in maskpaths:
        ID = int(maskpath.stem)
        try:
            try:
                aoi = xr.open_dataset(maskpath, engine='netcdf4')['Band1']
            except:
                aoi = xr.open_dataarray(maskpath, engine='netcdf4')
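            # binarise the mask: cells with data become 1, empty cells stay null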
            aoi = xr.where(aoi.notnull(), 1, aoi)
+             aoi = check_coordinates(aoi)
            masks[ID] = aoi
        except Exception as e:
            print(f'ERROR: The mask {maskpath} could not be read: {e}')
            continue

    return masks

+
def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
    """It reads the LISFLOOD pixel area static map
-
+
    Parameters:
    -----------
    pixarea: string or Path

@@ -120,25 +143,29 @@ def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
    if not pixarea.is_file():
        print(f'ERROR: {pixarea} is not a file!')
        sys.exit(1)
-
+
    try:
        weight = xr.open_dataset(pixarea, engine='netcdf4')['Band1']
    except Exception as e:
        print(f'ERROR: The weighting map "{pixarea}" could not be loaded: {e}')
        sys.exit(2)

+     # check coordinates
+     weight = check_coordinates(weight)
+
    return weight

+
def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
                         masks: Dict[int, xr.DataArray],
-                          statistic: Union[str, List[str]],
+                          statistic: Union[str, List[str]],
                         weight: Optional[xr.DataArray] = None,
                         output: Optional[Union[str, Path]] = None,
                         overwrite: bool = False
                         ) -> Optional[xr.Dataset]:
    """
    Given a set of input maps and catchment masks, it computes catchment statistics.
-
+
    Parameters:
    -----------
    maps: xarray.DataArray or xarray.Dataset

@@ -153,7 +180,7 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
        directory where the resulting NetCDF files will be saved. If not provided, the results are returned as an xr.Dataset
    overwrite: boolean
        whether to overwrite or skip catchments whose output NetCDF file already exists. It defaults to False, so catchments with an existing file are skipped
-
+
    Returns:
    --------
    An xr.Dataset of all catchment statistics or a NetCDF file for each catchment in the "masks" dictionary

@@ -167,19 +194,21 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
    # check statistic
    if isinstance(statistic, str):
        statistic = [statistic]
-     possible_stats = ['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count']
-     assert all(stat in possible_stats for stat in statistic), "All values in 'statistic' should be one of these: {0}".format(', '.join(possible_stats))
+     possible_stats = ['mean', 'sum', 'std',
+                       'var', 'min', 'max', 'median', 'count']
+     assert all(stat in possible_stats for stat in statistic), "All values in 'statistic' should be one of these: {0}".format(
+         ', '.join(possible_stats))
    stats_dict = {var: statistic for var in maps}
-
+
    # output directory
    if output is None:
        results = []
    else:
        output = Path(output)
        output.mkdir(parents=True, exist_ok=True)
-
+
    # define coordinates and variables of the resulting Dataset
-     dims = dict(maps.dims)
+     dims = dict(maps.sizes)
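+     # .sizes is a dimension-name -> length mapping for both DataArray and
+     # Dataset, whereas DataArray.dims is only a tuple of names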
    dimnames = [dim.lower() for dim in dims]
    if 'lat' in dimnames and 'lon' in dimnames:
        x_dim, y_dim = 'lon', 'lat'

@@ -188,34 +217,39 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
    del dims[x_dim]
    del dims[y_dim]
    coords = {dim: maps[dim] for dim in dims}
-     variables = [f'{var}_{stat}' for var, stats in stats_dict.items() for stat in stats]
-
+     variables = [f'{var}_{stat}' for var, stats in stats_dict.items()
+                  for stat in stats]
+
    # compute statistics for each catchment
    # for ID in tqdm(masks.keys(), desc='processing catchments'):
-     for ID in masks.keys():
+     for ID in masks.keys():

        if output is not None:
            fileout = output / f'{ID:04}.nc'
            if fileout.exists() and not overwrite:
-                 print(f'Output file {fileout} already exists. Moving forward to the next catchment')
+                 print(
+                     f'Output file {fileout} already exists. Moving forward to the next catchment')
                continue
-
+
        # create empty Dataset
        coords.update({'id': [ID]})
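        # the 'id' coordinate tags the statistics with the catchment ID; when
        # no output directory is given, the results are later concatenated
        # along this dimension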
-         maps_aoi = xr.Dataset({var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})
-
+         maps_aoi = xr.Dataset(
+             {var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})
+
        # apply mask to the dataset
        aoi = masks[ID]
-         masked_maps = maps.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
+         masked_maps = maps.sel(
+             {x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
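+         # .sel crops the maps to the mask's coordinates and .where blanks
+         # out the pixels outside the catchment (mask != 1)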
        masked_maps = masked_maps.compute()

        # apply weighting
        if weight is not None:
-             masked_weight = weight.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
-             weighted_maps = masked_maps.weighted(masked_weight.fillna(0))
+             masked_weight = weight.sel(
+                 {x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
+             weighted_maps = masked_maps.weighted(masked_weight.fillna(0))

        # compute statistics
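        # the weighted object is used only for mean/sum/std/var; the other
        # statistics are computed on the unweighted masked maps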
-         for var, stats in stats_dict.items():
+         for var, stats in stats_dict.items():
            for stat in stats:
                if (stat in ['mean', 'sum', 'std', 'var']) and (weight is not None):
                    x = getattr(weighted_maps, stat)(dim=[x_dim, y_dim])[var]

@@ -236,7 +270,8 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
    if output is None:
        results = xr.concat(results, dim='id')
        return results
-
+
+
def main(argv=sys.argv):
    prog = os.path.basename(argv[0])
    parser = argparse.ArgumentParser(

@@ -247,26 +282,35 @@ def main(argv=sys.argv):
        """,
        prog=prog,
    )
-     parser.add_argument("-i", "--input", required=True, help="Directory containing the input NetCDF files")
-     parser.add_argument("-m", "--mask", required=True, help="Directory containing the mask NetCDF files")
-     parser.add_argument("-s", "--statistic", nargs='+', required=True, help='List of statistics to be computed. Possible values: mean, sum, std, var, min, max, median, count')
-     parser.add_argument("-o", "--output", required=True, help="Directory where the output NetCDF files will be saved")
-     parser.add_argument("-a", "--area", required=False, default=None, help="NetCDF file of pixel area used to weigh the statistics")
-     parser.add_argument("-w", "--overwrite", action="store_true", default=False, help="Overwrite existing output files")
-
+     parser.add_argument("-i", "--input", required=True,
+                         help="Directory containing the input NetCDF files")
+     parser.add_argument("-m", "--mask", required=True,
+                         help="Directory containing the mask NetCDF files")
+     parser.add_argument("-s", "--statistic", nargs='+', required=True,
+                         help='List of statistics to be computed. Possible values: mean, sum, std, var, min, max, median, count')
+     parser.add_argument("-o", "--output", required=True,
+                         help="Directory where the output NetCDF files will be saved")
+     parser.add_argument("-a", "--area", required=False, default=None,
+                         help="NetCDF file of pixel area used to weigh the statistics")
+     parser.add_argument("-w", "--overwrite", action="store_true",
+                         default=False, help="Overwrite existing output files")
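+     # example invocation (entry-point name is illustrative):
+     #   catchstats -i ./maps -m ./masks -s mean max -o ./stats -a pixarea.nc -w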
+
    args = parser.parse_args()

    try:
        maps = read_inputmaps(args.input)
        masks = read_masks(args.mask)
        weight = read_pixarea(args.area) if args.area is not None else None
-         catchment_statistics(maps, masks, args.statistic, weight=weight, output=args.output, overwrite=args.overwrite)
+         catchment_statistics(maps, masks, args.statistic, weight=weight,
+                              output=args.output, overwrite=args.overwrite)
    except Exception as e:
        print(f'ERROR: {e}')
        sys.exit(1)
-
+
+
def main_script():
    sys.exit(main())

+
if __name__ == "__main__":
    main_script()