# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""The module to tune schedules."""
import argparse
import copy
import functools
import importlib
import json
import os
import pathlib
import re
import sys
import time

from slapo.framework_dialect import get_dialect_cls
from slapo.logger import get_logger
logger = get_logger()
def must_fix(func):
"""Decorator to mark a function in Symbol to ensure its value has been fixed."""
def wrapper(self, *args, **kwargs):
return self.is_fixed() and func(self, *args, **kwargs)
return wrapper
[docs]class Symbol:
"""A tunable symbol."""
def __init__(self, name, vals):
self.name = name
self.vals = vals
self.fixed_idx = -1
@property
def value(self):
if not self.is_fixed():
raise ValueError(f"The value of symbol {self.name} has not been fixed")
return self.vals[self.fixed_idx]
@must_fix
def __gt__(self, other):
return self.value > other
@must_fix
def __ge__(self, other):
return self.value >= other
@must_fix
def __lt__(self, other):
return self.value < other
@must_fix
def __le__(self, other):
return self.value <= other
def __len__(self):
return len(self.vals)
[docs] def add(self, val):
"""Add a value to the symbol. If the value is already in the symbol, do nothing."""
if val not in self.vals:
self.vals.append(val)
[docs] def fix_at(self, idx):
"""Fix the value of this symbol at the given index."""
if idx < len(self.vals):
self.fixed_idx = idx
return
raise ValueError(
f"Cannot fix {self.name} with {len(self.vals)} values at idx {idx}"
)
[docs] def is_fixed(self):
"""Check if the value of this symbol has been fixed."""
return self.fixed_idx != -1
[docs]class Space:
"""The tuning space."""
def __init__(self):
self.space = {}
self.idx_to_name = []
self.fixed_idx = -1
[docs] def create_symbol(self, name, vals):
"""Create a symbol in the space. If the symbol already exists:
1) If the symbol is fixed, do nothing;
2) Otherwise re-create the symbol, because its candidate values may change
due to other fixed symbols.
"""
if name in self.space:
# Ignore if the value has been fixed; otherwise re-generate the symbol.
if self.space[name].is_fixed():
return self.space[name]
# Create a new symbol.
if name not in self.space:
self.idx_to_name.append(name)
self.space[name] = Symbol(name, vals)
return self.space[name]
[docs] def next(self):
"""Get the next symbol to fix."""
if self.fixed_idx + 1 < len(self.space):
self.fixed_idx += 1
return self.space[self.idx_to_name[self.fixed_idx]]
return None
[docs] def reset(self):
"""Reset the space to the initial state."""
self.fixed_idx = -1
for symbol in self.space.values():
symbol.fixed_idx = -1
[docs] def to_dict(self):
"""Convert the space to a dict. Note that all symbols must be fixed
before calling this function.
"""
cfg = {}
for symbol in self.space.values():
cfg[symbol.name] = symbol.value
return cfg
[docs] def clone(self):
"""Clone the space."""
return copy.deepcopy(self)
[docs] @staticmethod
def cfg_dict_to_str(cfg_dict):
"""Convert a config dict to a string for logging and debugging."""
ret = "("
for idx, (k, v) in enumerate(cfg_dict.items()):
is_last = idx == len(cfg_dict) - 1
last_ch = ")" if is_last else ", "
ret += f"{k}: {v}{last_ch}"
return ret
[docs] def log_space(self, training_script_args, update_space_fn):
"""Print the tuning space for logging."""
def _run(space, count=0):
symbol = space.next()
if symbol is not None:
for idx in range(len(symbol.vals)):
symbol.fix_at(idx)
space = update_space_fn(training_script_args, space)
count = _run(space.clone(), count)
return count
logger.info("\t%s", self.cfg_dict_to_str(space.to_dict()))
return count + 1
logger.info("Enumerating the search space:")
count = _run(self.clone())
logger.info("Space size: %d", count)
def __repr__(self):
ret = "Space(\n"
for idx, symbol in enumerate(self.space.values()):
is_last = idx == len(self.space) - 1
val = f"{symbol.value} (fixed)" if symbol.is_fixed() else str(symbol.vals)
last_ch = ")" if is_last else ",\n"
ret += f"\t{symbol.name}: {val}{last_ch}"
return ret
[docs]class Database:
"""A simple database to store the results of tuning in Dict."""
def __init__(self, db_file_name=None):
self.db_file_name = db_file_name
self.db = {}
if self.db_file_name:
logger.info("Tuning records will be saved to %s", self.db_file_name)
[docs] def load(self):
"""Load the database from the file."""
if self.db_file_name and os.path.exists(self.db_file_name):
with open(self.db_file_name, "r", encoding="utf-8") as filep:
self.db = json.load(filep)
logger.info(
"Loaded %d tuning records from %s", len(self.db), self.db_file_name
)
[docs] def commit(self, key, data):
"""Commit the data to the database and update the DB file."""
self.db[key] = data
if self.db_file_name:
with open(self.db_file_name, "w", encoding="utf-8") as filep:
json.dump(self.db, filep, indent=2)
def parse_args():
parser = argparse.ArgumentParser("Auto-Tuning for Model Schedule")
parser.add_argument(
"--config",
type=str,
required=True,
help="The config file including tuning space definition, etc",
)
parser.add_argument(
"--db",
type=str,
help="The file path to store tuning records in JSON format",
)
parser.add_argument(
"--error-stop",
type=str,
default="none",
choices=["none", "symbol", "all"],
help="When error occurs, either stop tuning the current symbol or all symbols",
)
parser.add_argument(
"training_script",
type=str,
help="The full path to the training script. The defined tunable parameters in "
"config file can be used in both the script and the arguments",
)
parser.add_argument(
"training_script_args",
nargs=argparse.REMAINDER,
help="The arguments to the training script",
)
return parser.parse_args()
def convert_nargs_to_dict(nargs):
"""Convert the arguments to a dict."""
if not nargs:
return {}
def infer_type(val):
try:
val = float(val)
if val // 1 == val:
return int(val)
return val
except ValueError:
return val
def remove_leading_minus(name):
idx = re.match("-*", name).end()
return name[idx:]
ret = {}
ptr = 0
while ptr < len(nargs):
arg_name = nargs[ptr]
ptr += 1
if (
ptr == len(nargs)
or not arg_name.startswith(("-", "--"))
or nargs[ptr].startswith(("-", "--"))
):
# True/False flag or positional argument.
ret[remove_leading_minus(arg_name)] = 1
else:
vals = []
while ptr < len(nargs) and not nargs[ptr].startswith(("-", "--")):
vals.append(infer_type(nargs[ptr]))
ptr += 1
ret[remove_leading_minus(arg_name)] = vals[0] if len(vals) == 1 else vals
return ret
def run_training_script(args, tuneable_cfg):
"""Run the training script with the given config."""
train_script_args = " ".join(args.training_script_args)
# Replace tunable values in training script arguments.
for key, val in tuneable_cfg.items():
train_script_args = re.sub(key, f'"{str(val)}"', train_script_args)
# Set all tunable parameters as environment variables.
env = " ".join(f"{k}={v}" for k, v in tuneable_cfg.items())
cmd = f"{env} python3 {args.training_script} {train_script_args}"
cmd += " > run_script.log 2>&1"
logger.info("\tRunning command: %s", cmd)
os.system(cmd)
return "run_script.log"
def tune(args, get_bs_range, eval_fn):
"""Tune the given space with an evaluation function."""
training_script_args = convert_nargs_to_dict(args.training_script_args)
min_bs, max_bs, step = get_bs_range(training_script_args)
bs_range = list(range(min_bs, max_bs + 1, step))
ckpt_ratio_range = [1.0, 0.92, 0.84, 0.67, 0.5, 0.34, 0.25]
early_stopping_patience = 0
def is_valid(config):
if "slapo-deepspeed" in training_script_args:
# DeepSpeed uses data parallelism requiring the global batch size
# can be divided by number of devices
return config["batch_size"] % int(training_script_args["gpus"]) == 0
return True
def binary_search(data, cfg_dict, key, curr_best, lt=0, rt=None):
nonlocal early_stopping_patience
logger.info("Binary searching %s without OOM", key)
if rt is None:
rt = len(data) - 1
while lt <= rt:
mid = (lt + rt) // 2
cfg_dict[key] = data[mid]
logger.info("- Evaluating %s", str(cfg_dict))
# early pruning
if is_valid(cfg_dict):
thrpt = eval_fn(cfg_dict)
else:
thrpt = 0.0
logger.info(
"Invalid configuration point %s, n_gpu=%s",
str(cfg_dict),
training_script_args["gpus"],
)
time.sleep(0.5)
logger.info("\tThroughput: %.2f", thrpt)
# TODO: threshold should be a larger value used for pruning
# maybe provide an interface for the users
if thrpt < 0.01:
rt = mid - 1
else:
lt = mid + 1
if thrpt > curr_best[1]:
curr_best = (cfg_dict.copy(), thrpt)
early_stopping_patience = 0
else:
early_stopping_patience += 1
# set step 5 as the patience
if early_stopping_patience >= 5:
return mid, None, curr_best
logger.info(
"\tCurrent best config: %s, thrpt: %.2f",
str(curr_best[0]),
curr_best[1],
)
if thrpt < 0.01:
mid = mid - 1
return mid, thrpt, curr_best
def _run(min_bs, max_bs, step):
if "megatron" in training_script_args:
ckpt_ratio = "full"
else:
ckpt_ratio = 1.0
cfg_dict = {"batch_size": max_bs, "ckpt_ratio": ckpt_ratio}
# suppose the user given minimum bs is always executable
logger.info("Evaluating inital config...")
logger.info("- Evaluating %s", str(cfg_dict))
thrpt = eval_fn(cfg_dict)
logger.info("\tThroughput: %.2f", thrpt)
curr_best = (cfg_dict.copy(), thrpt)
if thrpt == 0: # OOM
mid, thrpt, curr_best = binary_search(
bs_range, cfg_dict, "batch_size", curr_best
)
max_bs = bs_range[mid]
else:
mid = 0
logger.info("Maximum batch size without OOM: %d", max_bs)
if (
"slapo-megatron" in training_script_args
or "slapo-deepspeed" in training_script_args
):
for bs in reversed(list(range(min_bs, max_bs + 1, step))):
cfg_dict["batch_size"] = bs
mid, thrpt, curr_best = binary_search(
ckpt_ratio_range, cfg_dict, "ckpt_ratio", curr_best, lt=mid
)
if thrpt is None: # early stopping
break
return curr_best
logger.info("Start tuning...")
curr_best = _run(min_bs, max_bs, step)
logger.info("Tuning done!")
return curr_best[0]
def load_config(config_file):
"""Load required functions from the tuning config."""
path = pathlib.Path(config_file).absolute()
sys.path.append(str(path.parent))
module = importlib.import_module(path.stem)
if not hasattr(module, "get_bs_range"):
raise ValueError("Missing 'get_bs' function in config file")
return module.get_bs_range
def parse_log(args, log_file):
with open(log_file, "r", encoding="utf-8") as f:
text = f.read()
if "slapo-megatron" in args or "megatron" in args:
parser = get_dialect_cls("log_parser", "megatron")
_, samples_per_sec, _, error_code = parser.parse_log(log_file)
elif "slapo-deepspeed" in args or "deepspeed" in args:
parser = get_dialect_cls("log_parser", "deepspeed")
_, samples_per_sec, _, error_code = parser.parse_log(log_file)
else:
raise RuntimeError("Please provide correct `impl`")
return (error_code, samples_per_sec, text)
def main():
"""Entry point."""
args = parse_args()
get_bs_range = load_config(args.config)
db = Database(args.db)
def eval_fn(cfg):
log_file = run_training_script(args, cfg)
error_code, thrpt, memo = parse_log(
convert_nargs_to_dict(args.training_script_args), "log.txt"
)
with open(log_file, "r", encoding="utf-8") as filep:
log_file_ctx = filep.read()
db.commit(
Space.cfg_dict_to_str(cfg),
{
"error_code": error_code,
"thrpt": thrpt,
"log": log_file_ctx,
"memo": memo,
},
)
if args.error_stop == "all" and error_code != 0:
raise ValueError("Stop tuning due to error. Check log for details")
return thrpt if error_code == 0 else 0
curr_best = tune(args, get_bs_range, eval_fn)
logger.info("Best config: %s", curr_best)
if __name__ == "__main__":
main()