e9dd11bddb
* Optimize latency (#1259) * add attribute: configs/config.py Optimize latency: tools/rvc_for_realtime.py * new file: assets/Synthesizer_inputs.pth * fix: configs/config.py fix: tools/rvc_for_realtime.py * fix bug: infer/lib/infer_pack/models.py * new file: assets/hubert_inputs.pth new file: assets/rmvpe_inputs.pth modified: configs/config.py new features: infer/lib/rmvpe.py new features: tools/jit_export/__init__.py new features: tools/jit_export/get_hubert.py new features: tools/jit_export/get_rmvpe.py new features: tools/jit_export/get_synthesizer.py optimize: tools/rvc_for_realtime.py * optimize: tools/jit_export/get_synthesizer.py fix bug: tools/jit_export/__init__.py * Fixed a bug caused by using half on the CPU: infer/lib/rmvpe.py Fixed a bug caused by using half on the CPU: tools/jit_export/__init__.py Fixed CIRCULAR IMPORT: tools/jit_export/get_rmvpe.py Fixed CIRCULAR IMPORT: tools/jit_export/get_synthesizer.py Fixed a bug caused by using half on the CPU: tools/rvc_for_realtime.py * Remove useless code: infer/lib/rmvpe.py * Delete gui_v1 copy.py * Delete .vscode/launch.json * Delete jit_export_test.py * Delete tools/rvc_for_realtime copy.py * Delete configs/config.json * Delete .gitignore * Fix exceptions caused by switching inference devices: infer/lib/rmvpe.py Fix exceptions caused by switching inference devices: tools/jit_export/__init__.py Fix exceptions caused by switching inference devices: tools/rvc_for_realtime.py * restore * replace(you can undo this commit) * remove debug_print --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * Fixed some bugs when exporting ONNX model (#1254) * fix import (#1280) * fix import * lint * 🎨 同步 locale (#1242) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Fix jit load and import issue (#1282) * fix jit model loading : infer/lib/rmvpe.py * modified: assets/hubert/.gitignore move file: assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth modified: 
assets/rmvpe/.gitignore move file: assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth fix import: gui_v1.py * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * Add input wav and delay time monitor for real-time gui (#1293) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * add input wav and delay time monitor --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> * Optimize latency using scripted jit (#1291) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * Optimize-latency-using-scripted: configs/config.py Optimize-latency-using-scripted: infer/lib/infer_pack/attentions.py Optimize-latency-using-scripted: infer/lib/infer_pack/commons.py Optimize-latency-using-scripted: infer/lib/infer_pack/models.py Optimize-latency-using-scripted: infer/lib/infer_pack/modules.py Optimize-latency-using-scripted: infer/lib/jit/__init__.py Optimize-latency-using-scripted: infer/lib/jit/get_hubert.py Optimize-latency-using-scripted: infer/lib/jit/get_rmvpe.py Optimize-latency-using-scripted: infer/lib/jit/get_synthesizer.py Optimize-latency-using-scripted: infer/lib/rmvpe.py Optimize-latency-using-scripted: tools/rvc_for_realtime.py * modified: infer/lib/infer_pack/models.py * fix some bug: configs/config.py fix some bug: infer/lib/infer_pack/models.py fix some bug: infer/lib/rmvpe.py * Fixed abnormal reference of logger in multiprocessing: 
infer/modules/train/train.py --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Format code (#1298) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 🎨 同步 locale (#1299) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: optimize actions * feat(workflow): add sync dev * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: add jit options (#1303) Delete useless code: infer/lib/jit/get_synthesizer.py Optimized code: tools/rvc_for_realtime.py * Code refactor + re-design inference ui (#1304) * Code refacor + re-design inference ui * Fix tabname * i18n jp --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * feat: optimize actions * feat: optimize actions * Update README & en_US locale file (#1309) * critical: some bug fixes (#1322) * JIT acceleration switch does not support hot update * fix padding bug of rmvpe in torch-directml * fix padding bug of rmvpe in torch-directml * Fix STFT under torch_directml (#1330) * chore(format): run black on dev (#1318) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * chore(i18n): sync locale on dev (#1317) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: allow for tta to be passed to uvr (#1361) * chore(format): run black on dev (#1373) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Added script for automatically download all needed models at install (#1366) * Delete modules.py * Add files via upload * Add files via upload * Add files via upload * Add files via upload * chore(i18n): sync locale on dev (#1377) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 
chore(format): run black on dev (#1376) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Update IPEX library (#1362) * Update IPEX library * Update ipex index * chore(format): run black on dev (#1378) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com> Co-authored-by: Ftps <ftpsflandre@gmail.com> Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com> Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com> Co-authored-by: Rice Cake <gak141808@gmail.com> Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com> Co-authored-by: Dmitry <nda2911@yandex.ru> Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
281 lines
11 KiB
Python
281 lines
11 KiB
Python
import torch
|
|
from infer.lib.rmvpe import STFT
|
|
from torch.nn.functional import conv1d, conv2d
|
|
from typing import Union, Optional
|
|
from .utils import linspace, temperature_sigmoid, amp_to_db
|
|
|
|
|
|
class TorchGate(torch.nn.Module):
    """
    A PyTorch module that applies a spectral gate to an input signal.

    Arguments:
        sr {int} -- Sample rate of the input signal.
        nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}).
        n_std_thresh_stationary {float} -- Number of standard deviations above the mean used as the noise
                                           threshold for stationary masking (default: {1.5}).
        n_thresh_nonstationary {float} -- Threshold passed to the temperature sigmoid for non-stationary
                                          masking (default: {1.3}).
        temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}).
        n_movemean_nonstationary {int} -- Number of samples for moving average smoothing in non-stationary
                                          masking (default: {20}).
        prop_decrease {float} -- Proportion to decrease signal by where the mask is zero; must lie in
                                 [0, 1] (default: {1.0}).
        n_fft {int} -- Size of FFT for STFT (default: {1024}).
        win_length {int} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}).
        hop_length {int} -- Hop length for STFT. If None, defaults to `win_length` // 4 (default: {None}).
        freq_mask_smooth_hz {float} -- Frequency smoothing width for mask (in Hz). If None, no frequency
                                       smoothing is applied (default: {500}).
        time_mask_smooth_ms {float} -- Time smoothing width for mask (in ms). If None, no time smoothing
                                       is applied (default: {50}).
    """

    @torch.no_grad()
    def __init__(
        self,
        sr: int,
        nonstationary: bool = False,
        n_std_thresh_stationary: float = 1.5,
        n_thresh_nonstationary: float = 1.3,
        temp_coeff_nonstationary: float = 0.1,
        n_movemean_nonstationary: int = 20,
        prop_decrease: float = 1.0,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        freq_mask_smooth_hz: Optional[float] = 500,
        time_mask_smooth_ms: Optional[float] = 50,
    ):
        super().__init__()

        # General Params
        self.sr = sr
        self.nonstationary = nonstationary
        # Kept as an assert so the exception type seen by existing callers is unchanged.
        assert 0.0 <= prop_decrease <= 1.0, "prop_decrease must be within [0, 1]"
        self.prop_decrease = prop_decrease

        # STFT Params
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        # Stationary Params
        self.n_std_thresh_stationary = n_std_thresh_stationary

        # Non-Stationary Params
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary

        # Smooth Mask Params
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        # The buffer may be None, in which case forward() skips mask smoothing.
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self) -> Union[torch.Tensor, None]:
        """
        Build the normalized 2D kernel used to smooth the signal mask in frequency and time.

        Returns:
            smoothing_filter (torch.Tensor): the outer product of two triangular ramps, with two leading
            singleton dimensions added so it can be used as a conv2d weight, normalized to sum to 1.
            Returns None when both self.freq_mask_smooth_hz and self.time_mask_smooth_ms are None, or
            when both smoothing widths round down to a single bin/frame.

        Raises:
            ValueError: if a requested smoothing width is smaller than one frequency bin / one hop.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        n_grad_freq = (
            1
            if self.freq_mask_smooth_hz is None
            else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2)))
        )
        if n_grad_freq < 1:
            # The original message referenced the non-existent `self._n_fft`, which turned this
            # ValueError into an AttributeError.
            raise ValueError(
                f"freq_mask_smooth_hz needs to be at least {int((self.sr / (self.n_fft / 2)))} Hz"
            )

        n_grad_time = (
            1
            if self.time_mask_smooth_ms is None
            else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000))
        )
        if n_grad_time < 1:
            raise ValueError(
                f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms"
            )

        # A 1x1 kernel would be a no-op, so skip smoothing entirely.
        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        # Rising ramp followed by falling ramp; the first and last samples are trimmed off.
        v_f = torch.cat(
            [
                linspace(0, 1, n_grad_freq + 1, endpoint=False),
                linspace(1, 0, n_grad_freq + 2),
            ]
        )[1:-1]
        v_t = torch.cat(
            [
                linspace(0, 1, n_grad_time + 1, endpoint=False),
                linspace(1, 0, n_grad_time + 2),
            ]
        )[1:-1]
        smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @staticmethod
    def _needs_fallback_stft(tensor: torch.Tensor) -> bool:
        """Return True when the tensor lives on a "privateuseone" device, where the STFT module is used."""
        return "privateuseone" in str(tensor.device)

    def _hann_window(self, device: torch.device) -> torch.Tensor:
        """Create the Hann analysis/synthesis window of length `win_length` on `device`."""
        return torch.hann_window(self.win_length).to(device)

    def _transform(self, signal: torch.Tensor, return_phase: bool = False):
        """
        Compute the STFT of `signal`, choosing the backend from the tensor's device.

        Returns:
            A `(spectrogram, phase)` pair. On "privateuseone" devices the project `STFT` module is
            created lazily on first use and cached as `self.stft`; `phase` is whatever
            `STFT.transform(..., return_phase=True)` returns (None when `return_phase` is False).
            On every other device the complex `torch.stft` result is returned and `phase` is None.
        """
        if self._needs_fallback_stft(signal):
            if not hasattr(self, "stft"):
                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window="hann",
                ).to(signal.device)
            if return_phase:
                return self.stft.transform(signal, return_phase=True)
            return self.stft.transform(signal), None

        spectrogram = torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            return_complex=True,
            pad_mode="constant",
            center=True,
            window=self._hann_window(signal.device),
        )
        return spectrogram, None

    @torch.no_grad()
    def _stationary_mask(
        self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Computes a stationary binary mask to filter out noise in a log-magnitude spectrogram.

        Arguments:
            X_db (torch.Tensor): log-magnitude spectrogram. It must be 3D because the threshold is
                                 broadcast with `unsqueeze(2)`; with `torch.stft` input this is
                                 (batch, freq_bins, frames).
            xn (torch.Tensor): optional time-domain noise signal. When given, the noise statistics are
                               computed from its spectrogram instead of from X_db.

        Returns:
            sig_mask (torch.Tensor): Boolean mask of the same shape as X_db, True where X_db exceeds
                                     the per-row noise threshold.
        """
        if xn is not None:
            XN, _ = self._transform(xn)
            XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
        else:
            XN_db = X_db

        # mean and standard deviation over the last axis (the frame axis of a torch.stft
        # spectrogram), i.e. one statistic per frequency bin
        std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)

        # compute noise threshold
        noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary

        # create binary mask by thresholding the spectrogram
        sig_mask = X_db > noise_thresh.unsqueeze(2)
        return sig_mask

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor:
        """
        Computes a non-stationary soft mask to filter out noise in a magnitude spectrogram.

        Arguments:
            X_abs (torch.Tensor): magnitude spectrogram; the moving average runs along its last axis.

        Returns:
            sig_mask (torch.Tensor): tensor of the same shape as X_abs produced by
                                     `temperature_sigmoid` applied to the relative deviation from the
                                     moving average.
        """
        # Moving average along the last axis, implemented as a box-filter convolution.
        box_kernel = torch.ones(
            self.n_movemean_nonstationary,
            dtype=X_abs.dtype,
            device=X_abs.device,
        ).view(1, 1, -1)
        X_smoothed = (
            conv1d(
                X_abs.reshape(-1, 1, X_abs.shape[-1]),
                box_kernel,
                padding="same",
            ).view(X_abs.shape)
            / self.n_movemean_nonstationary
        )

        # Compute slowness ratio and apply temperature sigmoid
        # (the small epsilon avoids division by zero in silent regions)
        slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6)
        sig_mask = temperature_sigmoid(
            slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary
        )

        return sig_mask

    def forward(
        self, x: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the proposed algorithm to the input signal.

        Arguments:
            x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length).
            xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input
                                         signal is used as the noise signal. Default: `None`.

        Returns:
            torch.Tensor: The denoised audio signal, cast to the dtype of `x`. No explicit `length` is
                          passed to the inverse STFT, so the output length is whatever the inverse
                          transform produces.
        """
        # Compute short-time Fourier transform (STFT)
        X, phase = self._transform(x, return_phase=True)

        # Compute signal mask based on stationary or nonstationary assumptions
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(X.abs())
        else:
            sig_mask = self._stationary_mask(amp_to_db(X), xn)

        # Propagate decrease in signal power
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        # Smooth signal mask with 2D convolution
        if self.smoothing_filter is not None:
            # conv2d needs a channel axis; add it for the convolution and drop it afterwards.
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            ).squeeze(1)

        # Apply signal mask to the spectrogram
        Y = X * sig_mask

        # Inverse STFT to obtain time-domain signal
        if self._needs_fallback_stft(Y):
            # self.stft was created by _transform() above for this device type.
            y = self.stft.inverse(Y, phase)
        else:
            y = torch.istft(
                Y,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=True,
                window=self._hann_window(Y.device),
            )

        return y.to(dtype=x.dtype)
|