From 9602ea649c2f99f623d2553efc5356cb98660270 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:03:00 +0800 Subject: [PATCH] Update train.py --- infer/modules/train/train.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index a9c590f..38a5678 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -104,10 +104,11 @@ def main(): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) children = [] + logger = utils.get_logger(hps.model_dir) for i in range(n_gpus): subproc = mp.Process( target=run, - args=(i, n_gpus, hps), + args=(i, n_gpus, hps, logger), ) children.append(subproc) subproc.start() @@ -116,14 +117,10 @@ def main(): children[i].join() -def run( - rank, - n_gpus, - hps, -): +def run(rank, n_gpus, hps, logger: logging.Logger): global global_step if rank == 0: - logger = utils.get_logger(hps.model_dir) + # logger = utils.get_logger(hps.model_dir) logger.info(hps) # utils.check_git_hash(hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir) @@ -229,13 +226,13 @@ def run( if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainG)) if hasattr(net_g, "module"): - print( + logger.info( net_g.module.load_state_dict( torch.load(hps.pretrainG, map_location="cpu")["model"] ) ) ##测试不加载优化器 else: - print( + logger.info( net_g.load_state_dict( torch.load(hps.pretrainG, map_location="cpu")["model"] ) @@ -244,13 +241,13 @@ def run( if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainD)) if hasattr(net_d, "module"): - print( + logger.info( net_d.module.load_state_dict( torch.load(hps.pretrainD, map_location="cpu")["model"] ) ) else: - print( + logger.info( net_d.load_state_dict( torch.load(hps.pretrainD, map_location="cpu")["model"] )