'''
    code by TaeHwan Jung (@graykode)
    Original paper and repository: https://github.com/openai/gpt-2
    GPT-2 PyTorch model: https://github.com/huggingface/pytorch-pretrained-BERT
'''
import copy
import math

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


def gelu(x):
    # GELU activation, tanh approximation (as used by the original GPT-2 code).
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

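# Illustrative sketch, not part of the original file: the tanh expression above
# approximates the exact GELU, x * Phi(x) with Phi the standard normal CDF; the
# two stay within roughly 1e-3 of each other over typical activation ranges.
def _gelu_demo():
    x = torch.linspace(-3, 3, 13)
    exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
    assert torch.allclose(gelu(x), exact, atol=1e-3)
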
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

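# Illustrative sketch, not part of the original file: with the same eps and the
# same affine parameters, this "TF style" LayerNorm matches torch.nn.LayerNorm.
def _layernorm_demo():
    ln = LayerNorm(8, eps=1e-12)
    ref = nn.LayerNorm(8, eps=1e-12)
    with torch.no_grad():
        ref.weight.copy_(ln.weight)
        ref.bias.copy_(ln.bias)
    x = torch.randn(2, 3, 8)
    assert torch.allclose(ln(x), ref(x), atol=1e-5)
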
class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

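# Illustrative sketch, not part of the original file: Conv1D is simply an affine
# projection with an (nx, nf) weight, i.e. the transpose of nn.Linear's layout,
# mirroring how the TF checkpoint stores its weights.
def _conv1d_demo():
    c = Conv1D(nf=6, nx=4)          # project 4 features to 6
    x = torch.randn(2, 3, 4)        # (batch, seq, nx)
    out = c(x)                      # (batch, seq, nf)
    ref = x @ c.weight + c.bias     # the same affine map written directly
    assert out.shape == (2, 3, 6)
    assert torch.allclose(out, ref, atol=1e-6)
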
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns-nd:ns, :ns]
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

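# Illustrative sketch, not part of the original file: shapes through one
# attention call, including the `present` cache that a later call can consume
# via `layer_past`. The SimpleNamespace stands in for the real config object.
def _attention_demo():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_head=4)
    attn = Attention(nx=16, n_ctx=32, config=cfg, scale=True)
    x = torch.randn(2, 5, 16)                # (batch, seq, n_embd)
    a, present = attn(x)
    assert a.shape == (2, 5, 16)
    assert present.shape == (2, 2, 4, 5, 4)  # (key/value, batch, head, seq, head_dim)
    # Feed one more token, reusing the cached keys/values.
    a2, _ = attn(torch.randn(2, 1, 16), layer_past=present)
    assert a2.shape == (2, 1, 16)
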
class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

class GPT2Model(nn.Module):
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
                                        device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

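# Illustrative sketch, not part of the original file: incremental decoding with
# the `past` cache. A prefix is encoded once; later single-token calls pass the
# returned `presents` back in so earlier positions are not recomputed. The
# SimpleNamespace stands in for the real config object.
def _past_demo():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=100, n_positions=64, n_ctx=64, n_embd=32,
                          n_layer=2, n_head=4, layer_norm_epsilon=1e-5)
    model = GPT2Model(cfg)
    prefix = torch.randint(0, cfg.vocab_size, (1, 6))
    hidden, presents = model(prefix)                   # hidden: (1, 6, 32)
    next_token = torch.randint(0, cfg.vocab_size, (1, 1))
    hidden2, _ = model(next_token, past=presents)      # hidden2: (1, 1, 32)
    assert hidden.shape == (1, 6, 32) and hidden2.shape == (1, 1, 32)
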
class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        # Truncated Language modeling logits (we remove the last token)
        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        lm_logits = self.decoder(hidden_state)
        return lm_logits

class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """Make sure we are sharing the embeddings."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
            return loss
        return lm_logits, presents
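
# Illustrative usage sketch, not part of the original file. The config below is
# a stand-in exposing only the attributes this module reads (the real GPT2Config
# lives elsewhere in the repository); GPT-2 small uses vocab_size=50257,
# n_positions=n_ctx=1024, n_embd=768, n_layer=12, n_head=12.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(vocab_size=1000, n_positions=128, n_ctx=128,
                             n_embd=64, n_layer=2, n_head=4,
                             layer_norm_epsilon=1e-5)
    model = GPT2LMHeadModel(config)
    model.set_tied()  # re-tie the LM head to the token embedding matrix
    assert model.lm_head.decoder.weight is model.transformer.wte.weight
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch, seq_len)
    with torch.no_grad():
        lm_logits, presents = model(input_ids)
    print(lm_logits.shape)   # torch.Size([1, 8, 1000])
    print(len(presents))     # one cached key/value pair per block: 2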