Update train_nsf_sim_cache_sid_load_pretrain.py
This commit is contained in:
parent
47c665384f
commit
70c6865bc3
@ -7,7 +7,7 @@ os.environ["CUDA_VISIBLE_DEVICES"]=hps.gpus.replace("-",",")
|
|||||||
n_gpus=len(hps.gpus.split("-"))
|
n_gpus=len(hps.gpus.split("-"))
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
import traceback,json,argparse,itertools,math,torch,pdb
|
import traceback,json,argparse,itertools,math,torch,pdb
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = False
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
@ -311,7 +311,7 @@ def train_and_evaluate(
|
|||||||
)
|
)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
# if global_step % hps.train.eval_interval == 0:
|
# if global_step % hps.train.eval_interval == 0:
|
||||||
if epoch % hps.save_every_epoch == 0:
|
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||||
if(hps.if_latest==0):
|
if(hps.if_latest==0):
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
net_g,
|
net_g,
|
||||||
@ -466,7 +466,7 @@ def train_and_evaluate(
|
|||||||
)
|
)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
# if global_step % hps.train.eval_interval == 0:
|
# if global_step % hps.train.eval_interval == 0:
|
||||||
if epoch % hps.save_every_epoch == 0:
|
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||||
if(hps.if_latest==0):
|
if(hps.if_latest==0):
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
net_g,
|
net_g,
|
||||||
|
Loading…
Reference in New Issue
Block a user