mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-28 09:21:03 +01:00
762 lines
28 KiB
Python
762 lines
28 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""
|
|
This code contains the spectrogram and Hybrid version of Demucs.
|
|
"""
|
|
from copy import deepcopy
|
|
import math
|
|
|
|
from openunmix.filtering import wiener
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from .demucs import DConv, rescale_module
|
|
from .states import capture_init
|
|
from .spec import spectro, ispectro
|
|
|
|
|
|
class ScaledEmbedding(nn.Module):
    """
    Embedding table with a boosted effective learning rate.

    The stored parameters are divided by `scale` at construction time and
    multiplied by `scale` whenever they are read, so an optimizer step on the
    stored values moves the effective embedding `scale` times further.
    With `smooth=True`, row `i` is initialized to the running sum of the first
    `i + 1` random rows divided by `sqrt(i + 1)`, so neighbouring indices start
    out with similar vectors.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        table = self.embedding.weight.data  # raw storage, edited in place below
        if smooth:
            # A sum of n i.i.d. gaussians has a std growing like sqrt(n);
            # dividing row i by sqrt(i + 1) keeps every row at a similar scale.
            counts = torch.arange(1, num_embeddings + 1).to(table)
            table[:] = torch.cumsum(table, dim=0) / counts.sqrt().unsqueeze(1)
        table /= scale
        self.scale = scale

    @property
    def weight(self):
        """Effective (rescaled) embedding matrix."""
        return self.embedding.weight * self.scale

    def forward(self, x):
        """Look up the indices in `x` and return the rescaled embeddings."""
        return self.embedding(x) * self.scale
