diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index ab6ba2d..a9c590f 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -229,13 +229,13 @@ def run( if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainG)) if hasattr(net_g, "module"): - logger.info( + print( net_g.module.load_state_dict( torch.load(hps.pretrainG, map_location="cpu")["model"] ) ) ##测试不加载优化器 else: - logger.info( + print( net_g.load_state_dict( torch.load(hps.pretrainG, map_location="cpu")["model"] ) @@ -244,13 +244,13 @@ def run( if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainD)) if hasattr(net_d, "module"): - logger.info( + print( net_d.module.load_state_dict( torch.load(hps.pretrainD, map_location="cpu")["model"] ) ) else: - logger.info( + print( net_d.load_state_dict( torch.load(hps.pretrainD, map_location="cpu")["model"] )