Update train.py
This commit is contained in:
parent
d187257944
commit
b304564c9e
@ -229,13 +229,13 @@ def run(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.info("loaded pretrained %s" % (hps.pretrainG))
|
logger.info("loaded pretrained %s" % (hps.pretrainG))
|
||||||
if hasattr(net_g, "module"):
|
if hasattr(net_g, "module"):
|
||||||
logger.info(
|
print(
|
||||||
net_g.module.load_state_dict(
|
net_g.module.load_state_dict(
|
||||||
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
||||||
)
|
)
|
||||||
) ##测试不加载优化器
|
) ##测试不加载优化器
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
net_g.load_state_dict(
|
net_g.load_state_dict(
|
||||||
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
||||||
)
|
)
|
||||||
@ -244,13 +244,13 @@ def run(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.info("loaded pretrained %s" % (hps.pretrainD))
|
logger.info("loaded pretrained %s" % (hps.pretrainD))
|
||||||
if hasattr(net_d, "module"):
|
if hasattr(net_d, "module"):
|
||||||
logger.info(
|
print(
|
||||||
net_d.module.load_state_dict(
|
net_d.module.load_state_dict(
|
||||||
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
net_d.load_state_dict(
|
net_d.load_state_dict(
|
||||||
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user