Commit 35d1282

[major] release HAT
1 parent 9e3d8af commit 35d1282

File tree: 214 files changed, +89495 -10 lines changed


HAT_ACL.pdf

-2.29 MB
Binary file not shown.

LICENSE

+50
@@ -0,0 +1,50 @@
MIT License
------------ LICENSE For Hardware-Aware Transformer software ---------------
Copyright (c) 2020, Hanrui Wang, Zhanghao Wu, Zhijian Liu, Han Cai,
Ligeng Zhu, Chuang Gan and Song Han
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

------------------------- LICENSE FOR Fairseq ------------------------------
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

+208-10
Large diffs are not rendered by default.

average_checkpoints.py

+140
@@ -0,0 +1,140 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import os
import re

import torch


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case p is a shared parameter
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        averaged_params[k].div_(num_models)
    new_state['model'] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
                                'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                                'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
                             'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
        '--checkpoint-upper-bound requires --num-epoch-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
        'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
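
To make the averaging semantics concrete, a minimal sketch (toy checkpoints with hypothetical file names, not part of this commit) showing that the 'model' entry of the result is the elementwise mean of the input state dicts:

import collections
import torch
from average_checkpoints import average_checkpoints

# Write two toy checkpoints whose single parameter is all-1s and all-2s.
paths = []
for i in (1, 2):
    model = collections.OrderedDict({'w': torch.full((2,), float(i))})
    path = 'toy_checkpoint{}.pt'.format(i)
    torch.save({'model': model}, path)
    paths.append(path)

avg = average_checkpoints(paths)
print(avg['model']['w'])  # tensor([1.5000, 1.5000]): the elementwise mean of 1.0 and 2.0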
(Shell script; file name not shown in this commit view)

@@ -0,0 +1,10 @@
checkpoints_path=$1
avg_checkpoints=${2:-10}

model=average_model_$avg_checkpoints.pt
output_path=$checkpoints_path

python average_checkpoints.py \
    --inputs $output_path \
    --num-epoch-checkpoints $avg_checkpoints \
    --output $output_path/$model
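
A hypothetical invocation (the wrapper's file name is hidden in this view; 'average.sh' and the run directory are assumed), averaging the last 10 epoch checkpoints and writing average_model_10.pt into the same directory:

bash average.sh checkpoints/iwslt14.de-en/run1 10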
(YAML config for evolutionary search; file name not shown in this commit view)

@@ -0,0 +1,22 @@
evo-iter: 30
population-size: 125
parent-size: 25
mutation-size: 50
crossover-size: 50
mutation-prob: 0.3


# path to load latency predictor
ckpt-path: ./latency_dataset/predictors/iwslt14deen_gpu_titanxp.pt
# feature-norm should match the one used when training the latency predictor
feature-norm: [640, 6, 2048, 6, 640, 6, 2048, 6, 6, 2]
# lat-norm should match the one used when training the latency predictor
lat-norm: 200
# path to load supertransformer weights
restore-file: ./downloaded_models/HAT_iwslt14deen_super_space1.pt

# path to write subtransformer configs
write-config-path: configs/iwslt14.de-en/subtransformer/[email protected]
# latency constraint (ms)
latency-constraint: 200
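
For intuition, a minimal sketch of the evolutionary loop these hyperparameters drive (sample_arch, mutate, crossover, predict_latency, and fitness are hypothetical helpers, not this commit's code): each of the 30 iterations keeps the best 25 of a 125-candidate population, then refills it with 50 mutations and 50 crossovers (25 + 50 + 50 = 125), discarding candidates whose predicted latency exceeds the 200 ms constraint.

import random

EVO_ITER, POP, PARENTS = 30, 125, 25
MUT, CROSS, MUT_PROB, LAT_LIMIT = 50, 50, 0.3, 200  # latency constraint in ms

def evolve(sample_arch, mutate, crossover, predict_latency, fitness):
    population = [sample_arch() for _ in range(POP)]
    for _ in range(EVO_ITER):
        # Keep the best architectures (e.g., lowest validation loss).
        parents = sorted(population, key=fitness)[:PARENTS]
        population = list(parents)
        # Refill with mutations that satisfy the latency constraint.
        while len(population) < PARENTS + MUT:
            child = mutate(random.choice(parents), prob=MUT_PROB)
            if predict_latency(child) <= LAT_LIMIT:
                population.append(child)
        # Refill with crossovers of two random parents.
        while len(population) < PARENTS + MUT + CROSS:
            child = crossover(random.choice(parents), random.choice(parents))
            if predict_latency(child) <= LAT_LIMIT:
                population.append(child)
    return min(population, key=fitness)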
(Data download script; file name not shown in this commit view)

@@ -0,0 +1,9 @@
#!/bin/bash

mkdir -p data/binary/iwslt14_de_en

wget -O data/binary/iwslt14_de_en/iwslt14_de_en.preprocessed.tgz 'https://www.dropbox.com/s/t5dqiamjdzahhfc/iwslt14_de_en.preproessed.tgz?dl=0'

cd data/binary/iwslt14_de_en

tar -xzvf iwslt14_de_en.preprocessed.tgz
(YAML config for latency-dataset collection; file name not shown in this commit view)

@@ -0,0 +1,49 @@
lat-dataset-path: ./latency_dataset/iwslt14deen_gpu_titanxp.csv
lat-dataset-size: 2000
latgpu: True
latiter: 20
latsilent: True


# below are the configs for the data-point sampling space of the latency predictor

# model
arch: transformersuper_iwslt_de_en
max-tokens: 4096
data: data/binary/iwslt14_de_en
source-lang: de
target-lang: en

# SuperTransformer configs
encoder-embed-dim: 640
decoder-embed-dim: 640

encoder-ffn-embed-dim: 3072
decoder-ffn-embed-dim: 3072

encoder-layers: 6
decoder-layers: 6

encoder-attention-heads: 8
decoder-attention-heads: 8


qkv-dim: 512

# SubTransformers search space
encoder-embed-choice: [640, 512]
decoder-embed-choice: [640, 512]

encoder-ffn-embed-dim-choice: [3072, 2048, 1024, 512]
decoder-ffn-embed-dim-choice: [3072, 2048, 1024, 512]

encoder-layer-num-choice: [6]
decoder-layer-num-choice: [6, 5, 4, 3, 2, 1]

encoder-self-attention-heads-choice: [8, 4, 2]
decoder-self-attention-heads-choice: [8, 4, 2]
decoder-ende-attention-heads-choice: [8, 4, 2]

# for arbitrary encoder-decoder attention: -1 means attending to the last encoder layer,
# 1 means the last two encoder layers, 2 means the last three encoder layers
decoder-arbitrary-ende-attn-choice: [-1, 1, 2]
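
To illustrate how such a search space is sampled when collecting latency data points, a minimal sketch (choice lists copied from the config above; the function name is hypothetical, and it samples one value per knob for brevity, whereas per-layer knobs such as FFN dims and head counts can be sampled independently per layer):

import random

# Search-space choices copied from the config above.
SPACE = {
    'encoder-embed-dim': [640, 512],
    'decoder-embed-dim': [640, 512],
    'encoder-ffn-embed-dim': [3072, 2048, 1024, 512],
    'decoder-ffn-embed-dim': [3072, 2048, 1024, 512],
    'encoder-layer-num': [6],
    'decoder-layer-num': [6, 5, 4, 3, 2, 1],
    'encoder-self-attention-heads': [8, 4, 2],
    'decoder-self-attention-heads': [8, 4, 2],
    'decoder-ende-attention-heads': [8, 4, 2],
    'decoder-arbitrary-ende-attn': [-1, 1, 2],
}

def sample_subtransformer():
    """Draw one random SubTransformer configuration from the space."""
    return {k: random.choice(v) for k, v in SPACE.items()}

# Collecting a latency dataset then amounts to timing each sampled
# config on the target hardware and writing (features, latency) rows.
print(sample_subtransformer())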
(YAML config for latency-predictor training; file name not shown in this commit view)

@@ -0,0 +1,10 @@
lat-dataset-path: ./latency_dataset/iwslt14deen_gpu_titanxp_all.csv
feature-norm: [640, 6, 2048, 6, 640, 6, 2048, 6, 6, 2]
lat-norm: 200
feature-dim: 10
hidden-dim: 400
hidden-layer-num: 3
ckpt-path: ./latency_dataset/predictors/iwslt14deen_gpu_titanxp.pt
train-steps: 5000
bsz: 128
lr: 1e-5
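
For intuition, a minimal PyTorch sketch (an assumption, not the repository's implementation) of a latency predictor matching these hyperparameters: a 10-dimensional architecture feature vector is normalized elementwise by feature-norm, passed through 3 hidden layers of width 400, and the scalar output is rescaled by lat-norm back to milliseconds.

import torch
import torch.nn as nn

# Hyperparameters from the config above.
FEATURE_NORM = torch.tensor([640, 6, 2048, 6, 640, 6, 2048, 6, 6, 2], dtype=torch.float)
LAT_NORM, FEATURE_DIM, HIDDEN_DIM, HIDDEN_LAYERS = 200.0, 10, 400, 3

class LatencyPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        layers, in_dim = [], FEATURE_DIM
        for _ in range(HIDDEN_LAYERS):
            layers += [nn.Linear(in_dim, HIDDEN_DIM), nn.ReLU()]
            in_dim = HIDDEN_DIM
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, features):
        # Normalize each architecture feature, predict a scaled latency,
        # then map back to milliseconds via lat-norm.
        x = features / FEATURE_NORM
        return self.net(x).squeeze(-1) * LAT_NORM

predictor = LatencyPredictor()
dummy = torch.tensor([[512., 6, 1024, 6, 512, 3, 1024, 4, 4, 1]])  # hypothetical feature row
print(predictor(dummy))  # predicted latency in ms (untrained weights)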
