mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-12-18 10:26:03 +01:00
503 lines
20 KiB
Python
503 lines
20 KiB
Python
from typing import Optional
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.utils.data import DataLoader
|
|
|
|
def atan2(y, x):
|
|
r"""Element-wise arctangent function of y/x.
|
|
Returns a new tensor with signed angles in radians.
|
|
It is an alternative implementation of torch.atan2
|
|
|
|
Args:
|
|
y (Tensor): First input tensor
|
|
x (Tensor): Second input tensor [shape=y.shape]
|
|
|
|
Returns:
|
|
Tensor: [shape=y.shape].
|
|
"""
|
|
pi = 2 * torch.asin(torch.tensor(1.0))
|
|
x += ((x == 0) & (y == 0)) * 1.0
|
|
out = torch.atan(y / x)
|
|
out += ((y >= 0) & (x < 0)) * pi
|
|
out -= ((y < 0) & (x < 0)) * pi
|
|
out *= 1 - ((y > 0) & (x == 0)) * 1.0
|
|
out += ((y > 0) & (x == 0)) * (pi / 2)
|
|
out *= 1 - ((y < 0) & (x == 0)) * 1.0
|
|
out += ((y < 0) & (x == 0)) * (-pi / 2)
|
|
return out
|
|
|
|
|
|
# Define basic complex operations on torch.Tensor objects whose last dimension
|
|
# consists in the concatenation of the real and imaginary parts.
|
|
|
|
|
|
def _norm(x: torch.Tensor) -> torch.Tensor:
|
|
r"""Computes the norm value of a torch Tensor, assuming that it
|
|
comes as real and imaginary part in its last dimension.
|
|
|
|
Args:
|
|
x (Tensor): Input Tensor of shape [shape=(..., 2)]
|
|
|
|
Returns:
|
|
Tensor: shape as x excluding the last dimension.
|
|
"""
|
|
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
|
|
|
|
|
|
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
"""Element-wise multiplication of two complex Tensors described
|
|
through their real and imaginary parts.
|
|
The result is added to the `out` tensor"""
|
|
|
|
# check `out` and allocate it if needed
|
|
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
|
if out is None or out.shape != target_shape:
|
|
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
|
if out is a:
|
|
real_a = a[..., 0]
|
|
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
|
|
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
|
|
else:
|
|
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
|
|
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
|
|
return out
|
|
|
|
|
|
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
"""Element-wise multiplication of two complex Tensors described
|
|
through their real and imaginary parts
|
|
can work in place in case out is a only"""
|
|
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
|
if out is None or out.shape != target_shape:
|
|
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
|
if out is a:
|
|
real_a = a[..., 0]
|
|
out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
|
|
out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
|
|
else:
|
|
out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
|
|
out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
|
|
return out
|
|
|
|
|
|
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
"""Element-wise multiplicative inverse of a Tensor with complex
|
|
entries described through their real and imaginary parts.
|
|
can work in place in case out is z"""
|
|
ez = _norm(z)
|
|
if out is None or out.shape != z.shape:
|
|
out = torch.zeros_like(z)
|
|
out[..., 0] = z[..., 0] / ez
|
|
out[..., 1] = -z[..., 1] / ez
|
|
return out
|
|
|
|
|
|
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
"""Element-wise complex conjugate of a Tensor with complex entries
|
|
described through their real and imaginary parts.
|
|
can work in place in case out is z"""
|
|
if out is None or out.shape != z.shape:
|
|
out = torch.zeros_like(z)
|
|
out[..., 0] = z[..., 0]
|
|
out[..., 1] = -z[..., 1]
|
|
return out
|
|
|
|
|
|
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
"""
|
|
Invert 1x1 or 2x2 matrices
|
|
|
|
Will generate errors if the matrices are singular: user must handle this
|
|
through his own regularization schemes.
|
|
|
|
Args:
|
|
M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
|
|
matrices to invert: must be square along dimensions -3 and -2
|
|
|
|
Returns:
|
|
invM (Tensor): [shape=M.shape]
|
|
inverses of M
|
|
"""
|
|
nb_channels = M.shape[-2]
|
|
|
|
if out is None or out.shape != M.shape:
|
|
out = torch.empty_like(M)
|
|
|
|
if nb_channels == 1:
|
|
# scalar case
|
|
out = _inv(M, out)
|
|
elif nb_channels == 2:
|
|
# two channels case: analytical expression
|
|
|
|
# first compute the determinent
|
|
det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
|
|
det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
|
|
# invert it
|
|
invDet = _inv(det)
|
|
|
|
# then fill out the matrix with the inverse
|
|
out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
|
|
out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
|
|
out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
|
|
out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
|
|
else:
|
|
raise Exception("Only 2 channels are supported for the torch version.")
|
|
return out
|
|
|
|
|
|
# Now define the signal-processing low-level functions used by the Separator
|
|
|
|
|
|
def expectation_maximization(
|
|
y: torch.Tensor,
|
|
x: torch.Tensor,
|
|
iterations: int = 2,
|
|
eps: float = 1e-10,
|
|
batch_size: int = 200,
|
|
):
|
|
r"""Expectation maximization algorithm, for refining source separation
|
|
estimates.
|
|
|
|
This algorithm allows to make source separation results better by
|
|
enforcing multichannel consistency for the estimates. This usually means
|
|
a better perceptual quality in terms of spatial artifacts.
|
|
|
|
The implementation follows the details presented in [1]_, taking
|
|
inspiration from the original EM algorithm proposed in [2]_ and its
|
|
weighted refinement proposed in [3]_, [4]_.
|
|
It works by iteratively:
|
|
|
|
* Re-estimate source parameters (power spectral densities and spatial
|
|
covariance matrices) through :func:`get_local_gaussian_model`.
|
|
|
|
* Separate again the mixture with the new parameters by first computing
|
|
the new modelled mixture covariance matrices with :func:`get_mix_model`,
|
|
prepare the Wiener filters through :func:`wiener_gain` and apply them
|
|
with :func:`apply_filter``.
|
|
|
|
References
|
|
----------
|
|
.. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
|
|
N. Takahashi and Y. Mitsufuji, "Improving music source separation based
|
|
on deep neural networks through data augmentation and network
|
|
blending." 2017 IEEE International Conference on Acoustics, Speech
|
|
and Signal Processing (ICASSP). IEEE, 2017.
|
|
|
|
.. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
|
|
reverberant audio source separation using a full-rank spatial
|
|
covariance model." IEEE Transactions on Audio, Speech, and Language
|
|
Processing 18.7 (2010): 1830-1840.
|
|
|
|
.. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
|
|
separation with deep neural networks." IEEE/ACM Transactions on Audio,
|
|
Speech, and Language Processing 24.9 (2016): 1652-1664.
|
|
|
|
.. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
|
|
separation with deep neural networks." 2016 24th European Signal
|
|
Processing Conference (EUSIPCO). IEEE, 2016.
|
|
|
|
.. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
|
|
source separation." IEEE Transactions on Signal Processing
|
|
62.16 (2014): 4298-4310.
|
|
|
|
Args:
|
|
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
|
|
initial estimates for the sources
|
|
x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
|
|
complex STFT of the mixture signal
|
|
iterations (int): [scalar]
|
|
number of iterations for the EM algorithm.
|
|
eps (float or None): [scalar]
|
|
The epsilon value to use for regularization and filters.
|
|
|
|
Returns:
|
|
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
|
|
estimated sources after iterations
|
|
v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
|
|
estimated power spectral densities
|
|
R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
|
|
estimated spatial covariance matrices
|
|
|
|
Notes:
|
|
* You need an initial estimate for the sources to apply this
|
|
algorithm. This is precisely what the :func:`wiener` function does.
|
|
* This algorithm *is not* an implementation of the "exact" EM
|
|
proposed in [1]_. In particular, it does compute the posterior
|
|
covariance matrices the same (exact) way. Instead, it uses the
|
|
simplified approximate scheme initially proposed in [5]_ and further
|
|
refined in [3]_, [4]_, that boils down to just take the empirical
|
|
covariance of the recent source estimates, followed by a weighted
|
|
average for the update of the spatial covariance matrix. It has been
|
|
empirically demonstrated that this simplified algorithm is more
|
|
robust for music separation.
|
|
|
|
Warning:
|
|
It is *very* important to make sure `x.dtype` is `torch.float64`
|
|
if you want double precision, because this function will **not**
|
|
do such conversion for you from `torch.complex32`, in case you want the
|
|
smaller RAM usage on purpose.
|
|
|
|
It is usually always better in terms of quality to have double
|
|
precision, by e.g. calling :func:`expectation_maximization`
|
|
with ``x.to(torch.float64)``.
|
|
"""
|
|
# dimensions
|
|
(nb_frames, nb_bins, nb_channels) = x.shape[:-1]
|
|
nb_sources = y.shape[-1]
|
|
|
|
regularization = torch.cat(
|
|
(
|
|
torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
|
|
torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
|
|
),
|
|
dim=2,
|
|
)
|
|
regularization = torch.sqrt(torch.as_tensor(eps)) * (
|
|
regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
|
|
)
|
|
|
|
# allocate the spatial covariance matrices
|
|
R = [
|
|
torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
|
|
for j in range(nb_sources)
|
|
]
|
|
weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
|
|
|
|
v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
|
|
for it in range(iterations):
|
|
# constructing the mixture covariance matrix. Doing it with a loop
|
|
# to avoid storing anytime in RAM the whole 6D tensor
|
|
|
|
# update the PSD as the average spectrogram over channels
|
|
v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
|
|
|
|
# update spatial covariance matrices (weighted update)
|
|
for j in range(nb_sources):
|
|
R[j] = torch.tensor(0.0, device=x.device)
|
|
weight = torch.tensor(eps, device=x.device)
|
|
pos: int = 0
|
|
batch_size = batch_size if batch_size else nb_frames
|
|
while pos < nb_frames:
|
|
t = torch.arange(pos, min(nb_frames, pos + batch_size))
|
|
pos = int(t[-1]) + 1
|
|
|
|
R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
|
|
weight = weight + torch.sum(v[t, ..., j], dim=0)
|
|
R[j] = R[j] / weight[..., None, None, None]
|
|
weight = torch.zeros_like(weight)
|
|
|
|
# cloning y if we track gradient, because we're going to update it
|
|
if y.requires_grad:
|
|
y = y.clone()
|
|
|
|
pos = 0
|
|
while pos < nb_frames:
|
|
t = torch.arange(pos, min(nb_frames, pos + batch_size))
|
|
pos = int(t[-1]) + 1
|
|
|
|
y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
|
|
|
|
# compute mix covariance matrix
|
|
Cxx = regularization
|
|
for j in range(nb_sources):
|
|
Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
|
|
|
|
# invert it
|
|
inv_Cxx = _invert(Cxx)
|
|
|
|
# separate the sources
|
|
for j in range(nb_sources):
|
|
|
|
# create a wiener gain for this source
|
|
gain = torch.zeros_like(inv_Cxx)
|
|
|
|
# computes multichannel Wiener gain as v_j R_j inv_Cxx
|
|
indices = torch.cartesian_prod(
|
|
torch.arange(nb_channels),
|
|
torch.arange(nb_channels),
|
|
torch.arange(nb_channels),
|
|
)
|
|
for index in indices:
|
|
gain[:, :, index[0], index[1], :] = _mul_add(
|
|
R[j][None, :, index[0], index[2], :].clone(),
|
|
inv_Cxx[:, :, index[2], index[1], :],
|
|
gain[:, :, index[0], index[1], :],
|
|
)
|
|
gain = gain * v[t, ..., None, None, None, j]
|
|
|
|
# apply it to the mixture
|
|
for i in range(nb_channels):
|
|
y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
|
|
|
|
return y, v, R
|
|
|
|
|
|
def wiener(
|
|
targets_spectrograms: torch.Tensor,
|
|
mix_stft: torch.Tensor,
|
|
iterations: int = 1,
|
|
softmask: bool = False,
|
|
residual: bool = False,
|
|
scale_factor: float = 10.0,
|
|
eps: float = 1e-10,
|
|
):
|
|
"""Wiener-based separation for multichannel audio.
|
|
|
|
The method uses the (possibly multichannel) spectrograms of the
|
|
sources to separate the (complex) Short Term Fourier Transform of the
|
|
mix. Separation is done in a sequential way by:
|
|
|
|
* Getting an initial estimate. This can be done in two ways: either by
|
|
directly using the spectrograms with the mixture phase, or
|
|
by using a softmasking strategy. This initial phase is controlled
|
|
by the `softmask` flag.
|
|
|
|
* If required, adding an additional residual target as the mix minus
|
|
all targets.
|
|
|
|
* Refinining these initial estimates through a call to
|
|
:func:`expectation_maximization` if the number of iterations is nonzero.
|
|
|
|
This implementation also allows to specify the epsilon value used for
|
|
regularization. It is based on [1]_, [2]_, [3]_, [4]_.
|
|
|
|
References
|
|
----------
|
|
.. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
|
|
N. Takahashi and Y. Mitsufuji, "Improving music source separation based
|
|
on deep neural networks through data augmentation and network
|
|
blending." 2017 IEEE International Conference on Acoustics, Speech
|
|
and Signal Processing (ICASSP). IEEE, 2017.
|
|
|
|
.. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
|
|
separation with deep neural networks." IEEE/ACM Transactions on Audio,
|
|
Speech, and Language Processing 24.9 (2016): 1652-1664.
|
|
|
|
.. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
|
|
separation with deep neural networks." 2016 24th European Signal
|
|
Processing Conference (EUSIPCO). IEEE, 2016.
|
|
|
|
.. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
|
|
source separation." IEEE Transactions on Signal Processing
|
|
62.16 (2014): 4298-4310.
|
|
|
|
Args:
|
|
targets_spectrograms (Tensor): spectrograms of the sources
|
|
[shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
|
|
This is a nonnegative tensor that is
|
|
usually the output of the actual separation method of the user. The
|
|
spectrograms may be mono, but they need to be 4-dimensional in all
|
|
cases.
|
|
mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
|
|
STFT of the mixture signal.
|
|
iterations (int): [scalar]
|
|
number of iterations for the EM algorithm
|
|
softmask (bool): Describes how the initial estimates are obtained.
|
|
* if `False`, then the mixture phase will directly be used with the
|
|
spectrogram as initial estimates.
|
|
* if `True`, initial estimates are obtained by multiplying the
|
|
complex mix element-wise with the ratio of each target spectrogram
|
|
with the sum of them all. This strategy is better if the model are
|
|
not really good, and worse otherwise.
|
|
residual (bool): if `True`, an additional target is created, which is
|
|
equal to the mixture minus the other targets, before application of
|
|
expectation maximization
|
|
eps (float): Epsilon value to use for computing the separations.
|
|
This is used whenever division with a model energy is
|
|
performed, i.e. when softmasking and when iterating the EM.
|
|
It can be understood as the energy of the additional white noise
|
|
that is taken out when separating.
|
|
|
|
Returns:
|
|
Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
|
|
STFT of estimated sources
|
|
|
|
Notes:
|
|
* Be careful that you need *magnitude spectrogram estimates* for the
|
|
case `softmask==False`.
|
|
* `softmask=False` is recommended
|
|
* The epsilon value will have a huge impact on performance. If it's
|
|
large, only the parts of the signal with a significant energy will
|
|
be kept in the sources. This epsilon then directly controls the
|
|
energy of the reconstruction error.
|
|
|
|
Warning:
|
|
As in :func:`expectation_maximization`, we recommend converting the
|
|
mixture `x` to double precision `torch.float64` *before* calling
|
|
:func:`wiener`.
|
|
"""
|
|
if softmask:
|
|
# if we use softmask, we compute the ratio mask for all targets and
|
|
# multiply by the mix stft
|
|
y = (
|
|
mix_stft[..., None]
|
|
* (
|
|
targets_spectrograms
|
|
/ (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
|
|
)[..., None, :]
|
|
)
|
|
else:
|
|
# otherwise, we just multiply the targets spectrograms with mix phase
|
|
# we tacitly assume that we have magnitude estimates.
|
|
angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
|
|
nb_sources = targets_spectrograms.shape[-1]
|
|
y = torch.zeros(
|
|
mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
|
|
)
|
|
y[..., 0, :] = targets_spectrograms * torch.cos(angle)
|
|
y[..., 1, :] = targets_spectrograms * torch.sin(angle)
|
|
|
|
if residual:
|
|
# if required, adding an additional target as the mix minus
|
|
# available targets
|
|
y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
|
|
|
|
if iterations == 0:
|
|
return y
|
|
|
|
# we need to refine the estimates. Scales down the estimates for
|
|
# numerical stability
|
|
max_abs = torch.max(
|
|
torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
|
|
torch.sqrt(_norm(mix_stft)).max() / scale_factor,
|
|
)
|
|
|
|
mix_stft = mix_stft / max_abs
|
|
y = y / max_abs
|
|
|
|
# call expectation maximization
|
|
y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
|
|
|
|
# scale estimates up again
|
|
y = y * max_abs
|
|
return y
|
|
|
|
|
|
def _covariance(y_j):
|
|
"""
|
|
Compute the empirical covariance for a source.
|
|
|
|
Args:
|
|
y_j (Tensor): complex stft of the source.
|
|
[shape=(nb_frames, nb_bins, nb_channels, 2)].
|
|
|
|
Returns:
|
|
Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
|
|
just y_j * conj(y_j.T): empirical covariance for each TF bin.
|
|
"""
|
|
(nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
|
|
Cj = torch.zeros(
|
|
(nb_frames, nb_bins, nb_channels, nb_channels, 2),
|
|
dtype=y_j.dtype,
|
|
device=y_j.device,
|
|
)
|
|
indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
|
|
for index in indices:
|
|
Cj[:, :, index[0], index[1], :] = _mul_add(
|
|
y_j[:, :, index[0], :],
|
|
_conj(y_j[:, :, index[1], :]),
|
|
Cj[:, :, index[0], index[1], :],
|
|
)
|
|
return Cj
|