1
0
mirror of synced 2024-11-30 18:24:32 +01:00

Update train_nsf_sim_cache_sid_load_pretrain.py

This commit is contained in:
liujing04 2023-04-09 15:23:27 +00:00 committed by GitHub
parent 47c665384f
commit 70c6865bc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,7 @@ os.environ["CUDA_VISIBLE_DEVICES"]=hps.gpus.replace("-",",")
n_gpus=len(hps.gpus.split("-")) n_gpus=len(hps.gpus.split("-"))
from random import shuffle from random import shuffle
import traceback,json,argparse,itertools,math,torch,pdb import traceback,json,argparse,itertools,math,torch,pdb
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
from torch import nn, optim from torch import nn, optim
from torch.nn import functional as F from torch.nn import functional as F
@ -311,7 +311,7 @@ def train_and_evaluate(
) )
global_step += 1 global_step += 1
# if global_step % hps.train.eval_interval == 0: # if global_step % hps.train.eval_interval == 0:
if epoch % hps.save_every_epoch == 0: if epoch % hps.save_every_epoch == 0 and rank == 0:
if(hps.if_latest==0): if(hps.if_latest==0):
utils.save_checkpoint( utils.save_checkpoint(
net_g, net_g,
@ -466,7 +466,7 @@ def train_and_evaluate(
) )
global_step += 1 global_step += 1
# if global_step % hps.train.eval_interval == 0: # if global_step % hps.train.eval_interval == 0:
if epoch % hps.save_every_epoch == 0: if epoch % hps.save_every_epoch == 0 and rank == 0:
if(hps.if_latest==0): if(hps.if_latest==0):
utils.save_checkpoint( utils.save_checkpoint(
net_g, net_g,