diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py
index ad9a5b5..b03d572 100644
--- a/infer/modules/train/train.py
+++ b/infer/modules/train/train.py
@@ -23,16 +23,14 @@
 try:
     if torch.xpu.is_available():
         from infer.modules.ipex import ipex_init
-
-        ipex_init()
-
-        from torch.xpu.amp import autocast
         from infer.modules.ipex.gradscaler import gradscaler_init
+        from torch.xpu.amp import autocast
 
         GradScaler = gradscaler_init()
+        ipex_init()
     else:
         from torch.cuda.amp import GradScaler, autocast
-except Exception:  # pylint: disable=broad-exception-caught
+except Exception:
     from torch.cuda.amp import GradScaler, autocast
 
 torch.backends.cudnn.deterministic = False
@@ -106,11 +104,10 @@ def main():
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
-    logger = utils.get_logger(hps.model_dir)
     for i in range(n_gpus):
         subproc = mp.Process(
             target=run,
-            args=(i, n_gpus, hps, logger),
+            args=(i, n_gpus, hps),
         )
         children.append(subproc)
         subproc.start()
@@ -119,10 +116,10 @@
         children[i].join()
 
 
-def run(rank, n_gpus, hps, logger: logging.Logger):
+def run(rank, n_gpus, hps,):
     global global_step
     if rank == 0:
-        # logger = utils.get_logger(hps.model_dir)
+        logger = utils.get_logger(hps.model_dir)
         logger.info(hps)
         # utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)