2022-12-18 21:18:56 -06:00

219 lines
7.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch as th
from torch import nn
from .utils import capture_init, center_trim
class BLSTM(nn.Module):
    """
    Bidirectional LSTM followed by a linear layer that maps the
    concatenated forward/backward states (`2 * dim`) back to `dim` features.
    """

    def __init__(self, dim, layers=1):
        """
        Args:
            dim (int): feature size of both the input and the output.
            layers (int): number of stacked LSTM layers.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=dim,
                            hidden_size=dim,
                            num_layers=layers,
                            bidirectional=True)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        """
        Run the LSTM along the last dimension of `x`.

        The input is laid out as `(batch, dim, time)`; the output uses the
        same layout.
        """
        # nn.LSTM (batch_first=False) wants (time, batch, features).
        sequence = x.permute(2, 0, 1)
        hidden, _ = self.lstm(sequence)
        projected = self.linear(hidden)
        # Move back to (batch, features, time).
        return projected.permute(1, 2, 0)
def rescale_conv(conv, reference):
    """
    Rescale, in place, the weight (and bias, if any) of `conv`.

    Both are divided by `sqrt(std / reference)`, where `std` is the current
    standard deviation of the weight, which moves that standard deviation
    towards `reference`.
    """
    # Computed once, before anything is modified, so the weight and the
    # bias are divided by the same factor.
    factor = (conv.weight.std().detach() / reference) ** 0.5
    for param in (conv.weight, conv.bias):
        if param is not None:
            param.data.div_(factor)
def rescale_module(module, reference):
    """
    Apply `rescale_conv` with the given `reference` to every `nn.Conv1d`
    and `nn.ConvTranspose1d` found in `module` (including `module` itself
    and all nested submodules).
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    convolutions = (sub for sub in module.modules() if isinstance(sub, conv_types))
    for conv in convolutions:
        rescale_conv(conv, reference)
def upsample(x, stride):
    """
    Linear upsampling along the last dimension.

    Each pair of consecutive time steps `(x[t], x[t + 1])` is expanded into
    `stride` samples linearly interpolated from `x[t]` (included) towards
    `x[t + 1]` (excluded).

    Args:
        x (Tensor): input of shape `(batch, channels, time)`.
        stride (int): number of output samples generated per pair of
            consecutive input samples.

    Returns:
        Tensor of shape `(batch, channels, (time - 1) * stride)`. Note that
        the last input sample only serves as an interpolation end point, so
        the output is slightly shorter than `time * stride`.
    """
    batch, channels, time = x.size()
    # Build the interpolation weights in the dtype of `x` so that float64
    # inputs are interpolated at full precision and half precision inputs are
    # not silently promoted to float32. Non floating point inputs keep the
    # previous float32 weights.
    dtype = x.dtype if x.is_floating_point() else th.float
    weight = th.arange(stride, device=x.device, dtype=dtype) / stride
    # Trailing singleton dimension that broadcasts against the `stride`
    # weights; `unsqueeze` does not depend on the memory layout of `x`.
    x = x.unsqueeze(-1)
    out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
    return out.reshape(batch, channels, -1)
def downsample(x, stride):
    """
    Decimate `x` along its third dimension, keeping one sample out of
    every `stride`, starting with the first one.
    """
    every_stride = slice(None, None, stride)
    return x[:, :, every_stride]
