2
2
3
3
import numpy as np
4
4
import tensorflow as tf
5
+
5
6
from alibi_detect .utils .tensorflow .prediction import (
6
- predict_batch , predict_batch_transformer )
7
- from tensorflow .keras .layers import Dense , Flatten , Input , InputLayer
7
+ predict_batch , predict_batch_transformer , get_call_arg_mapping
8
+ )
9
+ from tensorflow .keras .layers import Dense , Flatten , Input , Lambda
8
10
from tensorflow .keras .models import Model
9
11
10
12
@@ -34,7 +36,11 @@ def __init__(
34
36
'tf.keras.Sequential or tf.keras.Model `mlp`' )
35
37
36
38
def call (self , x : Union [np .ndarray , tf .Tensor , Dict [str , tf .Tensor ]]) -> tf .Tensor :
37
- x = self .input_layer (x )
39
+ if not isinstance (x , (np .ndarray , tf .Tensor )):
40
+ x = get_call_arg_mapping (self .input_layer , x )
41
+ x = self .input_layer (** x )
42
+ else :
43
+ x = self .input_layer (x )
38
44
return self .mlp (x )
39
45
40
46
@@ -52,7 +58,7 @@ def __init__(
52
58
if is_enc :
53
59
self .encoder = encoder_net
54
60
elif not is_enc and is_enc_dim : # set default encoder
55
- input_layer = InputLayer ( input_shape = shape ) if input_layer is None else input_layer
61
+ input_layer = Lambda ( lambda x : x ) if input_layer is None else input_layer
56
62
input_dim = np .prod (shape )
57
63
step_dim = int ((input_dim - enc_dim ) / 3 )
58
64
self .encoder = _Encoder (input_layer , enc_dim = enc_dim , step_dim = step_dim )
@@ -61,7 +67,11 @@ def __init__(
61
67
' or tf.keras.Model `encoder_net`.' )
62
68
63
69
def call (self , x : Union [np .ndarray , tf .Tensor , Dict [str , tf .Tensor ]]) -> tf .Tensor :
64
- return self .encoder (x )
70
+ if not isinstance (x , (np .ndarray , tf .Tensor )):
71
+ x = get_call_arg_mapping (self .encoder , x )
72
+ return self .encoder (** x )
73
+ else :
74
+ return self .encoder (x )
65
75
66
76
67
77
class HiddenOutput (tf .keras .Model ):
@@ -73,7 +83,7 @@ def __init__(
73
83
flatten : bool = False
74
84
) -> None :
75
85
super ().__init__ ()
76
- if input_shape and not model .inputs :
86
+ if input_shape and not ( hasattr ( model , 'inputs' ) and model .inputs ) :
77
87
inputs = Input (shape = input_shape )
78
88
model .call (inputs )
79
89
else :
0 commit comments