mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2025-01-19 17:18:39 +01:00
460 lines
17 KiB
Python
460 lines
17 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import math
|
||
|
import typing as tp
|
||
|
|
||
|
import julius
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
from .states import capture_init
|
||
|
from .utils import center_trim, unfold
|
||
|
|
||
|
|
||
|
class BLSTM(nn.Module):
    """
    Bidirectional LSTM whose hidden size equals the input dimension.

    When `max_steps` is not None and the input is longer than `max_steps`,
    the input is cut into overlapping chunks and the LSTM is run on every
    chunk independently; the chunk outputs are then stitched back together.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        # The chunked path halves `max_steps` twice (stride, then limit).
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Both directions are concatenated by the LSTM, project back to `dim`.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        residual = x
        chunked = self.max_steps is not None and T > self.max_steps
        if chunked:
            width = self.max_steps
            stride = width // 2
            # `unfold` comes from `.utils`; the indexing below treats its output
            # as (B, C, nframes, width) -- TODO confirm against the helper.
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            # Fold the frame axis into the batch axis so every chunk is
            # processed as an independent sequence.
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        # (batch, channels, time) -> (time, batch, channels) for the LSTM.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = self.linear(x)
        # Back to (batch, channels, time).
        x = x.permute(1, 2, 0)

        if chunked:
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            last = nframes - 1
            pieces = []
            for k in range(nframes):
                # Drop `limit` steps on every side that overlaps a neighbour:
                # the first frame keeps its head, the last frame keeps its tail.
                if k == 0:
                    start, stop = None, -limit
                elif k == last:
                    start, stop = limit, None
                else:
                    start, stop = limit, -limit
                pieces.append(frames[:, k, :, start:stop])
            # Concatenate along time and cut back to the original length.
            x = torch.cat(pieces, -1)[..., :T]

        if self.skip:
            x = x + residual
        return x
|
||
|
|
||
|
|
||
|
def rescale_conv(conv, reference):
    """Rescale the initial weights (and bias) of `conv` in place.

    Both are divided by `sqrt(std(weight) / reference)`, where the standard
    deviation is measured on the weights before rescaling.
    It is unclear why it helps but it certainly does.
    """
    factor = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= factor
    bias = conv.bias
    if bias is None:
        return
    bias.data /= factor
|
||
|
|
||
|
|
||
|
def rescale_module(module, reference):
    """Apply `rescale_conv` to every 1d/2d (transposed) convolution inside `module`.

    `module.modules()` is walked recursively, so nested convolutions (and
    `module` itself, if it is a convolution) are rescaled in place.
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    convs = (sub for sub in module.modules() if isinstance(sub, conv_types))
    for conv in convs:
        rescale_conv(conv, reference)
|
||
|
|
||
|
|
||
|
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Holds one learnable scale per channel, initialised to `init`, and
    multiplies the input by it. With a small `init` the residual output it is
    applied to starts close to 0 and the scale is then learnt.
    """
    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        weights = torch.zeros(channels, requires_grad=True)
        self.scale = nn.Parameter(weights)
        # Overwrite the zeros with the requested starting value.
        self.scale.data[:] = init

    def forward(self, x):
        # Trailing singleton axis so the per-channel scale broadcasts over
        # the last (time) axis of `x`.
        per_channel = self.scale.unsqueeze(-1)
        return per_channel * x
|
||
|
|
||
|
|
||
|
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention. A negative value
                builds `abs(depth)` layers and disables dilation.
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions, must be odd.
            dilate: if true, use dilation, increasing with the depth.
                Dilation is only used when `depth` is also positive.
        """

        super().__init__()
        # An odd kernel lets the symmetric padding below keep the length unchanged.
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # The `dilate` argument used to be overwritten by `depth > 0` and was
        # therefore ignored. Honour it, while keeping the negative-depth
        # convention as a second way of turning dilation off.
        dilate = bool(dilate) and depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                # Compress to `hidden` channels with the (dilated) convolution.
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                # Expand to 2 * channels; the GLU halves it back to `channels`.
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            # Both optional modules are inserted right after the first activation.
            # When both are enabled the LSTM ends up before the attention.
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Every layer is a residual branch added onto its own input.
        for layer in self.layers:
            x = x + layer(x)
        return x
