Update train.py
parent b304564c9e
commit 9602ea649c
@@ -104,10 +104,11 @@ def main():
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
+    logger = utils.get_logger(hps.model_dir)
     for i in range(n_gpus):
         subproc = mp.Process(
             target=run,
-            args=(i, n_gpus, hps),
+            args=(i, n_gpus, hps, logger),
         )
         children.append(subproc)
         subproc.start()
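Note: the MASTER_ADDR / MASTER_PORT lines in the context above set up the standard env:// rendezvous that torch.distributed workers read at startup. A minimal sketch of how a worker would typically consume them follows; the backend and arguments actually used inside run() are not visible in this hunk, so the choices here (gloo backend, world_size=1) are illustrative assumptions only.

# Illustrative sketch only; the exact init_process_group call in run() is not shown in this diff.
import os
from random import randint

import torch.distributed as dist

def init_worker(rank: int, n_gpus: int) -> None:
    # MASTER_ADDR / MASTER_PORT were exported by the parent process before spawning.
    dist.init_process_group(
        backend="gloo",        # assumption; "nccl" is typical on Linux + CUDA
        init_method="env://",  # read MASTER_ADDR / MASTER_PORT from the environment
        world_size=n_gpus,
        rank=rank,
    )

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    init_worker(rank=0, n_gpus=1)  # single-process demo of the rendezvous
    dist.destroy_process_group()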
@@ -116,14 +117,10 @@ def main():
         children[i].join()
 
 
-def run(
-    rank,
-    n_gpus,
-    hps,
-):
+def run(rank, n_gpus, hps, logger: logging.Logger):
     global global_step
     if rank == 0:
-        logger = utils.get_logger(hps.model_dir)
+        # logger = utils.get_logger(hps.model_dir)
         logger.info(hps)
         # utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)
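The two hunks above create the logger once in main() and hand the same logging.Logger to every worker through the mp.Process args. A minimal standalone sketch of that pattern, using plain logging stand-ins for the repo's utils.get_logger and hps.model_dir (assumptions, not the project's code):

# Minimal sketch; get_logger below is a stand-in for utils.get_logger.
import logging
import multiprocessing as mp  # torch.multiprocessing exposes the same Process API

def get_logger(name: str = "train") -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    return logger

def run(rank: int, n_gpus: int, logger: logging.Logger) -> None:
    # Worker entry point; mirrors the rank-0 guard used in train.py.
    if rank == 0:
        logger.info("rank %d/%d started", rank, n_gpus)

def main() -> None:
    n_gpus = 2
    logger = get_logger()  # created once, in the parent process
    children = []
    for i in range(n_gpus):
        # With the default fork start method on Linux, children inherit the
        # configured handler along with the passed-in Logger object.
        p = mp.Process(target=run, args=(i, n_gpus, logger))
        children.append(p)
        p.start()
    for p in children:
        p.join()

if __name__ == "__main__":
    main()

In the later hunks the logger.info calls that replace print are not inside the if rank == 0 guard, so every rank needs a logger object; passing it in from main() covers that, rather than creating it only on rank 0.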
@@ -229,13 +226,13 @@ def run(
         if rank == 0:
             logger.info("loaded pretrained %s" % (hps.pretrainG))
         if hasattr(net_g, "module"):
-            print(
+            logger.info(
                 net_g.module.load_state_dict(
                     torch.load(hps.pretrainG, map_location="cpu")["model"]
                 )
             )  ##测试不加载优化器 (test: do not load the optimizer)
         else:
-            print(
+            logger.info(
                 net_g.load_state_dict(
                     torch.load(hps.pretrainG, map_location="cpu")["model"]
                 )
             )
@@ -244,13 +241,13 @@ def run(
         if rank == 0:
             logger.info("loaded pretrained %s" % (hps.pretrainD))
         if hasattr(net_d, "module"):
-            print(
+            logger.info(
                 net_d.module.load_state_dict(
                     torch.load(hps.pretrainD, map_location="cpu")["model"]
                 )
             )
         else:
-            print(
+            logger.info(
                 net_d.load_state_dict(
                     torch.load(hps.pretrainD, map_location="cpu")["model"]
                 )
             )
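The print → logger.info swap in the last two hunks captures the return value of load_state_dict, which reports missing or unexpected keys, and sends it to the training log instead of stdout (the hasattr(..., "module") branch just unwraps DistributedDataParallel). A small self-contained sketch of that behaviour, with a toy module and an in-memory checkpoint standing in for net_g and hps.pretrainG (illustrative assumptions):

# Sketch only; nn.Linear and the BytesIO checkpoint stand in for net_g / hps.pretrainG.
import io
import logging

import torch
from torch import nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

net_g = nn.Linear(4, 4)

# Checkpoints in train.py keep the weights under the "model" key.
buf = io.BytesIO()
torch.save({"model": net_g.state_dict()}, buf)
buf.seek(0)

state = torch.load(buf, map_location="cpu")["model"]
# load_state_dict returns an IncompatibleKeys report; logging it records
# "<All keys matched successfully>" or the missing/unexpected keys in the run log.
logger.info(net_g.load_state_dict(state))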