Retrieval-based-Voice-Conve.../infer/modules/train/train.py
github-actions[bot] e9dd11bddb
chore(sync): merge dev into main (#1379)
* Optimize latency (#1259)

* add attribute:   configs/config.py
	Optimize latency:   tools/rvc_for_realtime.py

* new file:   assets/Synthesizer_inputs.pth

* fix:   configs/config.py
	fix:   tools/rvc_for_realtime.py

* fix bug:   infer/lib/infer_pack/models.py

* new file:   assets/hubert_inputs.pth
	new file:   assets/rmvpe_inputs.pth
	modified:   configs/config.py
	new features:   infer/lib/rmvpe.py
	new features:   tools/jit_export/__init__.py
	new features:   tools/jit_export/get_hubert.py
	new features:   tools/jit_export/get_rmvpe.py
	new features:   tools/jit_export/get_synthesizer.py
	optimize:   tools/rvc_for_realtime.py

* optimize:   tools/jit_export/get_synthesizer.py
	fix bug:   tools/jit_export/__init__.py

* Fixed a bug caused by using half on the CPU:   infer/lib/rmvpe.py
	Fixed a bug caused by using half on the CPU:   tools/jit_export/__init__.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_rmvpe.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_synthesizer.py
	Fixed a bug caused by using half on the CPU:   tools/rvc_for_realtime.py

* Remove useless code:   infer/lib/rmvpe.py

* Delete gui_v1 copy.py

* Delete .vscode/launch.json

* Delete jit_export_test.py

* Delete tools/rvc_for_realtime copy.py

* Delete configs/config.json

* Delete .gitignore

* Fix exceptions caused by switching inference devices:   infer/lib/rmvpe.py
	Fix exceptions caused by switching inference devices:   tools/jit_export/__init__.py
	Fix exceptions caused by switching inference devices:   tools/rvc_for_realtime.py

* restore

* replace(you can undo this commit)

* remove debug_print

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* Fixed some bugs when exporting ONNX model (#1254)

* fix import (#1280)

* fix import

* lint

* 🎨 sync locale (#1242)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Fix jit load and import issue (#1282)

* fix jit model loading :   infer/lib/rmvpe.py

* modified:   assets/hubert/.gitignore
	move file:    assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth
	modified:   assets/rmvpe/.gitignore
	move file:    assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth
	fix import:   gui_v1.py

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* Add input wav and delay time monitor for real-time gui (#1293)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 sync locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* add input wav and delay time monitor

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>

* Optimize latency using scripted jit (#1291)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 sync locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* Optimize-latency-using-scripted:   configs/config.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/attentions.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/commons.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/models.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/modules.py
	Optimize-latency-using-scripted:   infer/lib/jit/__init__.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_hubert.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_rmvpe.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_synthesizer.py
	Optimize-latency-using-scripted:   infer/lib/rmvpe.py
	Optimize-latency-using-scripted:   tools/rvc_for_realtime.py

* modified:   infer/lib/infer_pack/models.py

* fix some bug:   configs/config.py
	fix some bug:   infer/lib/infer_pack/models.py
	fix some bug:   infer/lib/rmvpe.py

* Fixed abnormal reference of logger in multiprocessing:   infer/modules/train/train.py

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Format code (#1298)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* 🎨 sync locale (#1299)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: optimize actions

* feat(workflow): add sync dev

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: add jit options (#1303)

Delete useless code:   infer/lib/jit/get_synthesizer.py
	Optimized code:   tools/rvc_for_realtime.py

* Code refactor + re-design inference ui (#1304)

* Code refactor + re-design inference ui

* Fix tabname

* i18n jp

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* feat: optimize actions

* feat: optimize actions

* Update README & en_US locale file (#1309)

* critical: some bug fixes (#1322)

* JIT acceleration switch does not support hot update

* fix padding bug of rmvpe in torch-directml

* fix padding bug of rmvpe in torch-directml

* Fix STFT under torch_directml (#1330)

* chore(format): run black on dev (#1318)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(i18n): sync locale on dev (#1317)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: allow for tta to be passed to uvr (#1361)

* chore(format): run black on dev (#1373)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Added script for automatically downloading all needed models at install (#1366)

* Delete modules.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* chore(i18n): sync locale on dev (#1377)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(format): run black on dev (#1376)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Update IPEX library (#1362)

* Update IPEX library

* Update ipex index

* chore(format): run black on dev (#1378)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com>
Co-authored-by: Ftps <ftpsflandre@gmail.com>
Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com>
Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com>
Co-authored-by: Rice Cake <gak141808@gmail.com>
Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com>
Co-authored-by: Dmitry <nda2911@yandex.ru>
Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
2023-10-06 17:14:33 +08:00

643 lines
22 KiB
Python
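# Distributed training entry for the RVC models: the parent process spawns one
# worker per GPU, each worker trains the synthesizer (generator) and the
# multi-period discriminator with mixed precision, and rank 0 writes
# TensorBoard summaries and checkpoints.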

import os
import sys
import logging

logger = logging.getLogger(__name__)

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

import datetime

from infer.lib.train import utils

hps = utils.get_hparams()
os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
n_gpus = len(hps.gpus.split("-"))

from random import randint, shuffle

import torch

try:
    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

    if torch.xpu.is_available():
        from infer.modules.ipex import ipex_init

        ipex_init()
        from torch.xpu.amp import autocast
        from infer.modules.ipex.gradscaler import gradscaler_init

        GradScaler = gradscaler_init()
    else:
        from torch.cuda.amp import GradScaler, autocast
except Exception:  # pylint: disable=broad-exception-caught
    from torch.cuda.amp import GradScaler, autocast

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False

from time import sleep
from time import time as ttime

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from infer.lib.infer_pack import commons
from infer.lib.train.data_utils import (
    DistributedBucketSampler,
    TextAudioCollate,
    TextAudioCollateMultiNSFsid,
    TextAudioLoader,
    TextAudioLoaderMultiNSFsid,
)

if hps.version == "v1":
    from infer.lib.infer_pack.models import MultiPeriodDiscriminator
    from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
    )
else:
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs768NSFsid as RVC_Model_f0,
        SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
        MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
    )

from infer.lib.train.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from infer.lib.train.process_ckpt import savee

global_step = 0


class EpochRecorder:
    def __init__(self):
        self.last_time = ttime()

    def record(self):
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"[{current_time}] | ({elapsed_time_str})"
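# main() is the launcher: it prepares the rendezvous address/port, creates the
# experiment logger, and spawns one run() worker process per device before
# waiting for all of them to finish.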
def main():
    n_gpus = torch.cuda.device_count()

    # Create the experiment logger before it can be referenced by the
    # CPU-fallback warning below.
    logger = utils.get_logger(hps.model_dir)

    if not torch.cuda.is_available() and torch.backends.mps.is_available():
        n_gpus = 1
    if n_gpus < 1:
        # patch to unblock people without gpus. there is probably a better way.
        logger.warning("NO GPU DETECTED: falling back to CPU - this may take a while")
        n_gpus = 1

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    children = []
    for i in range(n_gpus):
        subproc = mp.Process(
            target=run,
            args=(i, n_gpus, hps, logger),
        )
        children.append(subproc)
        subproc.start()

    for i in range(n_gpus):
        children[i].join()
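# run() is executed once per rank: it joins the gloo process group, builds the
# dataset, bucket sampler and models, tries to resume from the latest D_*/G_*
# checkpoints, and otherwise falls back to the pretrained generator and
# discriminator weights, then iterates over epochs.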
def run(rank, n_gpus, hps, logger: logging.Logger):
    global global_step
    if rank == 0:
        # logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        # utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(
        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    if hps.if_f0 == 1:
        train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
    else:
        train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size * n_gpus,
        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400],  # 16s
        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    # It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
    # num_workers=8 -> num_workers=4
    if hps.if_f0 == 1:
        collate_fn = TextAudioCollateMultiNSFsid()
    else:
        collate_fn = TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )

    if hps.if_f0 == 1:
        net_g = RVC_Model_f0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
            sr=hps.sample_rate,
        )
    else:
        net_g = RVC_Model_nof0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
        )
    if torch.cuda.is_available():
        net_g = net_g.cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
    if torch.cuda.is_available():
        net_d = net_d.cuda(rank)

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        pass
    elif torch.cuda.is_available():
        net_g = DDP(net_g, device_ids=[rank])
        net_d = DDP(net_d, device_ids=[rank])
    else:
        net_g = DDP(net_g)
        net_d = DDP(net_d)

    try:  # resume automatically if checkpoints can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
        )  # loading D usually succeeds
        if rank == 0:
            logger.info("loaded D")
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
        )
        global_step = (epoch_str - 1) * len(train_loader)
        # epoch_str = 1
        # global_step = 0
    except:  # on a fresh run there is nothing to resume, so load the pretrained weights instead
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        if hps.pretrainG != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % (hps.pretrainG))
            if hasattr(net_g, "module"):
                logger.info(
                    net_g.module.load_state_dict(
                        torch.load(hps.pretrainG, map_location="cpu")["model"]
                    )
                )  ## experiment: do not load the optimizer state
            else:
                logger.info(
                    net_g.load_state_dict(
                        torch.load(hps.pretrainG, map_location="cpu")["model"]
                    )
                )  ## experiment: do not load the optimizer state
        if hps.pretrainD != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % (hps.pretrainD))
            if hasattr(net_d, "module"):
                logger.info(
                    net_d.module.load_state_dict(
                        torch.load(hps.pretrainD, map_location="cpu")["model"]
                    )
                )
            else:
                logger.info(
                    net_d.load_state_dict(
                        torch.load(hps.pretrainD, map_location="cpu")["model"]
                    )
                )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )

    scaler = GradScaler(enabled=hps.train.fp16_run)

    cache = []
    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                logger,
                [writer, writer_eval],
                cache,
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
                cache,
            )
        scheduler_g.step()
        scheduler_d.step()
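# One training epoch: optionally serves batches from a GPU-resident cache, runs
# a discriminator update followed by a generator update for every batch, and on
# rank 0 handles logging and periodic checkpointing.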
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
):
    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers
    train_loader.batch_sampler.set_epoch(epoch)

    global global_step

    net_g.train()
    net_d.train()

    # Prepare data iterator
    if hps.if_cache_data_in_gpu == True:
        # Use Cache
        data_iterator = cache
        if cache == []:
            # Make new cache
            for batch_idx, info in enumerate(train_loader):
                # Unpack
                if hps.if_f0 == 1:
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                else:
                    (
                        phone,
                        phone_lengths,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                # Load on CUDA
                if torch.cuda.is_available():
                    phone = phone.cuda(rank, non_blocking=True)
                    phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                    if hps.if_f0 == 1:
                        pitch = pitch.cuda(rank, non_blocking=True)
                        pitchf = pitchf.cuda(rank, non_blocking=True)
                    sid = sid.cuda(rank, non_blocking=True)
                    spec = spec.cuda(rank, non_blocking=True)
                    spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                    wave = wave.cuda(rank, non_blocking=True)
                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
                # Cache on list
                if hps.if_f0 == 1:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                pitch,
                                pitchf,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
                else:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
        else:
            # Load shuffled cache
            shuffle(cache)
    else:
        # Loader
        data_iterator = enumerate(train_loader)
    # Run steps
    epoch_recorder = EpochRecorder()
    for batch_idx, info in data_iterator:
        # Data
        ## Unpack
        if hps.if_f0 == 1:
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info
        else:
            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
        ## Load on CUDA
        if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
            phone = phone.cuda(rank, non_blocking=True)
            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
            if hps.if_f0 == 1:
                pitch = pitch.cuda(rank, non_blocking=True)
                pitchf = pitchf.cuda(rank, non_blocking=True)
            sid = sid.cuda(rank, non_blocking=True)
            spec = spec.cuda(rank, non_blocking=True)
            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
            wave = wave.cuda(rank, non_blocking=True)
            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)

        # Calculate
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
            else:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            with autocast(enabled=False):
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.float().squeeze(1),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
            if hps.train.fp16_run == True:
                y_hat_mel = y_hat_mel.half()
            wave = commons.slice_segments(
                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
        optim_d.zero_grad()
        scaler.scale(loss_disc).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()
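        # Rank-0 logging: every log_interval steps record losses, learning rate,
        # gradient norms and mel-spectrogram slices to TensorBoard.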
        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                # Amor For Tensorboard display
                if loss_mel > 75:
                    loss_mel = 75
                if loss_kl > 9:
                    loss_kl = 9

                logger.info([global_step, lr])
                logger.info(
                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
                )
                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl": loss_kl,
                    }
                )
                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
    # /Run steps

    if epoch % hps.save_every_epoch == 0 and rank == 0:
        if hps.if_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
            )
        else:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
            )
        if rank == 0 and hps.save_every_weights == "1":
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.sample_rate,
                        hps.if_f0,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        hps.version,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
    if epoch >= hps.total_epoch and rank == 0:
        logger.info("Training is done. The program is closed.")

        if hasattr(net_g, "module"):
            ckpt = net_g.module.state_dict()
        else:
            ckpt = net_g.state_dict()
        logger.info(
            "saving final ckpt:%s"
            % (
                savee(
                    ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
                )
            )
        )
        sleep(1)
        os._exit(2333333)
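# The "spawn" start method gives each child process a fresh interpreter, which
# avoids the problems CUDA has with forked subprocesses.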
if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()