|
||
|
|
||
|
|
||
|
class LocalState(nn.Module):
    """Attention that depends only on the data (no positional embedding), with a
    constraint on the time window expressed as a learnt, decaying penalty term.

    Also contains a failed experiment that tried to provide some frequency
    based attention (enabled with `nfreqs > 0`).
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            decay.weight.data *= 0.01
            assert decay.bias is not None  # keeps the type checker happy
            decay.bias.data[:] = -2
            self.query_decay = decay
        # The frequency features, when enabled, are concatenated to the content.
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        device, dtype = x.device, x.dtype

        def per_head(tensor):
            # Split the channel axis into (heads, channels per head).
            return tensor.view(B, heads, -1, T)

        positions = torch.arange(T, device=device, dtype=dtype)
        # left index are keys, right index are queries
        delta = positions[:, None] - positions[None, :]

        queries = per_head(self.query(x))
        keys = per_head(self.key(x))
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5

        use_freqs = bool(self.nfreqs)
        if use_freqs:
            periods = torch.arange(1, self.nfreqs + 1, device=device, dtype=dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = per_head(self.query_freqs(x)) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)

        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=device, dtype=dtype)
            # The sigmoid output is halved, so every decay weight lies in (0, 0.5).
            decay_q = torch.sigmoid(per_head(self.query_decay(x))) / 2
            # Penalty grows with the distance |t - s| between key and query.
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        diagonal = torch.eye(T, device=dots.device, dtype=torch.bool)
        dots.masked_fill_(diagonal, -100)
        # Normalise over the key axis.
        weights = torch.softmax(dots, dim=2)

        content = per_head(self.content(x))
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if use_freqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        # Merge the heads back into a single channel axis.
        merged = result.reshape(B, -1, T)
        return x + self.proj(merged)
|
||
|
|
||
|
|
||
|
class Demucs(nn.Module):
    """Waveform source separation model made of a stack of strided convolution
    encoder layers, an optional LSTM, and a mirrored stack of transposed
    convolution decoder layers with skip connections.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): controls the kernel size (`2 * context + 1`) of the
                convolution in the decoder before the transposed convolution.
                If > 0, will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
                A falsy value disables the rescaling.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        # GLU halves the channel axis, so the rewrite conv has to output twice as many.
        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            # Bit 0 of `dconv_mode` enables the DConv branch in the encoder.
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            # The outermost decoder layer outputs one waveform per source.
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            # Bit 1 of `dconv_mode` enables the DConv branch in the decoder.
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            # Decoder layers are stored in reverse order of the encoder layers.
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        # Channel count produced by the innermost encoder layer.
        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that input are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        # Walk down the encoder, rounding up so that no sample is dropped ...
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        # ... then back up through the transposed convolutions of the decoder.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        """Separate `mix` and return a tensor viewed as
        (batch, len(sources), audio_channels, time), trimmed to the input length.
        """
        x = mix
        length = x.shape[-1]

        if self.normalize:
            # Statistics are taken on the mono downmix, over the time axis.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        # Pad on both sides up to a length the model can process exactly.
        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            # Skip connection from the matching encoder layer (last saved first).
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        # Undo the input normalization.
        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        """Load `state`, first renaming keys saved by previous generation Demucs models.

        For every encoder/decoder layer, a `{layer}.2.{bias,weight}` entry is moved
        to `{layer}.3.{bias,weight}` when the latter is absent. Note that the
        renaming is done in place on the `state` mapping given by the caller.

        Returns:
            The value returned by `nn.Module.load_state_dict`, which reports
            the missing and unexpected keys.
        """
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        # Propagate the result so callers using `strict=False` can inspect
        # which keys were missing or unexpected (it used to be dropped).
        return super().load_state_dict(state, strict=strict)
|