Skip to content

Commit 7bc5280

Browse files
bugfix
1 parent 29c4124 commit 7bc5280

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

chkpnt_2_h5_2_tf_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from model import model
1313

1414
# specify directory as data io info
15-
BASEDIR = Path('/Users/biplovbhandari/Works/SIG/hydrafloods/output')
16-
MODEL_DIR = BASEDIR / 'trial_469ae9d7b6c82488deb9be9c0a0a25e7'
17-
CHECKPOINT_DIR = MODEL_DIR / '469ae9d7b6c82488deb9be9c0a0a25e7' / 'checkpoint'
15+
BASEDIR = Path('/mnt/hydrafloods/output/jrc_adjusted_LR_2020_07_13_V1/model/sentinel1-surface-water')
16+
MODEL_DIR = BASEDIR / 'trial_6ba0bc0ef8458bf43280b5814775bd2b'
17+
CHECKPOINT_DIR = MODEL_DIR / 'checkpoints' / 'epoch_0' / 'checkpoint'
1818
H5_MODEL = MODEL_DIR / 'tf-model-h5'
1919
TF_MODEL_DIR = MODEL_DIR / 'tf-model'
2020

evaluate_models.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
import os
1010
import tensorflow as tf
1111

12+
from tensorflow import keras
1213
from pathlib import Path
1314
from model import dataio, model
1415

1516

1617
# specify directory as data io info
17-
BASEDIR = Path('/Users/biplovbhandari/Works/SIG/hydrafloods/output')
18-
MODEL_DIR = BASEDIR / 'trial_469ae9d7b6c82488deb9be9c0a0a25e7'
19-
H5_MODEL = MODEL_DIR / 'tf-model-h5'
18+
DATADIR = Path('/home/ubuntu/hydrafloods')
19+
BASEDIR = Path('/mnt/hydrafloods/output/jrc_adjusted_LR_2020_07_13_V1/model/sentinel1-surface-water')
20+
MODEL_DIR = BASEDIR / 'trial_6ba0bc0ef8458bf43280b5814775bd2b'
21+
CHECKPOINT = MODEL_DIR / 'checkpoints' / 'epoch_0'/ 'checkpoint'
2022

2123
# specify some data structure
2224
FEATURES = ast.literal_eval(os.getenv('FEATURES'))
@@ -25,14 +27,28 @@
2527
in_shape = (None, None) + (len(FEATURES),)
2628
out_classes = int(os.getenv('OUT_CLASSES_NUM'))
2729

28-
VALIDATION_DIR = BASEDIR / 'validation_patches'
30+
VALIDATION_DIR = DATADIR / 'validation_patches_jrc'
2931
validation_files = glob.glob(str(VALIDATION_DIR) + '/*')
32+
# eval is batched by 1
3033
validation = dataio.get_dataset(validation_files, FEATURES, LABELS, PATCH_SHAPE, 1)
3134

3235
this_model = model.get_model(in_shape, out_classes)
3336

3437
# open and save model
35-
this_model.load_weights(f'{str(H5_MODEL)}')
38+
this_model.load_weights(f'{str(CHECKPOINT)}')
39+
40+
# compile the model
41+
this_model.compile(
42+
optimizer='adam',
43+
loss=model.bce_loss,
44+
metrics=[
45+
keras.metrics.categorical_accuracy,
46+
keras.metrics.Precision(),
47+
keras.metrics.Recall(),
48+
model.dice_coef,
49+
model.f1_m
50+
]
51+
)
3652

3753
# check how the model trained
3854
score = this_model.evaluate(validation)

0 commit comments

Comments
 (0)