Commit 0fbf949

Author: kyubyong park
Commit message: removed query masking
1 parent: 3a6723b

1 file changed: modules.py (+47 lines, -52 lines)
@@ -53,14 +53,15 @@ def get_token_embeddings(vocab_size, num_units, zero_pad=True):
                                     embeddings[1:, :]), 0)
     return embeddings
 
-def scaled_dot_product_attention(Q, K, V,
+def scaled_dot_product_attention(Q, K, V, key_masks,
                                  causality=False, dropout_rate=0.,
                                  training=True,
                                  scope="scaled_dot_product_attention"):
     '''See 3.2.1.
     Q: Packed queries. 3d tensor. [N, T_q, d_k].
     K: Packed keys. 3d tensor. [N, T_k, d_k].
     V: Packed values. 3d tensor. [N, T_k, d_v].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     causality: If True, applies masking for future blinding
     dropout_rate: A floating point number of [0, 1].
     training: boolean for controlling droput
@@ -76,7 +77,7 @@ def scaled_dot_product_attention(Q, K, V,
         outputs /= d_k ** 0.5
 
         # key masking
-        outputs = mask(outputs, Q, K, type="key")
+        outputs = mask(outputs, key_masks=key_masks, type="key")
 
         # causality or future blinding masking
         if causality:
@@ -87,8 +88,8 @@ def scaled_dot_product_attention(Q, K, V,
         attention = tf.transpose(outputs, [0, 2, 1])
         tf.summary.image("attention", tf.expand_dims(attention[:1], -1))
 
-        # query masking
-        outputs = mask(outputs, Q, K, type="query")
+        # # query masking
+        # outputs = mask(outputs, Q, K, type="query")
 
         # dropout
         outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
@@ -98,65 +99,58 @@ def scaled_dot_product_attention(Q, K, V,
 
     return outputs
 
-def mask(inputs, queries=None, keys=None, type=None):
+
+def mask(inputs, key_masks=None, type=None):
     """Masks paddings on keys or queries to inputs
-    inputs: 3d tensor. (N, T_q, T_k)
-    queries: 3d tensor. (N, T_q, d)
-    keys: 3d tensor. (N, T_k, d)
+    inputs: 3d tensor. (h*N, T_q, T_k)
+    key_masks: 3d tensor. (N, 1, T_k)
+    type: string. "key" | "future"
 
     e.g.,
-    >> queries = tf.constant([[[1.],
-                               [2.],
-                               [0.]]], tf.float32)  # (1, 3, 1)
-    >> keys = tf.constant([[[4.],
-                            [0.]]], tf.float32)  # (1, 2, 1)
-    >> inputs = tf.constant([[[4., 0.],
-                              [8., 0.],
-                              [0., 0.]]], tf.float32)
-    >> mask(inputs, queries, keys, "key")
-    array([[[ 4.0000000e+00, -4.2949673e+09],
-            [ 8.0000000e+00, -4.2949673e+09],
-            [ 0.0000000e+00, -4.2949673e+09]]], dtype=float32)
-    >> inputs = tf.constant([[[1., 0.],
-                              [1., 0.],
-                              [1., 0.]]], tf.float32)
-    >> mask(inputs, queries, keys, "query")
-    array([[[1., 0.],
-            [1., 0.],
-            [0., 0.]]], dtype=float32)
+    >> inputs = tf.zeros([2, 2, 3], dtype=tf.float32)
+    >> key_masks = tf.constant([[0., 0., 1.],
+                                [0., 1., 1.]])
+    >> mask(inputs, key_masks=key_masks, type="key")
+    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
+            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],
+
+           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
+            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],
+
+           [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
+            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],
+
+           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
+            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
     """
     padding_num = -2 ** 32 + 1
     if type in ("k", "key", "keys"):
-        # Generate masks
-        masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1))  # (N, T_k)
-        masks = tf.expand_dims(masks, 1)  # (N, 1, T_k)
-        masks = tf.tile(masks, [1, tf.shape(queries)[1], 1])  # (N, T_q, T_k)
-
-        # Apply masks to inputs
-        paddings = tf.ones_like(inputs) * padding_num
-        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)  # (N, T_q, T_k)
-    elif type in ("q", "query", "queries"):
-        # Generate masks
-        masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
-        masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
-        masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
-
-        # Apply masks to inputs
-        outputs = inputs*masks
+        key_masks = tf.to_float(key_masks)
+        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, seqlen)
+        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
+        outputs = inputs + key_masks * padding_num
+    # elif type in ("q", "query", "queries"):
+    #     # Generate masks
+    #     masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
+    #     masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
+    #     masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
+    #
+    #     # Apply masks to inputs
+    #     outputs = inputs*masks
     elif type in ("f", "future", "right"):
         diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
         tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
-        masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
+        future_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
 
-        paddings = tf.ones_like(masks) * padding_num
-        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
+        paddings = tf.ones_like(future_masks) * padding_num
+        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
     else:
         print("Check if you entered type correctly!")
 
-
     return outputs
 
-def multihead_attention(queries, keys, values,
+
+def multihead_attention(queries, keys, values, key_masks,
                         num_heads=8,
                         dropout_rate=0,
                         training=True,
@@ -166,6 +160,7 @@ def multihead_attention(queries, keys, values,
     queries: A 3d tensor with shape of [N, T_q, d_model].
     keys: A 3d tensor with shape of [N, T_k, d_model].
    values: A 3d tensor with shape of [N, T_k, d_model].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     num_heads: An int. Number of heads.
     dropout_rate: A floating point number.
     training: Boolean. Controller of mechanism for dropout.
@@ -178,17 +173,17 @@ def multihead_attention(queries, keys, values,
     d_model = queries.get_shape().as_list()[-1]
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         # Linear projections
-        Q = tf.layers.dense(queries, d_model, use_bias=False)  # (N, T_q, d_model)
-        K = tf.layers.dense(keys, d_model, use_bias=False)  # (N, T_k, d_model)
-        V = tf.layers.dense(values, d_model, use_bias=False)  # (N, T_k, d_model)
+        Q = tf.layers.dense(queries, d_model, use_bias=True)  # (N, T_q, d_model)
+        K = tf.layers.dense(keys, d_model, use_bias=True)  # (N, T_k, d_model)
+        V = tf.layers.dense(values, d_model, use_bias=True)  # (N, T_k, d_model)
 
         # Split and concat
         Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, d_model/h)
         K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
         V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
 
         # Attention
-        outputs = scaled_dot_product_attention(Q_, K_, V_, causality, dropout_rate, training)
+        outputs = scaled_dot_product_attention(Q_, K_, V_, key_masks, causality, dropout_rate, training)
 
         # Restore shape
         outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, d_model)
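
To make the effect of the rewritten "key" branch of mask() easier to see, here is a minimal sketch of the same additive-masking arithmetic in plain NumPy rather than the repo's TensorFlow 1.x ops (the name additive_key_mask is illustrative only, not from the repo): key positions where key_masks is 1 (padding) get a huge negative number added to the attention logits, so a later softmax gives them effectively zero weight.

import numpy as np

PADDING_NUM = -2 ** 32 + 1  # same large-negative constant the diff uses

def additive_key_mask(inputs, key_masks):
    """inputs: (h*N, T_q, T_k) attention logits; key_masks: (N, T_k), 1.0 = padded key."""
    hN, N = inputs.shape[0], key_masks.shape[0]
    key_masks = np.tile(key_masks.astype(np.float32), (hN // N, 1))  # (h*N, T_k)
    key_masks = key_masks[:, None, :]                                # (h*N, 1, T_k)
    return inputs + key_masks * PADDING_NUM                          # broadcasts over T_q

logits = np.zeros((2, 2, 3), dtype=np.float32)  # mirrors the docstring example above
masks = np.array([[0., 0., 1.],                 # sample 0: last key is padding
                  [0., 1., 1.]])                # sample 1: last two keys are padding
print(additive_key_mask(logits, masks))         # padded columns become ~ -4.29e9

Unlike the old tf.where-based version, this branch no longer looks at the queries at all, which fits with the commit dropping the separate query-masking step.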
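Since scaled_dot_product_attention and multihead_attention now take key_masks explicitly, callers have to build the mask themselves and pass it through. A hedged sketch of what such a call site might look like (TensorFlow 1.x style; the token ids, shapes, and hyperparameter values below are assumptions for illustration, not part of this commit):

import tensorflow as tf                   # TensorFlow 1.x, matching the repo's API
from modules import multihead_attention   # the function whose signature changed here

x = tf.constant([[3, 7, 0, 0],
                 [5, 0, 0, 0]], tf.int32)  # (N=2, T_k=4) token ids; 0 assumed to be the pad id
enc = tf.ones((2, 4, 512))                 # stand-in for embedded inputs, (N, T_k, d_model)

src_masks = tf.equal(x, 0)                 # (N, T_k), True where the key position is padding
out = multihead_attention(queries=enc, keys=enc, values=enc,
                          key_masks=src_masks,  # the new required argument
                          num_heads=8,
                          dropout_rate=0.3,
                          training=True,
                          causality=False)

mask() casts key_masks with tf.to_float, so a boolean tensor such as tf.equal(x, 0) works as well as a 0/1 float tensor, as long as 1/True marks the padded key positions.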
