replace lib
This commit is contained in:
parent
c25bb6c5d5
commit
6721b81dcf
21
infer/lib/audio.py
Normal file
21
infer/lib/audio.py
Normal file
@ -0,0 +1,21 @@
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_audio(file, sr):
|
||||
try:
|
||||
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
|
||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||
file = (
|
||||
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
|
||||
) # 防止小白拷路径头尾带了空格和"和回车
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e}")
|
||||
|
||||
return np.frombuffer(out, np.float32).flatten()
|
654
infer/lib/rmvpe.py
Normal file
654
infer/lib/rmvpe.py
Normal file
@ -0,0 +1,654 @@
|
||||
import torch, numpy as np,pdb
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch,pdb
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import get_window
|
||||
from librosa.util import pad_center, tiny,normalize
|
||||
###stft codes from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
|
||||
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
||||
n_fft=800, dtype=np.float32, norm=None):
|
||||
"""
|
||||
# from librosa 0.6
|
||||
Compute the sum-square envelope of a window function at a given hop length.
|
||||
This is used to estimate modulation effects induced by windowing
|
||||
observations in short-time fourier transforms.
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, number, callable, or list-like
|
||||
Window specification, as in `get_window`
|
||||
n_frames : int > 0
|
||||
The number of analysis frames
|
||||
hop_length : int > 0
|
||||
The number of samples to advance between frames
|
||||
win_length : [optional]
|
||||
The length of the window function. By default, this matches `n_fft`.
|
||||
n_fft : int > 0
|
||||
The length of each analysis frame.
|
||||
dtype : np.dtype
|
||||
The data type of the output
|
||||
Returns
|
||||
-------
|
||||
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
||||
The sum-squared envelope of the window function
|
||||
"""
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
n = n_fft + hop_length * (n_frames - 1)
|
||||
x = np.zeros(n, dtype=dtype)
|
||||
|
||||
# Compute the squared window at the desired length
|
||||
win_sq = get_window(window, win_length, fftbins=True)
|
||||
win_sq = normalize(win_sq, norm=norm)**2
|
||||
win_sq = pad_center(win_sq, n_fft)
|
||||
|
||||
# Fill the envelope
|
||||
for i in range(n_frames):
|
||||
sample = i * hop_length
|
||||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
return x
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
def __init__(self, filter_length=1024, hop_length=512, win_length=None,
|
||||
window='hann'):
|
||||
"""
|
||||
This module implements an STFT using 1D convolution and 1D transpose convolutions.
|
||||
This is a bit tricky so there are some cases that probably won't work as working
|
||||
out the same sizes before and after in all overlap add setups is tough. Right now,
|
||||
this code should work with hop lengths that are half the filter length (50% overlap
|
||||
between frames).
|
||||
|
||||
Keyword Arguments:
|
||||
filter_length {int} -- Length of filters used (default: {1024})
|
||||
hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
|
||||
win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
|
||||
equals the filter length). (default: {None})
|
||||
window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
|
||||
(default: {'hann'})
|
||||
"""
|
||||
super(STFT, self).__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length if win_length else filter_length
|
||||
self.window = window
|
||||
self.forward_transform = None
|
||||
self.pad_amount = int(self.filter_length / 2)
|
||||
scale = self.filter_length / self.hop_length
|
||||
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
||||
|
||||
cutoff = int((self.filter_length / 2 + 1))
|
||||
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),np.imag(fourier_basis[:cutoff, :])])
|
||||
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
||||
inverse_basis = torch.FloatTensor(
|
||||
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
||||
|
||||
assert (filter_length >= self.win_length)
|
||||
# get window and zero center pad it to filter_length
|
||||
fft_window = get_window(window, self.win_length, fftbins=True)
|
||||
fft_window = pad_center(fft_window, size=filter_length)
|
||||
fft_window = torch.from_numpy(fft_window).float()
|
||||
|
||||
# window the bases
|
||||
forward_basis *= fft_window
|
||||
inverse_basis *= fft_window
|
||||
|
||||
self.register_buffer('forward_basis', forward_basis.float())
|
||||
self.register_buffer('inverse_basis', inverse_basis.float())
|
||||
|
||||
def transform(self, input_data):
|
||||
"""Take input data (audio) to STFT domain.
|
||||
|
||||
Arguments:
|
||||
input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
|
||||
|
||||
Returns:
|
||||
magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
|
||||
num_frequencies, num_frames)
|
||||
phase {tensor} -- Phase of STFT with shape (num_batch,
|
||||
num_frequencies, num_frames)
|
||||
"""
|
||||
num_batches = input_data.shape[0]
|
||||
num_samples = input_data.shape[-1]
|
||||
|
||||
self.num_samples = num_samples
|
||||
|
||||
# similar to librosa, reflect-pad the input
|
||||
input_data = input_data.view(num_batches, 1, num_samples)
|
||||
# print(1234,input_data.shape)
|
||||
input_data = F.pad(input_data.unsqueeze(1),(self.pad_amount, self.pad_amount, 0, 0,0,0),mode='reflect').squeeze(1)
|
||||
# print(2333,input_data.shape,self.forward_basis.shape,self.hop_length)
|
||||
# pdb.set_trace()
|
||||
forward_transform = F.conv1d(
|
||||
input_data,
|
||||
self.forward_basis,
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
cutoff = int((self.filter_length / 2) + 1)
|
||||
real_part = forward_transform[:, :cutoff, :]
|
||||
imag_part = forward_transform[:, cutoff:, :]
|
||||
|
||||
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
|
||||
# phase = torch.atan2(imag_part.data, real_part.data)
|
||||
|
||||
return magnitude#, phase
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
"""Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
|
||||
by the ```transform``` function.
|
||||
|
||||
Arguments:
|
||||
magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
|
||||
num_frequencies, num_frames)
|
||||
phase {tensor} -- Phase of STFT with shape (num_batch,
|
||||
num_frequencies, num_frames)
|
||||
|
||||
Returns:
|
||||
inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
|
||||
shape (num_batch, num_samples)
|
||||
"""
|
||||
recombine_magnitude_phase = torch.cat(
|
||||
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
|
||||
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
recombine_magnitude_phase,
|
||||
self.inverse_basis,
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
if self.window is not None:
|
||||
window_sum = window_sumsquare(
|
||||
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
||||
win_length=self.win_length, n_fft=self.filter_length,
|
||||
dtype=np.float32)
|
||||
# remove modulation effects
|
||||
approx_nonzero_indices = torch.from_numpy(
|
||||
np.where(window_sum > tiny(window_sum))[0])
|
||||
window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
|
||||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[..., self.pad_amount:]
|
||||
inverse_transform = inverse_transform[..., :self.num_samples]
|
||||
inverse_transform = inverse_transform.squeeze(1)
|
||||
|
||||
return inverse_transform
|
||||
|
||||
def forward(self, input_data):
|
||||
"""Take input data (audio) to STFT domain and then back to audio.
|
||||
|
||||
Arguments:
|
||||
input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
|
||||
|
||||
Returns:
|
||||
reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
|
||||
shape (num_batch, num_samples)
|
||||
"""
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
||||
from time import time as ttime
|
||||
class BiGRU(nn.Module):
|
||||
def __init__(self, input_features, hidden_features, num_layers):
|
||||
super(BiGRU, self).__init__()
|
||||
self.gru = nn.GRU(
|
||||
input_features,
|
||||
hidden_features,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.gru(x)[0]
|
||||
|
||||
|
||||
class ConvBlockRes(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, momentum=0.01):
|
||||
super(ConvBlockRes, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
padding=(1, 1),
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
padding=(1, 1),
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
)
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
||||
self.is_shortcut = True
|
||||
else:
|
||||
self.is_shortcut = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.is_shortcut:
|
||||
return self.conv(x) + self.shortcut(x)
|
||||
else:
|
||||
return self.conv(x) + x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
in_size,
|
||||
n_encoders,
|
||||
kernel_size,
|
||||
n_blocks,
|
||||
out_channels=16,
|
||||
momentum=0.01,
|
||||
):
|
||||
super(Encoder, self).__init__()
|
||||
self.n_encoders = n_encoders
|
||||
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
||||
self.layers = nn.ModuleList()
|
||||
self.latent_channels = []
|
||||
for i in range(self.n_encoders):
|
||||
self.layers.append(
|
||||
ResEncoderBlock(
|
||||
in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
|
||||
)
|
||||
)
|
||||
self.latent_channels.append([out_channels, in_size])
|
||||
in_channels = out_channels
|
||||
out_channels *= 2
|
||||
in_size //= 2
|
||||
self.out_size = in_size
|
||||
self.out_channel = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
concat_tensors = []
|
||||
x = self.bn(x)
|
||||
for i in range(self.n_encoders):
|
||||
_, x = self.layers[i](x)
|
||||
concat_tensors.append(_)
|
||||
return x, concat_tensors
|
||||
|
||||
|
||||
class ResEncoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
|
||||
):
|
||||
super(ResEncoderBlock, self).__init__()
|
||||
self.n_blocks = n_blocks
|
||||
self.conv = nn.ModuleList()
|
||||
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
||||
for i in range(n_blocks - 1):
|
||||
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
||||
self.kernel_size = kernel_size
|
||||
if self.kernel_size is not None:
|
||||
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.n_blocks):
|
||||
x = self.conv[i](x)
|
||||
if self.kernel_size is not None:
|
||||
return x, self.pool(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class Intermediate(nn.Module): #
|
||||
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
||||
super(Intermediate, self).__init__()
|
||||
self.n_inters = n_inters
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers.append(
|
||||
ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
|
||||
)
|
||||
for i in range(self.n_inters - 1):
|
||||
self.layers.append(
|
||||
ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.n_inters):
|
||||
x = self.layers[i](x)
|
||||
return x
|
||||
|
||||
|
||||
class ResDecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
||||
super(ResDecoderBlock, self).__init__()
|
||||
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
||||
self.n_blocks = n_blocks
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=stride,
|
||||
padding=(1, 1),
|
||||
output_padding=out_padding,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.conv2 = nn.ModuleList()
|
||||
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
||||
for i in range(n_blocks - 1):
|
||||
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
||||
|
||||
def forward(self, x, concat_tensor):
|
||||
x = self.conv1(x)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
for i in range(self.n_blocks):
|
||||
x = self.conv2[i](x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
||||
super(Decoder, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
self.n_decoders = n_decoders
|
||||
for i in range(self.n_decoders):
|
||||
out_channels = in_channels // 2
|
||||
self.layers.append(
|
||||
ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
|
||||
)
|
||||
in_channels = out_channels
|
||||
|
||||
def forward(self, x, concat_tensors):
|
||||
for i in range(self.n_decoders):
|
||||
x = self.layers[i](x, concat_tensors[-1 - i])
|
||||
return x
|
||||
|
||||
|
||||
class DeepUnet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size,
|
||||
n_blocks,
|
||||
en_de_layers=5,
|
||||
inter_layers=4,
|
||||
in_channels=1,
|
||||
en_out_channels=16,
|
||||
):
|
||||
super(DeepUnet, self).__init__()
|
||||
self.encoder = Encoder(
|
||||
in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
|
||||
)
|
||||
self.intermediate = Intermediate(
|
||||
self.encoder.out_channel // 2,
|
||||
self.encoder.out_channel,
|
||||
inter_layers,
|
||||
n_blocks,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x, concat_tensors = self.encoder(x)
|
||||
x = self.intermediate(x)
|
||||
x = self.decoder(x, concat_tensors)
|
||||
return x
|
||||
|
||||
|
||||
class E2E(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_blocks,
|
||||
n_gru,
|
||||
kernel_size,
|
||||
en_de_layers=5,
|
||||
inter_layers=4,
|
||||
in_channels=1,
|
||||
en_out_channels=16,
|
||||
):
|
||||
super(E2E, self).__init__()
|
||||
self.unet = DeepUnet(
|
||||
kernel_size,
|
||||
n_blocks,
|
||||
en_de_layers,
|
||||
inter_layers,
|
||||
in_channels,
|
||||
en_out_channels,
|
||||
)
|
||||
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
||||
if n_gru:
|
||||
self.fc = nn.Sequential(
|
||||
BiGRU(3 * 128, 256, n_gru),
|
||||
nn.Linear(512, 360),
|
||||
nn.Dropout(0.25),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
else:
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, mel):
|
||||
# print(mel.shape)
|
||||
mel = mel.transpose(-1, -2).unsqueeze(1)
|
||||
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
|
||||
x = self.fc(x)
|
||||
# print(x.shape)
|
||||
return x
|
||||
|
||||
|
||||
from librosa.filters import mel
|
||||
|
||||
|
||||
class MelSpectrogram(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
is_half,
|
||||
n_mel_channels,
|
||||
sampling_rate,
|
||||
win_length,
|
||||
hop_length,
|
||||
n_fft=None,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clamp=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
n_fft = win_length if n_fft is None else n_fft
|
||||
self.hann_window = {}
|
||||
mel_basis = mel(
|
||||
sr=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mel_channels,
|
||||
fmin=mel_fmin,
|
||||
fmax=mel_fmax,
|
||||
htk=True,
|
||||
)
|
||||
mel_basis = torch.from_numpy(mel_basis).float()
|
||||
self.register_buffer("mel_basis", mel_basis)
|
||||
self.n_fft = win_length if n_fft is None else n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.clamp = clamp
|
||||
self.is_half = is_half
|
||||
|
||||
def forward(self, audio, keyshift=0, speed=1, center=True):
|
||||
factor = 2 ** (keyshift / 12)
|
||||
n_fft_new = int(np.round(self.n_fft * factor))
|
||||
win_length_new = int(np.round(self.win_length * factor))
|
||||
hop_length_new = int(np.round(self.hop_length * speed))
|
||||
keyshift_key = str(keyshift) + "_" + str(audio.device)
|
||||
if keyshift_key not in self.hann_window:
|
||||
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
|
||||
# "cpu"if(audio.device.type=="privateuseone") else audio.device
|
||||
audio.device
|
||||
)
|
||||
# fft = torch.stft(#doesn't support pytorch_dml
|
||||
# # audio.cpu() if(audio.device.type=="privateuseone")else audio,
|
||||
# audio,
|
||||
# n_fft=n_fft_new,
|
||||
# hop_length=hop_length_new,
|
||||
# win_length=win_length_new,
|
||||
# window=self.hann_window[keyshift_key],
|
||||
# center=center,
|
||||
# return_complex=True,
|
||||
# )
|
||||
# magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
|
||||
# print(1111111111)
|
||||
# print(222222222222222,audio.device,self.is_half)
|
||||
if hasattr(self, "stft") == False:
|
||||
# print(n_fft_new,hop_length_new,win_length_new,audio.shape)
|
||||
self.stft=STFT(
|
||||
filter_length=n_fft_new,
|
||||
hop_length=hop_length_new,
|
||||
win_length=win_length_new,
|
||||
window='hann'
|
||||
).to(audio.device)
|
||||
magnitude = self.stft.transform(audio)#phase
|
||||
# if (audio.device.type == "privateuseone"):
|
||||
# magnitude=magnitude.to(audio.device)
|
||||
if keyshift != 0:
|
||||
size = self.n_fft // 2 + 1
|
||||
resize = magnitude.size(1)
|
||||
if resize < size:
|
||||
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
|
||||
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
|
||||
mel_output = torch.matmul(self.mel_basis, magnitude)
|
||||
if self.is_half == True:
|
||||
mel_output = mel_output.half()
|
||||
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
|
||||
# print(log_mel_spec.device.type)
|
||||
return log_mel_spec
|
||||
|
||||
|
||||
class RMVPE:
|
||||
def __init__(self, model_path, is_half, device=None):
|
||||
self.resample_kernel = {}
|
||||
self.resample_kernel = {}
|
||||
self.is_half = is_half
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
self.mel_extractor = MelSpectrogram(
|
||||
is_half, 128, 16000, 1024, 160, None, 30, 8000
|
||||
).to(device)
|
||||
if ("privateuseone" in str(device)):
|
||||
import onnxruntime as ort
|
||||
ort_session = ort.InferenceSession("rmvpe.onnx", providers=["DmlExecutionProvider"])
|
||||
self.model=ort_session
|
||||
else:
|
||||
model = E2E(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
if is_half == True:
|
||||
model = model.half()
|
||||
self.model = model
|
||||
self.model = self.model.to(device)
|
||||
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
|
||||
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
|
||||
|
||||
def mel2hidden(self, mel):
|
||||
with torch.no_grad():
|
||||
n_frames = mel.shape[-1]
|
||||
mel = F.pad(
|
||||
mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
|
||||
)
|
||||
if("privateuseone" in str(self.device) ):
|
||||
onnx_input_name = self.model.get_inputs()[0].name
|
||||
onnx_outputs_names = self.model.get_outputs()[0].name
|
||||
hidden = self.model.run([onnx_outputs_names], input_feed={onnx_input_name: mel.cpu().numpy()})[0]
|
||||
else:
|
||||
hidden = self.model(mel)
|
||||
return hidden[:, :n_frames]
|
||||
|
||||
def decode(self, hidden, thred=0.03):
|
||||
cents_pred = self.to_local_average_cents(hidden, thred=thred)
|
||||
f0 = 10 * (2 ** (cents_pred / 1200))
|
||||
f0[f0 == 10] = 0
|
||||
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
|
||||
return f0
|
||||
|
||||
def infer_from_audio(self, audio, thred=0.03):
|
||||
# torch.cuda.synchronize()
|
||||
t0=ttime()
|
||||
mel = self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True)
|
||||
# print(123123123,mel.device.type)
|
||||
# torch.cuda.synchronize()
|
||||
t1=ttime()
|
||||
hidden = self.mel2hidden(mel)
|
||||
# torch.cuda.synchronize()
|
||||
t2=ttime()
|
||||
# print(234234,hidden.device.type)
|
||||
if("privateuseone" not in str(self.device)):
|
||||
hidden = hidden.squeeze(0).cpu().numpy()
|
||||
else:
|
||||
hidden=hidden[0]
|
||||
if self.is_half == True:
|
||||
hidden = hidden.astype("float32")
|
||||
|
||||
f0 = self.decode(hidden, thred=thred)
|
||||
# torch.cuda.synchronize()
|
||||
t3=ttime()
|
||||
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
|
||||
return f0
|
||||
|
||||
def to_local_average_cents(self, salience, thred=0.05):
|
||||
# t0 = ttime()
|
||||
center = np.argmax(salience, axis=1) # 帧长#index
|
||||
salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
|
||||
# t1 = ttime()
|
||||
center += 4
|
||||
todo_salience = []
|
||||
todo_cents_mapping = []
|
||||
starts = center - 4
|
||||
ends = center + 5
|
||||
for idx in range(salience.shape[0]):
|
||||
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
|
||||
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
|
||||
# t2 = ttime()
|
||||
todo_salience = np.array(todo_salience) # 帧长,9
|
||||
todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9
|
||||
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
|
||||
weight_sum = np.sum(todo_salience, 1) # 帧长
|
||||
devided = product_sum / weight_sum # 帧长
|
||||
# t3 = ttime()
|
||||
maxx = np.max(salience, axis=1) # 帧长
|
||||
devided[maxx <= thred] = 0
|
||||
# t4 = ttime()
|
||||
# print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
return devided
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import soundfile as sf, librosa
|
||||
audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
|
||||
if len(audio.shape) > 1:
|
||||
audio = librosa.to_mono(audio.transpose(1, 0))
|
||||
audio_bak = audio.copy()
|
||||
if sampling_rate != 16000:
|
||||
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
|
||||
model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
|
||||
thred = 0.03 # 0.01
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
rmvpe = RMVPE(model_path,is_half=False, device=device)
|
||||
t0=ttime()
|
||||
f0 = rmvpe.infer_from_audio(audio, thred=thred)
|
||||
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
|
||||
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
|
||||
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
|
||||
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
|
||||
t1=ttime()
|
||||
print(f0.shape,t1-t0)
|
260
infer/lib/slicer2.py
Normal file
260
infer/lib/slicer2.py
Normal file
@ -0,0 +1,260 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
# This function is obtained from librosa.
|
||||
def get_rms(
|
||||
y,
|
||||
frame_length=2048,
|
||||
hop_length=512,
|
||||
pad_mode="constant",
|
||||
):
|
||||
padding = (int(frame_length // 2), int(frame_length // 2))
|
||||
y = np.pad(y, padding, mode=pad_mode)
|
||||
|
||||
axis = -1
|
||||
# put our new within-frame axis at the end for now
|
||||
out_strides = y.strides + tuple([y.strides[axis]])
|
||||
# Reduce the shape on the framing axis
|
||||
x_shape_trimmed = list(y.shape)
|
||||
x_shape_trimmed[axis] -= frame_length - 1
|
||||
out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
|
||||
xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
|
||||
if axis < 0:
|
||||
target_axis = axis - 1
|
||||
else:
|
||||
target_axis = axis + 1
|
||||
xw = np.moveaxis(xw, -1, target_axis)
|
||||
# Downsample along the target axis
|
||||
slices = [slice(None)] * xw.ndim
|
||||
slices[axis] = slice(0, None, hop_length)
|
||||
x = xw[tuple(slices)]
|
||||
|
||||
# Calculate power
|
||||
power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
|
||||
|
||||
return np.sqrt(power)
|
||||
|
||||
|
||||
class Slicer:
|
||||
def __init__(
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 5000,
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 5000,
|
||||
):
|
||||
if not min_length >= min_interval >= hop_size:
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
|
||||
)
|
||||
if not max_sil_kept >= hop_size:
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: max_sil_kept >= hop_size"
|
||||
)
|
||||
min_interval = sr * min_interval / 1000
|
||||
self.threshold = 10 ** (threshold / 20.0)
|
||||
self.hop_size = round(sr * hop_size / 1000)
|
||||
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
||||
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
||||
self.min_interval = round(min_interval / self.hop_size)
|
||||
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
||||
|
||||
def _apply_slice(self, waveform, begin, end):
|
||||
if len(waveform.shape) > 1:
|
||||
return waveform[
|
||||
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
|
||||
]
|
||||
else:
|
||||
return waveform[
|
||||
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
|
||||
]
|
||||
|
||||
# @timeit
|
||||
def slice(self, waveform):
|
||||
if len(waveform.shape) > 1:
|
||||
samples = waveform.mean(axis=0)
|
||||
else:
|
||||
samples = waveform
|
||||
if samples.shape[0] <= self.min_length:
|
||||
return [waveform]
|
||||
rms_list = get_rms(
|
||||
y=samples, frame_length=self.win_size, hop_length=self.hop_size
|
||||
).squeeze(0)
|
||||
sil_tags = []
|
||||
silence_start = None
|
||||
clip_start = 0
|
||||
for i, rms in enumerate(rms_list):
|
||||
# Keep looping while frame is silent.
|
||||
if rms < self.threshold:
|
||||
# Record start of silent frames.
|
||||
if silence_start is None:
|
||||
silence_start = i
|
||||
continue
|
||||
# Keep looping while frame is not silent and silence start has not been recorded.
|
||||
if silence_start is None:
|
||||
continue
|
||||
# Clear recorded silence start if interval is not enough or clip is too short
|
||||
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
||||
need_slice_middle = (
|
||||
i - silence_start >= self.min_interval
|
||||
and i - clip_start >= self.min_length
|
||||
)
|
||||
if not is_leading_silence and not need_slice_middle:
|
||||
silence_start = None
|
||||
continue
|
||||
# Need slicing. Record the range of silent frames to be removed.
|
||||
if i - silence_start <= self.max_sil_kept:
|
||||
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos))
|
||||
else:
|
||||
sil_tags.append((pos, pos))
|
||||
clip_start = pos
|
||||
elif i - silence_start <= self.max_sil_kept * 2:
|
||||
pos = rms_list[
|
||||
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
pos += i - self.max_sil_kept
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos_r))
|
||||
clip_start = pos_r
|
||||
else:
|
||||
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
||||
clip_start = max(pos_r, pos)
|
||||
else:
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos_r))
|
||||
else:
|
||||
sil_tags.append((pos_l, pos_r))
|
||||
clip_start = pos_r
|
||||
silence_start = None
|
||||
# Deal with trailing silence.
|
||||
total_frames = rms_list.shape[0]
|
||||
if (
|
||||
silence_start is not None
|
||||
and total_frames - silence_start >= self.min_interval
|
||||
):
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
# Apply and return slices.
|
||||
if len(sil_tags) == 0:
|
||||
return [waveform]
|
||||
else:
|
||||
chunks = []
|
||||
if sil_tags[0][0] > 0:
|
||||
chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
|
||||
for i in range(len(sil_tags) - 1):
|
||||
chunks.append(
|
||||
self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])
|
||||
)
|
||||
if sil_tags[-1][1] < total_frames:
|
||||
chunks.append(
|
||||
self._apply_slice(waveform, sil_tags[-1][1], total_frames)
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
def main():
|
||||
import os.path
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import librosa
|
||||
import soundfile
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("audio", type=str, help="The audio to be sliced")
|
||||
parser.add_argument(
|
||||
"--out", type=str, help="Output directory of the sliced audio clips"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db_thresh",
|
||||
type=float,
|
||||
required=False,
|
||||
default=-40,
|
||||
help="The dB threshold for silence detection",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
type=int,
|
||||
required=False,
|
||||
default=5000,
|
||||
help="The minimum milliseconds required for each sliced audio clip",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_interval",
|
||||
type=int,
|
||||
required=False,
|
||||
default=300,
|
||||
help="The minimum milliseconds for a silence part to be sliced",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hop_size",
|
||||
type=int,
|
||||
required=False,
|
||||
default=10,
|
||||
help="Frame length in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_sil_kept",
|
||||
type=int,
|
||||
required=False,
|
||||
default=500,
|
||||
help="The maximum silence length kept around the sliced clip, presented in milliseconds",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
out = args.out
|
||||
if out is None:
|
||||
out = os.path.dirname(os.path.abspath(args.audio))
|
||||
audio, sr = librosa.load(args.audio, sr=None, mono=False)
|
||||
slicer = Slicer(
|
||||
sr=sr,
|
||||
threshold=args.db_thresh,
|
||||
min_length=args.min_length,
|
||||
min_interval=args.min_interval,
|
||||
hop_size=args.hop_size,
|
||||
max_sil_kept=args.max_sil_kept,
|
||||
)
|
||||
chunks = slicer.slice(audio)
|
||||
if not os.path.exists(out):
|
||||
os.makedirs(out)
|
||||
for i, chunk in enumerate(chunks):
|
||||
if len(chunk.shape) > 1:
|
||||
chunk = chunk.T
|
||||
soundfile.write(
|
||||
os.path.join(
|
||||
out,
|
||||
f"%s_%d.wav"
|
||||
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
|
||||
),
|
||||
chunk,
|
||||
sr,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
512
infer/lib/train/data_utils.py
Normal file
512
infer/lib/train/data_utils.py
Normal file
@ -0,0 +1,512 @@
|
||||
import os, traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
from infer.lib.train.mel_processing import spectrogram_torch
|
||||
from infer.lib.train.utils import load_wav_to_torch, load_filepaths_and_text
|
||||
|
||||
|
||||
class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, text pairs
|
||||
2) normalizes text and converts them to sequences of integers
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, audiopaths_and_text, hparams):
|
||||
self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self._filter()
|
||||
|
||||
def _filter(self):
|
||||
"""
|
||||
Filter text & store spec lengths
|
||||
"""
|
||||
# Store spectrogram lengths for Bucketing
|
||||
# wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
|
||||
# spec_length = wav_length // hop_length
|
||||
audiopaths_and_text_new = []
|
||||
lengths = []
|
||||
for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
|
||||
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
||||
audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
|
||||
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
||||
self.audiopaths_and_text = audiopaths_and_text_new
|
||||
self.lengths = lengths
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
return sid
|
||||
|
||||
def get_audio_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
file = audiopath_and_text[0]
|
||||
phone = audiopath_and_text[1]
|
||||
pitch = audiopath_and_text[2]
|
||||
pitchf = audiopath_and_text[3]
|
||||
dv = audiopath_and_text[4]
|
||||
|
||||
phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
|
||||
spec, wav = self.get_audio(file)
|
||||
dv = self.get_sid(dv)
|
||||
|
||||
len_phone = phone.size()[0]
|
||||
len_spec = spec.size()[-1]
|
||||
# print(123,phone.shape,pitch.shape,spec.shape)
|
||||
if len_phone != len_spec:
|
||||
len_min = min(len_phone, len_spec)
|
||||
# amor
|
||||
len_wav = len_min * self.hop_length
|
||||
|
||||
spec = spec[:, :len_min]
|
||||
wav = wav[:, :len_wav]
|
||||
|
||||
phone = phone[:len_min, :]
|
||||
pitch = pitch[:len_min]
|
||||
pitchf = pitchf[:len_min]
|
||||
|
||||
return (spec, wav, phone, pitch, pitchf, dv)
|
||||
|
||||
def get_labels(self, phone, pitch, pitchf):
|
||||
phone = np.load(phone)
|
||||
phone = np.repeat(phone, 2, axis=0)
|
||||
pitch = np.load(pitch)
|
||||
pitchf = np.load(pitchf)
|
||||
n_num = min(phone.shape[0], 900) # DistributedBucketSampler
|
||||
# print(234,phone.shape,pitch.shape)
|
||||
phone = phone[:n_num, :]
|
||||
pitch = pitch[:n_num]
|
||||
pitchf = pitchf[:n_num]
|
||||
phone = torch.FloatTensor(phone)
|
||||
pitch = torch.LongTensor(pitch)
|
||||
pitchf = torch.FloatTensor(pitchf)
|
||||
return phone, pitch, pitchf
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
"{} SR doesn't match target {} SR".format(
|
||||
sampling_rate, self.sampling_rate
|
||||
)
|
||||
)
|
||||
audio_norm = audio
|
||||
# audio_norm = audio / self.max_wav_value
|
||||
# audio_norm = audio / np.abs(audio).max()
|
||||
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec_filename = filename.replace(".wav", ".spec.pt")
|
||||
if os.path.exists(spec_filename):
|
||||
try:
|
||||
spec = torch.load(spec_filename)
|
||||
except:
|
||||
print(spec_filename, traceback.format_exc())
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
else:
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
return spec, audio_norm
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
||||
|
||||
class TextAudioCollateMultiNSFsid:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text and aduio
|
||||
PARAMS
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized]
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
|
||||
)
|
||||
|
||||
max_spec_len = max([x[0].size(1) for x in batch])
|
||||
max_wave_len = max([x[1].size(1) for x in batch])
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
wave_lengths = torch.LongTensor(len(batch))
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
|
||||
wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
|
||||
spec_padded.zero_()
|
||||
wave_padded.zero_()
|
||||
|
||||
max_phone_len = max([x[2].size(0) for x in batch])
|
||||
phone_lengths = torch.LongTensor(len(batch))
|
||||
phone_padded = torch.FloatTensor(
|
||||
len(batch), max_phone_len, batch[0][2].shape[1]
|
||||
) # (spec, wav, phone, pitch)
|
||||
pitch_padded = torch.LongTensor(len(batch), max_phone_len)
|
||||
pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
|
||||
phone_padded.zero_()
|
||||
pitch_padded.zero_()
|
||||
pitchf_padded.zero_()
|
||||
# dv = torch.FloatTensor(len(batch), 256)#gin=256
|
||||
sid = torch.LongTensor(len(batch))
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
spec = row[0]
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wave = row[1]
|
||||
wave_padded[i, :, : wave.size(1)] = wave
|
||||
wave_lengths[i] = wave.size(1)
|
||||
|
||||
phone = row[2]
|
||||
phone_padded[i, : phone.size(0), :] = phone
|
||||
phone_lengths[i] = phone.size(0)
|
||||
|
||||
pitch = row[3]
|
||||
pitch_padded[i, : pitch.size(0)] = pitch
|
||||
pitchf = row[4]
|
||||
pitchf_padded[i, : pitchf.size(0)] = pitchf
|
||||
|
||||
# dv[i] = row[5]
|
||||
sid[i] = row[5]
|
||||
|
||||
return (
|
||||
phone_padded,
|
||||
phone_lengths,
|
||||
pitch_padded,
|
||||
pitchf_padded,
|
||||
spec_padded,
|
||||
spec_lengths,
|
||||
wave_padded,
|
||||
wave_lengths,
|
||||
# dv
|
||||
sid,
|
||||
)
|
||||
|
||||
|
||||
class TextAudioLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, text pairs
|
||||
2) normalizes text and converts them to sequences of integers
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, audiopaths_and_text, hparams):
|
||||
self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self._filter()
|
||||
|
||||
def _filter(self):
|
||||
"""
|
||||
Filter text & store spec lengths
|
||||
"""
|
||||
# Store spectrogram lengths for Bucketing
|
||||
# wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
|
||||
# spec_length = wav_length // hop_length
|
||||
audiopaths_and_text_new = []
|
||||
lengths = []
|
||||
for audiopath, text, dv in self.audiopaths_and_text:
|
||||
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
||||
audiopaths_and_text_new.append([audiopath, text, dv])
|
||||
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
||||
self.audiopaths_and_text = audiopaths_and_text_new
|
||||
self.lengths = lengths
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
return sid
|
||||
|
||||
def get_audio_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
file = audiopath_and_text[0]
|
||||
phone = audiopath_and_text[1]
|
||||
dv = audiopath_and_text[2]
|
||||
|
||||
phone = self.get_labels(phone)
|
||||
spec, wav = self.get_audio(file)
|
||||
dv = self.get_sid(dv)
|
||||
|
||||
len_phone = phone.size()[0]
|
||||
len_spec = spec.size()[-1]
|
||||
if len_phone != len_spec:
|
||||
len_min = min(len_phone, len_spec)
|
||||
len_wav = len_min * self.hop_length
|
||||
spec = spec[:, :len_min]
|
||||
wav = wav[:, :len_wav]
|
||||
phone = phone[:len_min, :]
|
||||
return (spec, wav, phone, dv)
|
||||
|
||||
def get_labels(self, phone):
|
||||
phone = np.load(phone)
|
||||
phone = np.repeat(phone, 2, axis=0)
|
||||
n_num = min(phone.shape[0], 900) # DistributedBucketSampler
|
||||
phone = phone[:n_num, :]
|
||||
phone = torch.FloatTensor(phone)
|
||||
return phone
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
"{} SR doesn't match target {} SR".format(
|
||||
sampling_rate, self.sampling_rate
|
||||
)
|
||||
)
|
||||
audio_norm = audio
|
||||
# audio_norm = audio / self.max_wav_value
|
||||
# audio_norm = audio / np.abs(audio).max()
|
||||
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec_filename = filename.replace(".wav", ".spec.pt")
|
||||
if os.path.exists(spec_filename):
|
||||
try:
|
||||
spec = torch.load(spec_filename)
|
||||
except:
|
||||
print(spec_filename, traceback.format_exc())
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
else:
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
return spec, audio_norm
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
||||
|
||||
class TextAudioCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text and aduio
|
||||
PARAMS
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized]
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
|
||||
)
|
||||
|
||||
max_spec_len = max([x[0].size(1) for x in batch])
|
||||
max_wave_len = max([x[1].size(1) for x in batch])
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
wave_lengths = torch.LongTensor(len(batch))
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
|
||||
wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
|
||||
spec_padded.zero_()
|
||||
wave_padded.zero_()
|
||||
|
||||
max_phone_len = max([x[2].size(0) for x in batch])
|
||||
phone_lengths = torch.LongTensor(len(batch))
|
||||
phone_padded = torch.FloatTensor(
|
||||
len(batch), max_phone_len, batch[0][2].shape[1]
|
||||
)
|
||||
phone_padded.zero_()
|
||||
sid = torch.LongTensor(len(batch))
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
spec = row[0]
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wave = row[1]
|
||||
wave_padded[i, :, : wave.size(1)] = wave
|
||||
wave_lengths[i] = wave.size(1)
|
||||
|
||||
phone = row[2]
|
||||
phone_padded[i, : phone.size(0), :] = phone
|
||||
phone_lengths[i] = phone.size(0)
|
||||
|
||||
sid[i] = row[3]
|
||||
|
||||
return (
|
||||
phone_padded,
|
||||
phone_lengths,
|
||||
spec_padded,
|
||||
spec_lengths,
|
||||
wave_padded,
|
||||
wave_lengths,
|
||||
sid,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
"""
|
||||
Maintain similar input lengths in a batch.
|
||||
Length groups are specified by boundaries.
|
||||
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
|
||||
|
||||
It removes samples which are not included in the boundaries.
|
||||
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
batch_size,
|
||||
boundaries,
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=True,
|
||||
):
|
||||
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
||||
self.lengths = dataset.lengths
|
||||
self.batch_size = batch_size
|
||||
self.boundaries = boundaries
|
||||
|
||||
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
||||
self.total_size = sum(self.num_samples_per_bucket)
|
||||
self.num_samples = self.total_size // self.num_replicas
|
||||
|
||||
def _create_buckets(self):
|
||||
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
||||
for i in range(len(self.lengths)):
|
||||
length = self.lengths[i]
|
||||
idx_bucket = self._bisect(length)
|
||||
if idx_bucket != -1:
|
||||
buckets[idx_bucket].append(i)
|
||||
|
||||
for i in range(len(buckets) - 1, -1, -1): #
|
||||
if len(buckets[i]) == 0:
|
||||
buckets.pop(i)
|
||||
self.boundaries.pop(i + 1)
|
||||
|
||||
num_samples_per_bucket = []
|
||||
for i in range(len(buckets)):
|
||||
len_bucket = len(buckets[i])
|
||||
total_batch_size = self.num_replicas * self.batch_size
|
||||
rem = (
|
||||
total_batch_size - (len_bucket % total_batch_size)
|
||||
) % total_batch_size
|
||||
num_samples_per_bucket.append(len_bucket + rem)
|
||||
return buckets, num_samples_per_bucket
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
indices = []
|
||||
if self.shuffle:
|
||||
for bucket in self.buckets:
|
||||
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
||||
else:
|
||||
for bucket in self.buckets:
|
||||
indices.append(list(range(len(bucket))))
|
||||
|
||||
batches = []
|
||||
for i in range(len(self.buckets)):
|
||||
bucket = self.buckets[i]
|
||||
len_bucket = len(bucket)
|
||||
ids_bucket = indices[i]
|
||||
num_samples_bucket = self.num_samples_per_bucket[i]
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
rem = num_samples_bucket - len_bucket
|
||||
ids_bucket = (
|
||||
ids_bucket
|
||||
+ ids_bucket * (rem // len_bucket)
|
||||
+ ids_bucket[: (rem % len_bucket)]
|
||||
)
|
||||
|
||||
# subsample
|
||||
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
||||
|
||||
# batching
|
||||
for j in range(len(ids_bucket) // self.batch_size):
|
||||
batch = [
|
||||
bucket[idx]
|
||||
for idx in ids_bucket[
|
||||
j * self.batch_size : (j + 1) * self.batch_size
|
||||
]
|
||||
]
|
||||
batches.append(batch)
|
||||
|
||||
if self.shuffle:
|
||||
batch_ids = torch.randperm(len(batches), generator=g).tolist()
|
||||
batches = [batches[i] for i in batch_ids]
|
||||
self.batches = batches
|
||||
|
||||
assert len(self.batches) * self.batch_size == self.num_samples
|
||||
return iter(self.batches)
|
||||
|
||||
def _bisect(self, x, lo=0, hi=None):
|
||||
if hi is None:
|
||||
hi = len(self.boundaries) - 1
|
||||
|
||||
if hi > lo:
|
||||
mid = (hi + lo) // 2
|
||||
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
|
||||
return mid
|
||||
elif x <= self.boundaries[mid]:
|
||||
return self._bisect(x, lo, mid)
|
||||
else:
|
||||
return self._bisect(x, mid + 1, hi)
|
||||
else:
|
||||
return -1
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples // self.batch_size
|
58
infer/lib/train/losses.py
Normal file
58
infer/lib/train/losses.py
Normal file
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
dr = dr.float()
|
||||
dg = dg.float()
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||
"""
|
||||
z_p, logs_q: [b, h, t_t]
|
||||
m_p, logs_p: [b, h, t_t]
|
||||
"""
|
||||
z_p = z_p.float()
|
||||
logs_q = logs_q.float()
|
||||
m_p = m_p.float()
|
||||
logs_p = logs_p.float()
|
||||
z_mask = z_mask.float()
|
||||
|
||||
kl = logs_p - logs_q - 0.5
|
||||
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
||||
kl = torch.sum(kl * z_mask)
|
||||
l = kl / torch.sum(z_mask)
|
||||
return l
|
130
infer/lib/train/mel_processing.py
Normal file
130
infer/lib/train/mel_processing.py
Normal file
@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
return dynamic_range_compression_torch(magnitudes)
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
return dynamic_range_decompression_torch(magnitudes)
|
||||
|
||||
|
||||
# Reusable banks
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram.
|
||||
|
||||
Args:
|
||||
y :: (B, T) - Audio waveforms
|
||||
n_fft
|
||||
sampling_rate
|
||||
hop_size
|
||||
win_size
|
||||
center
|
||||
Returns:
|
||||
:: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
|
||||
"""
|
||||
# Validation
|
||||
if torch.min(y) < -1.07:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.07:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
# Window - Cache if needed
|
||||
global hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
|
||||
# Padding
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
# Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
# Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
return spec
|
||||
|
||||
|
||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||
# MelBasis - Cache if needed
|
||||
global mel_basis
|
||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=spec.dtype, device=spec.device
|
||||
)
|
||||
|
||||
# Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
|
||||
melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
melspec = spectral_normalize_torch(melspec)
|
||||
return melspec
|
||||
|
||||
|
||||
def mel_spectrogram_torch(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
"""Convert waveform into Mel-frequency Log-amplitude spectrogram.
|
||||
|
||||
Args:
|
||||
y :: (B, T) - Waveforms
|
||||
Returns:
|
||||
melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
|
||||
"""
|
||||
# Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
|
||||
spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
|
||||
|
||||
# Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
|
||||
melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
|
||||
|
||||
return melspec
|
259
infer/lib/train/process_ckpt.py
Normal file
259
infer/lib/train/process_ckpt.py
Normal file
@ -0,0 +1,259 @@
|
||||
import torch, traceback, os, sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from collections import OrderedDict
|
||||
from i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
|
||||
def savee(ckpt, sr, if_f0, name, epoch, version, hps, i18n):
|
||||
try:
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
for key in ckpt.keys():
|
||||
if "enc_q" in key:
|
||||
continue
|
||||
opt["weight"][key] = ckpt[key].half()
|
||||
opt["config"] = [
|
||||
hps.data.filter_length // 2 + 1,
|
||||
32,
|
||||
hps.model.inter_channels,
|
||||
hps.model.hidden_channels,
|
||||
hps.model.filter_channels,
|
||||
hps.model.n_heads,
|
||||
hps.model.n_layers,
|
||||
hps.model.kernel_size,
|
||||
hps.model.p_dropout,
|
||||
hps.model.resblock,
|
||||
hps.model.resblock_kernel_sizes,
|
||||
hps.model.resblock_dilation_sizes,
|
||||
hps.model.upsample_rates,
|
||||
hps.model.upsample_initial_channel,
|
||||
hps.model.upsample_kernel_sizes,
|
||||
hps.model.spk_embed_dim,
|
||||
hps.model.gin_channels,
|
||||
hps.data.sampling_rate,
|
||||
]
|
||||
opt["info"] = "%sepoch" % epoch
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = if_f0
|
||||
opt["version"] = version
|
||||
torch.save(opt, "weights/%s.pth" % name)
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
def show_info(path):
|
||||
try:
|
||||
a = torch.load(path, map_location="cpu")
|
||||
return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
|
||||
a.get("info", "None"),
|
||||
a.get("sr", "None"),
|
||||
a.get("f0", "None"),
|
||||
a.get("version", "None"),
|
||||
)
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
def extract_small_model(path, name, sr, if_f0, info, version):
|
||||
try:
|
||||
ckpt = torch.load(path, map_location="cpu")
|
||||
if "model" in ckpt:
|
||||
ckpt = ckpt["model"]
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
for key in ckpt.keys():
|
||||
if "enc_q" in key:
|
||||
continue
|
||||
opt["weight"][key] = ckpt[key].half()
|
||||
if sr == "40k":
|
||||
opt["config"] = [
|
||||
1025,
|
||||
32,
|
||||
192,
|
||||
192,
|
||||
768,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
0,
|
||||
"1",
|
||||
[3, 7, 11],
|
||||
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
[10, 10, 2, 2],
|
||||
512,
|
||||
[16, 16, 4, 4],
|
||||
109,
|
||||
256,
|
||||
40000,
|
||||
]
|
||||
elif sr == "48k":
|
||||
if version == "v1":
|
||||
opt["config"] = [
|
||||
1025,
|
||||
32,
|
||||
192,
|
||||
192,
|
||||
768,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
0,
|
||||
"1",
|
||||
[3, 7, 11],
|
||||
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
[10, 6, 2, 2, 2],
|
||||
512,
|
||||
[16, 16, 4, 4, 4],
|
||||
109,
|
||||
256,
|
||||
48000,
|
||||
]
|
||||
else:
|
||||
opt["config"] = [
|
||||
1025,
|
||||
32,
|
||||
192,
|
||||
192,
|
||||
768,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
0,
|
||||
"1",
|
||||
[3, 7, 11],
|
||||
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
[12, 10, 2, 2],
|
||||
512,
|
||||
[24, 20, 4, 4],
|
||||
109,
|
||||
256,
|
||||
48000,
|
||||
]
|
||||
elif sr == "32k":
|
||||
if version == "v1":
|
||||
opt["config"] = [
|
||||
513,
|
||||
32,
|
||||
192,
|
||||
192,
|
||||
768,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
0,
|
||||
"1",
|
||||
[3, 7, 11],
|
||||
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
[10, 4, 2, 2, 2],
|
||||
512,
|
||||
[16, 16, 4, 4, 4],
|
||||
109,
|
||||
256,
|
||||
32000,
|
||||
]
|
||||
else:
|
||||
opt["config"] = [
|
||||
513,
|
||||
32,
|
||||
192,
|
||||
192,
|
||||
768,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
0,
|
||||
"1",
|
||||
[3, 7, 11],
|
||||
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
[10, 8, 2, 2],
|
||||
512,
|
||||
[20, 16, 4, 4],
|
||||
109,
|
||||
256,
|
||||
32000,
|
||||
]
|
||||
if info == "":
|
||||
info = "Extracted model."
|
||||
opt["info"] = info
|
||||
opt["version"] = version
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = int(if_f0)
|
||||
torch.save(opt, "weights/%s.pth" % name)
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
def change_info(path, info, name):
|
||||
try:
|
||||
ckpt = torch.load(path, map_location="cpu")
|
||||
ckpt["info"] = info
|
||||
if name == "":
|
||||
name = os.path.basename(path)
|
||||
torch.save(ckpt, "weights/%s" % name)
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
def merge(path1, path2, alpha1, sr, f0, info, name, version):
|
||||
try:
|
||||
|
||||
def extract(ckpt):
|
||||
a = ckpt["model"]
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
for key in a.keys():
|
||||
if "enc_q" in key:
|
||||
continue
|
||||
opt["weight"][key] = a[key]
|
||||
return opt
|
||||
|
||||
ckpt1 = torch.load(path1, map_location="cpu")
|
||||
ckpt2 = torch.load(path2, map_location="cpu")
|
||||
cfg = ckpt1["config"]
|
||||
if "model" in ckpt1:
|
||||
ckpt1 = extract(ckpt1)
|
||||
else:
|
||||
ckpt1 = ckpt1["weight"]
|
||||
if "model" in ckpt2:
|
||||
ckpt2 = extract(ckpt2)
|
||||
else:
|
||||
ckpt2 = ckpt2["weight"]
|
||||
if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
|
||||
return "Fail to merge the models. The model architectures are not the same."
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
for key in ckpt1.keys():
|
||||
# try:
|
||||
if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
|
||||
min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
|
||||
opt["weight"][key] = (
|
||||
alpha1 * (ckpt1[key][:min_shape0].float())
|
||||
+ (1 - alpha1) * (ckpt2[key][:min_shape0].float())
|
||||
).half()
|
||||
else:
|
||||
opt["weight"][key] = (
|
||||
alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
|
||||
).half()
|
||||
# except:
|
||||
# pdb.set_trace()
|
||||
opt["config"] = cfg
|
||||
"""
|
||||
if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000]
|
||||
elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000]
|
||||
elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
|
||||
"""
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = 1 if f0 == i18n("是") else 0
|
||||
opt["version"] = version
|
||||
opt["info"] = info
|
||||
torch.save(opt, "weights/%s.pth" % name)
|
||||
return "Success."
|
||||
except:
|
||||
return traceback.format_exc()
|
487
infer/lib/train/utils.py
Normal file
487
infer/lib/train/utils.py
Normal file
@ -0,0 +1,487 @@
|
||||
import os, traceback
|
||||
import glob
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from scipy.io.wavfile import read
|
||||
import torch
|
||||
|
||||
MATPLOTLIB_FLAG = False
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
|
||||
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
##################
|
||||
def go(model, bkey):
|
||||
saved_state_dict = checkpoint_dict[bkey]
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items(): # 模型需要的shape
|
||||
try:
|
||||
new_state_dict[k] = saved_state_dict[k]
|
||||
if saved_state_dict[k].shape != state_dict[k].shape:
|
||||
print(
|
||||
"shape-%s-mismatch|need-%s|get-%s"
|
||||
% (k, state_dict[k].shape, saved_state_dict[k].shape)
|
||||
) #
|
||||
raise KeyError
|
||||
except:
|
||||
# logger.info(traceback.format_exc())
|
||||
logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
|
||||
new_state_dict[k] = v # 模型自带的随机值
|
||||
if hasattr(model, "module"):
|
||||
model.module.load_state_dict(new_state_dict, strict=False)
|
||||
else:
|
||||
model.load_state_dict(new_state_dict, strict=False)
|
||||
return model
|
||||
|
||||
go(combd, "combd")
|
||||
model = go(sbd, "sbd")
|
||||
#############
|
||||
logger.info("Loaded model weights")
|
||||
|
||||
iteration = checkpoint_dict["iteration"]
|
||||
learning_rate = checkpoint_dict["learning_rate"]
|
||||
if (
|
||||
optimizer is not None and load_opt == 1
|
||||
): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
|
||||
# try:
|
||||
optimizer.load_state_dict(checkpoint_dict["optimizer"])
|
||||
# except:
|
||||
# traceback.print_exc()
|
||||
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
|
||||
# def load_checkpoint(checkpoint_path, model, optimizer=None):
|
||||
# assert os.path.isfile(checkpoint_path)
|
||||
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
# iteration = checkpoint_dict['iteration']
|
||||
# learning_rate = checkpoint_dict['learning_rate']
|
||||
# if optimizer is not None:
|
||||
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
# # print(1111)
|
||||
# saved_state_dict = checkpoint_dict['model']
|
||||
# # print(1111)
|
||||
#
|
||||
# if hasattr(model, 'module'):
|
||||
# state_dict = model.module.state_dict()
|
||||
# else:
|
||||
# state_dict = model.state_dict()
|
||||
# new_state_dict= {}
|
||||
# for k, v in state_dict.items():
|
||||
# try:
|
||||
# new_state_dict[k] = saved_state_dict[k]
|
||||
# except:
|
||||
# logger.info("%s is not in the checkpoint" % k)
|
||||
# new_state_dict[k] = v
|
||||
# if hasattr(model, 'module'):
|
||||
# model.module.load_state_dict(new_state_dict)
|
||||
# else:
|
||||
# model.load_state_dict(new_state_dict)
|
||||
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
|
||||
# checkpoint_path, iteration))
|
||||
# return model, optimizer, learning_rate, iteration
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
saved_state_dict = checkpoint_dict["model"]
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items(): # 模型需要的shape
|
||||
try:
|
||||
new_state_dict[k] = saved_state_dict[k]
|
||||
if saved_state_dict[k].shape != state_dict[k].shape:
|
||||
print(
|
||||
"shape-%s-mismatch|need-%s|get-%s"
|
||||
% (k, state_dict[k].shape, saved_state_dict[k].shape)
|
||||
) #
|
||||
raise KeyError
|
||||
except:
|
||||
# logger.info(traceback.format_exc())
|
||||
logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
|
||||
new_state_dict[k] = v # 模型自带的随机值
|
||||
if hasattr(model, "module"):
|
||||
model.module.load_state_dict(new_state_dict, strict=False)
|
||||
else:
|
||||
model.load_state_dict(new_state_dict, strict=False)
|
||||
logger.info("Loaded model weights")
|
||||
|
||||
iteration = checkpoint_dict["iteration"]
|
||||
learning_rate = checkpoint_dict["learning_rate"]
|
||||
if (
|
||||
optimizer is not None and load_opt == 1
|
||||
): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
|
||||
# try:
|
||||
optimizer.load_state_dict(checkpoint_dict["optimizer"])
|
||||
# except:
|
||||
# traceback.print_exc()
|
||||
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info(
|
||||
"Saving model and optimizer state at epoch {} to {}".format(
|
||||
iteration, checkpoint_path
|
||||
)
|
||||
)
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
torch.save(
|
||||
{
|
||||
"model": state_dict,
|
||||
"iteration": iteration,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"learning_rate": learning_rate,
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info(
|
||||
"Saving model and optimizer state at epoch {} to {}".format(
|
||||
iteration, checkpoint_path
|
||||
)
|
||||
)
|
||||
if hasattr(combd, "module"):
|
||||
state_dict_combd = combd.module.state_dict()
|
||||
else:
|
||||
state_dict_combd = combd.state_dict()
|
||||
if hasattr(sbd, "module"):
|
||||
state_dict_sbd = sbd.module.state_dict()
|
||||
else:
|
||||
state_dict_sbd = sbd.state_dict()
|
||||
torch.save(
|
||||
{
|
||||
"combd": state_dict_combd,
|
||||
"sbd": state_dict_sbd,
|
||||
"iteration": iteration,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"learning_rate": learning_rate,
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
def summarize(
|
||||
writer,
|
||||
global_step,
|
||||
scalars={},
|
||||
histograms={},
|
||||
images={},
|
||||
audios={},
|
||||
audio_sampling_rate=22050,
|
||||
):
|
||||
for k, v in scalars.items():
|
||||
writer.add_scalar(k, v, global_step)
|
||||
for k, v in histograms.items():
|
||||
writer.add_histogram(k, v, global_step)
|
||||
for k, v in images.items():
|
||||
writer.add_image(k, v, global_step, dataformats="HWC")
|
||||
for k, v in audios.items():
|
||||
writer.add_audio(k, v, global_step, audio_sampling_rate)
|
||||
|
||||
|
||||
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
|
||||
f_list = glob.glob(os.path.join(dir_path, regex))
|
||||
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
|
||||
x = f_list[-1]
|
||||
print(x)
|
||||
return x
|
||||
|
||||
|
||||
def plot_spectrogram_to_numpy(spectrogram):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.xlabel("Frames")
|
||||
plt.ylabel("Channels")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def plot_alignment_to_numpy(alignment, info=None):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
im = ax.imshow(
|
||||
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
|
||||
)
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = "Decoder timestep"
|
||||
if info is not None:
|
||||
xlabel += "\n\n" + info
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel("Encoder timestep")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def load_wav_to_torch(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
||||
|
||||
|
||||
def load_filepaths_and_text(filename, split="|"):
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def get_hparams(init=True):
|
||||
"""
|
||||
todo:
|
||||
结尾七人组:
|
||||
保存频率、总epoch done
|
||||
bs done
|
||||
pretrainG、pretrainD done
|
||||
卡号:os.en["CUDA_VISIBLE_DEVICES"] done
|
||||
if_latest done
|
||||
模型:if_f0 done
|
||||
采样率:自动选择config done
|
||||
是否缓存数据集进GPU:if_cache_data_in_gpu done
|
||||
|
||||
-m:
|
||||
自动决定training_files路径,改掉train_nsf_load_pretrain.py里的hps.data.training_files done
|
||||
-c不要了
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration')
|
||||
parser.add_argument(
|
||||
"-se",
|
||||
"--save_every_epoch",
|
||||
type=int,
|
||||
required=True,
|
||||
help="checkpoint save frequency (epoch)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-te", "--total_epoch", type=int, required=True, help="total_epoch"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
|
||||
)
|
||||
parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
|
||||
parser.add_argument(
|
||||
"-bs", "--batch_size", type=int, required=True, help="batch size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--experiment_dir", type=str, required=True, help="experiment dir"
|
||||
) # -m
|
||||
parser.add_argument(
|
||||
"-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-sw",
|
||||
"--save_every_weights",
|
||||
type=str,
|
||||
default="0",
|
||||
help="save the extracted model in weights directory when saving checkpoints",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", type=str, required=True, help="model version"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f0",
|
||||
"--if_f0",
|
||||
type=int,
|
||||
required=True,
|
||||
help="use f0 as one of the inputs of the model, 1 or 0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--if_latest",
|
||||
type=int,
|
||||
required=True,
|
||||
help="if only save the latest G/D pth file, 1 or 0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--if_cache_data_in_gpu",
|
||||
type=int,
|
||||
required=True,
|
||||
help="if caching the dataset in GPU memory, 1 or 0",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
name = args.experiment_dir
|
||||
experiment_dir = os.path.join("./logs", args.experiment_dir)
|
||||
|
||||
if not os.path.exists(experiment_dir):
|
||||
os.makedirs(experiment_dir)
|
||||
|
||||
if args.version == "v1" or args.sample_rate == "40k":
|
||||
config_path = "configs/%s.json" % args.sample_rate
|
||||
else:
|
||||
config_path = "configs/%s_v2.json" % args.sample_rate
|
||||
config_save_path = os.path.join(experiment_dir, "config.json")
|
||||
if init:
|
||||
with open(config_path, "r") as f:
|
||||
data = f.read()
|
||||
with open(config_save_path, "w") as f:
|
||||
f.write(data)
|
||||
else:
|
||||
with open(config_save_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams = HParams(**config)
|
||||
hparams.model_dir = hparams.experiment_dir = experiment_dir
|
||||
hparams.save_every_epoch = args.save_every_epoch
|
||||
hparams.name = name
|
||||
hparams.total_epoch = args.total_epoch
|
||||
hparams.pretrainG = args.pretrainG
|
||||
hparams.pretrainD = args.pretrainD
|
||||
hparams.version = args.version
|
||||
hparams.gpus = args.gpus
|
||||
hparams.train.batch_size = args.batch_size
|
||||
hparams.sample_rate = args.sample_rate
|
||||
hparams.if_f0 = args.if_f0
|
||||
hparams.if_latest = args.if_latest
|
||||
hparams.save_every_weights = args.save_every_weights
|
||||
hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
|
||||
hparams.data.training_files = "%s/filelist.txt" % experiment_dir
|
||||
return hparams
|
||||
|
||||
|
||||
def get_hparams_from_dir(model_dir):
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
with open(config_save_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams = HParams(**config)
|
||||
hparams.model_dir = model_dir
|
||||
return hparams
|
||||
|
||||
|
||||
def get_hparams_from_file(config_path):
|
||||
with open(config_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams = HParams(**config)
|
||||
return hparams
|
||||
|
||||
|
||||
def check_git_hash(model_dir):
|
||||
source_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
if not os.path.exists(os.path.join(source_dir, ".git")):
|
||||
logger.warn(
|
||||
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
||||
source_dir
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
||||
|
||||
path = os.path.join(model_dir, "githash")
|
||||
if os.path.exists(path):
|
||||
saved_hash = open(path).read()
|
||||
if saved_hash != cur_hash:
|
||||
logger.warn(
|
||||
"git hash values are different. {}(saved) != {}(current)".format(
|
||||
saved_hash[:8], cur_hash[:8]
|
||||
)
|
||||
)
|
||||
else:
|
||||
open(path, "w").write(cur_hash)
|
||||
|
||||
|
||||
def get_logger(model_dir, filename="train.log"):
|
||||
global logger
|
||||
logger = logging.getLogger(os.path.basename(model_dir))
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
h = logging.FileHandler(os.path.join(model_dir, filename))
|
||||
h.setLevel(logging.DEBUG)
|
||||
h.setFormatter(formatter)
|
||||
logger.addHandler(h)
|
||||
return logger
|
||||
|
||||
|
||||
class HParams:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if type(v) == dict:
|
||||
v = HParams(**v)
|
||||
self[k] = v
|
||||
|
||||
def keys(self):
|
||||
return self.__dict__.keys()
|
||||
|
||||
def items(self):
|
||||
return self.__dict__.items()
|
||||
|
||||
def values(self):
|
||||
return self.__dict__.values()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.__dict__
|
||||
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
Loading…
Reference in New Issue
Block a user