Skip to content

Commit b6c4dfa

Browse files
committed
Multiplatform training
Prep for training on aws boxes by ensuring num_cores is autodetected and that 
treelite can compile using correct compiler.
1 parent 5007d15 commit b6c4dfa

File tree

6 files changed

+36
-10
lines changed

6 files changed

+36
-10
lines changed

assess.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,5 @@ def run_faceoff(
102102
run_generation_ladder(
103103
environment_name=sys.argv[1],
104104
species_list=species_list,
105-
num_workers=settings.NUM_CORES,
105+
num_workers=settings.ASSESSMENT_THREADS,
106106
)

gbdt_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def train_from_training_data(
234234
'num_leaves': num_leaves,
235235
'max_bin': 128,
236236
'min_data_in_leaf': min_data_in_leaf,
237-
'num_threads': 16, # 0 is as many as CPUs for server
237+
'num_threads': settings.GBDT_TRAINING_THREADS, # 0 is as many as CPUs for server
238238
'verbose': 1,
239239
# 'max_depth': 3,
240240
# 'min_gain_to_split': 0.01,

paths.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from uuid import uuid4
33
import pathlib
44

5-
from settings import ROOT_DATA_DIRECTORY, TMP_DIRECTORY
5+
from settings import ROOT_DATA_DIRECTORY, TMP_DIRECTORY, TOOL_CHAIN
66

77

88
def full_path_mkdir_p(full_path):
@@ -62,7 +62,12 @@ def build_model_paths(environment, species, generation):
6262
# XXX VOMIT! Abstract this out.
6363
suffix = "model"
6464
if "gbdt" in species:
65-
suffix = "dylib"
65+
if TOOL_CHAIN == "clang":
66+
suffix = "dylib"
67+
elif TOOL_CHAIN == "gcc":
68+
suffix = "so"
69+
else:
70+
raise KeyError(f"Unsure about suffix for tool chain: {TOOL_CHAIN}")
6671
base = build_model_directory(environment, species, generation)
6772
value_basename = f"value_model_{generation:06d}.{suffix}"
6873
policy_basename = f"policy_model_{generation:06d}.{suffix}"

settings.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
2+
import platform
3+
import psutil
24
from pathlib import Path
35

46
HOME = str(Path.home())
7+
OS_PLATFORM = platform.system()
58

69
'''
710
0
@@ -26,7 +29,23 @@
2629
MONITORING_DB_PATH = f"{HOME}/system_monitoring/monitoring.db"
2730
TMP_DIRECTORY = "/tmp"
2831

29-
NUM_THREADS = 14
30-
NUM_CORES = 14
32+
NUM_CORES = psutil.cpu_count()
3133

32-
TOOL_CHAIN = "clang" # Will need to change this for linux/windows
34+
SELF_PLAY_THREADS = NUM_CORES
35+
GBDT_TRAINING_THREADS = NUM_CORES
36+
TREELITE_THREADS = NUM_CORES
37+
ASSESSMENT_THREADS = NUM_CORES
38+
39+
40+
'''
41+
TOOL_CHAIN used by treelite to compile tree
42+
mac/os: clang
43+
linux: gcc
44+
windows: ?
45+
'''
46+
if OS_PLATFORM == "Darwin":
47+
TOOL_CHAIN = "clang"
48+
elif OS_PLATFORM == "Linux":
49+
TOOL_CHAIN = "gcc"
50+
else:
51+
raise KeyError(f"Unhandled platform: {OS_PLATFORM}")

train_bot.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22
import time
33

4+
5+
import settings
46
from self_play import run as run_self_play
57
from species import get_species
68
from train import run as run_model_training
@@ -12,8 +14,8 @@ def run(
1214
environment,
1315
species_name,
1416
num_batches,
15-
num_workers=12,
16-
adjusted_win_rate_threshold=0.51,
17+
num_workers=settings.SELF_PLAY_THREADS,
18+
adjusted_win_rate_threshold=0.50,
1719
num_assessment_games=200,
1820
):
1921
num_faceoff_rounds = math.ceil(num_assessment_games / num_workers) # Will play at least num_workers per round

treelite_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def build_treelite_model(
5858
# Compile model to C/C++
5959
print("Compiling Tree")
6060
params = dict(
61-
parallel_comp=settings.NUM_THREADS,
61+
parallel_comp=settings.TREELITE_THREADS,
6262
# quantize=1, # Supposed to speed up predictions. Didn't when I tried it.
6363
)
6464
if annotation_results_path is not None:

0 commit comments

Comments
 (0)