class Demucs(nn.Module):
    """
    Encoder / decoder network with skip connections and an optional
    bidirectional LSTM between the encoder and the decoder.

    The input mixture goes through `depth` encoder layers, the optional LSTM,
    then `depth` decoder layers; each decoder layer receives the sum of the
    previous output and the (center trimmed) output of the matching encoder
    layer.
    """

    @capture_init
    def __init__(self,
                 sources=4,
                 audio_channels=2,
                 channels=64,
                 depth=6,
                 rewrite=True,
                 glu=True,
                 upsample=False,
                 rescale=0.1,
                 kernel_size=8,
                 stride=4,
                 growth=2.,
                 lstm_layers=2,
                 context=3,
                 samplerate=44100):
        """
        Args:
            sources (int): number of sources to separate
            audio_channels (int): number of audio channels (stereo or mono)
            channels (int): number of channels of the first convolution
            depth (int): number of encoder/decoder layers
            rewrite (bool): add a 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use GLU instead of ReLU after the rewrite convolutions
            upsample (bool): use linear upsampling with convolutions
                Wave-U-Net style, instead of transposed convolutions
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
                A falsy value disables rescaling.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
            samplerate (int): stored as `self.samplerate`; not used by the
                computations in this class.
        """
        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.upsample = upsample
        self.channels = channels
        self.samplerate = samplerate

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if upsample:
            # Final 1x1 convolution mixing the last decoder output with the
            # input mixture. In this mode the convolutions themselves are
            # not strided, `self.stride` is applied by decimation instead.
            self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
            conv_stride = 1
        else:
            self.final = None
            conv_stride = stride

        # The same activation instance is shared by every layer.
        # GLU halves the number of channels, hence the doubled conv output.
        activation = nn.GLU(dim=1) if glu else nn.ReLU()
        ch_scale = 2 if glu else 1

        in_channels = audio_channels
        for index in range(depth):
            encode = [nn.Conv1d(in_channels, channels, kernel_size, conv_stride), nn.ReLU()]
            if rewrite:
                encode.extend([nn.Conv1d(channels, ch_scale * channels, 1), activation])
            self.encoder.append(nn.Sequential(*encode))

            # A decoder layer mirrors its encoder layer, except for the
            # outermost one which produces the network output.
            if index > 0:
                out_channels = in_channels
            elif upsample:
                out_channels = channels
            else:
                out_channels = sources * audio_channels

            decode = []
            if rewrite:
                decode.extend([nn.Conv1d(channels, ch_scale * channels, context), activation])
            if upsample:
                decode.append(nn.Conv1d(channels, out_channels, kernel_size, stride=1))
            else:
                decode.append(nn.ConvTranspose1d(channels, out_channels, kernel_size, conv_stride))
            if index > 0:
                decode.append(nn.ReLU())
            # Decoder layers are stored innermost first.
            self.decoder.insert(0, nn.Sequential(*decode))

            in_channels = channels
            channels = int(growth * channels)

        # Number of channels coming out of the last encoder layer.
        bottleneck = in_channels
        self.lstm = BLSTM(bottleneck, lstm_layers) if lstm_layers else None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in a convolution, e.g. for all
        layers, (size of the input - kernel_size) % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length.
        """
        kernel = self.kernel_size
        stride = self.stride
        extra_context = self.context - 1

        # Follow the length through the encoder...
        for _ in range(self.depth):
            if self.upsample:
                length = math.ceil(length / stride) + kernel - 1
            else:
                length = math.ceil((length - kernel) / stride) + 1
            length = max(1, length)
            length += extra_context

        # ... and back through the decoder.
        for _ in range(self.depth):
            if self.upsample:
                length = length * stride + kernel - 1
            else:
                length = (length - 1) * stride + kernel

        return int(length)

    def forward(self, mix):
        """
        Separate `mix` into `self.sources` sources.

        The result is reshaped to
        `(batch, sources, audio_channels, time)`.
        """
        x = mix
        # Skip connections; the mixture itself is only consumed by the
        # final convolution of the `upsample` variant.
        skips = [mix]
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)
            if self.upsample:
                x = downsample(x, self.stride)

        if self.lstm is not None:
            x = self.lstm(x)

        for layer in self.decoder:
            if self.upsample:
                x = upsample(x, stride=self.stride)
            x = x + center_trim(skips.pop(), x)
            x = layer(x)

        if self.final is not None:
            x = th.cat([x, center_trim(skips.pop(), x)], dim=1)
            x = self.final(x)

        return x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))