replace lib

parent: c25bb6c5d5
commit: 6721b81dcf

infer/lib/audio.py (new file, +21 lines)
@@ -0,0 +1,21 @@
import ffmpeg
import numpy as np


def load_audio(file, sr):
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and the `ffmpeg-python` package to be installed.
        file = (
            file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        )  # guard against paths copied with stray spaces, quotes, or newlines at either end
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load audio: {e}")

    return np.frombuffer(out, np.float32).flatten()
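A minimal usage sketch of load_audio, assuming the ffmpeg CLI is on PATH; the input path is hypothetical. It returns a mono float32 waveform already resampled to the requested rate:

    from infer.lib.audio import load_audio

    wav = load_audio("input.wav", 16000)  # any format ffmpeg can decode
    print(wav.dtype, wav.shape)           # float32, (num_samples,) at 16 kHz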
infer/lib/rmvpe.py (new file, +654 lines)
@@ -0,0 +1,654 @@
from time import time as ttime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel
from librosa.util import pad_center, tiny, normalize
from scipy.signal import get_window

# Constants from the original RMVPE model: 128 mel bins in, 360 pitch classes out.
N_MELS = 128
N_CLASS = 360


### STFT code from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = normalize(win_sq, norm=norm) ** 2
    win_sq = pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


class STFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
                 window='hann'):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky so there are some cases that probably won't work as working
        out the same sizes before and after in all overlap add setups is tough. Right now,
        this code should work with hop lengths that are half the filter length (50% overlap
        between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # window the bases
        forward_basis *= fft_window
        inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Take input data (audio) to STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)
        """
        num_batches = input_data.shape[0]
        num_samples = input_data.shape[-1]

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
            mode='reflect',
        ).squeeze(1)
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0,
        )

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.atan2(imag_part.data, real_part.data)

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ```transform``` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[..., self.pad_amount:]
        inverse_transform = inverse_transform[..., :self.num_samples]
        inverse_transform = inverse_transform.squeeze(1)

        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
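# A quick round-trip sketch of the conv1d-based STFT above (a sanity check, not part
# of the inference path); with hop_length = filter_length // 2 the reconstruction
# should be close to the input up to boundary effects:
#
#     stft = STFT(filter_length=1024, hop_length=512, window='hann')
#     audio = torch.randn(1, 16000)
#     mag, phase = stft.transform(audio)   # each (1, 513, n_frames)
#     recon = stft.inverse(mag, phase)     # (1, 16000)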
class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        if self.is_shortcut:
            return self.conv(x) + self.shortcut(x)
        else:
            return self.conv(x) + x


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for i in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []
        x = self.bn(x)
        for i in range(self.n_encoders):
            _, x = self.layers[i](x)
            concat_tensors.append(_)
        return x, concat_tensors


class ResEncoderBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i in range(self.n_blocks):
            x = self.conv[i](x)
        if self.kernel_size is not None:
            return x, self.pool(x)
        else:
            return x


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        )
        for i in range(self.n_inters - 1):
            self.layers.append(
                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        for i in range(self.n_inters):
            x = self.layers[i](x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for i in range(self.n_blocks):
            x = self.conv2[i](x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i in range(self.n_decoders):
            x = self.layers[i](x, concat_tensors[-1 - i])
        return x


class DeepUnet(nn.Module):
    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x


class E2E(nn.Module):
    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            # the original code referenced undefined nn.N_MELS / nn.N_CLASS here;
            # the module-level constants match the BiGRU branch above
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        # torch.stft does not support the pytorch_dml ("privateuseone") backend,
        # so the conv1d-based STFT above is used instead.
        if not hasattr(self, "stft"):
            self.stft = STFT(
                filter_length=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window='hann',
            ).to(audio.device)
        magnitude, _ = self.stft.transform(audio)  # phase is not needed here
        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec
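# The keyshift path stretches the analysis window by factor = 2 ** (keyshift / 12)
# (one factor-of-two per octave, i.e. one semitone per unit), then rescales the
# magnitudes back onto the original frequency grid. For example, keyshift = 12
# doubles n_fft and win_length, so a 1024-point window becomes 2048 points and the
# resulting spectrum is cropped back to n_fft // 2 + 1 = 513 bins.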
class RMVPE:
    def __init__(self, model_path, is_half, device=None):
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort

            ort_session = ort.InferenceSession(
                "rmvpe.onnx", providers=["DmlExecutionProvider"]
            )
            self.model = ort_session
        else:
            model = E2E(4, 1, (2, 2))
            ckpt = torch.load(model_path, map_location="cpu")
            model.load_state_dict(ckpt)
            model.eval()
            if is_half:
                model = model.half()
            self.model = model
            self.model = self.model.to(device)
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # pad the frame axis up to a multiple of 32 for the U-Net
            mel = F.pad(
                mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
            )
            if "privateuseone" in str(self.device):
                onnx_input_name = self.model.get_inputs()[0].name
                onnx_outputs_names = self.model.get_outputs()[0].name
                hidden = self.model.run(
                    [onnx_outputs_names],
                    input_feed={onnx_input_name: mel.cpu().numpy()},
                )[0]
            else:
                hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0  # cents of 0 marks an unvoiced frame
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        mel = self.mel_extractor(
            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
        )
        hidden = self.mel2hidden(mel)
        if "privateuseone" not in str(self.device):
            hidden = hidden.squeeze(0).cpu().numpy()
        else:
            hidden = hidden[0]
        if self.is_half:
            hidden = hidden.astype("float32")

        f0 = self.decode(hidden, thred=thred)
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        center = np.argmax(salience, axis=1)  # (n_frames,) index of the peak bin
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
        todo_salience = np.array(todo_salience)  # (n_frames, 9)
        todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        divided = product_sum / weight_sum  # (n_frames,)
        maxx = np.max(salience, axis=1)  # (n_frames,)
        divided[maxx <= thred] = 0
        return divided
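# Decoding maps each 360-bin salience row to a pitch via a weighted average of cents
# in a 9-bin window around the peak, then f0 = 10 * 2 ** (cents / 1200). The bins
# sit at 20 * i + 1997.38 cents, covering roughly 31.7 Hz to 2005.5 Hz in 20-cent
# steps. For example, a window averaging to 5700 cents decodes to
# 10 * 2 ** (5700 / 1200) ≈ 269.1 Hz.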
if __name__ == '__main__':
    import librosa
    import soundfile as sf

    audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))
    audio_bak = audio.copy()
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
    thred = 0.03  # 0.01
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rmvpe = RMVPE(model_path, is_half=False, device=device)
    t0 = ttime()
    f0 = rmvpe.infer_from_audio(audio, thred=thred)
    t1 = ttime()
    print(f0.shape, t1 - t0)
infer/lib/slicer2.py (new file, +260 lines)
@@ -0,0 +1,260 @@
import numpy as np


# This function is obtained from librosa.
def get_rms(
    y,
    frame_length=2048,
    hop_length=512,
    pad_mode="constant",
):
    padding = (int(frame_length // 2), int(frame_length // 2))
    y = np.pad(y, padding, mode=pad_mode)

    axis = -1
    # put our new within-frame axis at the end for now
    out_strides = y.strides + tuple([y.strides[axis]])
    # Reduce the shape on the framing axis
    x_shape_trimmed = list(y.shape)
    x_shape_trimmed[axis] -= frame_length - 1
    out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
    xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
    if axis < 0:
        target_axis = axis - 1
    else:
        target_axis = axis + 1
    xw = np.moveaxis(xw, -1, target_axis)
    # Downsample along the target axis
    slices = [slice(None)] * xw.ndim
    slices[axis] = slice(0, None, hop_length)
    x = xw[tuple(slices)]

    # Calculate power
    power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)

    return np.sqrt(power)
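# This is the strided-view framing used by librosa.feature.rms: frame the padded
# signal into overlapping frame_length windows every hop_length samples, then take
# the root-mean-square of each window. For a 1-D input the output shape is
# (1, n_frames), which is why Slicer.slice() squeezes axis 0 below.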
class Slicer:
    def __init__(
        self,
        sr: int,
        threshold: float = -40.0,
        min_length: int = 5000,
        min_interval: int = 300,
        hop_size: int = 20,
        max_sil_kept: int = 5000,
    ):
        if not min_length >= min_interval >= hop_size:
            raise ValueError(
                "The following condition must be satisfied: min_length >= min_interval >= hop_size"
            )
        if not max_sil_kept >= hop_size:
            raise ValueError(
                "The following condition must be satisfied: max_sil_kept >= hop_size"
            )
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        if len(waveform.shape) > 1:
            return waveform[
                :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
            ]
        else:
            return waveform[
                begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
            ]

    def slice(self, waveform):
        if len(waveform.shape) > 1:
            samples = waveform.mean(axis=0)
        else:
            samples = waveform
        if samples.shape[0] <= self.min_length:
            return [waveform]
        rms_list = get_rms(
            y=samples, frame_length=self.win_size, hop_length=self.hop_size
        ).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (
                i - silence_start >= self.min_interval
                and i - clip_start >= self.min_length
            )
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[
                    i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
                ].argmin()
                pos += i - self.max_sil_kept
                pos_l = (
                    rms_list[
                        silence_start : silence_start + self.max_sil_kept + 1
                    ].argmin()
                    + silence_start
                )
                pos_r = (
                    rms_list[i - self.max_sil_kept : i + 1].argmin()
                    + i
                    - self.max_sil_kept
                )
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = (
                    rms_list[
                        silence_start : silence_start + self.max_sil_kept + 1
                    ].argmin()
                    + silence_start
                )
                pos_r = (
                    rms_list[i - self.max_sil_kept : i + 1].argmin()
                    + i
                    - self.max_sil_kept
                )
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if (
            silence_start is not None
            and total_frames - silence_start >= self.min_interval
        ):
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices.
        if len(sil_tags) == 0:
            return [waveform]
        else:
            chunks = []
            if sil_tags[0][0] > 0:
                chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
            for i in range(len(sil_tags) - 1):
                chunks.append(
                    self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])
                )
            if sil_tags[-1][1] < total_frames:
                chunks.append(
                    self._apply_slice(waveform, sil_tags[-1][1], total_frames)
                )
            return chunks


def main():
    import os.path
    from argparse import ArgumentParser

    import librosa
    import soundfile

    parser = ArgumentParser()
    parser.add_argument("audio", type=str, help="The audio to be sliced")
    parser.add_argument(
        "--out", type=str, help="Output directory of the sliced audio clips"
    )
    parser.add_argument(
        "--db_thresh",
        type=float,
        required=False,
        default=-40,
        help="The dB threshold for silence detection",
    )
    parser.add_argument(
        "--min_length",
        type=int,
        required=False,
        default=5000,
        help="The minimum milliseconds required for each sliced audio clip",
    )
    parser.add_argument(
        "--min_interval",
        type=int,
        required=False,
        default=300,
        help="The minimum milliseconds for a silence part to be sliced",
    )
    parser.add_argument(
        "--hop_size",
        type=int,
        required=False,
        default=10,
        help="Frame length in milliseconds",
    )
    parser.add_argument(
        "--max_sil_kept",
        type=int,
        required=False,
        default=500,
        help="The maximum silence length kept around the sliced clip, presented in milliseconds",
    )
    args = parser.parse_args()
    out = args.out
    if out is None:
        out = os.path.dirname(os.path.abspath(args.audio))
    audio, sr = librosa.load(args.audio, sr=None, mono=False)
    slicer = Slicer(
        sr=sr,
        threshold=args.db_thresh,
        min_length=args.min_length,
        min_interval=args.min_interval,
        hop_size=args.hop_size,
        max_sil_kept=args.max_sil_kept,
    )
    chunks = slicer.slice(audio)
    if not os.path.exists(out):
        os.makedirs(out)
    for i, chunk in enumerate(chunks):
        if len(chunk.shape) > 1:
            chunk = chunk.T
        soundfile.write(
            os.path.join(
                out,
                "%s_%d.wav"
                % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
            ),
            chunk,
            sr,
        )


if __name__ == "__main__":
    main()
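A minimal programmatic sketch (separate from the CLI above), assuming a mono numpy waveform; the values mirror the CLI defaults and the input path is hypothetical:

    import librosa
    from infer.lib.slicer2 import Slicer

    audio, sr = librosa.load("speech.wav", sr=None)
    slicer = Slicer(sr=sr, threshold=-40.0, min_length=5000,
                    min_interval=300, hop_size=10, max_sil_kept=500)
    chunks = slicer.slice(audio)  # list of numpy arrays with long silences cut out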
infer/lib/train/data_utils.py (new file, +512 lines)
@@ -0,0 +1,512 @@
import os
import traceback

import numpy as np
import torch
import torch.utils.data

from infer.lib.train.mel_processing import spectrogram_torch
from infer.lib.train.utils import load_wav_to_torch, load_filepaths_and_text


class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
    """
    1) loads audio, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, audiopaths_and_text, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        # separate filename and text
        file = audiopath_and_text[0]
        phone = audiopath_and_text[1]
        pitch = audiopath_and_text[2]
        pitchf = audiopath_and_text[3]
        dv = audiopath_and_text[4]

        phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)

        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length

            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]

            phone = phone[:len_min, :]
            pitch = pitch[:len_min]
            pitchf = pitchf[:len_min]

        return (spec, wav, phone, pitch, pitchf, dv)

    def get_labels(self, phone, pitch, pitchf):
        phone = np.load(phone)
        phone = np.repeat(phone, 2, axis=0)
        pitch = np.load(pitch)
        pitchf = np.load(pitchf)
        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
        phone = phone[:n_num, :]
        pitch = pitch[:n_num]
        pitchf = pitchf[:n_num]
        phone = torch.FloatTensor(phone)
        pitch = torch.LongTensor(pitch)
        pitchf = torch.FloatTensor(pitchf)
        return phone, pitch, pitchf

    def get_audio(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        audio_norm = audio
        # audio_norm = audio / self.max_wav_value
        # audio_norm = audio / np.abs(audio).max()

        audio_norm = audio_norm.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception:
                print(spec_filename, traceback.format_exc())
                spec = spectrogram_torch(
                    audio_norm,
                    self.filter_length,
                    self.sampling_rate,
                    self.hop_length,
                    self.win_length,
                    center=False,
                )
                spec = torch.squeeze(spec, 0)
                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        else:
            spec = spectrogram_torch(
                audio_norm,
                self.filter_length,
                self.sampling_rate,
                self.hop_length,
                self.win_length,
                center=False,
            )
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
class TextAudioCollateMultiNSFsid:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates a training batch from normalized text and audio
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
        )

        max_spec_len = max([x[0].size(1) for x in batch])
        max_wave_len = max([x[1].size(1) for x in batch])
        spec_lengths = torch.LongTensor(len(batch))
        wave_lengths = torch.LongTensor(len(batch))
        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
        spec_padded.zero_()
        wave_padded.zero_()

        max_phone_len = max([x[2].size(0) for x in batch])
        phone_lengths = torch.LongTensor(len(batch))
        phone_padded = torch.FloatTensor(
            len(batch), max_phone_len, batch[0][2].shape[1]
        )  # (spec, wav, phone, pitch)
        pitch_padded = torch.LongTensor(len(batch), max_phone_len)
        pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
        phone_padded.zero_()
        pitch_padded.zero_()
        pitchf_padded.zero_()
        # dv = torch.FloatTensor(len(batch), 256)  # gin=256
        sid = torch.LongTensor(len(batch))

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            spec = row[0]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wave = row[1]
            wave_padded[i, :, : wave.size(1)] = wave
            wave_lengths[i] = wave.size(1)

            phone = row[2]
            phone_padded[i, : phone.size(0), :] = phone
            phone_lengths[i] = phone.size(0)

            pitch = row[3]
            pitch_padded[i, : pitch.size(0)] = pitch
            pitchf = row[4]
            pitchf_padded[i, : pitchf.size(0)] = pitchf

            # dv[i] = row[5]
            sid[i] = row[5]

        return (
            phone_padded,
            phone_lengths,
            pitch_padded,
            pitchf_padded,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            # dv
            sid,
        )
class TextAudioLoader(torch.utils.data.Dataset):
    """
    1) loads audio, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, audiopaths_and_text, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        # separate filename and text
        file = audiopath_and_text[0]
        phone = audiopath_and_text[1]
        dv = audiopath_and_text[2]

        phone = self.get_labels(phone)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)

        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]
            phone = phone[:len_min, :]
        return (spec, wav, phone, dv)

    def get_labels(self, phone):
        phone = np.load(phone)
        phone = np.repeat(phone, 2, axis=0)
        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
        phone = phone[:n_num, :]
        phone = torch.FloatTensor(phone)
        return phone

    def get_audio(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        audio_norm = audio
        # audio_norm = audio / self.max_wav_value
        # audio_norm = audio / np.abs(audio).max()

        audio_norm = audio_norm.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception:
                print(spec_filename, traceback.format_exc())
                spec = spectrogram_torch(
                    audio_norm,
                    self.filter_length,
                    self.sampling_rate,
                    self.hop_length,
                    self.win_length,
                    center=False,
                )
                spec = torch.squeeze(spec, 0)
                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        else:
            spec = spectrogram_torch(
                audio_norm,
                self.filter_length,
                self.sampling_rate,
                self.hop_length,
                self.win_length,
                center=False,
            )
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)


class TextAudioCollate:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates a training batch from normalized text and audio
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
        )

        max_spec_len = max([x[0].size(1) for x in batch])
        max_wave_len = max([x[1].size(1) for x in batch])
        spec_lengths = torch.LongTensor(len(batch))
        wave_lengths = torch.LongTensor(len(batch))
        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
        spec_padded.zero_()
        wave_padded.zero_()

        max_phone_len = max([x[2].size(0) for x in batch])
        phone_lengths = torch.LongTensor(len(batch))
        phone_padded = torch.FloatTensor(
            len(batch), max_phone_len, batch[0][2].shape[1]
        )
        phone_padded.zero_()
        sid = torch.LongTensor(len(batch))

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            spec = row[0]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wave = row[1]
            wave_padded[i, :, : wave.size(1)] = wave
            wave_lengths[i] = wave.size(1)

            phone = row[2]
            phone_padded[i, : phone.size(0), :] = phone
            phone_lengths[i] = phone.size(0)

            sid[i] = row[3]

        return (
            phone_padded,
            phone_lengths,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            sid,
        )
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # drop empty buckets (iterate backwards so indices stay valid)
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
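A hedged usage sketch: the sampler yields lists of dataset indices (one batch per yield), so it goes to DataLoader as batch_sampler; the boundary values below are hypothetical spec-length buckets.

    train_sampler = DistributedBucketSampler(
        dataset,                       # must expose dataset.lengths, as built by _filter()
        batch_size=4,
        boundaries=[100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=1,
        rank=0,
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_sampler, collate_fn=TextAudioCollateMultiNSFsid()
    )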
infer/lib/train/losses.py (new file, +58 lines)
@@ -0,0 +1,58 @@
import torch


def feature_loss(fmap_r, fmap_g):
    # L1 feature-matching loss over every discriminator feature map.
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # Least-squares GAN loss: push real logits toward 1, fake logits toward 0.
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    # Least-squares GAN loss for the generator: push fake logits toward 1.
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    KL divergence between posterior and prior, masked and averaged over the mask.
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l
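
# Quick sanity check (illustrative only, not part of the training code): these
# are the least-squares GAN losses used by HiFi-GAN-style vocoders; dummy
# discriminator outputs stand in for real ones.
if __name__ == "__main__":
    torch.manual_seed(0)
    real_outs = [torch.rand(2, 10)]  # one sub-discriminator, batch of 2
    fake_outs = [torch.rand(2, 10)]
    d_loss, r_losses, g_losses = discriminator_loss(real_outs, fake_outs)
    g_loss, gen_losses = generator_loss(fake_outs)
    fm_loss = feature_loss([[torch.rand(2, 4, 8)]], [[torch.rand(2, 4, 8)]])
    print(d_loss.item(), g_loss.item(), fm_loss.item())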
130
infer/lib/train/mel_processing.py
Normal file
@ -0,0 +1,130 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn


MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)


# Reusable banks
mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.

    Args:
        y :: (B, T) - Audio waveforms
        n_fft
        sampling_rate
        hop_size
        win_size
        center
    Returns:
        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
    """
    # Validation
    if torch.min(y) < -1.07:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.07:
        print("max value is ", torch.max(y))

    # Window - Cache if needed
    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    # Padding
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
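
# Shape note (added for clarity, not in the original commit): with center=False
# and the manual reflect padding of (n_fft - hop_size) / 2 on each side, a
# waveform of T samples is padded to T + n_fft - hop_size samples, so torch.stft
# yields (T + n_fft - hop_size - n_fft) // hop_size + 1 == T // hop_size frames
# whenever hop_size divides T (e.g. T = 16000, hop_size = 400 -> 40 frames).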

def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    # MelBasis - Cache if needed
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )

    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
    melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    melspec = spectral_normalize_torch(melspec)
    return melspec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    """Convert waveform into Mel-frequency Log-amplitude spectrogram.

    Args:
        y :: (B, T) - Waveforms
    Returns:
        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
    spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)

    # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
    melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)

    return melspec
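
# Usage sketch (illustrative; every parameter value below is an assumption for
# demonstration, not read from this repo's configs):
if __name__ == "__main__":
    y = torch.randn(1, 16000).clamp(-1.0, 1.0)  # (B, T) waveform in [-1, 1]
    mel = mel_spectrogram_torch(
        y,
        n_fft=2048,
        num_mels=128,
        sampling_rate=40000,
        hop_size=400,
        win_size=2048,
        fmin=0,
        fmax=None,  # librosa defaults fmax to sampling_rate / 2
        center=False,
    )
    print(mel.shape)  # torch.Size([1, 128, 40]) == (B, num_mels, T // hop_size)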
259
infer/lib/train/process_ckpt.py
Normal file
@ -0,0 +1,259 @@
import torch, traceback, os, sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from collections import OrderedDict
from i18n.i18n import I18nAuto

i18n = I18nAuto()


def savee(ckpt, sr, if_f0, name, epoch, version, hps, i18n):
    try:
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except:
        return traceback.format_exc()
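
# Field order of opt["config"] (read off savee() above; the hard-coded lists in
# extract_small_model() below follow the same layout):
#   [filter_length // 2 + 1, 32 (fixed), inter_channels, hidden_channels,
#    filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock,
#    resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
#    upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim,
#    gin_channels, sampling_rate]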

def show_info(path):
    try:
        a = torch.load(path, map_location="cpu")
        return "Model info: %s\nSampling rate: %s\nPitch guidance (f0): %s\nVersion: %s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
            a.get("version", "None"),
        )
    except:
        return traceback.format_exc()


def extract_small_model(path, name, sr, if_f0, info, version):
    try:
        ckpt = torch.load(path, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        if sr == "40k":
            opt["config"] = [
                1025,
                32,
                192,
                192,
                768,
                2,
                6,
                3,
                0,
                "1",
                [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 10, 2, 2],
                512,
                [16, 16, 4, 4],
                109,
                256,
                40000,
            ]
        elif sr == "48k":
            if version == "v1":
                opt["config"] = [
                    1025,
                    32,
                    192,
                    192,
                    768,
                    2,
                    6,
                    3,
                    0,
                    "1",
                    [3, 7, 11],
                    [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    [10, 6, 2, 2, 2],
                    512,
                    [16, 16, 4, 4, 4],
                    109,
                    256,
                    48000,
                ]
            else:
                opt["config"] = [
                    1025,
                    32,
                    192,
                    192,
                    768,
                    2,
                    6,
                    3,
                    0,
                    "1",
                    [3, 7, 11],
                    [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    [12, 10, 2, 2],
                    512,
                    [24, 20, 4, 4],
                    109,
                    256,
                    48000,
                ]
        elif sr == "32k":
            if version == "v1":
                opt["config"] = [
                    513,
                    32,
                    192,
                    192,
                    768,
                    2,
                    6,
                    3,
                    0,
                    "1",
                    [3, 7, 11],
                    [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    [10, 4, 2, 2, 2],
                    512,
                    [16, 16, 4, 4, 4],
                    109,
                    256,
                    32000,
                ]
            else:
                opt["config"] = [
                    513,
                    32,
                    192,
                    192,
                    768,
                    2,
                    6,
                    3,
                    0,
                    "1",
                    [3, 7, 11],
                    [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    [10, 8, 2, 2],
                    512,
                    [20, 16, 4, 4],
                    109,
                    256,
                    32000,
                ]
        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["version"] = version
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except:
        return traceback.format_exc()


def change_info(path, info, name):
    try:
        ckpt = torch.load(path, map_location="cpu")
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        torch.save(ckpt, "weights/%s" % name)
        return "Success."
    except:
        return traceback.format_exc()


def merge(path1, path2, alpha1, sr, f0, info, name, version):
    try:

        def extract(ckpt):
            a = ckpt["model"]
            opt = OrderedDict()
            opt["weight"] = {}
            for key in a.keys():
                if "enc_q" in key:
                    continue
                opt["weight"][key] = a[key]
            return opt

        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        cfg = ckpt1["config"]
        if "model" in ckpt1:
            ckpt1 = extract(ckpt1)
        else:
            ckpt1 = ckpt1["weight"]
        if "model" in ckpt2:
            ckpt2 = extract(ckpt2)
        else:
            ckpt2 = ckpt2["weight"]
        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
            return "Fail to merge the models. The model architectures are not the same."
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            # try:
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Speaker-embedding tables may differ in speaker count; blend the overlap.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key][:min_shape0].float())
                    + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
                ).half()
            else:
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
                ).half()
            # except:
            #     pdb.set_trace()
        opt["config"] = cfg
        """
        if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000]
        elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000]
        elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
        """
        opt["sr"] = sr
        opt["f0"] = 1 if f0 == i18n("是") else 0  # i18n("是") is the key for "Yes"
        opt["version"] = version
        opt["info"] = info
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except:
        return traceback.format_exc()
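
# Usage sketch (paths and names are hypothetical): blend two architecture-
# compatible checkpoints 70/30 and write weights/blend.pth.
if __name__ == "__main__":
    msg = merge(
        "weights/a.pth",
        "weights/b.pth",
        alpha1=0.7,  # weight of the first model
        sr="40k",
        f0=i18n("是"),  # the "Yes" key, so merge() stores opt["f0"] = 1
        info="70/30 blend of a and b",
        name="blend",
        version="v2",
    )
    print(msg)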
487
infer/lib/train/utils.py
Normal file
@ -0,0 +1,487 @@
import os, traceback
import glob
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    ##################
    def go(model, bkey):
        saved_state_dict = checkpoint_dict[bkey]
        if hasattr(model, "module"):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():  # iterate over the shapes the model expects
            try:
                new_state_dict[k] = saved_state_dict[k]
                if saved_state_dict[k].shape != state_dict[k].shape:
                    print(
                        "shape-%s-mismatch|need-%s|get-%s"
                        % (k, state_dict[k].shape, saved_state_dict[k].shape)
                    )  #
                    raise KeyError
            except:
                # logger.info(traceback.format_exc())
                logger.info("%s is not in the checkpoint" % k)  # key missing from the pretrained checkpoint
                new_state_dict[k] = v  # keep the model's own randomly initialized value
        if hasattr(model, "module"):
            model.module.load_state_dict(new_state_dict, strict=False)
        else:
            model.load_state_dict(new_state_dict, strict=False)
        return model

    go(combd, "combd")
    model = go(sbd, "sbd")
    #############
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if (
        optimizer is not None and load_opt == 1
    ):  # If the optimizer state cannot be loaded (e.g. it is empty), it is re-initialized; that may also break the LR schedule update, so this is caught at the outermost level of the training script.
        # try:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
        # except:
        #     traceback.print_exc()
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


# def load_checkpoint(checkpoint_path, model, optimizer=None):
#     assert os.path.isfile(checkpoint_path)
#     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
#     iteration = checkpoint_dict['iteration']
#     learning_rate = checkpoint_dict['learning_rate']
#     if optimizer is not None:
#         optimizer.load_state_dict(checkpoint_dict['optimizer'])
#     # print(1111)
#     saved_state_dict = checkpoint_dict['model']
#     # print(1111)
#
#     if hasattr(model, 'module'):
#         state_dict = model.module.state_dict()
#     else:
#         state_dict = model.state_dict()
#     new_state_dict= {}
#     for k, v in state_dict.items():
#         try:
#             new_state_dict[k] = saved_state_dict[k]
#         except:
#             logger.info("%s is not in the checkpoint" % k)
#             new_state_dict[k] = v
#     if hasattr(model, 'module'):
#         model.module.load_state_dict(new_state_dict)
#     else:
#         model.load_state_dict(new_state_dict)
#     logger.info("Loaded checkpoint '{}' (epoch {})" .format(
#         checkpoint_path, iteration))
#     return model, optimizer, learning_rate, iteration
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # iterate over the shapes the model expects
        try:
            new_state_dict[k] = saved_state_dict[k]
            if saved_state_dict[k].shape != state_dict[k].shape:
                print(
                    "shape-%s-mismatch|need-%s|get-%s"
                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
                )  #
                raise KeyError
        except:
            # logger.info(traceback.format_exc())
            logger.info("%s is not in the checkpoint" % k)  # key missing from the pretrained checkpoint
            new_state_dict[k] = v  # keep the model's own randomly initialized value
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if (
        optimizer is not None and load_opt == 1
    ):  # If the optimizer state cannot be loaded (e.g. it is empty), it is re-initialized; that may also break the LR schedule update, so this is caught at the outermost level of the training script.
        # try:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
        # except:
        #     traceback.print_exc()
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
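
# Round-trip sanity sketch (illustrative only; the temp path and toy Linear
# model are made up): write a checkpoint dict in the format load_checkpoint()
# expects, then load it back.
if __name__ == "__main__":
    import tempfile

    net = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(net.parameters(), lr=1e-4)
    ckpt_path = os.path.join(tempfile.mkdtemp(), "G_1.pth")
    torch.save(
        {
            "model": net.state_dict(),
            "iteration": 1,
            "learning_rate": 1e-4,
            "optimizer": optim.state_dict(),
        },
        ckpt_path,
    )
    net, optim, lr, it = load_checkpoint(ckpt_path, net, optim)
    print(lr, it)  # -> 0.0001 1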

def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at epoch {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at epoch {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(combd, "module"):
        state_dict_combd = combd.module.state_dict()
    else:
        state_dict_combd = combd.state_dict()
    if hasattr(sbd, "module"):
        state_dict_sbd = sbd.module.state_dict()
    else:
        state_dict_sbd = sbd.state_dict()
    torch.save(
        {
            "combd": state_dict_combd,
            "sbd": state_dict_sbd,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    # Sort by the integer formed from every digit in the path, so G_2333.pth sorts after G_500.pth.
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    x = f_list[-1]
    print(x)
    return x
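
# Usage sketch (illustrative; the temp dir and step numbers are made up): pick
# the newest checkpoint by the step number embedded in its filename.
if __name__ == "__main__":
    import tempfile

    d = tempfile.mkdtemp()
    for step in (100, 500, 2333):
        open(os.path.join(d, "G_%d.pth" % step), "w").close()
    latest_checkpoint_path(d, "G_*.pth")  # prints and returns .../G_2333.pth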

def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding="utf-8") as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def get_hparams(init=True):
    """
    todo — the trailing group of seven CLI flags:
        save frequency / total epochs: done
        bs: done
        pretrainG / pretrainD: done
        GPU ids: os.environ["CUDA_VISIBLE_DEVICES"]: done
        if_latest: done
        model: if_f0: done
        sampling rate: config chosen automatically: done
        cache the dataset in GPU memory: if_cache_data_in_gpu: done

    -m:
        derive the training_files path automatically; replace hps.data.training_files in train_nsf_load_pretrain.py: done
    the old -c/--config flag is dropped
    """
    parser = argparse.ArgumentParser()
    # parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration')
    parser.add_argument(
        "-se",
        "--save_every_epoch",
        type=int,
        required=True,
        help="checkpoint save frequency (epoch)",
    )
    parser.add_argument(
        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
    )
    parser.add_argument(
        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
    )
    parser.add_argument(
        "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path"
    )
    parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
    parser.add_argument(
        "-bs", "--batch_size", type=int, required=True, help="batch size"
    )
    parser.add_argument(
        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
    )  # -m
    parser.add_argument(
        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
    )
    parser.add_argument(
        "-sw",
        "--save_every_weights",
        type=str,
        default="0",
        help="save the extracted model in weights directory when saving checkpoints",
    )
    parser.add_argument(
        "-v", "--version", type=str, required=True, help="model version"
    )
    parser.add_argument(
        "-f0",
        "--if_f0",
        type=int,
        required=True,
        help="use f0 as one of the inputs of the model, 1 or 0",
    )
    parser.add_argument(
        "-l",
        "--if_latest",
        type=int,
        required=True,
        help="if only save the latest G/D pth file, 1 or 0",
    )
    parser.add_argument(
        "-c",
        "--if_cache_data_in_gpu",
        type=int,
        required=True,
        help="if caching the dataset in GPU memory, 1 or 0",
    )

    args = parser.parse_args()
    name = args.experiment_dir
    experiment_dir = os.path.join("./logs", args.experiment_dir)

    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)

    if args.version == "v1" or args.sample_rate == "40k":
        config_path = "configs/%s.json" % args.sample_rate
    else:
        config_path = "configs/%s_v2.json" % args.sample_rate
    config_save_path = os.path.join(experiment_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = hparams.experiment_dir = experiment_dir
    hparams.save_every_epoch = args.save_every_epoch
    hparams.name = name
    hparams.total_epoch = args.total_epoch
    hparams.pretrainG = args.pretrainG
    hparams.pretrainD = args.pretrainD
    hparams.version = args.version
    hparams.gpus = args.gpus
    hparams.train.batch_size = args.batch_size
    hparams.sample_rate = args.sample_rate
    hparams.if_f0 = args.if_f0
    hparams.if_latest = args.if_latest
    hparams.save_every_weights = args.save_every_weights
    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
    hparams.data.training_files = "%s/filelist.txt" % experiment_dir
    return hparams
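
# Example invocation (added for illustration; flag values are arbitrary and the
# entry-point script name is taken from the docstring above, pretrained paths
# are hypothetical):
#
#   python train_nsf_load_pretrain.py -e myexp -sr 40k -f0 1 -bs 8 -te 200 \
#       -se 5 -pg pretrained/f0G40k.pth -pd pretrained/f0D40k.pth \
#       -l 0 -c 0 -sw 0 -v v2 -g 0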

def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        open(path, "w").write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
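
# Usage sketch: HParams turns nested config dicts into attribute access,
# mirroring how get_hparams_from_file() wraps the JSON configs.
if __name__ == "__main__":
    hps = HParams(train={"batch_size": 8}, data={"sampling_rate": 40000})
    print(hps.train.batch_size, hps["data"].sampling_rate, "train" in hps)
    # -> 8 40000 True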