1
0
mirror of synced 2024-11-23 23:21:03 +01:00

Update train.py

This commit is contained in:
RVC-Boss 2024-01-26 16:01:14 +08:00 committed by GitHub
parent d187257944
commit b304564c9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -229,13 +229,13 @@ def run(
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG)) logger.info("loaded pretrained %s" % (hps.pretrainG))
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
logger.info( print(
net_g.module.load_state_dict( net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"] torch.load(hps.pretrainG, map_location="cpu")["model"]
) )
) ##测试不加载优化器 ) ##测试不加载优化器
else: else:
logger.info( print(
net_g.load_state_dict( net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"] torch.load(hps.pretrainG, map_location="cpu")["model"]
) )
@ -244,13 +244,13 @@ def run(
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD)) logger.info("loaded pretrained %s" % (hps.pretrainD))
if hasattr(net_d, "module"): if hasattr(net_d, "module"):
logger.info( print(
net_d.module.load_state_dict( net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"] torch.load(hps.pretrainD, map_location="cpu")["model"]
) )
) )
else: else:
logger.info( print(
net_d.load_state_dict( net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"] torch.load(hps.pretrainD, map_location="cpu")["model"]
) )