@@ -53,14 +53,15 @@ def get_token_embeddings(vocab_size, num_units, zero_pad=True):
                                     embeddings[1:, :]), 0)
     return embeddings
 
-def scaled_dot_product_attention(Q, K, V,
+def scaled_dot_product_attention(Q, K, V, key_masks,
                                  causality=False, dropout_rate=0.,
                                  training=True,
                                  scope="scaled_dot_product_attention"):
     '''See 3.2.1.
     Q: Packed queries. 3d tensor. [N, T_q, d_k].
     K: Packed keys. 3d tensor. [N, T_k, d_k].
     V: Packed values. 3d tensor. [N, T_k, d_v].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     causality: If True, applies masking for future blinding
     dropout_rate: A floating point number in [0, 1].
     training: boolean for controlling dropout
@@ -76,7 +77,7 @@ def scaled_dot_product_attention(Q, K, V,
         outputs /= d_k ** 0.5
 
         # key masking
-        outputs = mask(outputs, Q, K, type="key")
+        outputs = mask(outputs, key_masks=key_masks, type="key")
 
         # causality or future blinding masking
         if causality:
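For orientation, the steps around this hunk amount to softmax(Q K^T / sqrt(d_k)) * V, with padded key columns pushed to a large negative value before the softmax. A rough NumPy sketch of that ordering (not the repo's code; shapes as in the docstring):

import numpy as np

def sdpa_sketch(Q, K, V, key_masks, padding_num=-2 ** 32 + 1):
    # Q: (N, T_q, d_k), K: (N, T_k, d_k), V: (N, T_k, d_v), key_masks: (N, T_k), 1. at pads
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # scaled dot products (N, T_q, T_k)
    scores = scores + key_masks[:, None, :] * padding_num     # key masking before softmax
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                 # softmax over the key axis
    return weights @ V                                        # (N, T_q, d_v)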
@@ -87,8 +88,8 @@ def scaled_dot_product_attention(Q, K, V,
             attention = tf.transpose(outputs, [0, 2, 1])
             tf.summary.image("attention", tf.expand_dims(attention[:1], -1))
 
-        # query masking
-        outputs = mask(outputs, Q, K, type="query")
+        # # query masking
+        # outputs = mask(outputs, Q, K, type="query")
 
         # dropout
         outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
@@ -98,65 +99,58 @@ def scaled_dot_product_attention(Q, K, V,
 
     return outputs
 
-def mask(inputs, queries=None, keys=None, type=None):
+
+def mask(inputs, key_masks=None, type=None):
     """Masks paddings on keys or queries to inputs
-    inputs: 3d tensor. (N, T_q, T_k)
-    queries: 3d tensor. (N, T_q, d)
-    keys: 3d tensor. (N, T_k, d)
+    inputs: 3d tensor. (h*N, T_q, T_k)
+    key_masks: 2d tensor. (N, T_k)
+    type: string. "key" | "future"
 
     e.g.,
-    >> queries = tf.constant([[[1.],
-                               [2.],
-                               [0.]]], tf.float32)  # (1, 3, 1)
-    >> keys = tf.constant([[[4.],
-                            [0.]]], tf.float32)  # (1, 2, 1)
-    >> inputs = tf.constant([[[4., 0.],
-                              [8., 0.],
-                              [0., 0.]]], tf.float32)
-    >> mask(inputs, queries, keys, "key")
-    array([[[ 4.0000000e+00, -4.2949673e+09],
-            [ 8.0000000e+00, -4.2949673e+09],
-            [ 0.0000000e+00, -4.2949673e+09]]], dtype=float32)
-    >> inputs = tf.constant([[[1., 0.],
-                              [1., 0.],
-                              [1., 0.]]], tf.float32)
-    >> mask(inputs, queries, keys, "query")
-    array([[[1., 0.],
-            [1., 0.],
-            [0., 0.]]], dtype=float32)
+    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
+    >> key_masks = tf.constant([[0., 0., 1.],
+                                [0., 1., 1.]])
+    >> mask(inputs, key_masks=key_masks, type="key")
+    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
+            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],
+
+           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
+            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],
+
+           [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
+            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],
+
+           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
+            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
     """
     padding_num = -2 ** 32 + 1
     if type in ("k", "key", "keys"):
-        # Generate masks
-        masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1))  # (N, T_k)
-        masks = tf.expand_dims(masks, 1)  # (N, 1, T_k)
-        masks = tf.tile(masks, [1, tf.shape(queries)[1], 1])  # (N, T_q, T_k)
-
-        # Apply masks to inputs
-        paddings = tf.ones_like(inputs) * padding_num
-        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)  # (N, T_q, T_k)
-    elif type in ("q", "query", "queries"):
-        # Generate masks
-        masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
-        masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
-        masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
-
-        # Apply masks to inputs
-        outputs = inputs * masks
+        key_masks = tf.to_float(key_masks)
+        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, seqlen)
+        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
+        outputs = inputs + key_masks * padding_num
+    # elif type in ("q", "query", "queries"):
+    #     # Generate masks
+    #     masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
+    #     masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
+    #     masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
+    #
+    #     # Apply masks to inputs
+    #     outputs = inputs * masks
     elif type in ("f", "future", "right"):
         diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
         tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
-        masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
+        future_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
 
-        paddings = tf.ones_like(masks) * padding_num
-        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
+        paddings = tf.ones_like(future_masks) * padding_num
+        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
     else:
         print("Check if you entered type correctly!")
 
-
     return outputs
 
-def multihead_attention(queries, keys, values,
+
+def multihead_attention(queries, keys, values, key_masks,
                         num_heads=8,
                         dropout_rate=0,
                         training=True,
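The new key branch expects key_masks of shape (N, T_k) while inputs arrive as (h*N, T_q, T_k) from the head split, so the integer-division tile repeats the whole per-batch mask once per head before it is broadcast over T_q. A toy check of that shape logic (TF 1.x, illustrative values only):

import tensorflow as tf

inputs = tf.zeros([4, 2, 3])                                  # (h*N, T_q, T_k) with h=2, N=2
key_masks = tf.constant([[0., 0., 1.],
                         [0., 1., 1.]])                       # (N, T_k), 1. marks padding
rep = tf.shape(inputs)[0] // tf.shape(key_masks)[0]           # == h, number of heads
tiled = tf.tile(key_masks, [rep, 1])                          # (h*N, T_k), head-major like inputs
masked = inputs + tf.expand_dims(tiled, 1) * (-2 ** 32 + 1)   # broadcast over the T_q axis
# Every padded key column now holds a large negative score in every head.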
@@ -166,6 +160,7 @@ def multihead_attention(queries, keys, values,
     queries: A 3d tensor with shape of [N, T_q, d_model].
     keys: A 3d tensor with shape of [N, T_k, d_model].
     values: A 3d tensor with shape of [N, T_k, d_model].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     num_heads: An int. Number of heads.
     dropout_rate: A floating point number.
     training: Boolean. Controller of mechanism for dropout.
@@ -178,17 +173,17 @@ def multihead_attention(queries, keys, values,
     d_model = queries.get_shape().as_list()[-1]
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         # Linear projections
-        Q = tf.layers.dense(queries, d_model, use_bias=False)  # (N, T_q, d_model)
-        K = tf.layers.dense(keys, d_model, use_bias=False)  # (N, T_k, d_model)
-        V = tf.layers.dense(values, d_model, use_bias=False)  # (N, T_k, d_model)
+        Q = tf.layers.dense(queries, d_model, use_bias=True)  # (N, T_q, d_model)
+        K = tf.layers.dense(keys, d_model, use_bias=True)  # (N, T_k, d_model)
+        V = tf.layers.dense(values, d_model, use_bias=True)  # (N, T_k, d_model)
 
         # Split and concat
         Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, d_model/h)
         K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
         V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
 
         # Attention
-        outputs = scaled_dot_product_attention(Q_, K_, V_, causality, dropout_rate, training)
+        outputs = scaled_dot_product_attention(Q_, K_, V_, key_masks, causality, dropout_rate, training)
 
         # Restore shape
         outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, d_model)
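A shape-only sketch of the split/concat round trip used above (toy sizes, assuming d_model is divisible by num_heads); note that key_masks stays (N, T_k) and is only tiled and expanded inside mask():

import tensorflow as tf

N, T, d_model, h = 2, 5, 8, 4
x = tf.random_normal([N, T, d_model])
x_ = tf.concat(tf.split(x, h, axis=2), axis=0)   # (h*N, T, d_model/h), heads stacked on the batch axis
y = tf.concat(tf.split(x_, h, axis=0), axis=2)   # back to (N, T, d_model), undoing the split
# scaled_dot_product_attention receives the (h*N, ...) tensors plus the un-split key_masks.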