|
|
|
|
|
|
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            kernel_size: kernel size of the main convolution (along the frequency
                axis when `freq` is True, along time otherwise).
            stride: stride of the main convolution.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: dict of kwargs for the DConv class (None means no kwargs).
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        if dconv_kw is None:
            # Avoid sharing one mutable default dict between all instances.
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        pad = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Convolve along the frequency axis only: (k, 1) kernels on (B, C, Fr, T).
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            # Doubled channels are halved again by the GLU in `forward`.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """Apply the encoder layer.

        Args:
            x: input, (B, C, T) for a time layer or (B, C, Fr, T) for a frequency
                layer. A 4-D input to a time layer is flattened to (B, C * Fr, T).
            inject: optional tensor added to the output of the first conv. It is
                used to inject the result from the time branch into the frequency
                branch, when both have the same stride. A 3-D `inject` is
                broadcast over the frequency axis of a 4-D output.

        Returns:
            The encoded tensor (only the first conv output when `empty`).
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Right-pad the time axis up to a multiple of the stride.
            le = x.shape[-1]
            if le % self.stride != 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv is not None:
            if self.freq:
                # DConv runs on 1-D sequences: fold frequencies into the batch.
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite is not None:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
|
|
|
|
|
|
class MultiWrap(nn.Module):
    """
    Takes one layer and replicates it N times. Each replica acts
    on a frequency band. Everything is done so that if the N replicas have the same
    weights, then this is exactly equivalent to applying the original module on all
    frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
                `len(split_ratios) + 1` replicas are built; the last one covers
                the remaining frequencies up to the top.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Frequency padding is re-created manually in `forward`.
                lay.conv.padding = (0, 0)
            else:
                # Cropping of the transposed conv output is done in `forward`.
                lay.pad = False
            # Each replica starts from freshly initialized weights.
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        """Apply each replica to its own frequency band and stitch the results.

        Args:
            x: input of shape (B, C, Fr, T).
            skip: skip connection (decoder mode only), sliced like `x`.
            length: unused, kept for call compatibility with HDecLayer.

        Returns:
            The stitched tensor in encoder mode, `(out, None)` in decoder mode.
        """
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    # Snap the band end so that it ends exactly on a conv window.
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                # Re-create the zero padding of the unsplit conv, but only at the
                # outer edges of the full frequency range.
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # The next band starts where the next conv window would start.
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                last = layer.last
                # Force a linear output; the GELU (if any) is applied once after
                # stitching so that the overlap-add below stays exact.
                layer.last = True
                try:
                    y = x[:, :, start:limit]
                    s = skip[:, :, start:limit]
                    out, _ = layer(y, s, None)
                finally:
                    # Restore the flag even if the layer raised.
                    layer.last = last
                if outs:
                    # Overlap-add the first `stride` rows onto the previous band,
                    # removing the bias that would otherwise be counted twice.
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                start = limit
        out = torch.cat(outs, dim=2)
        if self.conv:
            return out
        # All replicas are copies of one layer, so the final replica's (restored)
        # `last` flag decides whether the GELU is applied.
        if not self.layers[-1].last:
            out = F.gelu(out)
        return out, None
|
|
|
|
|
|
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        Additional args:
            last: if True, no GELU is applied to the layer output.
            context_freq: if True, the rewrite conv uses a `1 + 2 * context` kernel
                on every spatial axis; if False, a `[1, 1 + 2 * context]` kernel,
                i.e. no context along the frequency axis (2-D layers only).
            dconv_kw: dict of kwargs for the DConv class (None means no kwargs).
        """
        super().__init__()
        if dconv_kw is None:
            # Avoid sharing one mutable default dict between all instances.
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        pad = kernel_size // 4 if pad else 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        # No padding in the transposed conv: `forward` crops `self.pad` instead.
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Apply the decoder layer.

        Args:
            x: input tensor. A 3-D input to a frequency layer is reshaped to
                (B, chin, -1, T).
            skip: skip connection added to `x`; must be None for an empty layer.
            length: target time length for time layers (unused by frequency layers).

        Returns:
            `(z, y)`: the layer output and the tensor fed to the transposed conv.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite is not None:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv is not None:
                if self.freq:
                    # DConv runs on 1-D sequences: fold frequencies into the batch.
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
|
|
|
|
|
|
class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
                A negative value selects plain masking that keeps the mixture phase.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick
            samplerate: stored as metadata only (not used by this class).
            segment: stored as metadata only (not used by this class).
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2  # real and imaginary parts are separate channels
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                # Pure time layer, after the frequency axis has been collapsed.
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: one unpadded window spans all remaining
                # frequencies and collapses them to a single one.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            # Time-branch layers always use the regular kernel/stride with padding.
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec['context_freq'] = False

            if last_freq:
                # Both branches must agree on width where they merge.
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The outermost decoder layers output one set of channels per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    @staticmethod
    def _pad_reflect(x, padding_left, padding_right):
        """Reflect-pad the last dimension of `x`, also for very short inputs.

        `F.pad(..., mode='reflect')` requires every padding to be strictly smaller
        than the padded dimension. When `x` is too short for that, zeros are first
        appended (to the right as far as possible, the rest to the left) until it
        is just long enough, and reflection provides the remaining padding.
        The output length is always `x.shape[-1] + padding_left + padding_right`,
        and for long enough inputs the result equals plain reflect padding.
        """
        length = x.shape[-1]
        max_pad = max(padding_left, padding_right)
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            extra_pad_right = min(padding_right, extra_pad)
            extra_pad_left = extra_pad - extra_pad_right
            padding_left -= extra_pad_left
            padding_right -= extra_pad_right
            x = F.pad(x, (extra_pad_left, extra_pad_right))
        return F.pad(x, (padding_left, padding_right), mode='reflect')

    def _spec(self, x):
        """Compute the spectrogram fed to the model, dropping the last frequency bin.

        In hybrid mode the number of frames is exactly `ceil(length / hop_length)`.
        """
        hl = self.hop_length
        nfft = self.nfft

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft).
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allow to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                # Plain reflect padding fails on inputs shorter than `pad` samples.
                x = self._pad_reflect(x, pad, pad + le * hl - x.shape[-1])
            else:
                x = F.pad(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2:2+le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Invert a spectrogram in the layout produced by `_spec` to a waveform.

        Args:
            z: complex spectrogram without its last frequency bin.
            length: number of output samples (required in hybrid mode).
            scale: the hop length is divided by `4 ** scale`.
        """
        hl = self.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))  # restore the frequency bin dropped by `_spec`
        if self.hybrid:
            z = F.pad(z, (2, 2))  # restore the 2 frames cropped on each side by `_spec`
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one,
        # giving a real tensor of shape (B, 2 * C, Fr, T).
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            # Keep the mixture phase, use `m` as the magnitude.
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix, on windows of `wiener_win_len` frames.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        # Reorder to the time-first layouts passed to `wiener` below.
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        """Separate the waveform `mix` (B, C, length) into `len(self.sources)` sources.

        Returns a tensor of shape (B, len(self.sources), -1, length).
        """
        length = mix.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        # The first `offset` decoder layers (innermost, time-only) have no
        # separate time-branch layer.
        offset = self.depth - len(self.tdecoder) if self.hybrid else 0
        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x
|