Add files via upload

This commit is contained in:
Anjok07 2022-12-18 21:18:56 -06:00 committed by GitHub
parent 9f1652fdf3
commit a58c26520d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 14473 additions and 2 deletions

4830
UVR.py

File diff suppressed because it is too large Load Diff

View File

@ -1 +1,2 @@
VERSION = '5.3.0' VERSION = 'v5.5.0'
PATCH = 'UVR_Patch_12_16_22_3_30'

5
demucs/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

272
demucs/__main__.py Normal file
View File

@ -0,0 +1,272 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import sys
import time
from dataclasses import dataclass, field
from fractions import Fraction
import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel
from .augment import FlipChannels, FlipSign, Remix, Shift
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import human_seconds, load_model, save_model, sizeof_fmt
@dataclass
class SavedState:
metrics: list = field(default_factory=list)
last_state: dict = None
best_state: dict = None
optimizer: dict = None
def main():
parser = get_parser()
args = parser.parse_args()
name = get_name(parser, args)
print(f"Experiment {name}")
if args.musdb is None and args.rank == 0:
print(
"You must provide the path to the MusDB dataset with the --musdb flag. "
"To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
file=sys.stderr)
sys.exit(1)
eval_folder = args.evals / name
eval_folder.mkdir(exist_ok=True, parents=True)
args.logs.mkdir(exist_ok=True)
metrics_path = args.logs / f"{name}.json"
eval_folder.mkdir(exist_ok=True, parents=True)
args.checkpoints.mkdir(exist_ok=True, parents=True)
args.models.mkdir(exist_ok=True, parents=True)
if args.device is None:
device = "cpu"
if th.cuda.is_available():
device = "cuda"
else:
device = args.device
th.manual_seed(args.seed)
# Prevents too many threads to be started when running `museval` as it can be quite
# inefficient on NUMA architectures.
os.environ["OMP_NUM_THREADS"] = "1"
if args.world_size > 1:
if device != "cuda" and args.rank == 0:
print("Error: distributed training is only available with cuda device", file=sys.stderr)
sys.exit(1)
th.cuda.set_device(args.rank % th.cuda.device_count())
distributed.init_process_group(backend="nccl",
init_method="tcp://" + args.master,
rank=args.rank,
world_size=args.world_size)
checkpoint = args.checkpoints / f"{name}.th"
checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
if args.restart and checkpoint.exists():
checkpoint.unlink()
if args.test:
args.epochs = 1
args.repeat = 0
model = load_model(args.models / args.test)
elif args.tasnet:
model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
else:
model = Demucs(
audio_channels=args.audio_channels,
channels=args.channels,
context=args.context,
depth=args.depth,
glu=args.glu,
growth=args.growth,
kernel_size=args.kernel_size,
lstm_layers=args.lstm_layers,
rescale=args.rescale,
rewrite=args.rewrite,
sources=4,
stride=args.conv_stride,
upsample=args.upsample,
samplerate=args.samplerate
)
model.to(device)
if args.show:
print(model)
size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")
return
optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
try:
saved = th.load(checkpoint, map_location='cpu')
except IOError:
saved = SavedState()
else:
model.load_state_dict(saved.last_state)
optimizer.load_state_dict(saved.optimizer)
if args.save_model:
if args.rank == 0:
model.to("cpu")
model.load_state_dict(saved.best_state)
save_model(model, args.models / f"{name}.th")
return
if args.rank == 0:
done = args.logs / f"{name}.done"
if done.exists():
done.unlink()
if args.augment:
augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
Remix(group_size=args.remix_group_size)).to(device)
else:
augment = Shift(args.data_stride)
if args.mse:
criterion = nn.MSELoss()
else:
criterion = nn.L1Loss()
# Setting number of samples so that all convolution windows are full.
# Prevents hard to debug mistake with the prediction being shifted compared
# to the input mixture.
samples = model.valid_length(args.samples)
print(f"Number of training samples adjusted to {samples}")
if args.raw:
train_set = Rawset(args.raw / "train",
samples=samples + args.data_stride,
channels=args.audio_channels,
streams=[0, 1, 2, 3, 4],
stride=args.data_stride)
valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
else:
if not args.metadata.is_file() and args.rank == 0:
build_musdb_metadata(args.metadata, args.musdb, args.workers)
if args.world_size > 1:
distributed.barrier()
metadata = json.load(open(args.metadata))
duration = Fraction(samples + args.data_stride, args.samplerate)
stride = Fraction(args.data_stride, args.samplerate)
train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
metadata,
duration=duration,
stride=stride,
samplerate=args.samplerate,
channels=args.audio_channels)
valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
metadata,
samplerate=args.samplerate,
channels=args.audio_channels)
best_loss = float("inf")
for epoch, metrics in enumerate(saved.metrics):
print(f"Epoch {epoch:03d}: "
f"train={metrics['train']:.8f} "
f"valid={metrics['valid']:.8f} "
f"best={metrics['best']:.4f} "
f"duration={human_seconds(metrics['duration'])}")
best_loss = metrics['best']
if args.world_size > 1:
dmodel = DistributedDataParallel(model,
device_ids=[th.cuda.current_device()],
output_device=th.cuda.current_device())
else:
dmodel = model
for epoch in range(len(saved.metrics), args.epochs):
begin = time.time()
model.train()
train_loss = train_model(epoch,
train_set,
dmodel,
criterion,
optimizer,
augment,
batch_size=args.batch_size,
device=device,
repeat=args.repeat,
seed=args.seed,
workers=args.workers,
world_size=args.world_size)
model.eval()
valid_loss = validate_model(epoch,
valid_set,
model,
criterion,
device=device,
rank=args.rank,
split=args.split_valid,
world_size=args.world_size)
duration = time.time() - begin
if valid_loss < best_loss:
best_loss = valid_loss
saved.best_state = {
key: value.to("cpu").clone()
for key, value in model.state_dict().items()
}
saved.metrics.append({
"train": train_loss,
"valid": valid_loss,
"best": best_loss,
"duration": duration
})
if args.rank == 0:
json.dump(saved.metrics, open(metrics_path, "w"))
saved.last_state = model.state_dict()
saved.optimizer = optimizer.state_dict()
if args.rank == 0 and not args.test:
th.save(saved, checkpoint_tmp)
checkpoint_tmp.rename(checkpoint)
print(f"Epoch {epoch:03d}: "
f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
f"duration={human_seconds(duration)}")
del dmodel
model.load_state_dict(saved.best_state)
if args.eval_cpu:
device = "cpu"
model.to(device)
model.eval()
evaluate(model,
args.musdb,
eval_folder,
rank=args.rank,
world_size=args.world_size,
device=device,
save=args.save,
split=args.split_valid,
shifts=args.shifts,
workers=args.eval_workers)
model.to("cpu")
save_model(model, args.models / f"{name}.th")
if args.rank == 0:
print("done")
done.write_text("done")
if __name__ == "__main__":
main()

294
demucs/apply.py Normal file
View File

@ -0,0 +1,294 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
inteprolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
from multiprocessing import Process,Queue,Pipe
import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
import tkinter as tk
from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor
Model = tp.Union[Demucs, HDemucs]
progress_bar_num = 0
class BagOfModels(nn.Module):
def __init__(self, models: tp.List[Model],
weights: tp.Optional[tp.List[tp.List[float]]] = None,
segment: tp.Optional[float] = None):
"""
Represents a bag of models with specific weights.
You should call `apply_model` rather than calling directly the forward here for
optimal performance.
Args:
models (list[nn.Module]): list of Demucs/HDemucs models.
weights (list[list[float]]): list of weights. If None, assumed to
be all ones, otherwise it should be a list of N list (N number of models),
each containing S floats (S number of sources).
segment (None or float): overrides the `segment` attribute of each model
(this is performed inplace, be careful if you reuse the models passed).
"""
super().__init__()
assert len(models) > 0
first = models[0]
for other in models:
assert other.sources == first.sources
assert other.samplerate == first.samplerate
assert other.audio_channels == first.audio_channels
if segment is not None:
other.segment = segment
self.audio_channels = first.audio_channels
self.samplerate = first.samplerate
self.sources = first.sources
self.models = nn.ModuleList(models)
if weights is None:
weights = [[1. for _ in first.sources] for _ in models]
else:
assert len(weights) == len(models)
for weight in weights:
assert len(weight) == len(first.sources)
self.weights = weights
def forward(self, x):
raise NotImplementedError("Call `apply_model` on this.")
class TensorChunk:
def __init__(self, tensor, offset=0, length=None):
total_length = tensor.shape[-1]
assert offset >= 0
assert offset < total_length
if length is None:
length = total_length - offset
else:
length = min(total_length - offset, length)
if isinstance(tensor, TensorChunk):
self.tensor = tensor.tensor
self.offset = offset + tensor.offset
else:
self.tensor = tensor
self.offset = offset
self.length = length
self.device = tensor.device
@property
def shape(self):
shape = list(self.tensor.shape)
shape[-1] = self.length
return shape
def padded(self, target_length):
delta = target_length - self.length
total_length = self.tensor.shape[-1]
assert delta >= 0
start = self.offset - delta // 2
end = start + target_length
correct_start = max(0, start)
correct_end = min(total_length, end)
pad_left = correct_start - start
pad_right = end - correct_end
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
assert out.shape[-1] == target_length
return out
def tensor_chunk(tensor_or_chunk):
if isinstance(tensor_or_chunk, TensorChunk):
return tensor_or_chunk
else:
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
"""
Apply model to a given mixture.
Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
device (torch.device, str, or None): if provided, device on which to
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
"""
global fut_length
global bag_num
global prog_bar
if device is None:
device = mix.device
else:
device = th.device(device)
if pool is None:
if num_workers > 0 and device.type == 'cpu':
pool = ThreadPoolExecutor(num_workers)
else:
pool = DummyPoolExecutor()
kwargs = {
'shifts': shifts,
'split': split,
'overlap': overlap,
'transition_power': transition_power,
'progress': progress,
'device': device,
'pool': pool,
'set_progress_bar': set_progress_bar,
'static_shifts': static_shifts,
}
if isinstance(model, BagOfModels):
# Special treatment for bag of model.
# We explicitely apply multiple times `apply_model` so that the random shifts
# are different for each model.
estimates = 0
totals = [0] * len(model.sources)
bag_num = len(model.models)
fut_length = 0
prog_bar = 0
current_model = 0 #(bag_num + 1)
for sub_model, weight in zip(model.models, model.weights):
original_model_device = next(iter(sub_model.parameters())).device
sub_model.to(device)
fut_length += fut_length
current_model += 1
out = apply_model(sub_model, mix, **kwargs)
sub_model.to(original_model_device)
for k, inst_weight in enumerate(weight):
out[:, k, :, :] *= inst_weight
totals[k] += inst_weight
estimates += out
del out
for k in range(estimates.shape[1]):
estimates[:, k, :, :] /= totals[k]
return estimates
model.to(device)
model.eval()
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
batch, channels, length = mix.shape
if shifts:
kwargs['shifts'] = 0
max_shift = int(0.5 * model.samplerate)
mix = tensor_chunk(mix)
padded_mix = mix.padded(length + 2 * max_shift)
out = 0
for _ in range(shifts):
offset = random.randint(0, max_shift)
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
shifted_out = apply_model(model, shifted, **kwargs)
out += shifted_out[..., max_shift - offset:]
out /= shifts
return out
elif split:
kwargs['split'] = False
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
sum_weight = th.zeros(length, device=mix.device)
segment = int(model.samplerate * model.segment)
stride = int((1 - overlap) * segment)
offsets = range(0, length, stride)
scale = float(format(stride / model.samplerate, ".2f"))
# We start from a triangle shaped weight, with maximal weight in the middle
# of the segment. Then we normalize and take to the power `transition_power`.
# Large values of transition power will lead to sharper transitions.
weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
th.arange(segment - segment // 2, 0, -1, device=device)])
assert len(weight) == segment
# If the overlap < 50%, this will translate to linear transition when
# transition_power is 1.
weight = (weight / weight.max())**transition_power
futures = []
for offset in offsets:
chunk = TensorChunk(mix, offset, segment)
future = pool.submit(apply_model, model, chunk, **kwargs)
futures.append((future, offset))
offset += segment
if progress:
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
for future, offset in futures:
if set_progress_bar:
fut_length = (len(futures) * bag_num * static_shifts)
prog_bar += 1
set_progress_bar(0.1, (0.8/fut_length*prog_bar))
chunk_out = future.result()
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
assert sum_weight.min() > 0
out /= sum_weight
return out
else:
if hasattr(model, 'valid_length'):
valid_length = model.valid_length(length)
else:
valid_length = length
mix = tensor_chunk(mix)
padded_mix = mix.padded(valid_length).to(device)
with th.no_grad():
out = model(padded_mix)
return center_trim(out, length)
def demucs_segments(demucs_segment, demucs_model):
if demucs_segment == 'Default':
segment = None
if isinstance(demucs_model, BagOfModels):
if segment is not None:
for sub in demucs_model.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
else:
try:
segment = int(demucs_segment)
if isinstance(demucs_model, BagOfModels):
if segment is not None:
for sub in demucs_model.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
except:
segment = None
if isinstance(demucs_model, BagOfModels):
if segment is not None:
for sub in demucs_model.models:
sub.segment = segment
else:
if segment is not None:
sub.segment = segment
return demucs_model

459
demucs/demucs.py Normal file
View File

@ -0,0 +1,459 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import julius
import torch
from torch import nn
from torch.nn import functional as F
from .states import capture_init
from .utils import center_trim, unfold
class BLSTM(nn.Module):
"""
BiLSTM with same hidden units as input dim.
If `max_steps` is not None, input will be splitting in overlapping
chunks and the LSTM applied separately on each chunk.
"""
def __init__(self, dim, layers=1, max_steps=None, skip=False):
super().__init__()
assert max_steps is None or max_steps % 4 == 0
self.max_steps = max_steps
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
self.skip = skip
def forward(self, x):
B, C, T = x.shape
y = x
framed = False
if self.max_steps is not None and T > self.max_steps:
width = self.max_steps
stride = width // 2
frames = unfold(x, width, stride)
nframes = frames.shape[2]
framed = True
x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
x = x.permute(2, 0, 1)
x = self.lstm(x)[0]
x = self.linear(x)
x = x.permute(1, 2, 0)
if framed:
out = []
frames = x.reshape(B, -1, C, width)
limit = stride // 2
for k in range(nframes):
if k == 0:
out.append(frames[:, k, :, :-limit])
elif k == nframes - 1:
out.append(frames[:, k, :, limit:])
else:
out.append(frames[:, k, :, limit:-limit])
out = torch.cat(out, -1)
out = out[..., :T]
x = out
if self.skip:
x = x + y
return x
def rescale_conv(conv, reference):
"""Rescale initial weight scale. It is unclear why it helps but it certainly does.
"""
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
rescale_conv(sub, reference)
class LayerScale(nn.Module):
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
This rescales diagonaly residual outputs close to 0 initially, then learnt.
"""
def __init__(self, channels: int, init: float = 0):
super().__init__()
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
self.scale.data[:] = init
def forward(self, x):
return self.scale[:, None] * x
class DConv(nn.Module):
"""
New residual branches in each encoder layer.
This alternates dilated convolutions, potentially with LSTMs and attention.
Also before entering each residual branch, dimension is projected on a smaller subspace,
e.g. of dim `channels // compress`.
"""
def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
kernel=3, dilate=True):
"""
Args:
channels: input/output channels for residual branch.
compress: amount of channel compression inside the branch.
depth: number of layers in the residual branch. Each layer has its own
projection, and potentially LSTM and attention.
init: initial scale for LayerNorm.
norm: use GroupNorm.
attn: use LocalAttention.
heads: number of heads for the LocalAttention.
ndecay: number of decay controls in the LocalAttention.
lstm: use LSTM.
gelu: Use GELU activation.
kernel: kernel size for the (dilated) convolutions.
dilate: if true, use dilation, increasing with the depth.
"""
super().__init__()
assert kernel % 2 == 1
self.channels = channels
self.compress = compress
self.depth = abs(depth)
dilate = depth > 0
norm_fn: tp.Callable[[int], nn.Module]
norm_fn = lambda d: nn.Identity() # noqa
if norm:
norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
hidden = int(channels / compress)
act: tp.Type[nn.Module]
if gelu:
act = nn.GELU
else:
act = nn.ReLU
self.layers = nn.ModuleList([])
for d in range(self.depth):
dilation = 2 ** d if dilate else 1
padding = dilation * (kernel // 2)
mods = [
nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
norm_fn(hidden), act(),
nn.Conv1d(hidden, 2 * channels, 1),
norm_fn(2 * channels), nn.GLU(1),
LayerScale(channels, init),
]
if attn:
mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
if lstm:
mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
layer = nn.Sequential(*mods)
self.layers.append(layer)
def forward(self, x):
for layer in self.layers:
x = x + layer(x)
return x
class LocalState(nn.Module):
"""Local state allows to have attention based only on data (no positional embedding),
but while setting a constraint on the time window (e.g. decaying penalty term).
Also a failed experiments with trying to provide some frequency based attention.
"""
def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
super().__init__()
assert channels % heads == 0, (channels, heads)
self.heads = heads
self.nfreqs = nfreqs
self.ndecay = ndecay
self.content = nn.Conv1d(channels, channels, 1)
self.query = nn.Conv1d(channels, channels, 1)
self.key = nn.Conv1d(channels, channels, 1)
if nfreqs:
self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
if ndecay:
self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
# Initialize decay close to zero (there is a sigmoid), for maximum initial window.
self.query_decay.weight.data *= 0.01
assert self.query_decay.bias is not None # stupid type checker
self.query_decay.bias.data[:] = -2
self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
def forward(self, x):
B, C, T = x.shape
heads = self.heads
indexes = torch.arange(T, device=x.device, dtype=x.dtype)
# left index are keys, right index are queries
delta = indexes[:, None] - indexes[None, :]
queries = self.query(x).view(B, heads, -1, T)
keys = self.key(x).view(B, heads, -1, T)
# t are keys, s are queries
dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
dots /= keys.shape[2]**0.5
if self.nfreqs:
periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
if self.ndecay:
decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
decay_q = self.query_decay(x).view(B, heads, -1, T)
decay_q = torch.sigmoid(decay_q) / 2
decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
# Kill self reference.
dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
weights = torch.softmax(dots, dim=2)
content = self.content(x).view(B, heads, -1, T)
result = torch.einsum("bhts,bhct->bhcs", weights, content)
if self.nfreqs:
time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
result = torch.cat([result, time_sig], 2)
result = result.reshape(B, -1, T)
return x + self.proj(result)
class Demucs(nn.Module):
@capture_init
def __init__(self,
sources,
# Channels
audio_channels=2,
channels=64,
growth=2.,
# Main structure
depth=6,
rewrite=True,
lstm_layers=0,
# Convolutions
kernel_size=8,
stride=4,
context=1,
# Activations
gelu=True,
glu=True,
# Normalization
norm_starts=4,
norm_groups=4,
# DConv residual branch
dconv_mode=1,
dconv_depth=2,
dconv_comp=4,
dconv_attn=4,
dconv_lstm=4,
dconv_init=1e-4,
# Pre/post processing
normalize=True,
resample=True,
# Weight init
rescale=0.1,
# Metadata
samplerate=44100,
segment=4 * 10):
"""
Args:
sources (list[str]): list of source names
audio_channels (int): stereo or mono
channels (int): first convolution channels
depth (int): number of encoder/decoder layers
growth (float): multiply (resp divide) number of channels by that
for each layer of the encoder (resp decoder)
depth (int): number of layers in the encoder and in the decoder.
rewrite (bool): add 1x1 convolution to each layer.
lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
by default, as this is now replaced by the smaller and faster small LSTMs
in the DConv branches.
kernel_size (int): kernel size for convolutions
stride (int): stride for convolutions
context (int): kernel size of the convolution in the
decoder before the transposed convolution. If > 1,
will provide some context from neighboring time steps.
gelu: use GELU activation function.
glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
norm_starts: layer at which group norm starts being used.
decoder layers are numbered in reverse order.
norm_groups: number of groups for group norm.
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
dconv_depth: depth of residual DConv branch.
dconv_comp: compression of DConv branch.
dconv_attn: adds attention layers in DConv branch starting at this layer.
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
dconv_init: initial scale for the DConv branch LayerScale.
normalize (bool): normalizes the input audio on the fly, and scales back
the output by the same amount.
resample (bool): upsample x2 the input and downsample /2 the output.
rescale (int): rescale initial weights of convolutions
to get their standard deviation closer to `rescale`.
samplerate (int): stored as meta information for easing
future evaluations of the model.
segment (float): duration of the chunks of audio to ideally evaluate the model on.
This is used by `demucs.apply.apply_model`.
"""
super().__init__()
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.resample = resample
self.channels = channels
self.normalize = normalize
self.samplerate = samplerate
self.segment = segment
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.skip_scales = nn.ModuleList()
if glu:
activation = nn.GLU(dim=1)
ch_scale = 2
else:
activation = nn.ReLU()
ch_scale = 1
if gelu:
act2 = nn.GELU
else:
act2 = nn.ReLU
in_channels = audio_channels
padding = 0
for index in range(depth):
norm_fn = lambda d: nn.Identity() # noqa
if index >= norm_starts:
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
encode = []
encode += [
nn.Conv1d(in_channels, channels, kernel_size, stride),
norm_fn(channels),
act2(),
]
attn = index >= dconv_attn
lstm = index >= dconv_lstm
if dconv_mode & 1:
encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
compress=dconv_comp, attn=attn, lstm=lstm)]
if rewrite:
encode += [
nn.Conv1d(channels, ch_scale * channels, 1),
norm_fn(ch_scale * channels), activation]
self.encoder.append(nn.Sequential(*encode))
decode = []
if index > 0:
out_channels = in_channels
else:
out_channels = len(self.sources) * audio_channels
if rewrite:
decode += [
nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
norm_fn(ch_scale * channels), activation]
if dconv_mode & 2:
decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
compress=dconv_comp, attn=attn, lstm=lstm)]
decode += [nn.ConvTranspose1d(channels, out_channels,
kernel_size, stride, padding=padding)]
if index > 0:
decode += [norm_fn(out_channels), act2()]
self.decoder.insert(0, nn.Sequential(*decode))
in_channels = channels
channels = int(growth * channels)
channels = in_channels
if lstm_layers:
self.lstm = BLSTM(channels, lstm_layers)
else:
self.lstm = None
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
"""
Return the nearest valid length to use with the model so that
there is no time steps left over in a convolution, e.g. for all
layers, size of the input - kernel_size % stride = 0.
Note that input are automatically padded if necessary to ensure that the output
has the same length as the input.
"""
if self.resample:
length *= 2
for _ in range(self.depth):
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(1, length)
for idx in range(self.depth):
length = (length - 1) * self.stride + self.kernel_size
if self.resample:
length = math.ceil(length / 2)
return int(length)
def forward(self, mix):
x = mix
length = x.shape[-1]
if self.normalize:
mono = mix.mean(dim=1, keepdim=True)
mean = mono.mean(dim=-1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
x = (x - mean) / (1e-5 + std)
else:
mean = 0
std = 1
delta = self.valid_length(length) - length
x = F.pad(x, (delta // 2, delta - delta // 2))
if self.resample:
x = julius.resample_frac(x, 1, 2)
saved = []
for encode in self.encoder:
x = encode(x)
saved.append(x)
if self.lstm:
x = self.lstm(x)
for decode in self.decoder:
skip = saved.pop(-1)
skip = center_trim(skip, x)
x = decode(x + skip)
if self.resample:
x = julius.resample_frac(x, 2, 1)
x = x * std + mean
x = center_trim(x, length)
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
return x
def load_state_dict(self, state, strict=True):
# fix a mismatch with previous generation Demucs models.
for idx in range(self.depth):
for a in ['encoder', 'decoder']:
for b in ['bias', 'weight']:
new = f'{a}.{idx}.3.{b}'
old = f'{a}.{idx}.2.{b}'
if old in state and new not in state:
state[new] = state.pop(old)
super().load_state_dict(state, strict=strict)

502
demucs/filtering.py Normal file
View File

@ -0,0 +1,502 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
def atan2(y, x):
r"""Element-wise arctangent function of y/x.
Returns a new tensor with signed angles in radians.
It is an alternative implementation of torch.atan2
Args:
y (Tensor): First input tensor
x (Tensor): Second input tensor [shape=y.shape]
Returns:
Tensor: [shape=y.shape].
"""
pi = 2 * torch.asin(torch.tensor(1.0))
x += ((x == 0) & (y == 0)) * 1.0
out = torch.atan(y / x)
out += ((y >= 0) & (x < 0)) * pi
out -= ((y < 0) & (x < 0)) * pi
out *= 1 - ((y > 0) & (x == 0)) * 1.0
out += ((y > 0) & (x == 0)) * (pi / 2)
out *= 1 - ((y < 0) & (x == 0)) * 1.0
out += ((y < 0) & (x == 0)) * (-pi / 2)
return out
# Define basic complex operations on torch.Tensor objects whose last dimension
# consists in the concatenation of the real and imaginary parts.
def _norm(x: torch.Tensor) -> torch.Tensor:
r"""Computes the norm value of a torch Tensor, assuming that it
comes as real and imaginary part in its last dimension.
Args:
x (Tensor): Input Tensor of shape [shape=(..., 2)]
Returns:
Tensor: shape as x excluding the last dimension.
"""
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Element-wise multiplication of two complex Tensors described
through their real and imaginary parts.
The result is added to the `out` tensor"""
# check `out` and allocate it if needed
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
if out is None or out.shape != target_shape:
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
if out is a:
real_a = a[..., 0]
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
else:
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
return out
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Element-wise multiplication of two complex Tensors described
through their real and imaginary parts
can work in place in case out is a only"""
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
if out is None or out.shape != target_shape:
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
if out is a:
real_a = a[..., 0]
out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
else:
out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
return out
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Element-wise multiplicative inverse of a Tensor with complex
entries described through their real and imaginary parts.
can work in place in case out is z"""
ez = _norm(z)
if out is None or out.shape != z.shape:
out = torch.zeros_like(z)
out[..., 0] = z[..., 0] / ez
out[..., 1] = -z[..., 1] / ez
return out
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Element-wise complex conjugate of a Tensor with complex entries
described through their real and imaginary parts.
can work in place in case out is z"""
if out is None or out.shape != z.shape:
out = torch.zeros_like(z)
out[..., 0] = z[..., 0]
out[..., 1] = -z[..., 1]
return out
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Invert 1x1 or 2x2 matrices
Will generate errors if the matrices are singular: user must handle this
through his own regularization schemes.
Args:
M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
matrices to invert: must be square along dimensions -3 and -2
Returns:
invM (Tensor): [shape=M.shape]
inverses of M
"""
nb_channels = M.shape[-2]
if out is None or out.shape != M.shape:
out = torch.empty_like(M)
if nb_channels == 1:
# scalar case
out = _inv(M, out)
elif nb_channels == 2:
# two channels case: analytical expression
# first compute the determinent
det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
# invert it
invDet = _inv(det)
# then fill out the matrix with the inverse
out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
else:
raise Exception("Only 2 channels are supported for the torch version.")
return out
# Now define the signal-processing low-level functions used by the Separator
def expectation_maximization(
y: torch.Tensor,
x: torch.Tensor,
iterations: int = 2,
eps: float = 1e-10,
batch_size: int = 200,
):
r"""Expectation maximization algorithm, for refining source separation
estimates.
This algorithm allows to make source separation results better by
enforcing multichannel consistency for the estimates. This usually means
a better perceptual quality in terms of spatial artifacts.
The implementation follows the details presented in [1]_, taking
inspiration from the original EM algorithm proposed in [2]_ and its
weighted refinement proposed in [3]_, [4]_.
It works by iteratively:
* Re-estimate source parameters (power spectral densities and spatial
covariance matrices) through :func:`get_local_gaussian_model`.
* Separate again the mixture with the new parameters by first computing
the new modelled mixture covariance matrices with :func:`get_mix_model`,
prepare the Wiener filters through :func:`wiener_gain` and apply them
with :func:`apply_filter``.
References
----------
.. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
N. Takahashi and Y. Mitsufuji, "Improving music source separation based
on deep neural networks through data augmentation and network
blending." 2017 IEEE International Conference on Acoustics, Speech
and Signal Processing (ICASSP). IEEE, 2017.
.. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
reverberant audio source separation using a full-rank spatial
covariance model." IEEE Transactions on Audio, Speech, and Language
Processing 18.7 (2010): 1830-1840.
.. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
separation with deep neural networks." IEEE/ACM Transactions on Audio,
Speech, and Language Processing 24.9 (2016): 1652-1664.
.. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
separation with deep neural networks." 2016 24th European Signal
Processing Conference (EUSIPCO). IEEE, 2016.
.. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
source separation." IEEE Transactions on Signal Processing
62.16 (2014): 4298-4310.
Args:
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
initial estimates for the sources
x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
complex STFT of the mixture signal
iterations (int): [scalar]
number of iterations for the EM algorithm.
eps (float or None): [scalar]
The epsilon value to use for regularization and filters.
Returns:
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
estimated sources after iterations
v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
estimated power spectral densities
R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
estimated spatial covariance matrices
Notes:
* You need an initial estimate for the sources to apply this
algorithm. This is precisely what the :func:`wiener` function does.
* This algorithm *is not* an implementation of the "exact" EM
proposed in [1]_. In particular, it does compute the posterior
covariance matrices the same (exact) way. Instead, it uses the
simplified approximate scheme initially proposed in [5]_ and further
refined in [3]_, [4]_, that boils down to just take the empirical
covariance of the recent source estimates, followed by a weighted
average for the update of the spatial covariance matrix. It has been
empirically demonstrated that this simplified algorithm is more
robust for music separation.
Warning:
It is *very* important to make sure `x.dtype` is `torch.float64`
if you want double precision, because this function will **not**
do such conversion for you from `torch.complex32`, in case you want the
smaller RAM usage on purpose.
It is usually always better in terms of quality to have double
precision, by e.g. calling :func:`expectation_maximization`
with ``x.to(torch.float64)``.
"""
# dimensions
(nb_frames, nb_bins, nb_channels) = x.shape[:-1]
nb_sources = y.shape[-1]
regularization = torch.cat(
(
torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
),
dim=2,
)
regularization = torch.sqrt(torch.as_tensor(eps)) * (
regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
)
# allocate the spatial covariance matrices
R = [
torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
for j in range(nb_sources)
]
weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
for it in range(iterations):
# constructing the mixture covariance matrix. Doing it with a loop
# to avoid storing anytime in RAM the whole 6D tensor
# update the PSD as the average spectrogram over channels
v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
# update spatial covariance matrices (weighted update)
for j in range(nb_sources):
R[j] = torch.tensor(0.0, device=x.device)
weight = torch.tensor(eps, device=x.device)
pos: int = 0
batch_size = batch_size if batch_size else nb_frames
while pos < nb_frames:
t = torch.arange(pos, min(nb_frames, pos + batch_size))
pos = int(t[-1]) + 1
R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
weight = weight + torch.sum(v[t, ..., j], dim=0)
R[j] = R[j] / weight[..., None, None, None]
weight = torch.zeros_like(weight)
# cloning y if we track gradient, because we're going to update it
if y.requires_grad:
y = y.clone()
pos = 0
while pos < nb_frames:
t = torch.arange(pos, min(nb_frames, pos + batch_size))
pos = int(t[-1]) + 1
y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
# compute mix covariance matrix
Cxx = regularization
for j in range(nb_sources):
Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
# invert it
inv_Cxx = _invert(Cxx)
# separate the sources
for j in range(nb_sources):
# create a wiener gain for this source
gain = torch.zeros_like(inv_Cxx)
# computes multichannel Wiener gain as v_j R_j inv_Cxx
indices = torch.cartesian_prod(
torch.arange(nb_channels),
torch.arange(nb_channels),
torch.arange(nb_channels),
)
for index in indices:
gain[:, :, index[0], index[1], :] = _mul_add(
R[j][None, :, index[0], index[2], :].clone(),
inv_Cxx[:, :, index[2], index[1], :],
gain[:, :, index[0], index[1], :],
)
gain = gain * v[t, ..., None, None, None, j]
# apply it to the mixture
for i in range(nb_channels):
y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
return y, v, R
def wiener(
targets_spectrograms: torch.Tensor,
mix_stft: torch.Tensor,
iterations: int = 1,
softmask: bool = False,
residual: bool = False,
scale_factor: float = 10.0,
eps: float = 1e-10,
):
"""Wiener-based separation for multichannel audio.
The method uses the (possibly multichannel) spectrograms of the
sources to separate the (complex) Short Term Fourier Transform of the
mix. Separation is done in a sequential way by:
* Getting an initial estimate. This can be done in two ways: either by
directly using the spectrograms with the mixture phase, or
by using a softmasking strategy. This initial phase is controlled
by the `softmask` flag.
* If required, adding an additional residual target as the mix minus
all targets.
* Refinining these initial estimates through a call to
:func:`expectation_maximization` if the number of iterations is nonzero.
This implementation also allows to specify the epsilon value used for
regularization. It is based on [1]_, [2]_, [3]_, [4]_.
References
----------
.. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
N. Takahashi and Y. Mitsufuji, "Improving music source separation based
on deep neural networks through data augmentation and network
blending." 2017 IEEE International Conference on Acoustics, Speech
and Signal Processing (ICASSP). IEEE, 2017.
.. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
separation with deep neural networks." IEEE/ACM Transactions on Audio,
Speech, and Language Processing 24.9 (2016): 1652-1664.
.. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
separation with deep neural networks." 2016 24th European Signal
Processing Conference (EUSIPCO). IEEE, 2016.
.. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
source separation." IEEE Transactions on Signal Processing
62.16 (2014): 4298-4310.
Args:
targets_spectrograms (Tensor): spectrograms of the sources
[shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
This is a nonnegative tensor that is
usually the output of the actual separation method of the user. The
spectrograms may be mono, but they need to be 4-dimensional in all
cases.
mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
STFT of the mixture signal.
iterations (int): [scalar]
number of iterations for the EM algorithm
softmask (bool): Describes how the initial estimates are obtained.
* if `False`, then the mixture phase will directly be used with the
spectrogram as initial estimates.
* if `True`, initial estimates are obtained by multiplying the
complex mix element-wise with the ratio of each target spectrogram
with the sum of them all. This strategy is better if the model are
not really good, and worse otherwise.
residual (bool): if `True`, an additional target is created, which is
equal to the mixture minus the other targets, before application of
expectation maximization
eps (float): Epsilon value to use for computing the separations.
This is used whenever division with a model energy is
performed, i.e. when softmasking and when iterating the EM.
It can be understood as the energy of the additional white noise
that is taken out when separating.
Returns:
Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
STFT of estimated sources
Notes:
* Be careful that you need *magnitude spectrogram estimates* for the
case `softmask==False`.
* `softmask=False` is recommended
* The epsilon value will have a huge impact on performance. If it's
large, only the parts of the signal with a significant energy will
be kept in the sources. This epsilon then directly controls the
energy of the reconstruction error.
Warning:
As in :func:`expectation_maximization`, we recommend converting the
mixture `x` to double precision `torch.float64` *before* calling
:func:`wiener`.
"""
if softmask:
# if we use softmask, we compute the ratio mask for all targets and
# multiply by the mix stft
y = (
mix_stft[..., None]
* (
targets_spectrograms
/ (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
)[..., None, :]
)
else:
# otherwise, we just multiply the targets spectrograms with mix phase
# we tacitly assume that we have magnitude estimates.
angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
nb_sources = targets_spectrograms.shape[-1]
y = torch.zeros(
mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
)
y[..., 0, :] = targets_spectrograms * torch.cos(angle)
y[..., 1, :] = targets_spectrograms * torch.sin(angle)
if residual:
# if required, adding an additional target as the mix minus
# available targets
y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
if iterations == 0:
return y
# we need to refine the estimates. Scales down the estimates for
# numerical stability
max_abs = torch.max(
torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
torch.sqrt(_norm(mix_stft)).max() / scale_factor,
)
mix_stft = mix_stft / max_abs
y = y / max_abs
# call expectation maximization
y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
# scale estimates up again
y = y * max_abs
return y
def _covariance(y_j):
"""
Compute the empirical covariance for a source.
Args:
y_j (Tensor): complex stft of the source.
[shape=(nb_frames, nb_bins, nb_channels, 2)].
Returns:
Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
just y_j * conj(y_j.T): empirical covariance for each TF bin.
"""
(nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
Cj = torch.zeros(
(nb_frames, nb_bins, nb_channels, nb_channels, 2),
dtype=y_j.dtype,
device=y_j.device,
)
indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
for index in indices:
Cj[:, :, index[0], index[1], :] = _mul_add(
y_j[:, :, index[0], :],
_conj(y_j[:, :, index[1], :]),
Cj[:, :, index[0], index[1], :],
)
return Cj

782
demucs/hdemucs.py Normal file
View File

@ -0,0 +1,782 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
from copy import deepcopy
import math
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
from .filtering import wiener
from .demucs import DConv, rescale_module
from .states import capture_init
from .spec import spectro, ispectro
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen."""
x0 = x
length = x.shape[-1]
padding_left, padding_right = paddings
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
if length <= max_pad:
extra_pad = max_pad - length + 1
extra_pad_right = min(padding_right, extra_pad)
extra_pad_left = extra_pad - extra_pad_right
paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
x = F.pad(x, (extra_pad_left, extra_pad_right))
out = F.pad(x, paddings, mode, value)
assert out.shape[-1] == length + padding_left + padding_right
assert (out[..., padding_left: padding_left + length] == x0).all()
return out
class ScaledEmbedding(nn.Module):
"""
Boost learning rate for embeddings (with `scale`).
Also, can make embeddings continuous with `smooth`.
"""
def __init__(self, num_embeddings: int, embedding_dim: int,
scale: float = 10., smooth=False):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
if smooth:
weight = torch.cumsum(self.embedding.weight.data, dim=0)
# when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
self.embedding.weight.data[:] = weight
self.embedding.weight.data /= scale
self.scale = scale
@property
def weight(self):
return self.embedding.weight * self.scale
def forward(self, x):
out = self.embedding(x) * self.scale
return out
class HEncLayer(nn.Module):
def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
rewrite=True):
"""Encoder layer. This used both by the time and the frequency branch.
Args:
chin: number of input channels.
chout: number of output channels.
norm_groups: number of groups for group norm.
empty: used to make a layer with just the first conv. this is used
before merging the time and freq. branches.
freq: this is acting on frequencies.
dconv: insert DConv residual branches.
norm: use GroupNorm.
context: context size for the 1x1 conv.
dconv_kw: list of kwargs for the DConv class.
pad: pad the input. Padding is done so that the output size is
always the input size / stride.
rewrite: add 1x1 conv at the end of the layer.
"""
super().__init__()
norm_fn = lambda d: nn.Identity() # noqa
if norm:
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
if pad:
pad = kernel_size // 4
else:
pad = 0
klass = nn.Conv1d
self.freq = freq
self.kernel_size = kernel_size
self.stride = stride
self.empty = empty
self.norm = norm
self.pad = pad
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
pad = [pad, 0]
klass = nn.Conv2d
self.conv = klass(chin, chout, kernel_size, stride, pad)
if self.empty:
return
self.norm1 = norm_fn(chout)
self.rewrite = None
if rewrite:
self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
self.norm2 = norm_fn(2 * chout)
self.dconv = None
if dconv:
self.dconv = DConv(chout, **dconv_kw)
def forward(self, x, inject=None):
"""
`inject` is used to inject the result from the time branch into the frequency branch,
when both have the same stride.
"""
if not self.freq and x.dim() == 4:
B, C, Fr, T = x.shape
x = x.view(B, -1, T)
if not self.freq:
le = x.shape[-1]
if not le % self.stride == 0:
x = F.pad(x, (0, self.stride - (le % self.stride)))
y = self.conv(x)
if self.empty:
return y
if inject is not None:
assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
if inject.dim() == 3 and y.dim() == 4:
inject = inject[:, :, None]
y = y + inject
y = F.gelu(self.norm1(y))
if self.dconv:
if self.freq:
B, C, Fr, T = y.shape
y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
y = self.dconv(y)
if self.freq:
y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
if self.rewrite:
z = self.norm2(self.rewrite(y))
z = F.glu(z, dim=1)
else:
z = y
return z
class MultiWrap(nn.Module):
"""
Takes one layer and replicate it N times. each replica will act
on a frequency band. All is done so that if the N replica have the same weights,
then this is exactly equivalent to applying the original module on all frequencies.
This is a bit over-engineered to avoid edge artifacts when splitting
the frequency bands, but it is possible the naive implementation would work as well...
"""
def __init__(self, layer, split_ratios):
"""
Args:
layer: module to clone, must be either HEncLayer or HDecLayer.
split_ratios: list of float indicating which ratio to keep for each band.
"""
super().__init__()
self.split_ratios = split_ratios
self.layers = nn.ModuleList()
self.conv = isinstance(layer, HEncLayer)
assert not layer.norm
assert layer.freq
assert layer.pad
if not self.conv:
assert not layer.context_freq
for k in range(len(split_ratios) + 1):
lay = deepcopy(layer)
if self.conv:
lay.conv.padding = (0, 0)
else:
lay.pad = False
for m in lay.modules():
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
self.layers.append(lay)
def forward(self, x, skip=None, length=None):
B, C, Fr, T = x.shape
ratios = list(self.split_ratios) + [1]
start = 0
outs = []
for ratio, layer in zip(ratios, self.layers):
if self.conv:
pad = layer.kernel_size // 4
if ratio == 1:
limit = Fr
frames = -1
else:
limit = int(round(Fr * ratio))
le = limit - start
if start == 0:
le += pad
frames = round((le - layer.kernel_size) / layer.stride + 1)
limit = start + (frames - 1) * layer.stride + layer.kernel_size
if start == 0:
limit -= pad
assert limit - start > 0, (limit, start)
assert limit <= Fr, (limit, Fr)
y = x[:, :, start:limit, :]
if start == 0:
y = F.pad(y, (0, 0, pad, 0))
if ratio == 1:
y = F.pad(y, (0, 0, 0, pad))
outs.append(layer(y))
start = limit - layer.kernel_size + layer.stride
else:
if ratio == 1:
limit = Fr
else:
limit = int(round(Fr * ratio))
last = layer.last
layer.last = True
y = x[:, :, start:limit]
s = skip[:, :, start:limit]
out, _ = layer(y, s, None)
if outs:
outs[-1][:, :, -layer.stride:] += (
out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
out = out[:, :, layer.stride:]
if ratio == 1:
out = out[:, :, :-layer.stride // 2, :]
if start == 0:
out = out[:, :, layer.stride // 2:, :]
outs.append(out)
layer.last = last
start = limit
out = torch.cat(outs, dim=2)
if not self.conv and not last:
out = F.gelu(out)
if self.conv:
return out
else:
return out, None
class HDecLayer(nn.Module):
def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
context_freq=True, rewrite=True):
"""
Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
"""
super().__init__()
norm_fn = lambda d: nn.Identity() # noqa
if norm:
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
if pad:
pad = kernel_size // 4
else:
pad = 0
self.pad = pad
self.last = last
self.freq = freq
self.chin = chin
self.empty = empty
self.stride = stride
self.kernel_size = kernel_size
self.norm = norm
self.context_freq = context_freq
klass = nn.Conv1d
klass_tr = nn.ConvTranspose1d
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
klass = nn.Conv2d
klass_tr = nn.ConvTranspose2d
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
self.norm2 = norm_fn(chout)
if self.empty:
return
self.rewrite = None
if rewrite:
if context_freq:
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
else:
self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
[0, context])
self.norm1 = norm_fn(2 * chin)
self.dconv = None
if dconv:
self.dconv = DConv(chin, **dconv_kw)
def forward(self, x, skip, length):
if self.freq and x.dim() == 3:
B, C, T = x.shape
x = x.view(B, self.chin, -1, T)
if not self.empty:
x = x + skip
if self.rewrite:
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
else:
y = x
if self.dconv:
if self.freq:
B, C, Fr, T = y.shape
y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
y = self.dconv(y)
if self.freq:
y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
else:
y = x
assert skip is None
z = self.norm2(self.conv_tr(y))
if self.freq:
if self.pad:
z = z[..., self.pad:-self.pad, :]
else:
z = z[..., self.pad:self.pad + length]
assert z.shape[-1] == length, (z.shape[-1], length)
if not self.last:
z = F.gelu(z)
return z, y
class HDemucs(nn.Module):
"""
Spectrogram and hybrid Demucs model.
The spectrogram model has the same structure as Demucs, except the first few layers are over the
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
Frequency layers can still access information across time steps thanks to the DConv residual.
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
as the frequency branch and then the two are combined. The opposite happens in the decoder.
Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
Open Unmix implementation [Stoter et al. 2019].
The loss is always on the temporal domain, by backpropagating through the above
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
contribution, without changing the one from the waveform, which will lead to worse performance.
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
CaC on the other hand provides similar performance for hybrid, and works naturally with
hybrid models.
This model also uses frequency embeddings are used to improve efficiency on convolutions
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
"""
@capture_init
def __init__(self,
sources,
# Channels
audio_channels=2,
channels=48,
channels_time=None,
growth=2,
# STFT
nfft=4096,
wiener_iters=0,
end_iters=0,
wiener_residual=False,
cac=True,
# Main structure
depth=6,
rewrite=True,
hybrid=True,
hybrid_old=False,
# Frequency branch
multi_freqs=None,
multi_freqs_depth=2,
freq_emb=0.2,
emb_scale=10,
emb_smooth=True,
# Convolutions
kernel_size=8,
time_stride=2,
stride=4,
context=1,
context_enc=0,
# Normalization
norm_starts=4,
norm_groups=4,
# DConv residual branch
dconv_mode=1,
dconv_depth=2,
dconv_comp=4,
dconv_attn=4,
dconv_lstm=4,
dconv_init=1e-4,
# Weight init
rescale=0.1,
# Metadata
samplerate=44100,
segment=4 * 10):
"""
Args:
sources (list[str]): list of source names.
audio_channels (int): input/output audio channels.
channels (int): initial number of hidden channels.
channels_time: if not None, use a different `channels` value for the time branch.
growth: increase the number of hidden channels by this factor at each layer.
nfft: number of fft bins. Note that changing this require careful computation of
various shape parameters and will not work out of the box for hybrid models.
wiener_iters: when using Wiener filtering, number of iterations at test time.
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
wiener_residual: add residual source before wiener filtering.
cac: uses complex as channels, i.e. complex numbers are 2 channels each
in input and output. no further processing is done before ISTFT.
depth (int): number of layers in the encoder and in the decoder.
rewrite (bool): add 1x1 convolution to each layer.
hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
hybrid_old: some models trained for MDX had a padding bug. This replicates
this bug to avoid retraining them.
multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
layers will be wrapped.
freq_emb: add frequency embedding after the first frequency layer if > 0,
the actual value controls the weight of the embedding.
emb_scale: equivalent to scaling the embedding learning rate
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
kernel_size: kernel_size for encoder and decoder layers.
stride: stride for encoder and decoder layers.
time_stride: stride for the final time layer, after the merge.
context: context for 1x1 conv in the decoder.
context_enc: context for 1x1 conv in the encoder.
norm_starts: layer at which group norm starts being used.
decoder layers are numbered in reverse order.
norm_groups: number of groups for group norm.
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
dconv_depth: depth of residual DConv branch.
dconv_comp: compression of DConv branch.
dconv_attn: adds attention layers in DConv branch starting at this layer.
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
dconv_init: initial scale for the DConv branch LayerScale.
rescale: weight recaling trick
"""
super().__init__()
self.cac = cac
self.wiener_residual = wiener_residual
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.channels = channels
self.samplerate = samplerate
self.segment = segment
self.nfft = nfft
self.hop_length = nfft // 4
self.wiener_iters = wiener_iters
self.end_iters = end_iters
self.freq_emb = None
self.hybrid = hybrid
self.hybrid_old = hybrid_old
if hybrid_old:
assert hybrid, "hybrid_old must come with hybrid=True"
if hybrid:
assert wiener_iters == end_iters
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
if hybrid:
self.tencoder = nn.ModuleList()
self.tdecoder = nn.ModuleList()
chin = audio_channels
chin_z = chin # number of channels for the freq branch
if self.cac:
chin_z *= 2
chout = channels_time or channels
chout_z = channels
freqs = nfft // 2
for index in range(depth):
lstm = index >= dconv_lstm
attn = index >= dconv_attn
norm = index >= norm_starts
freq = freqs > 1
stri = stride
ker = kernel_size
if not freq:
assert freqs == 1
ker = time_stride * 2
stri = time_stride
pad = True
last_freq = False
if freq and freqs <= kernel_size:
ker = freqs
pad = False
last_freq = True
kw = {
'kernel_size': ker,
'stride': stri,
'freq': freq,
'pad': pad,
'norm': norm,
'rewrite': rewrite,
'norm_groups': norm_groups,
'dconv_kw': {
'lstm': lstm,
'attn': attn,
'depth': dconv_depth,
'compress': dconv_comp,
'init': dconv_init,
'gelu': True,
}
}
kwt = dict(kw)
kwt['freq'] = 0
kwt['kernel_size'] = kernel_size
kwt['stride'] = stride
kwt['pad'] = True
kw_dec = dict(kw)
multi = False
if multi_freqs and index < multi_freqs_depth:
multi = True
kw_dec['context_freq'] = False
if last_freq:
chout_z = max(chout, chout_z)
chout = chout_z
enc = HEncLayer(chin_z, chout_z,
dconv=dconv_mode & 1, context=context_enc, **kw)
if hybrid and freq:
tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
empty=last_freq, **kwt)
self.tencoder.append(tenc)
if multi:
enc = MultiWrap(enc, multi_freqs)
self.encoder.append(enc)
if index == 0:
chin = self.audio_channels * len(self.sources)
chin_z = chin
if self.cac:
chin_z *= 2
dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
last=index == 0, context=context, **kw_dec)
if multi:
dec = MultiWrap(dec, multi_freqs)
if hybrid and freq:
tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
last=index == 0, context=context, **kwt)
self.tdecoder.insert(0, tdec)
self.decoder.insert(0, dec)
chin = chout
chin_z = chout_z
chout = int(growth * chout)
chout_z = int(growth * chout_z)
if freq:
if freqs <= kernel_size:
freqs = 1
else:
freqs //= stride
if index == 0 and freq_emb:
self.freq_emb = ScaledEmbedding(
freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
self.freq_emb_scale = freq_emb
if rescale:
rescale_module(self, reference=rescale)
def _spec(self, x):
hl = self.hop_length
nfft = self.nfft
x0 = x # noqa
if self.hybrid:
# We re-pad the signal in order to keep the property
# that the size of the output is exactly the size of the input
# divided by the stride (here hop_length), when divisible.
# This is achieved by padding by 1/4th of the kernel size (here nfft).
# which is not supported by torch.stft.
# Having all convolution operations follow this convention allow to easily
# align the time and frequency branches later on.
assert hl == nfft // 4
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
if not self.hybrid_old:
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
else:
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
z = spectro(x, nfft, hl)[..., :-1, :]
if self.hybrid:
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
z = z[..., 2:2+le]
return z
def _ispec(self, z, length=None, scale=0):
hl = self.hop_length // (4 ** scale)
z = F.pad(z, (0, 0, 0, 1))
if self.hybrid:
z = F.pad(z, (2, 2))
pad = hl // 2 * 3
if not self.hybrid_old:
le = hl * int(math.ceil(length / hl)) + 2 * pad
else:
le = hl * int(math.ceil(length / hl))
x = ispectro(z, hl, length=le)
if not self.hybrid_old:
x = x[..., pad:pad + length]
else:
x = x[..., :length]
else:
x = ispectro(z, hl, length)
return x
def _magnitude(self, z):
# return the magnitude of the spectrogram, except when cac is True,
# in which case we just move the complex dimension to the channel one.
if self.cac:
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
m = m.reshape(B, C * 2, Fr, T)
else:
m = z.abs()
return m
def _mask(self, z, m):
# Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
# If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
niters = self.wiener_iters
if self.cac:
B, S, C, Fr, T = m.shape
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())
return out
if self.training:
niters = self.end_iters
if niters < 0:
z = z[:, None]
return z / (1e-8 + z.abs()) * m
else:
return self._wiener(m, z, niters)
def _wiener(self, mag_out, mix_stft, niters):
# apply wiener filtering from OpenUnmix.
init = mix_stft.dtype
wiener_win_len = 300
residual = self.wiener_residual
B, S, C, Fq, T = mag_out.shape
mag_out = mag_out.permute(0, 4, 3, 2, 1)
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
outs = []
for sample in range(B):
pos = 0
out = []
for pos in range(0, T, wiener_win_len):
frame = slice(pos, pos + wiener_win_len)
z_out = wiener(
mag_out[sample, frame], mix_stft[sample, frame], niters,
residual=residual)
out.append(z_out.transpose(-1, -2))
outs.append(torch.cat(out, dim=0))
out = torch.view_as_complex(torch.stack(outs, 0))
out = out.permute(0, 4, 3, 2, 1).contiguous()
if residual:
out = out[:, :-1]
assert list(out.shape) == [B, S, C, Fq, T]
return out.to(init)
def forward(self, mix):
x = mix
length = x.shape[-1]
z = self._spec(mix)
mag = self._magnitude(z)
x = mag
B, C, Fq, T = x.shape
# unlike previous Demucs, we always normalize because it is easier.
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True)
x = (x - mean) / (1e-5 + std)
# x will be the freq. branch input.
if self.hybrid:
# Prepare the time branch input.
xt = mix
meant = xt.mean(dim=(1, 2), keepdim=True)
stdt = xt.std(dim=(1, 2), keepdim=True)
xt = (xt - meant) / (1e-5 + stdt)
# okay, this is a giant mess I know...
saved = [] # skip connections, freq.
saved_t = [] # skip connections, time.
lengths = [] # saved lengths to properly remove padding, freq branch.
lengths_t = [] # saved lengths for time branch.
for idx, encode in enumerate(self.encoder):
lengths.append(x.shape[-1])
inject = None
if self.hybrid and idx < len(self.tencoder):
# we have not yet merged branches.
lengths_t.append(xt.shape[-1])
tenc = self.tencoder[idx]
xt = tenc(xt)
if not tenc.empty:
# save for skip connection
saved_t.append(xt)
else:
# tenc contains just the first conv., so that now time and freq.
# branches have the same shape and can be merged.
inject = xt
x = encode(x, inject)
if idx == 0 and self.freq_emb is not None:
# add frequency embedding to allow for non equivariant convolutions
# over the frequency axis.
frs = torch.arange(x.shape[-2], device=x.device)
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
x = x + self.freq_emb_scale * emb
saved.append(x)
x = torch.zeros_like(x)
if self.hybrid:
xt = torch.zeros_like(x)
# initialize everything to zero (signal will go through u-net skips).
for idx, decode in enumerate(self.decoder):
skip = saved.pop(-1)
x, pre = decode(x, skip, lengths.pop(-1))
# `pre` contains the output just before final transposed convolution,
# which is used when the freq. and time branch separate.
if self.hybrid:
offset = self.depth - len(self.tdecoder)
if self.hybrid and idx >= offset:
tdec = self.tdecoder[idx - offset]
length_t = lengths_t.pop(-1)
if tdec.empty:
assert pre.shape[2] == 1, pre.shape
pre = pre[:, :, 0]
xt, _ = tdec(pre, None, length_t)
else:
skip = saved_t.pop(-1)
xt, _ = tdec(xt, skip, length_t)
# Let's make sure we used all stored skip connections.
assert len(saved) == 0
assert len(lengths_t) == 0
assert len(saved_t) == 0
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
zout = self._mask(z, x)
x = self._ispec(zout, length)
if self.hybrid:
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
x = xt + x
return x

648
demucs/htdemucs.py Normal file
View File

@ -0,0 +1,648 @@
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
import math
from .filtering import wiener
import torch
from torch import nn
from torch.nn import functional as F
from fractions import Fraction
from einops import rearrange
from .transformer import CrossTransformerEncoder
from .demucs import rescale_module
from .states import capture_init
from .spec import spectro, ispectro
from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
class HTDemucs(nn.Module):
"""
Spectrogram and hybrid Demucs model.
The spectrogram model has the same structure as Demucs, except the first few layers are over the
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
Frequency layers can still access information across time steps thanks to the DConv residual.
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
as the frequency branch and then the two are combined. The opposite happens in the decoder.
Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
Open Unmix implementation [Stoter et al. 2019].
The loss is always on the temporal domain, by backpropagating through the above
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
contribution, without changing the one from the waveform, which will lead to worse performance.
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
CaC on the other hand provides similar performance for hybrid, and works naturally with
hybrid models.
This model also uses frequency embeddings are used to improve efficiency on convolutions
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
"""
@capture_init
def __init__(
self,
sources,
# Channels
audio_channels=2,
channels=48,
channels_time=None,
growth=2,
# STFT
nfft=4096,
wiener_iters=0,
end_iters=0,
wiener_residual=False,
cac=True,
# Main structure
depth=4,
rewrite=True,
# Frequency branch
multi_freqs=None,
multi_freqs_depth=3,
freq_emb=0.2,
emb_scale=10,
emb_smooth=True,
# Convolutions
kernel_size=8,
time_stride=2,
stride=4,
context=1,
context_enc=0,
# Normalization
norm_starts=4,
norm_groups=4,
# DConv residual branch
dconv_mode=1,
dconv_depth=2,
dconv_comp=8,
dconv_init=1e-3,
# Before the Transformer
bottom_channels=0,
# Transformer
t_layers=5,
t_emb="sin",
t_hidden_scale=4.0,
t_heads=8,
t_dropout=0.0,
t_max_positions=10000,
t_norm_in=True,
t_norm_in_group=False,
t_group_norm=False,
t_norm_first=True,
t_norm_out=True,
t_max_period=10000.0,
t_weight_decay=0.0,
t_lr=None,
t_layer_scale=True,
t_gelu=True,
t_weight_pos_embed=1.0,
t_sin_random_shift=0,
t_cape_mean_normalize=True,
t_cape_augment=True,
t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
t_sparse_self_attn=False,
t_sparse_cross_attn=False,
t_mask_type="diag",
t_mask_random_seed=42,
t_sparse_attn_window=500,
t_global_window=100,
t_sparsity=0.95,
t_auto_sparsity=False,
# ------ Particuliar parameters
t_cross_first=False,
# Weight init
rescale=0.1,
# Metadata
samplerate=44100,
segment=10,
use_train_segment=True,
):
"""
Args:
sources (list[str]): list of source names.
audio_channels (int): input/output audio channels.
channels (int): initial number of hidden channels.
channels_time: if not None, use a different `channels` value for the time branch.
growth: increase the number of hidden channels by this factor at each layer.
nfft: number of fft bins. Note that changing this require careful computation of
various shape parameters and will not work out of the box for hybrid models.
wiener_iters: when using Wiener filtering, number of iterations at test time.
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
wiener_residual: add residual source before wiener filtering.
cac: uses complex as channels, i.e. complex numbers are 2 channels each
in input and output. no further processing is done before ISTFT.
depth (int): number of layers in the encoder and in the decoder.
rewrite (bool): add 1x1 convolution to each layer.
multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
layers will be wrapped.
freq_emb: add frequency embedding after the first frequency layer if > 0,
the actual value controls the weight of the embedding.
emb_scale: equivalent to scaling the embedding learning rate
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
kernel_size: kernel_size for encoder and decoder layers.
stride: stride for encoder and decoder layers.
time_stride: stride for the final time layer, after the merge.
context: context for 1x1 conv in the decoder.
context_enc: context for 1x1 conv in the encoder.
norm_starts: layer at which group norm starts being used.
decoder layers are numbered in reverse order.
norm_groups: number of groups for group norm.
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
dconv_depth: depth of residual DConv branch.
dconv_comp: compression of DConv branch.
dconv_attn: adds attention layers in DConv branch starting at this layer.
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
dconv_init: initial scale for the DConv branch LayerScale.
bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
transformer in order to change the number of channels
t_layers: number of layers in each branch (waveform and spec) of the transformer
t_emb: "sin", "cape" or "scaled"
t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
for instance if C = 384 (the number of channels in the transformer) and
t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
384 * 4 = 1536
t_heads: number of heads for the transformer
t_dropout: dropout in the transformer
t_max_positions: max_positions for the "scaled" positional embedding, only
useful if t_emb="scaled"
t_norm_in: (bool) norm before addinf positional embedding and getting into the
transformer layers
t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
timesteps (GroupNorm with group=1)
t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
timesteps (GroupNorm with group=1)
t_norm_first: (bool) if True the norm is before the attention and before the FFN
t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
t_max_period: (float) denominator in the sinusoidal embedding expression
t_weight_decay: (float) weight decay for the transformer
t_lr: (float) specific learning rate for the transformer
t_layer_scale: (bool) Layer Scale for the transformer
t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
t_weight_pos_embed: (float) weighting of the positional embedding
t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
see: https://arxiv.org/abs/2106.03143
t_cape_augment: (bool) if t_emb="cape", must be True during training and False
during the inference, see: https://arxiv.org/abs/2106.03143
t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
see: https://arxiv.org/abs/2106.03143
t_sparse_self_attn: (bool) if True, the self attentions are sparse
t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
unless you designed really specific masks)
t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
with '_' between: i.e. "diag_jmask_random" (note that this is permutation
invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
that generated the random part of the mask
t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
a key (j), the mask is True id |i-j|<=t_sparse_attn_window
t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
and mask[:, :t_global_window] will be True
t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
level of the random part of the mask.
t_cross_first: (bool) if True cross attention is the first layer of the
transformer (False seems to be better)
rescale: weight rescaling trick
use_train_segment: (bool) if True, the actual size that is used during the
training is used during inference.
"""
super().__init__()
self.cac = cac
self.wiener_residual = wiener_residual
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.bottom_channels = bottom_channels
self.channels = channels
self.samplerate = samplerate
self.segment = segment
self.use_train_segment = use_train_segment
self.nfft = nfft
self.hop_length = nfft // 4
self.wiener_iters = wiener_iters
self.end_iters = end_iters
self.freq_emb = None
assert wiener_iters == end_iters
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.tencoder = nn.ModuleList()
self.tdecoder = nn.ModuleList()
chin = audio_channels
chin_z = chin # number of channels for the freq branch
if self.cac:
chin_z *= 2
chout = channels_time or channels
chout_z = channels
freqs = nfft // 2
for index in range(depth):
norm = index >= norm_starts
freq = freqs > 1
stri = stride
ker = kernel_size
if not freq:
assert freqs == 1
ker = time_stride * 2
stri = time_stride
pad = True
last_freq = False
if freq and freqs <= kernel_size:
ker = freqs
pad = False
last_freq = True
kw = {
"kernel_size": ker,
"stride": stri,
"freq": freq,
"pad": pad,
"norm": norm,
"rewrite": rewrite,
"norm_groups": norm_groups,
"dconv_kw": {
"depth": dconv_depth,
"compress": dconv_comp,
"init": dconv_init,
"gelu": True,
},
}
kwt = dict(kw)
kwt["freq"] = 0
kwt["kernel_size"] = kernel_size
kwt["stride"] = stride
kwt["pad"] = True
kw_dec = dict(kw)
multi = False
if multi_freqs and index < multi_freqs_depth:
multi = True
kw_dec["context_freq"] = False
if last_freq:
chout_z = max(chout, chout_z)
chout = chout_z
enc = HEncLayer(
chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
)
if freq:
tenc = HEncLayer(
chin,
chout,
dconv=dconv_mode & 1,
context=context_enc,
empty=last_freq,
**kwt
)
self.tencoder.append(tenc)
if multi:
enc = MultiWrap(enc, multi_freqs)
self.encoder.append(enc)
if index == 0:
chin = self.audio_channels * len(self.sources)
chin_z = chin
if self.cac:
chin_z *= 2
dec = HDecLayer(
chout_z,
chin_z,
dconv=dconv_mode & 2,
last=index == 0,
context=context,
**kw_dec
)
if multi:
dec = MultiWrap(dec, multi_freqs)
if freq:
tdec = HDecLayer(
chout,
chin,
dconv=dconv_mode & 2,
empty=last_freq,
last=index == 0,
context=context,
**kwt
)
self.tdecoder.insert(0, tdec)
self.decoder.insert(0, dec)
chin = chout
chin_z = chout_z
chout = int(growth * chout)
chout_z = int(growth * chout_z)
if freq:
if freqs <= kernel_size:
freqs = 1
else:
freqs //= stride
if index == 0 and freq_emb:
self.freq_emb = ScaledEmbedding(
freqs, chin_z, smooth=emb_smooth, scale=emb_scale
)
self.freq_emb_scale = freq_emb
if rescale:
rescale_module(self, reference=rescale)
transformer_channels = channels * growth ** (depth - 1)
if bottom_channels:
self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
self.channel_downsampler = nn.Conv1d(
bottom_channels, transformer_channels, 1
)
self.channel_upsampler_t = nn.Conv1d(
transformer_channels, bottom_channels, 1
)
self.channel_downsampler_t = nn.Conv1d(
bottom_channels, transformer_channels, 1
)
transformer_channels = bottom_channels
if t_layers > 0:
self.crosstransformer = CrossTransformerEncoder(
dim=transformer_channels,
emb=t_emb,
hidden_scale=t_hidden_scale,
num_heads=t_heads,
num_layers=t_layers,
cross_first=t_cross_first,
dropout=t_dropout,
max_positions=t_max_positions,
norm_in=t_norm_in,
norm_in_group=t_norm_in_group,
group_norm=t_group_norm,
norm_first=t_norm_first,
norm_out=t_norm_out,
max_period=t_max_period,
weight_decay=t_weight_decay,
lr=t_lr,
layer_scale=t_layer_scale,
gelu=t_gelu,
sin_random_shift=t_sin_random_shift,
weight_pos_embed=t_weight_pos_embed,
cape_mean_normalize=t_cape_mean_normalize,
cape_augment=t_cape_augment,
cape_glob_loc_scale=t_cape_glob_loc_scale,
sparse_self_attn=t_sparse_self_attn,
sparse_cross_attn=t_sparse_cross_attn,
mask_type=t_mask_type,
mask_random_seed=t_mask_random_seed,
sparse_attn_window=t_sparse_attn_window,
global_window=t_global_window,
sparsity=t_sparsity,
auto_sparsity=t_auto_sparsity,
)
else:
self.crosstransformer = None
def _spec(self, x):
hl = self.hop_length
nfft = self.nfft
x0 = x # noqa
# We re-pad the signal in order to keep the property
# that the size of the output is exactly the size of the input
# divided by the stride (here hop_length), when divisible.
# This is achieved by padding by 1/4th of the kernel size (here nfft).
# which is not supported by torch.stft.
# Having all convolution operations follow this convention allow to easily
# align the time and frequency branches later on.
assert hl == nfft // 4
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
z = spectro(x, nfft, hl)[..., :-1, :]
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
z = z[..., 2: 2 + le]
return z
def _ispec(self, z, length=None, scale=0):
hl = self.hop_length // (4**scale)
z = F.pad(z, (0, 0, 0, 1))
z = F.pad(z, (2, 2))
pad = hl // 2 * 3
le = hl * int(math.ceil(length / hl)) + 2 * pad
x = ispectro(z, hl, length=le)
x = x[..., pad: pad + length]
return x
def _magnitude(self, z):
# return the magnitude of the spectrogram, except when cac is True,
# in which case we just move the complex dimension to the channel one.
if self.cac:
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
m = m.reshape(B, C * 2, Fr, T)
else:
m = z.abs()
return m
def _mask(self, z, m):
# Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
# If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
niters = self.wiener_iters
if self.cac:
B, S, C, Fr, T = m.shape
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())
return out
if self.training:
niters = self.end_iters
if niters < 0:
z = z[:, None]
return z / (1e-8 + z.abs()) * m
else:
return self._wiener(m, z, niters)
def _wiener(self, mag_out, mix_stft, niters):
# apply wiener filtering from OpenUnmix.
init = mix_stft.dtype
wiener_win_len = 300
residual = self.wiener_residual
B, S, C, Fq, T = mag_out.shape
mag_out = mag_out.permute(0, 4, 3, 2, 1)
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
outs = []
for sample in range(B):
pos = 0
out = []
for pos in range(0, T, wiener_win_len):
frame = slice(pos, pos + wiener_win_len)
z_out = wiener(
mag_out[sample, frame],
mix_stft[sample, frame],
niters,
residual=residual,
)
out.append(z_out.transpose(-1, -2))
outs.append(torch.cat(out, dim=0))
out = torch.view_as_complex(torch.stack(outs, 0))
out = out.permute(0, 4, 3, 2, 1).contiguous()
if residual:
out = out[:, :-1]
assert list(out.shape) == [B, S, C, Fq, T]
return out.to(init)
def valid_length(self, length: int):
"""
Return a length that is appropriate for evaluation.
In our case, always return the training length, unless
it is smaller than the given length, in which case this
raises an error.
"""
if not self.use_train_segment:
return length
training_length = int(self.segment * self.samplerate)
if training_length < length:
raise ValueError(
f"Given length {length} is longer than "
f"training length {training_length}")
return training_length
def forward(self, mix):
length = mix.shape[-1]
length_pre_pad = None
if self.use_train_segment:
if self.training:
self.segment = Fraction(mix.shape[-1], self.samplerate)
else:
training_length = int(self.segment * self.samplerate)
if mix.shape[-1] < training_length:
length_pre_pad = mix.shape[-1]
mix = F.pad(mix, (0, training_length - length_pre_pad))
z = self._spec(mix)
mag = self._magnitude(z)
x = mag
B, C, Fq, T = x.shape
# unlike previous Demucs, we always normalize because it is easier.
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True)
x = (x - mean) / (1e-5 + std)
# x will be the freq. branch input.
# Prepare the time branch input.
xt = mix
meant = xt.mean(dim=(1, 2), keepdim=True)
stdt = xt.std(dim=(1, 2), keepdim=True)
xt = (xt - meant) / (1e-5 + stdt)
# okay, this is a giant mess I know...
saved = [] # skip connections, freq.
saved_t = [] # skip connections, time.
lengths = [] # saved lengths to properly remove padding, freq branch.
lengths_t = [] # saved lengths for time branch.
for idx, encode in enumerate(self.encoder):
lengths.append(x.shape[-1])
inject = None
if idx < len(self.tencoder):
# we have not yet merged branches.
lengths_t.append(xt.shape[-1])
tenc = self.tencoder[idx]
xt = tenc(xt)
if not tenc.empty:
# save for skip connection
saved_t.append(xt)
else:
# tenc contains just the first conv., so that now time and freq.
# branches have the same shape and can be merged.
inject = xt
x = encode(x, inject)
if idx == 0 and self.freq_emb is not None:
# add frequency embedding to allow for non equivariant convolutions
# over the frequency axis.
frs = torch.arange(x.shape[-2], device=x.device)
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
x = x + self.freq_emb_scale * emb
saved.append(x)
if self.crosstransformer:
if self.bottom_channels:
b, c, f, t = x.shape
x = rearrange(x, "b c f t-> b c (f t)")
x = self.channel_upsampler(x)
x = rearrange(x, "b c (f t)-> b c f t", f=f)
xt = self.channel_upsampler_t(xt)
x, xt = self.crosstransformer(x, xt)
if self.bottom_channels:
x = rearrange(x, "b c f t-> b c (f t)")
x = self.channel_downsampler(x)
x = rearrange(x, "b c (f t)-> b c f t", f=f)
xt = self.channel_downsampler_t(xt)
for idx, decode in enumerate(self.decoder):
skip = saved.pop(-1)
x, pre = decode(x, skip, lengths.pop(-1))
# `pre` contains the output just before final transposed convolution,
# which is used when the freq. and time branch separate.
offset = self.depth - len(self.tdecoder)
if idx >= offset:
tdec = self.tdecoder[idx - offset]
length_t = lengths_t.pop(-1)
if tdec.empty:
assert pre.shape[2] == 1, pre.shape
pre = pre[:, :, 0]
xt, _ = tdec(pre, None, length_t)
else:
skip = saved_t.pop(-1)
xt, _ = tdec(xt, skip, length_t)
# Let's make sure we used all stored skip connections.
assert len(saved) == 0
assert len(lengths_t) == 0
assert len(saved_t) == 0
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
zout = self._mask(z, x)
if self.use_train_segment:
if self.training:
x = self._ispec(zout, length)
else:
x = self._ispec(zout, training_length)
else:
x = self._ispec(zout, length)
if self.use_train_segment:
if self.training:
xt = xt.view(B, S, -1, length)
else:
xt = xt.view(B, S, -1, training_length)
else:
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
x = xt + x
if length_pre_pad:
x = x[..., :length_pre_pad]
return x

218
demucs/model.py Normal file
View File

@ -0,0 +1,218 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch as th
from torch import nn
from .utils import capture_init, center_trim
class BLSTM(nn.Module):
def __init__(self, dim, layers=1):
super().__init__()
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x):
x = x.permute(2, 0, 1)
x = self.lstm(x)[0]
x = self.linear(x)
x = x.permute(1, 2, 0)
return x
def rescale_conv(conv, reference):
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
def upsample(x, stride):
"""
Linear upsampling, the output will be `stride` times longer.
"""
batch, channels, time = x.size()
weight = th.arange(stride, device=x.device, dtype=th.float) / stride
x = x.view(batch, channels, time, 1)
out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
return out.reshape(batch, channels, -1)
def downsample(x, stride):
"""
Downsample x by decimation.
"""
return x[:, :, ::stride]
class Demucs(nn.Module):
@capture_init
def __init__(self,
sources=4,
audio_channels=2,
channels=64,
depth=6,
rewrite=True,
glu=True,
upsample=False,
rescale=0.1,
kernel_size=8,
stride=4,
growth=2.,
lstm_layers=2,
context=3,
samplerate=44100):
"""
Args:
sources (int): number of sources to separate
audio_channels (int): stereo or mono
channels (int): first convolution channels
depth (int): number of encoder/decoder layers
rewrite (bool): add 1x1 convolution to each encoder layer
and a convolution to each decoder layer.
For the decoder layer, `context` gives the kernel size.
glu (bool): use glu instead of ReLU
upsample (bool): use linear upsampling with convolutions
Wave-U-Net style, instead of transposed convolutions
rescale (int): rescale initial weights of convolutions
to get their standard deviation closer to `rescale`
kernel_size (int): kernel size for convolutions
stride (int): stride for convolutions
growth (float): multiply (resp divide) number of channels by that
for each layer of the encoder (resp decoder)
lstm_layers (int): number of lstm layers, 0 = no lstm
context (int): kernel size of the convolution in the
decoder before the transposed convolution. If > 1,
will provide some context from neighboring time
steps.
"""
super().__init__()
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.upsample = upsample
self.channels = channels
self.samplerate = samplerate
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.final = None
if upsample:
self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
stride = 1
if glu:
activation = nn.GLU(dim=1)
ch_scale = 2
else:
activation = nn.ReLU()
ch_scale = 1
in_channels = audio_channels
for index in range(depth):
encode = []
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
if rewrite:
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
self.encoder.append(nn.Sequential(*encode))
decode = []
if index > 0:
out_channels = in_channels
else:
if upsample:
out_channels = channels
else:
out_channels = sources * audio_channels
if rewrite:
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
if upsample:
decode += [
nn.Conv1d(channels, out_channels, kernel_size, stride=1),
]
else:
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
in_channels = channels
channels = int(growth * channels)
channels = in_channels
if lstm_layers:
self.lstm = BLSTM(channels, lstm_layers)
else:
self.lstm = None
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
"""
Return the nearest valid length to use with the model so that
there is no time steps left over in a convolutions, e.g. for all
layers, size of the input - kernel_size % stride = 0.
If the mixture has a valid length, the estimated sources
will have exactly the same length when context = 1. If context > 1,
the two signals can be center trimmed to match.
For training, extracts should have a valid length.For evaluation
on full tracks we recommend passing `pad = True` to :method:`forward`.
"""
for _ in range(self.depth):
if self.upsample:
length = math.ceil(length / self.stride) + self.kernel_size - 1
else:
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(1, length)
length += self.context - 1
for _ in range(self.depth):
if self.upsample:
length = length * self.stride + self.kernel_size - 1
else:
length = (length - 1) * self.stride + self.kernel_size
return int(length)
def forward(self, mix):
x = mix
saved = [x]
for encode in self.encoder:
x = encode(x)
saved.append(x)
if self.upsample:
x = downsample(x, self.stride)
if self.lstm:
x = self.lstm(x)
for decode in self.decoder:
if self.upsample:
x = upsample(x, stride=self.stride)
skip = center_trim(saved.pop(-1), x)
x = x + skip
x = decode(x)
if self.final:
skip = center_trim(saved.pop(-1), x)
x = th.cat([x, skip], dim=1)
x = self.final(x)
x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
return x

218
demucs/model_v2.py Normal file
View File

@ -0,0 +1,218 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import julius
from torch import nn
from .tasnet_v2 import ConvTasNet
from .utils import capture_init, center_trim
class BLSTM(nn.Module):
def __init__(self, dim, layers=1):
super().__init__()
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x):
x = x.permute(2, 0, 1)
x = self.lstm(x)[0]
x = self.linear(x)
x = x.permute(1, 2, 0)
return x
def rescale_conv(conv, reference):
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
def auto_load_demucs_model_v2(sources, demucs_model_name):
if '48' in demucs_model_name:
channels=48
elif 'unittest' in demucs_model_name:
channels=4
else:
channels=64
if 'tasnet' in demucs_model_name:
init_demucs_model = ConvTasNet(sources, X=10)
else:
init_demucs_model = Demucs(sources, channels=channels)
return init_demucs_model
class Demucs(nn.Module):
@capture_init
def __init__(self,
sources,
audio_channels=2,
channels=64,
depth=6,
rewrite=True,
glu=True,
rescale=0.1,
resample=True,
kernel_size=8,
stride=4,
growth=2.,
lstm_layers=2,
context=3,
normalize=False,
samplerate=44100,
segment_length=4 * 10 * 44100):
"""
Args:
sources (list[str]): list of source names
audio_channels (int): stereo or mono
channels (int): first convolution channels
depth (int): number of encoder/decoder layers
rewrite (bool): add 1x1 convolution to each encoder layer
and a convolution to each decoder layer.
For the decoder layer, `context` gives the kernel size.
glu (bool): use glu instead of ReLU
resample_input (bool): upsample x2 the input and downsample /2 the output.
rescale (int): rescale initial weights of convolutions
to get their standard deviation closer to `rescale`
kernel_size (int): kernel size for convolutions
stride (int): stride for convolutions
growth (float): multiply (resp divide) number of channels by that
for each layer of the encoder (resp decoder)
lstm_layers (int): number of lstm layers, 0 = no lstm
context (int): kernel size of the convolution in the
decoder before the transposed convolution. If > 1,
will provide some context from neighboring time
steps.
samplerate (int): stored as meta information for easing
future evaluations of the model.
segment_length (int): stored as meta information for easing
future evaluations of the model. Length of the segments on which
the model was trained.
"""
super().__init__()
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.resample = resample
self.channels = channels
self.normalize = normalize
self.samplerate = samplerate
self.segment_length = segment_length
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
if glu:
activation = nn.GLU(dim=1)
ch_scale = 2
else:
activation = nn.ReLU()
ch_scale = 1
in_channels = audio_channels
for index in range(depth):
encode = []
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
if rewrite:
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
self.encoder.append(nn.Sequential(*encode))
decode = []
if index > 0:
out_channels = in_channels
else:
out_channels = len(self.sources) * audio_channels
if rewrite:
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
in_channels = channels
channels = int(growth * channels)
channels = in_channels
if lstm_layers:
self.lstm = BLSTM(channels, lstm_layers)
else:
self.lstm = None
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
"""
Return the nearest valid length to use with the model so that
there is no time steps left over in a convolutions, e.g. for all
layers, size of the input - kernel_size % stride = 0.
If the mixture has a valid length, the estimated sources
will have exactly the same length when context = 1. If context > 1,
the two signals can be center trimmed to match.
For training, extracts should have a valid length.For evaluation
on full tracks we recommend passing `pad = True` to :method:`forward`.
"""
if self.resample:
length *= 2
for _ in range(self.depth):
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(1, length)
length += self.context - 1
for _ in range(self.depth):
length = (length - 1) * self.stride + self.kernel_size
if self.resample:
length = math.ceil(length / 2)
return int(length)
def forward(self, mix):
x = mix
if self.normalize:
mono = mix.mean(dim=1, keepdim=True)
mean = mono.mean(dim=-1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
else:
mean = 0
std = 1
x = (x - mean) / (1e-5 + std)
if self.resample:
x = julius.resample_frac(x, 1, 2)
saved = []
for encode in self.encoder:
x = encode(x)
saved.append(x)
if self.lstm:
x = self.lstm(x)
for decode in self.decoder:
skip = center_trim(saved.pop(-1), x)
x = x + skip
x = decode(x)
if self.resample:
x = julius.resample_frac(x, 2, 1)
x = x * std + mean
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
return x

180
demucs/pretrained.py Normal file
View File

@ -0,0 +1,180 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""
import logging
from pathlib import Path
import typing as tp
from dora.log import fatal
import logging
from diffq import DiffQuantizer
import torch.hub
from .model import Demucs
from .tasnet_v2 import ConvTasNet
from .utils import set_state
from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
logger = logging.getLogger(__name__)
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
REMOTE_ROOT = Path(__file__).parent / 'remote'
SOURCES = ["drums", "bass", "other", "vocals"]
def demucs_unittest():
model = HDemucs(channels=4, sources=SOURCES)
return model
def add_model_flags(parser):
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("-s", "--sig", help="Locally trained XP signature.")
group.add_argument("-n", "--name", default="mdx_extra_q",
help="Pretrained model name or signature. Default is mdx_extra_q.")
parser.add_argument("--repo", type=Path,
help="Folder containing all pre-trained models for use with -n.")
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
root: str = ''
models: tp.Dict[str, str] = {}
for line in remote_file_list.read_text().split('\n'):
line = line.strip()
if line.startswith('#'):
continue
elif line.startswith('root:'):
root = line.split(':', 1)[1].strip()
else:
sig = line.split('-', 1)[0]
assert sig not in models
models[sig] = ROOT_URL + root + line
return models
def get_model(name: str,
repo: tp.Optional[Path] = None):
"""`name` must be a bag of models name or a pretrained signature
from the remote AWS model repo or the specified local repo if `repo` is not None.
"""
if name == 'demucs_unittest':
return demucs_unittest()
model_repo: ModelOnlyRepo
if repo is None:
models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
model_repo = RemoteRepo(models)
bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
else:
if not repo.is_dir():
fatal(f"{repo} must exist and be a directory.")
model_repo = LocalRepo(repo)
bag_repo = BagOnlyRepo(repo, model_repo)
any_repo = AnyModelRepo(model_repo, bag_repo)
model = any_repo.get_model(name)
model.eval()
return model
def get_model_from_args(args):
"""
Load local model package or pre-trained model.
"""
return get_model(name=args.name, repo=args.repo)
logger = logging.getLogger(__name__)
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
PRETRAINED_MODELS = {
'demucs': 'e07c671f',
'demucs48_hq': '28a1282c',
'demucs_extra': '3646af93',
'demucs_quantized': '07afea75',
'tasnet': 'beb46fac',
'tasnet_extra': 'df3777b2',
'demucs_unittest': '09ebc15f',
}
SOURCES = ["drums", "bass", "other", "vocals"]
def get_url(name):
sig = PRETRAINED_MODELS[name]
return ROOT + name + "-" + sig[:8] + ".th"
def is_pretrained(name):
return name in PRETRAINED_MODELS
def load_pretrained(name):
if name == "demucs":
return demucs(pretrained=True)
elif name == "demucs48_hq":
return demucs(pretrained=True, hq=True, channels=48)
elif name == "demucs_extra":
return demucs(pretrained=True, extra=True)
elif name == "demucs_quantized":
return demucs(pretrained=True, quantized=True)
elif name == "demucs_unittest":
return demucs_unittest(pretrained=True)
elif name == "tasnet":
return tasnet(pretrained=True)
elif name == "tasnet_extra":
return tasnet(pretrained=True, extra=True)
else:
raise ValueError(f"Invalid pretrained name {name}")
def _load_state(name, model, quantizer=None):
url = get_url(name)
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
set_state(model, quantizer, state)
if quantizer:
quantizer.detach()
def demucs_unittest(pretrained=True):
model = Demucs(channels=4, sources=SOURCES)
if pretrained:
_load_state('demucs_unittest', model)
return model
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
if not pretrained and (extra or quantized or hq):
raise ValueError("if extra or quantized is True, pretrained must be True.")
model = Demucs(sources=SOURCES, channels=channels)
if pretrained:
name = 'demucs'
if channels != 64:
name += str(channels)
quantizer = None
if sum([extra, quantized, hq]) > 1:
raise ValueError("Only one of extra, quantized, hq, can be True.")
if quantized:
quantizer = DiffQuantizer(model, group_size=8, min_size=1)
name += '_quantized'
if extra:
name += '_extra'
if hq:
name += '_hq'
_load_state(name, model, quantizer)
return model
def tasnet(pretrained=True, extra=False):
if not pretrained and extra:
raise ValueError("if extra is True, pretrained must be True.")
model = ConvTasNet(X=10, sources=SOURCES)
if pretrained:
name = 'tasnet'
if extra:
name = 'tasnet_extra'
_load_state(name, model)
return model

148
demucs/repo.py Normal file
View File

@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Represents a model repository, including pre-trained models and bags of models.
A repo can either be the main remote repository stored in AWS, or a local repository
with your own models.
"""
from hashlib import sha256
from pathlib import Path
import typing as tp
import torch
import yaml
from .apply import BagOfModels, Model
from .states import load_model
AnyModel = tp.Union[Model, BagOfModels]
class ModelLoadingError(RuntimeError):
pass
def check_checksum(path: Path, checksum: str):
sha = sha256()
with open(path, 'rb') as file:
while True:
buf = file.read(2**20)
if not buf:
break
sha.update(buf)
actual_checksum = sha.hexdigest()[:len(checksum)]
if actual_checksum != checksum:
raise ModelLoadingError(f'Invalid checksum for file {path}, '
f'expected {checksum} but got {actual_checksum}')
class ModelOnlyRepo:
"""Base class for all model only repos.
"""
def has_model(self, sig: str) -> bool:
raise NotImplementedError()
def get_model(self, sig: str) -> Model:
raise NotImplementedError()
class RemoteRepo(ModelOnlyRepo):
def __init__(self, models: tp.Dict[str, str]):
self._models = models
def has_model(self, sig: str) -> bool:
return sig in self._models
def get_model(self, sig: str) -> Model:
try:
url = self._models[sig]
except KeyError:
raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
return load_model(pkg)
class LocalRepo(ModelOnlyRepo):
def __init__(self, root: Path):
self.root = root
self.scan()
def scan(self):
self._models = {}
self._checksums = {}
for file in self.root.iterdir():
if file.suffix == '.th':
if '-' in file.stem:
xp_sig, checksum = file.stem.split('-')
self._checksums[xp_sig] = checksum
else:
xp_sig = file.stem
if xp_sig in self._models:
print('Whats xp? ', xp_sig)
raise ModelLoadingError(
f'Duplicate pre-trained model exist for signature {xp_sig}. '
'Please delete all but one.')
self._models[xp_sig] = file
def has_model(self, sig: str) -> bool:
return sig in self._models
def get_model(self, sig: str) -> Model:
try:
file = self._models[sig]
except KeyError:
raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
if sig in self._checksums:
check_checksum(file, self._checksums[sig])
return load_model(file)
class BagOnlyRepo:
"""Handles only YAML files containing bag of models, leaving the actual
model loading to some Repo.
"""
def __init__(self, root: Path, model_repo: ModelOnlyRepo):
self.root = root
self.model_repo = model_repo
self.scan()
def scan(self):
self._bags = {}
for file in self.root.iterdir():
if file.suffix == '.yaml':
self._bags[file.stem] = file
def has_model(self, name: str) -> bool:
return name in self._bags
def get_model(self, name: str) -> BagOfModels:
try:
yaml_file = self._bags[name]
except KeyError:
raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
'a bag of models.')
bag = yaml.safe_load(open(yaml_file))
signatures = bag['models']
models = [self.model_repo.get_model(sig) for sig in signatures]
weights = bag.get('weights')
segment = bag.get('segment')
return BagOfModels(models, weights, segment)
class AnyModelRepo:
def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
self.model_repo = model_repo
self.bag_repo = bag_repo
def has_model(self, name_or_sig: str) -> bool:
return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
def get_model(self, name_or_sig: str) -> AnyModel:
print('name_or_sig: ', name_or_sig)
if self.model_repo.has_model(name_or_sig):
return self.model_repo.get_model(name_or_sig)
else:
return self.bag_repo.get_model(name_or_sig)

41
demucs/spec.py Normal file
View File

@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Conveniance wrapper to perform STFT and iSTFT"""
import torch as th
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
window=th.hann_window(n_fft).to(x),
win_length=n_fft,
normalized=True,
center=True,
return_complex=True,
pad_mode='reflect')
_, freqs, frame = z.shape
return z.view(*other, freqs, frame)
def ispectro(z, hop_length=None, length=None, pad=0):
*other, freqs, frames = z.shape
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
x = th.istft(z,
n_fft,
hop_length,
window=th.hann_window(win_length).to(z.real),
win_length=win_length,
normalized=True,
length=length,
center=True)
_, length = x.shape
return x.view(*other, length)

148
demucs/states.py Normal file
View File

@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities to save and load models.
"""
from contextlib import contextmanager
import functools
import hashlib
import inspect
import io
from pathlib import Path
import warnings
from omegaconf import OmegaConf
from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
import torch
def get_quantizer(model, args, optimizer=None):
"""Return the quantizer given the XP quantization args."""
quantizer = None
if args.diffq:
quantizer = DiffQuantizer(
model, min_size=args.min_size, group_size=args.group_size)
if optimizer is not None:
quantizer.setup_optimizer(optimizer)
elif args.qat:
quantizer = UniformQuantizer(
model, bits=args.qat, min_size=args.min_size)
return quantizer
def load_model(path_or_package, strict=False):
"""Load a model from the given serialized model, either given as a dict (already loaded)
or a path to a file on disk."""
if isinstance(path_or_package, dict):
package = path_or_package
elif isinstance(path_or_package, (str, Path)):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
path = path_or_package
package = torch.load(path, 'cpu')
else:
raise ValueError(f"Invalid type for {path_or_package}.")
klass = package["klass"]
args = package["args"]
kwargs = package["kwargs"]
if strict:
model = klass(*args, **kwargs)
else:
sig = inspect.signature(klass)
for key in list(kwargs):
if key not in sig.parameters:
warnings.warn("Dropping inexistant parameter " + key)
del kwargs[key]
model = klass(*args, **kwargs)
state = package["state"]
set_state(model, state)
return model
def get_state(model, quantizer, half=False):
"""Get the state from a model, potentially with quantization applied.
If `half` is True, model are stored as half precision, which shouldn't impact performance
but half the state size."""
if quantizer is None:
dtype = torch.half if half else None
state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
else:
state = quantizer.get_quantized_state()
state['__quantized'] = True
return state
def set_state(model, state, quantizer=None):
"""Set the state on a given model."""
if state.get('__quantized'):
if quantizer is not None:
quantizer.restore_quantized_state(model, state['quantized'])
else:
restore_quantized_state(model, state)
else:
model.load_state_dict(state)
return state
def save_with_checksum(content, path):
"""Save the given value on disk, along with a sha256 hash.
Should be used with the output of either `serialize_model` or `get_state`."""
buf = io.BytesIO()
torch.save(content, buf)
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
path = path.parent / (path.stem + "-" + sig + path.suffix)
path.write_bytes(buf.getvalue())
def serialize_model(model, training_args, quantizer=None, half=True):
args, kwargs = model._init_args_kwargs
klass = model.__class__
state = get_state(model, quantizer, half)
return {
'klass': klass,
'args': args,
'kwargs': kwargs,
'state': state,
'training_args': OmegaConf.to_container(training_args, resolve=True),
}
def copy_state(state):
return {k: v.cpu().clone() for k, v in state.items()}
@contextmanager
def swap_state(model, state):
"""
Context manager that swaps the state of a model, e.g:
# model is in old state
with swap_state(model, new_state):
# model in new state
# model back to old state
"""
old_state = copy_state(model.state_dict())
model.load_state_dict(state, strict=False)
try:
yield
finally:
model.load_state_dict(old_state)
def capture_init(init):
@functools.wraps(init)
def __init__(self, *args, **kwargs):
self._init_args_kwargs = (args, kwargs)
init(self, *args, **kwargs)
return __init__

447
demucs/tasnet.py Normal file
View File

@ -0,0 +1,447 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import capture_init
EPS = 1e-8
def overlap_and_add(signal, frame_step):
outer_dimensions = signal.size()[:-2]
frames, frame_length = signal.size()[-2:]
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
subframe_step = frame_step // subframe_length
subframes_per_frame = frame_length // subframe_length
output_size = frame_step * (frames - 1) + frame_length
output_subframes = output_size // subframe_length
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
frame = torch.arange(0, output_subframes,
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
frame = frame.long() # signal may in GPU or CPU
frame = frame.contiguous().view(-1)
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
result.index_add_(-2, frame, subframe_signal)
result = result.view(*outer_dimensions, -1)
return result
class ConvTasNet(nn.Module):
@capture_init
def __init__(self,
N=256,
L=20,
B=256,
H=512,
P=3,
X=8,
R=4,
C=4,
audio_channels=1,
samplerate=44100,
norm_type="gLN",
causal=False,
mask_nonlinear='relu'):
"""
Args:
N: Number of filters in autoencoder
L: Length of the filters (in samples)
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
C: Number of speakers
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(ConvTasNet, self).__init__()
# Hyper-parameter
self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
self.norm_type = norm_type
self.causal = causal
self.mask_nonlinear = mask_nonlinear
self.audio_channels = audio_channels
self.samplerate = samplerate
# Components
self.encoder = Encoder(L, N, audio_channels)
self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
self.decoder = Decoder(N, L, audio_channels)
# init
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def valid_length(self, length):
return length
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
est_source: [M, C, T]
"""
mixture_w = self.encoder(mixture)
est_mask = self.separator(mixture_w)
est_source = self.decoder(mixture_w, est_mask)
# T changed after conv1d in encoder, fix it here
T_origin = mixture.size(-1)
T_conv = est_source.size(-1)
est_source = F.pad(est_source, (0, T_origin - T_conv))
return est_source
class Encoder(nn.Module):
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
"""
def __init__(self, L, N, audio_channels):
super(Encoder, self).__init__()
# Hyper-parameter
self.L, self.N = L, N
# Components
# 50% overlap
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
"""
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
return mixture_w
class Decoder(nn.Module):
def __init__(self, N, L, audio_channels):
super(Decoder, self).__init__()
# Hyper-parameter
self.N, self.L = N, L
self.audio_channels = audio_channels
# Components
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
def forward(self, mixture_w, est_mask):
"""
Args:
mixture_w: [M, N, K]
est_mask: [M, C, N, K]
Returns:
est_source: [M, C, T]
"""
# D = W * M
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
# S = DV
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
m, c, k, _ = est_source.size()
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
return est_source
class TemporalConvNet(nn.Module):
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
"""
Args:
N: Number of filters in autoencoder
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
C: Number of speakers
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(TemporalConvNet, self).__init__()
# Hyper-parameter
self.C = C
self.mask_nonlinear = mask_nonlinear
# Components
# [M, N, K] -> [M, N, K]
layer_norm = ChannelwiseLayerNorm(N)
# [M, N, K] -> [M, B, K]
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
# [M, B, K] -> [M, B, K]
repeats = []
for r in range(R):
blocks = []
for x in range(X):
dilation = 2**x
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
blocks += [
TemporalBlock(B,
H,
P,
stride=1,
padding=padding,
dilation=dilation,
norm_type=norm_type,
causal=causal)
]
repeats += [nn.Sequential(*blocks)]
temporal_conv_net = nn.Sequential(*repeats)
# [M, B, K] -> [M, C*N, K]
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
# Put together
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
mask_conv1x1)
def forward(self, mixture_w):
"""
Keep this API same with TasNet
Args:
mixture_w: [M, N, K], M is batch size
returns:
est_mask: [M, C, N, K]
"""
M, N, K = mixture_w.size()
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
if self.mask_nonlinear == 'softmax':
est_mask = F.softmax(score, dim=1)
elif self.mask_nonlinear == 'relu':
est_mask = F.relu(score)
else:
raise ValueError("Unsupported mask non-linear function")
return est_mask
class TemporalBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(TemporalBlock, self).__init__()
# [M, B, K] -> [M, H, K]
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
prelu = nn.PReLU()
norm = chose_norm(norm_type, out_channels)
# [M, H, K] -> [M, B, K]
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
dilation, norm_type, causal)
# Put together
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
def forward(self, x):
"""
Args:
x: [M, B, K]
Returns:
[M, B, K]
"""
residual = x
out = self.net(x)
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
return out + residual # look like w/o F.relu is better than w/ F.relu
# return F.relu(out + residual)
class DepthwiseSeparableConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(DepthwiseSeparableConv, self).__init__()
# Use `groups` option to implement depthwise convolution
# [M, H, K] -> [M, H, K]
depthwise_conv = nn.Conv1d(in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False)
if causal:
chomp = Chomp1d(padding)
prelu = nn.PReLU()
norm = chose_norm(norm_type, in_channels)
# [M, H, K] -> [M, B, K]
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
# Put together
if causal:
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
else:
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
def forward(self, x):
"""
Args:
x: [M, H, K]
Returns:
result: [M, B, K]
"""
return self.net(x)
class Chomp1d(nn.Module):
"""To ensure the output length is the same as the input.
"""
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
Args:
x: [M, H, Kpad]
Returns:
[M, H, K]
"""
return x[:, :, :-self.chomp_size].contiguous()
def chose_norm(norm_type, channel_size):
"""The input of normlization will be (M, C, K), where M is batch size,
C is channel size and K is sequence length.
"""
if norm_type == "gLN":
return GlobalLayerNorm(channel_size)
elif norm_type == "cLN":
return ChannelwiseLayerNorm(channel_size)
elif norm_type == "id":
return nn.Identity()
else: # norm_type == "BN":
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
# along M and K, so this BN usage is right.
return nn.BatchNorm1d(channel_size)
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
"""Channel-wise Layer Normalization (cLN)"""
def __init__(self, channel_size):
super(ChannelwiseLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
cLN_y: [M, N, K]
"""
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return cLN_y
class GlobalLayerNorm(nn.Module):
"""Global Layer Normalization (gLN)"""
def __init__(self, channel_size):
super(GlobalLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
gLN_y: [M, N, K]
"""
# TODO: in torch 1.0, torch.mean() support dim list
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return gLN_y
if __name__ == "__main__":
torch.manual_seed(123)
M, N, L, T = 2, 3, 4, 12
K = 2 * T // L - 1
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
mixture = torch.randint(3, (M, T))
# test Encoder
encoder = Encoder(L, N)
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
mixture_w = encoder(mixture)
print('mixture', mixture)
print('U', encoder.conv1d_U.weight)
print('mixture_w', mixture_w)
print('mixture_w size', mixture_w.size())
# test TemporalConvNet
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
est_mask = separator(mixture_w)
print('est_mask', est_mask)
# test Decoder
decoder = Decoder(N, L)
est_mask = torch.randint(2, (B, K, C, N))
est_source = decoder(mixture_w, est_mask)
print('est_source', est_source)
# test Conv-TasNet
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
est_source = conv_tasnet(mixture)
print('est_source', est_source)
print('est_source size', est_source.size())

452
demucs/tasnet_v2.py Normal file
View File

@ -0,0 +1,452 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import capture_init
EPS = 1e-8
def overlap_and_add(signal, frame_step):
outer_dimensions = signal.size()[:-2]
frames, frame_length = signal.size()[-2:]
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
subframe_step = frame_step // subframe_length
subframes_per_frame = frame_length // subframe_length
output_size = frame_step * (frames - 1) + frame_length
output_subframes = output_size // subframe_length
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
frame = torch.arange(0, output_subframes,
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
frame = frame.long() # signal may in GPU or CPU
frame = frame.contiguous().view(-1)
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
result.index_add_(-2, frame, subframe_signal)
result = result.view(*outer_dimensions, -1)
return result
class ConvTasNet(nn.Module):
@capture_init
def __init__(self,
sources,
N=256,
L=20,
B=256,
H=512,
P=3,
X=8,
R=4,
audio_channels=2,
norm_type="gLN",
causal=False,
mask_nonlinear='relu',
samplerate=44100,
segment_length=44100 * 2 * 4):
"""
Args:
sources: list of sources
N: Number of filters in autoencoder
L: Length of the filters (in samples)
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(ConvTasNet, self).__init__()
# Hyper-parameter
self.sources = sources
self.C = len(sources)
self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
self.norm_type = norm_type
self.causal = causal
self.mask_nonlinear = mask_nonlinear
self.audio_channels = audio_channels
self.samplerate = samplerate
self.segment_length = segment_length
# Components
self.encoder = Encoder(L, N, audio_channels)
self.separator = TemporalConvNet(
N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
self.decoder = Decoder(N, L, audio_channels)
# init
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def valid_length(self, length):
return length
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
est_source: [M, C, T]
"""
mixture_w = self.encoder(mixture)
est_mask = self.separator(mixture_w)
est_source = self.decoder(mixture_w, est_mask)
# T changed after conv1d in encoder, fix it here
T_origin = mixture.size(-1)
T_conv = est_source.size(-1)
est_source = F.pad(est_source, (0, T_origin - T_conv))
return est_source
class Encoder(nn.Module):
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
"""
def __init__(self, L, N, audio_channels):
super(Encoder, self).__init__()
# Hyper-parameter
self.L, self.N = L, N
# Components
# 50% overlap
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
"""
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
return mixture_w
class Decoder(nn.Module):
def __init__(self, N, L, audio_channels):
super(Decoder, self).__init__()
# Hyper-parameter
self.N, self.L = N, L
self.audio_channels = audio_channels
# Components
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
def forward(self, mixture_w, est_mask):
"""
Args:
mixture_w: [M, N, K]
est_mask: [M, C, N, K]
Returns:
est_source: [M, C, T]
"""
# D = W * M
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
# S = DV
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
m, c, k, _ = est_source.size()
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
return est_source
class TemporalConvNet(nn.Module):
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
"""
Args:
N: Number of filters in autoencoder
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
C: Number of speakers
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(TemporalConvNet, self).__init__()
# Hyper-parameter
self.C = C
self.mask_nonlinear = mask_nonlinear
# Components
# [M, N, K] -> [M, N, K]
layer_norm = ChannelwiseLayerNorm(N)
# [M, N, K] -> [M, B, K]
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
# [M, B, K] -> [M, B, K]
repeats = []
for r in range(R):
blocks = []
for x in range(X):
dilation = 2**x
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
blocks += [
TemporalBlock(B,
H,
P,
stride=1,
padding=padding,
dilation=dilation,
norm_type=norm_type,
causal=causal)
]
repeats += [nn.Sequential(*blocks)]
temporal_conv_net = nn.Sequential(*repeats)
# [M, B, K] -> [M, C*N, K]
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
# Put together
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
mask_conv1x1)
def forward(self, mixture_w):
"""
Keep this API same with TasNet
Args:
mixture_w: [M, N, K], M is batch size
returns:
est_mask: [M, C, N, K]
"""
M, N, K = mixture_w.size()
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
if self.mask_nonlinear == 'softmax':
est_mask = F.softmax(score, dim=1)
elif self.mask_nonlinear == 'relu':
est_mask = F.relu(score)
else:
raise ValueError("Unsupported mask non-linear function")
return est_mask
class TemporalBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(TemporalBlock, self).__init__()
# [M, B, K] -> [M, H, K]
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
prelu = nn.PReLU()
norm = chose_norm(norm_type, out_channels)
# [M, H, K] -> [M, B, K]
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
dilation, norm_type, causal)
# Put together
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
def forward(self, x):
"""
Args:
x: [M, B, K]
Returns:
[M, B, K]
"""
residual = x
out = self.net(x)
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
return out + residual # look like w/o F.relu is better than w/ F.relu
# return F.relu(out + residual)
class DepthwiseSeparableConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(DepthwiseSeparableConv, self).__init__()
# Use `groups` option to implement depthwise convolution
# [M, H, K] -> [M, H, K]
depthwise_conv = nn.Conv1d(in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False)
if causal:
chomp = Chomp1d(padding)
prelu = nn.PReLU()
norm = chose_norm(norm_type, in_channels)
# [M, H, K] -> [M, B, K]
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
# Put together
if causal:
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
else:
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
def forward(self, x):
"""
Args:
x: [M, H, K]
Returns:
result: [M, B, K]
"""
return self.net(x)
class Chomp1d(nn.Module):
"""To ensure the output length is the same as the input.
"""
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
Args:
x: [M, H, Kpad]
Returns:
[M, H, K]
"""
return x[:, :, :-self.chomp_size].contiguous()
def chose_norm(norm_type, channel_size):
"""The input of normlization will be (M, C, K), where M is batch size,
C is channel size and K is sequence length.
"""
if norm_type == "gLN":
return GlobalLayerNorm(channel_size)
elif norm_type == "cLN":
return ChannelwiseLayerNorm(channel_size)
elif norm_type == "id":
return nn.Identity()
else: # norm_type == "BN":
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
# along M and K, so this BN usage is right.
return nn.BatchNorm1d(channel_size)
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
"""Channel-wise Layer Normalization (cLN)"""
def __init__(self, channel_size):
super(ChannelwiseLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
cLN_y: [M, N, K]
"""
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return cLN_y
class GlobalLayerNorm(nn.Module):
"""Global Layer Normalization (gLN)"""
def __init__(self, channel_size):
super(GlobalLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
gLN_y: [M, N, K]
"""
# TODO: in torch 1.0, torch.mean() support dim list
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return gLN_y
if __name__ == "__main__":
torch.manual_seed(123)
M, N, L, T = 2, 3, 4, 12
K = 2 * T // L - 1
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
mixture = torch.randint(3, (M, T))
# test Encoder
encoder = Encoder(L, N)
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
mixture_w = encoder(mixture)
print('mixture', mixture)
print('U', encoder.conv1d_U.weight)
print('mixture_w', mixture_w)
print('mixture_w size', mixture_w.size())
# test TemporalConvNet
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
est_mask = separator(mixture_w)
print('est_mask', est_mask)
# test Decoder
decoder = Decoder(N, L)
est_mask = torch.randint(2, (B, K, C, N))
est_source = decoder(mixture_w, est_mask)
print('est_source', est_source)
# test Conv-TasNet
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
est_source = conv_tasnet(mixture)
print('est_source', est_source)
print('est_source size', est_source.size())

839
demucs/transformer.py Normal file
View File

@ -0,0 +1,839 @@
# Copyright (c) 2019-present, Meta, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.
import random
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange
def create_sin_embedding(
length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
# We aim for TBC format
assert dim % 2 == 0
pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
half_dim = dim // 2
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
phase = pos / (max_period ** (adim / (half_dim - 1)))
return torch.cat(
[
torch.cos(phase),
torch.sin(phase),
],
dim=-1,
)
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
"""
:param d_model: dimension of the model
:param height: height of the positions
:param width: width of the positions
:return: d_model*height*width position matrix
"""
if d_model % 4 != 0:
raise ValueError(
"Cannot use sin/cos positional encoding with "
"odd dimension (got dim={:d})".format(d_model)
)
pe = torch.zeros(d_model, height, width)
# Each dimension use half of d_model
d_model = int(d_model / 2)
div_term = torch.exp(
torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
)
pos_w = torch.arange(0.0, width).unsqueeze(1)
pos_h = torch.arange(0.0, height).unsqueeze(1)
pe[0:d_model:2, :, :] = (
torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
)
pe[1:d_model:2, :, :] = (
torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
)
pe[d_model::2, :, :] = (
torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
)
pe[d_model + 1:: 2, :, :] = (
torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
)
return pe[None, :].to(device)
def create_sin_embedding_cape(
length: int,
dim: int,
batch_size: int,
mean_normalize: bool,
augment: bool, # True during training
max_global_shift: float = 0.0, # delta max
max_local_shift: float = 0.0, # epsilon max
max_scale: float = 1.0,
device: str = "cpu",
max_period: float = 10000.0,
):
# We aim for TBC format
assert dim % 2 == 0
pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
if mean_normalize:
pos -= torch.nanmean(pos, dim=0, keepdim=True)
if augment:
delta = np.random.uniform(
-max_global_shift, +max_global_shift, size=[1, batch_size, 1]
)
delta_local = np.random.uniform(
-max_local_shift, +max_local_shift, size=[length, batch_size, 1]
)
log_lambdas = np.random.uniform(
-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
)
pos = (pos + delta + delta_local) * np.exp(log_lambdas)
pos = pos.to(device)
half_dim = dim // 2
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
phase = pos / (max_period ** (adim / (half_dim - 1)))
return torch.cat(
[
torch.cos(phase),
torch.sin(phase),
],
dim=-1,
).float()
def get_causal_mask(length):
pos = torch.arange(length)
return pos > pos[:, None]
def get_elementary_mask(
T1,
T2,
mask_type,
sparse_attn_window,
global_window,
mask_random_seed,
sparsity,
device,
):
"""
When the input of the Decoder has length T1 and the output T2
The mask matrix has shape (T2, T1)
"""
assert mask_type in ["diag", "jmask", "random", "global"]
if mask_type == "global":
mask = torch.zeros(T2, T1, dtype=torch.bool)
mask[:, :global_window] = True
line_window = int(global_window * T2 / T1)
mask[:line_window, :] = True
if mask_type == "diag":
mask = torch.zeros(T2, T1, dtype=torch.bool)
rows = torch.arange(T2)[:, None]
cols = (
(T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
.long()
.clamp(0, T1 - 1)
)
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
elif mask_type == "jmask":
mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
rows = torch.arange(T2 + 2)[:, None]
t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
t = (t * (t + 1) / 2).int()
t = torch.cat([-t.flip(0)[:-1], t])
cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
mask = mask[1:-1, 1:-1]
elif mask_type == "random":
gene = torch.Generator(device=device)
gene.manual_seed(mask_random_seed)
mask = (
torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
> sparsity
)
mask = mask.to(device)
return mask
def get_mask(
T1,
T2,
mask_type,
sparse_attn_window,
global_window,
mask_random_seed,
sparsity,
device,
):
"""
Return a SparseCSRTensor mask that is a combination of elementary masks
mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
"""
from xformers.sparse import SparseCSRTensor
# create a list
mask_types = mask_type.split("_")
all_masks = [
get_elementary_mask(
T1,
T2,
mask,
sparse_attn_window,
global_window,
mask_random_seed,
sparsity,
device,
)
for mask in mask_types
]
final_mask = torch.stack(all_masks).sum(axis=0) > 0
return SparseCSRTensor.from_dense(final_mask[None])
class ScaledEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
scale: float = 1.0,
boost: float = 3.0,
):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data *= scale / boost
self.boost = boost
@property
def weight(self):
return self.embedding.weight * self.boost
def forward(self, x):
return self.embedding(x) * self.boost
class LayerScale(nn.Module):
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
This rescales diagonaly residual outputs close to 0 initially, then learnt.
"""
def __init__(self, channels: int, init: float = 0, channel_last=False):
"""
channel_last = False corresponds to (B, C, T) tensors
channel_last = True corresponds to (T, B, C) tensors
"""
super().__init__()
self.channel_last = channel_last
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
self.scale.data[:] = init
def forward(self, x):
if self.channel_last:
return self.scale * x
else:
return self.scale[:, None] * x
class MyGroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
"""
x: (B, T, C)
if num_groups=1: Normalisation on all T and C together for each B
"""
x = x.transpose(1, 2)
return super().forward(x).transpose(1, 2)
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation=F.relu,
group_norm=0,
norm_first=False,
norm_out=False,
layer_norm_eps=1e-5,
layer_scale=False,
init_values=1e-4,
device=None,
dtype=None,
sparse=False,
mask_type="diag",
mask_random_seed=42,
sparse_attn_window=500,
global_window=50,
auto_sparsity=False,
sparsity=0.95,
batch_first=False,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=batch_first,
norm_first=norm_first,
device=device,
dtype=dtype,
)
self.sparse = sparse
self.auto_sparsity = auto_sparsity
if sparse:
if not auto_sparsity:
self.mask_type = mask_type
self.sparse_attn_window = sparse_attn_window
self.global_window = global_window
self.sparsity = sparsity
if group_norm:
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm_out = None
if self.norm_first & norm_out:
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
self.gamma_1 = (
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
)
self.gamma_2 = (
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
)
if sparse:
self.self_attn = MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first,
auto_sparsity=sparsity if auto_sparsity else 0,
)
self.__setattr__("src_mask", torch.zeros(1, 1))
self.mask_random_seed = mask_random_seed
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
if batch_first = False, src shape is (T, B, C)
the case where batch_first=True is not covered
"""
device = src.device
x = src
T, B, C = x.shape
if self.sparse and not self.auto_sparsity:
assert src_mask is None
src_mask = self.src_mask
if src_mask.shape[-1] != T:
src_mask = get_mask(
T,
T,
self.mask_type,
self.sparse_attn_window,
self.global_window,
self.mask_random_seed,
self.sparsity,
device,
)
self.__setattr__("src_mask", src_mask)
if self.norm_first:
x = x + self.gamma_1(
self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
)
x = x + self.gamma_2(self._ff_block(self.norm2(x)))
if self.norm_out:
x = self.norm_out(x)
else:
x = self.norm1(
x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
)
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
return x
class CrossTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation=F.relu,
layer_norm_eps: float = 1e-5,
layer_scale: bool = False,
init_values: float = 1e-4,
norm_first: bool = False,
group_norm: bool = False,
norm_out: bool = False,
sparse=False,
mask_type="diag",
mask_random_seed=42,
sparse_attn_window=500,
global_window=50,
sparsity=0.95,
auto_sparsity=None,
device=None,
dtype=None,
batch_first=False,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.sparse = sparse
self.auto_sparsity = auto_sparsity
if sparse:
if not auto_sparsity:
self.mask_type = mask_type
self.sparse_attn_window = sparse_attn_window
self.global_window = global_window
self.sparsity = sparsity
self.cross_attn: nn.Module
self.cross_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.norm1: nn.Module
self.norm2: nn.Module
self.norm3: nn.Module
if group_norm:
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
else:
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm_out = None
if self.norm_first & norm_out:
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
self.gamma_1 = (
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
)
self.gamma_2 = (
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = self._get_activation_fn(activation)
else:
self.activation = activation
if sparse:
self.cross_attn = MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first,
auto_sparsity=sparsity if auto_sparsity else 0)
if not auto_sparsity:
self.__setattr__("mask", torch.zeros(1, 1))
self.mask_random_seed = mask_random_seed
def forward(self, q, k, mask=None):
"""
Args:
q: tensor of shape (T, B, C)
k: tensor of shape (S, B, C)
mask: tensor of shape (T, S)
"""
device = q.device
T, B, C = q.shape
S, B, C = k.shape
if self.sparse and not self.auto_sparsity:
assert mask is None
mask = self.mask
if mask.shape[-1] != S or mask.shape[-2] != T:
mask = get_mask(
S,
T,
self.mask_type,
self.sparse_attn_window,
self.global_window,
self.mask_random_seed,
self.sparsity,
device,
)
self.__setattr__("mask", mask)
if self.norm_first:
x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
x = x + self.gamma_2(self._ff_block(self.norm3(x)))
if self.norm_out:
x = self.norm_out(x)
else:
x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
return x
# self-attention block
def _ca_block(self, q, k, attn_mask=None):
x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x):
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
def _get_activation_fn(self, activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
# ----------------- MULTI-BLOCKS MODELS: -----------------------
class CrossTransformerEncoder(nn.Module):
def __init__(
self,
dim: int,
emb: str = "sin",
hidden_scale: float = 4.0,
num_heads: int = 8,
num_layers: int = 6,
cross_first: bool = False,
dropout: float = 0.0,
max_positions: int = 1000,
norm_in: bool = True,
norm_in_group: bool = False,
group_norm: int = False,
norm_first: bool = False,
norm_out: bool = False,
max_period: float = 10000.0,
weight_decay: float = 0.0,
lr: tp.Optional[float] = None,
layer_scale: bool = False,
gelu: bool = True,
sin_random_shift: int = 0,
weight_pos_embed: float = 1.0,
cape_mean_normalize: bool = True,
cape_augment: bool = True,
cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
sparse_self_attn: bool = False,
sparse_cross_attn: bool = False,
mask_type: str = "diag",
mask_random_seed: int = 42,
sparse_attn_window: int = 500,
global_window: int = 50,
auto_sparsity: bool = False,
sparsity: float = 0.95,
):
super().__init__()
"""
"""
assert dim % num_heads == 0
hidden_dim = int(dim * hidden_scale)
self.num_layers = num_layers
# classic parity = 1 means that if idx%2 == 1 there is a
# classical encoder else there is a cross encoder
self.classic_parity = 1 if cross_first else 0
self.emb = emb
self.max_period = max_period
self.weight_decay = weight_decay
self.weight_pos_embed = weight_pos_embed
self.sin_random_shift = sin_random_shift
if emb == "cape":
self.cape_mean_normalize = cape_mean_normalize
self.cape_augment = cape_augment
self.cape_glob_loc_scale = cape_glob_loc_scale
if emb == "scaled":
self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
self.lr = lr
activation: tp.Any = F.gelu if gelu else F.relu
self.norm_in: nn.Module
self.norm_in_t: nn.Module
if norm_in:
self.norm_in = nn.LayerNorm(dim)
self.norm_in_t = nn.LayerNorm(dim)
elif norm_in_group:
self.norm_in = MyGroupNorm(int(norm_in_group), dim)
self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
else:
self.norm_in = nn.Identity()
self.norm_in_t = nn.Identity()
# spectrogram layers
self.layers = nn.ModuleList()
# temporal layers
self.layers_t = nn.ModuleList()
kwargs_common = {
"d_model": dim,
"nhead": num_heads,
"dim_feedforward": hidden_dim,
"dropout": dropout,
"activation": activation,
"group_norm": group_norm,
"norm_first": norm_first,
"norm_out": norm_out,
"layer_scale": layer_scale,
"mask_type": mask_type,
"mask_random_seed": mask_random_seed,
"sparse_attn_window": sparse_attn_window,
"global_window": global_window,
"sparsity": sparsity,
"auto_sparsity": auto_sparsity,
"batch_first": True,
}
kwargs_classic_encoder = dict(kwargs_common)
kwargs_classic_encoder.update({
"sparse": sparse_self_attn,
})
kwargs_cross_encoder = dict(kwargs_common)
kwargs_cross_encoder.update({
"sparse": sparse_cross_attn,
})
for idx in range(num_layers):
if idx % 2 == self.classic_parity:
self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
self.layers_t.append(
MyTransformerEncoderLayer(**kwargs_classic_encoder)
)
else:
self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
self.layers_t.append(
CrossTransformerEncoderLayer(**kwargs_cross_encoder)
)
def forward(self, x, xt):
B, C, Fr, T1 = x.shape
pos_emb_2d = create_2d_sin_embedding(
C, Fr, T1, x.device, self.max_period
) # (1, C, Fr, T1)
pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
x = self.norm_in(x)
x = x + self.weight_pos_embed * pos_emb_2d
B, C, T2 = xt.shape
xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
pos_emb = self._get_pos_embedding(T2, B, C, x.device)
pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
xt = self.norm_in_t(xt)
xt = xt + self.weight_pos_embed * pos_emb
for idx in range(self.num_layers):
if idx % 2 == self.classic_parity:
x = self.layers[idx](x)
xt = self.layers_t[idx](xt)
else:
old_x = x
x = self.layers[idx](x, xt)
xt = self.layers_t[idx](xt, old_x)
x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
xt = rearrange(xt, "b t2 c -> b c t2")
return x, xt
def _get_pos_embedding(self, T, B, C, device):
if self.emb == "sin":
shift = random.randrange(self.sin_random_shift + 1)
pos_emb = create_sin_embedding(
T, C, shift=shift, device=device, max_period=self.max_period
)
elif self.emb == "cape":
if self.training:
pos_emb = create_sin_embedding_cape(
T,
C,
B,
device=device,
max_period=self.max_period,
mean_normalize=self.cape_mean_normalize,
augment=self.cape_augment,
max_global_shift=self.cape_glob_loc_scale[0],
max_local_shift=self.cape_glob_loc_scale[1],
max_scale=self.cape_glob_loc_scale[2],
)
else:
pos_emb = create_sin_embedding_cape(
T,
C,
B,
device=device,
max_period=self.max_period,
mean_normalize=self.cape_mean_normalize,
augment=False,
)
elif self.emb == "scaled":
pos = torch.arange(T, device=device)
pos_emb = self.position_embeddings(pos)[:, None]
return pos_emb
def make_optim_group(self):
group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
if self.lr is not None:
group["lr"] = self.lr
return group
# Attention Modules
class MultiheadAttention(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
auto_sparsity=None,
):
super().__init__()
assert auto_sparsity is not None, "sanity check"
self.num_heads = num_heads
self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
self.attn_drop = torch.nn.Dropout(dropout)
self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
self.proj_drop = torch.nn.Dropout(dropout)
self.batch_first = batch_first
self.auto_sparsity = auto_sparsity
def forward(
self,
query,
key,
value,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
average_attn_weights=True,
):
if not self.batch_first: # N, B, C
query = query.permute(1, 0, 2) # B, N_q, C
key = key.permute(1, 0, 2) # B, N_k, C
value = value.permute(1, 0, 2) # B, N_k, C
B, N_q, C = query.shape
B, N_k, C = key.shape
q = (
self.q(query)
.reshape(B, N_q, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
q = q.flatten(0, 1)
k = (
self.k(key)
.reshape(B, N_k, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k = k.flatten(0, 1)
v = (
self.v(value)
.reshape(B, N_k, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
v = v.flatten(0, 1)
if self.auto_sparsity:
assert attn_mask is None
x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
else:
x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
x = x.transpose(1, 2).reshape(B, N_q, C)
x = self.proj(x)
x = self.proj_drop(x)
if not self.batch_first:
x = x.permute(1, 0, 2)
return x, None
def scaled_query_key_softmax(q, k, att_mask):
from xformers.ops import masked_matmul
q = q / (k.size(-1)) ** 0.5
att = masked_matmul(q, k.transpose(-2, -1), att_mask)
att = torch.nn.functional.softmax(att, -1)
return att
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
att = dropout(att)
y = att @ v
return y
def _compute_buckets(x, R):
qq = torch.einsum('btf,bfhi->bhti', x, R)
qq = torch.cat([qq, -qq], dim=-1)
buckets = qq.argmax(dim=-1)
return buckets.permute(0, 2, 1).byte().contiguous()
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
# assert False, "The code for the custom sparse kernel is not ready for release yet."
from xformers.ops import find_locations, sparse_memory_efficient_attention
n_hashes = 32
proj_size = 4
query, key, value = [x.contiguous() for x in [query, key, value]]
with torch.no_grad():
R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
bucket_query = _compute_buckets(query, R)
bucket_key = _compute_buckets(key, R)
row_offsets, column_indices = find_locations(
bucket_query, bucket_key, sparsity, infer_sparsity)
return sparse_memory_efficient_attention(
query, key, value, row_offsets, column_indices, attn_bias)

502
demucs/utils.py Normal file
View File

@ -0,0 +1,502 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp
import errno
import functools
import hashlib
import inspect
import io
import os
import random
import socket
import tempfile
import warnings
import zlib
import tkinter as tk
from diffq import UniformQuantizer, DiffQuantizer
import torch as th
import tqdm
from torch import distributed
from torch.nn import functional as F
import torch
def unfold(a, kernel_size, stride):
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
with K the kernel size, by extracting frames with the given stride.
This will pad the input so that `F = ceil(T / K)`.
see https://github.com/pytorch/pytorch/issues/60466
"""
*shape, length = a.shape
n_frames = math.ceil(length / stride)
tgt_length = (n_frames - 1) * stride + kernel_size
a = F.pad(a, (0, tgt_length - length))
strides = list(a.stride())
assert strides[-1] == 1, 'data should be contiguous'
strides = strides[:-1] + [stride, 1]
return a.as_strided([*shape, n_frames, kernel_size], strides)
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
"""
Center trim `tensor` with respect to `reference`, along the last dimension.
`reference` can also be a number, representing the length to trim to.
If the size difference != 0 mod 2, the extra sample is removed on the right side.
"""
ref_size: int
if isinstance(reference, torch.Tensor):
ref_size = reference.size(-1)
else:
ref_size = reference
delta = tensor.size(-1) - ref_size
if delta < 0:
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
if delta:
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
return tensor
def pull_metric(history: tp.List[dict], name: str):
out = []
for metrics in history:
metric = metrics
for part in name.split("."):
metric = metric[part]
out.append(metric)
return out
def EMA(beta: float = 1):
"""
Exponential Moving Average callback.
Returns a single function that can be called to repeatidly update the EMA
with a dict of metrics. The callback will return
the new averaged dict of metrics.
Note that for `beta=1`, this is just plain averaging.
"""
fix: tp.Dict[str, float] = defaultdict(float)
total: tp.Dict[str, float] = defaultdict(float)
def _update(metrics: dict, weight: float = 1) -> dict:
nonlocal total, fix
for key, value in metrics.items():
total[key] = total[key] * beta + weight * float(value)
fix[key] = fix[key] * beta + weight
return {key: tot / fix[key] for key, tot in total.items()}
return _update
def sizeof_fmt(num: float, suffix: str = 'B'):
"""
Given `num` bytes, return human readable size.
Taken from https://stackoverflow.com/a/1094933
"""
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
if abs(num) < 1024.0:
return "%3.1f%s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
@contextmanager
def temp_filenames(count: int, delete=True):
names = []
try:
for _ in range(count):
names.append(tempfile.NamedTemporaryFile(delete=False).name)
yield names
finally:
if delete:
for name in names:
os.unlink(name)
def average_metric(metric, count=1.):
"""
Average `metric` which should be a float across all hosts. `count` should be
the weight for this particular host (i.e. number of examples).
"""
metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
return metric[1].item() / metric[0].item()
def free_port(host='', low=20000, high=40000):
"""
Return a port number that is most likely free.
This could suffer from a race condition although
it should be quite rare.
"""
sock = socket.socket()
while True:
port = random.randint(low, high)
try:
sock.bind((host, port))
except OSError as error:
if error.errno == errno.EADDRINUSE:
continue
raise
return port
def sizeof_fmt(num, suffix='B'):
"""
Given `num` bytes, return human readable size.
Taken from https://stackoverflow.com/a/1094933
"""
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
if abs(num) < 1024.0:
return "%3.1f%s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
def human_seconds(seconds, display='.2f'):
"""
Given `seconds` seconds, return human readable duration.
"""
value = seconds * 1e6
ratios = [1e3, 1e3, 60, 60, 24]
names = ['us', 'ms', 's', 'min', 'hrs', 'days']
last = names.pop(0)
for name, ratio in zip(names, ratios):
if value / ratio < 0.3:
break
value /= ratio
last = name
return f"{format(value, display)} {last}"
class TensorChunk:
def __init__(self, tensor, offset=0, length=None):
total_length = tensor.shape[-1]
assert offset >= 0
assert offset < total_length
if length is None:
length = total_length - offset
else:
length = min(total_length - offset, length)
self.tensor = tensor
self.offset = offset
self.length = length
self.device = tensor.device
@property
def shape(self):
shape = list(self.tensor.shape)
shape[-1] = self.length
return shape
def padded(self, target_length):
delta = target_length - self.length
total_length = self.tensor.shape[-1]
assert delta >= 0
start = self.offset - delta // 2
end = start + target_length
correct_start = max(0, start)
correct_end = min(total_length, end)
pad_left = correct_start - start
pad_right = end - correct_end
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
assert out.shape[-1] == target_length
return out
def tensor_chunk(tensor_or_chunk):
if isinstance(tensor_or_chunk, TensorChunk):
return tensor_or_chunk
else:
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
"""
Apply model to a given mixture.
Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
"""
channels, length = mix.size()
device = mix.device
progress_value = 0
if split:
out = th.zeros(4, channels, length, device=device)
shift = model.samplerate * 10
offsets = range(0, length, shift)
scale = 10
if progress:
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
for offset in offsets:
chunk = mix[..., offset:offset + shift]
if set_progress_bar:
progress_value += 1
set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
else:
chunk_out = apply_model_v1(model, chunk, shifts=shifts)
out[..., offset:offset + shift] = chunk_out
offset += shift
return out
elif shifts:
max_shift = int(model.samplerate / 2)
mix = F.pad(mix, (max_shift, max_shift))
offsets = list(range(max_shift))
random.shuffle(offsets)
out = 0
for offset in offsets[:shifts]:
shifted = mix[..., offset:offset + length + max_shift]
if set_progress_bar:
shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
else:
shifted_out = apply_model_v1(model, shifted)
out += shifted_out[..., max_shift - offset:max_shift - offset + length]
out /= shifts
return out
else:
valid_length = model.valid_length(length)
delta = valid_length - length
padded = F.pad(mix, (delta // 2, delta - delta // 2))
with th.no_grad():
out = model(padded.unsqueeze(0))[0]
return center_trim(out, mix)
def apply_model_v2(model, mix, shifts=None, split=False,
overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
"""
Apply model to a given mixture.
Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
"""
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
device = mix.device
channels, length = mix.shape
progress_value = 0
if split:
out = th.zeros(len(model.sources), channels, length, device=device)
sum_weight = th.zeros(length, device=device)
segment = model.segment_length
stride = int((1 - overlap) * segment)
offsets = range(0, length, stride)
scale = stride / model.samplerate
if progress:
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
# We start from a triangle shaped weight, with maximal weight in the middle
# of the segment. Then we normalize and take to the power `transition_power`.
# Large values of transition power will lead to sharper transitions.
weight = th.cat([th.arange(1, segment // 2 + 1),
th.arange(segment - segment // 2, 0, -1)]).to(device)
assert len(weight) == segment
# If the overlap < 50%, this will translate to linear transition when
# transition_power is 1.
weight = (weight / weight.max())**transition_power
for offset in offsets:
chunk = TensorChunk(mix, offset, segment)
if set_progress_bar:
progress_value += 1
set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
else:
chunk_out = apply_model_v2(model, chunk, shifts=shifts)
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
sum_weight[offset:offset + segment] += weight[:chunk_length]
offset += segment
assert sum_weight.min() > 0
out /= sum_weight
return out
elif shifts:
max_shift = int(0.5 * model.samplerate)
mix = tensor_chunk(mix)
padded_mix = mix.padded(length + 2 * max_shift)
out = 0
for _ in range(shifts):
offset = random.randint(0, max_shift)
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
if set_progress_bar:
progress_value += 1
shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
else:
shifted_out = apply_model_v2(model, shifted)
out += shifted_out[..., max_shift - offset:]
out /= shifts
return out
else:
valid_length = model.valid_length(length)
mix = tensor_chunk(mix)
padded_mix = mix.padded(valid_length)
with th.no_grad():
out = model(padded_mix.unsqueeze(0))[0]
return center_trim(out, length)
@contextmanager
def temp_filenames(count, delete=True):
names = []
try:
for _ in range(count):
names.append(tempfile.NamedTemporaryFile(delete=False).name)
yield names
finally:
if delete:
for name in names:
os.unlink(name)
def get_quantizer(model, args, optimizer=None):
quantizer = None
if args.diffq:
quantizer = DiffQuantizer(
model, min_size=args.q_min_size, group_size=8)
if optimizer is not None:
quantizer.setup_optimizer(optimizer)
elif args.qat:
quantizer = UniformQuantizer(
model, bits=args.qat, min_size=args.q_min_size)
return quantizer
def load_model(path, strict=False):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
load_from = path
package = th.load(load_from, 'cpu')
klass = package["klass"]
args = package["args"]
kwargs = package["kwargs"]
if strict:
model = klass(*args, **kwargs)
else:
sig = inspect.signature(klass)
for key in list(kwargs):
if key not in sig.parameters:
warnings.warn("Dropping inexistant parameter " + key)
del kwargs[key]
model = klass(*args, **kwargs)
state = package["state"]
training_args = package["training_args"]
quantizer = get_quantizer(model, training_args)
set_state(model, quantizer, state)
return model
def get_state(model, quantizer):
if quantizer is None:
state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
else:
state = quantizer.get_quantized_state()
buf = io.BytesIO()
th.save(state, buf)
state = {'compressed': zlib.compress(buf.getvalue())}
return state
def set_state(model, quantizer, state):
if quantizer is None:
model.load_state_dict(state)
else:
buf = io.BytesIO(zlib.decompress(state["compressed"]))
state = th.load(buf, "cpu")
quantizer.restore_quantized_state(state)
return state
def save_state(state, path):
buf = io.BytesIO()
th.save(state, buf)
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
path = path.parent / (path.stem + "-" + sig + path.suffix)
path.write_bytes(buf.getvalue())
def save_model(model, quantizer, training_args, path):
args, kwargs = model._init_args_kwargs
klass = model.__class__
state = get_state(model, quantizer)
save_to = path
package = {
'klass': klass,
'args': args,
'kwargs': kwargs,
'state': state,
'training_args': training_args,
}
th.save(package, save_to)
def capture_init(init):
@functools.wraps(init)
def __init__(self, *args, **kwargs):
self._init_args_kwargs = (args, kwargs)
init(self, *args, **kwargs)
return __init__
class DummyPoolExecutor:
class DummyResult:
def __init__(self, func, *args, **kwargs):
self.func = func
self.args = args
self.kwargs = kwargs
def result(self):
return self.func(*self.args, **self.kwargs)
def __init__(self, workers=0):
pass
def submit(self, func, *args, **kwargs):
return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
return

736
lib_v5/spec_utils.py Normal file
View File

@ -0,0 +1,736 @@
import librosa
import numpy as np
import soundfile as sf
import math
import random
import pyrubberband
import math
#import noisereduce as nr
MAX_SPEC = 'Max Spec'
MIN_SPEC = 'Min Spec'
AVERAGE = 'Average'
def crop_center(h1, h2):
h1_shape = h1.size()
h2_shape = h2.size()
if h1_shape[3] == h2_shape[3]:
return h1
elif h1_shape[3] < h2_shape[3]:
raise ValueError('h1_shape[3] must be greater than h2_shape[3]')
# s_freq = (h2_shape[2] - h1_shape[2]) // 2
# e_freq = s_freq + h1_shape[2]
s_time = (h1_shape[3] - h2_shape[3]) // 2
e_time = s_time + h2_shape[3]
h1 = h1[:, :, :, s_time:e_time]
return h1
def preprocess(X_spec):
X_mag = np.abs(X_spec)
X_phase = np.angle(X_spec)
return X_mag, X_phase
def make_padding(width, cropsize, offset):
left = offset
roi_size = cropsize - offset * 2
if roi_size == 0:
roi_size = cropsize
right = roi_size - (width % roi_size) + left
return left, right, roi_size
def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
if reverse:
wave_left = np.flip(np.asfortranarray(wave[0]))
wave_right = np.flip(np.asfortranarray(wave[1]))
elif mid_side:
wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
elif mid_side_b2:
wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5))
wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5))
else:
wave_left = np.asfortranarray(wave[0])
wave_right = np.asfortranarray(wave[1])
spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length)
spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length)
spec = np.asfortranarray([spec_left, spec_right])
return spec
def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
import threading
if reverse:
wave_left = np.flip(np.asfortranarray(wave[0]))
wave_right = np.flip(np.asfortranarray(wave[1]))
elif mid_side:
wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
elif mid_side_b2:
wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5))
wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5))
else:
wave_left = np.asfortranarray(wave[0])
wave_right = np.asfortranarray(wave[1])
def run_thread(**kwargs):
global spec_left
spec_left = librosa.stft(**kwargs)
thread = threading.Thread(target=run_thread, kwargs={'y': wave_left, 'n_fft': n_fft, 'hop_length': hop_length})
thread.start()
spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length)
thread.join()
spec = np.asfortranarray([spec_left, spec_right])
return spec
def normalize(wave, is_normalize=False):
"""Save output music files"""
maxv = np.abs(wave).max()
if maxv > 1.0:
print(f"\nNormalization Set {is_normalize}: Input above threshold for clipping. Max:{maxv}")
if is_normalize:
print(f"The result was normalized.")
wave /= maxv
else:
print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}")
return wave
def normalize_two_stem(wave, mix, is_normalize=False):
"""Save output music files"""
maxv = np.abs(wave).max()
max_mix = np.abs(mix).max()
if maxv > 1.0:
print(f"\nNormalization Set {is_normalize}: Primary source above threshold for clipping. The result was normalized. Max:{maxv}")
print(f"\nNormalization Set {is_normalize}: Mixture above threshold for clipping. The result was normalized. Max:{max_mix}")
if is_normalize:
wave /= maxv
mix /= maxv
else:
print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}")
print(f"\nNormalization Set {is_normalize}: Primary source - Max:{np.abs(wave).max()}")
print(f"\nNormalization Set {is_normalize}: Mixture - Max:{np.abs(mix).max()}")
return wave, mix
def combine_spectrograms(specs, mp):
l = min([specs[i].shape[2] for i in specs])
spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64)
offset = 0
bands_n = len(mp.param['band'])
for d in range(1, bands_n + 1):
h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start']
spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l]
offset += h
if offset > mp.param['bins']:
raise ValueError('Too much bins')
# lowpass fiter
if mp.param['pre_filter_start'] > 0: # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']:
if bands_n == 1:
spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop'])
else:
gp = 1
for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']):
g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0)
gp = g
spec_c[:, b, :] *= g
return np.asfortranarray(spec_c)
def spectrogram_to_image(spec, mode='magnitude'):
if mode == 'magnitude':
if np.iscomplexobj(spec):
y = np.abs(spec)
else:
y = spec
y = np.log10(y ** 2 + 1e-8)
elif mode == 'phase':
if np.iscomplexobj(spec):
y = np.angle(spec)
else:
y = spec
y -= y.min()
y *= 255 / y.max()
img = np.uint8(y)
if y.ndim == 3:
img = img.transpose(1, 2, 0)
img = np.concatenate([
np.max(img, axis=2, keepdims=True), img
], axis=2)
return img
def reduce_vocal_aggressively(X, y, softmask):
v = X - y
y_mag_tmp = np.abs(y)
v_mag_tmp = np.abs(v)
v_mask = v_mag_tmp > y_mag_tmp
y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)
return y_mag * np.exp(1.j * np.angle(y))
def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
if min_range < fade_size * 2:
raise ValueError('min_range must be >= fade_size * 2')
idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
artifact_idx = np.where(end_idx - start_idx > min_range)[0]
weight = np.zeros_like(y_mask)
if len(artifact_idx) > 0:
start_idx = start_idx[artifact_idx]
end_idx = end_idx[artifact_idx]
old_e = None
for s, e in zip(start_idx, end_idx):
if old_e is not None and s - old_e < fade_size:
s = old_e - fade_size * 2
if s != 0:
weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size)
else:
s -= fade_size
if e != y_mask.shape[2]:
weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size)
else:
e += fade_size
weight[:, :, s + fade_size:e - fade_size] = 1
old_e = e
v_mask = 1 - y_mask
y_mask += weight * v_mask
return y_mask
def mask_silence(mag, ref, thres=0.2, min_range=64, fade_size=32):
if min_range < fade_size * 2:
raise ValueError('min_range must be >= fade_area * 2')
mag = mag.copy()
idx = np.where(ref.mean(axis=(0, 1)) < thres)[0]
starts = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
ends = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
uninformative = np.where(ends - starts > min_range)[0]
if len(uninformative) > 0:
starts = starts[uninformative]
ends = ends[uninformative]
old_e = None
for s, e in zip(starts, ends):
if old_e is not None and s - old_e < fade_size:
s = old_e - fade_size * 2
if s != 0:
weight = np.linspace(0, 1, fade_size)
mag[:, :, s:s + fade_size] += weight * ref[:, :, s:s + fade_size]
else:
s -= fade_size
if e != mag.shape[2]:
weight = np.linspace(1, 0, fade_size)
mag[:, :, e - fade_size:e] += weight * ref[:, :, e - fade_size:e]
else:
e += fade_size
mag[:, :, s + fade_size:e - fade_size] += ref[:, :, s + fade_size:e - fade_size]
old_e = e
return mag
def align_wave_head_and_tail(a, b):
l = min([a[0].size, b[0].size])
return a[:l,:l], b[:l,:l]
def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse, clamp=False):
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
wave_left = librosa.istft(spec_left, hop_length=hop_length)
wave_right = librosa.istft(spec_right, hop_length=hop_length)
if reverse:
return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
elif mid_side:
return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
elif mid_side_b2:
return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)])
else:
return np.asfortranarray([wave_left, wave_right])
def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2):
import threading
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
def run_thread(**kwargs):
global wave_left
wave_left = librosa.istft(**kwargs)
thread = threading.Thread(target=run_thread, kwargs={'stft_matrix': spec_left, 'hop_length': hop_length})
thread.start()
wave_right = librosa.istft(spec_right, hop_length=hop_length)
thread.join()
if reverse:
return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
elif mid_side:
return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
elif mid_side_b2:
return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)])
else:
return np.asfortranarray([wave_left, wave_right])
def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None):
bands_n = len(mp.param['band'])
offset = 0
for d in range(1, bands_n + 1):
bp = mp.param['band'][d]
spec_s = np.ndarray(shape=(2, bp['n_fft'] // 2 + 1, spec_m.shape[2]), dtype=complex)
h = bp['crop_stop'] - bp['crop_start']
spec_s[:, bp['crop_start']:bp['crop_stop'], :] = spec_m[:, offset:offset+h, :]
offset += h
if d == bands_n: # higher
if extra_bins_h: # if --high_end_process bypass
max_bin = bp['n_fft'] // 2
spec_s[:, max_bin-extra_bins_h:max_bin, :] = extra_bins[:, :extra_bins_h, :]
if bp['hpf_start'] > 0:
spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1)
if bands_n == 1:
wave = spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])
else:
wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']))
else:
sr = mp.param['band'][d+1]['sr']
if d == 1: # lower
spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop'])
wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']), bp['sr'], sr, res_type="sinc_fastest")
else: # mid
spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1)
spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop'])
wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']))
wave = librosa.resample(wave2, bp['sr'], sr, res_type="sinc_fastest")
return wave
def fft_lp_filter(spec, bin_start, bin_stop):
g = 1.0
for b in range(bin_start, bin_stop):
g -= 1 / (bin_stop - bin_start)
spec[:, b, :] = g * spec[:, b, :]
spec[:, bin_stop:, :] *= 0
return spec
def fft_hp_filter(spec, bin_start, bin_stop):
g = 1.0
for b in range(bin_start, bin_stop, -1):
g -= 1 / (bin_start - bin_stop)
spec[:, b, :] = g * spec[:, b, :]
spec[:, 0:bin_stop+1, :] *= 0
return spec
def mirroring(a, spec_m, input_high_end, mp):
if 'mirroring' == a:
mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1)
mirror = mirror * np.exp(1.j * np.angle(input_high_end))
return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
if 'mirroring2' == a:
mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1)
mi = np.multiply(mirror, input_high_end * 1.7)
return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
def adjust_aggr(mask, is_vocal_model, aggressiveness):
aggr = aggressiveness.get('value', 0.0) * 4
if aggr != 0:
if is_vocal_model:
aggr = 1 - aggr
aggr = [aggr, aggr]
if aggressiveness['aggr_correction'] is not None:
aggr[0] += aggressiveness['aggr_correction']['left']
aggr[1] += aggressiveness['aggr_correction']['right']
for ch in range(2):
mask[ch, :aggressiveness['split_bin']] = np.power(mask[ch, :aggressiveness['split_bin']], 1 + aggr[ch] / 3)
mask[ch, aggressiveness['split_bin']:] = np.power(mask[ch, aggressiveness['split_bin']:], 1 + aggr[ch])
return mask
def stft(wave, nfft, hl):
wave_left = np.asfortranarray(wave[0])
wave_right = np.asfortranarray(wave[1])
spec_left = librosa.stft(wave_left, nfft, hop_length=hl)
spec_right = librosa.stft(wave_right, nfft, hop_length=hl)
spec = np.asfortranarray([spec_left, spec_right])
return spec
def istft(spec, hl):
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
wave_left = librosa.istft(spec_left, hop_length=hl)
wave_right = librosa.istft(spec_right, hop_length=hl)
wave = np.asfortranarray([wave_left, wave_right])
return wave
def spec_effects(wave, algorithm='Default', value=None):
spec = [stft(wave[0],2048,1024), stft(wave[1],2048,1024)]
if algorithm == 'Min_Mag':
v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
wave = istft(v_spec_m,1024)
elif algorithm == 'Max_Mag':
v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
wave = istft(v_spec_m,1024)
elif algorithm == 'Default':
wave = (wave[1] * value) + (wave[0] * (1-value))
elif algorithm == 'Invert_p':
X_mag = np.abs(spec[0])
y_mag = np.abs(spec[1])
max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0]))
wave = istft(v_spec,1024)
return wave
def spectrogram_to_wave_bare(spec, hop_length=1024):
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
wave_left = librosa.istft(spec_left, hop_length=hop_length)
wave_right = librosa.istft(spec_right, hop_length=hop_length)
wave = np.asfortranarray([wave_left, wave_right])
return wave
def spectrogram_to_wave_no_mp(spec, hop_length=1024):
if spec.ndim == 2:
wave = librosa.istft(spec, hop_length=hop_length)
elif spec.ndim == 3:
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
wave_left = librosa.istft(spec_left, hop_length=hop_length)
wave_right = librosa.istft(spec_right, hop_length=hop_length)
wave = np.asfortranarray([wave_left, wave_right])
return wave
def wave_to_spectrogram_no_mp(wave):
wave_left = np.asfortranarray(wave[0])
wave_right = np.asfortranarray(wave[1])
spec_left = librosa.stft(wave_left, n_fft=2048, hop_length=1024)
spec_right = librosa.stft(wave_right, n_fft=2048, hop_length=1024)
spec = np.asfortranarray([spec_left, spec_right])
return spec
# def noise_reduction(audio_file):
# noise_pro = 'noise_pro.wav'
# wav, sr = librosa.load(audio_file, sr=44100, mono=False)
# wav_noise, noise_rate = librosa.load(noise_pro, sr=44100, mono=False)
# if wav.ndim == 1:
# wav = np.asfortranarray([wav,wav])
# wav_1 = nr.reduce_noise(audio_clip=wav[0], noise_clip=wav_noise, verbose=True)
# wav_2 = nr.reduce_noise(audio_clip=wav[1], noise_clip=wav_noise, verbose=True)
# if wav_1.shape > wav_2.shape:
# wav_2 = to_shape(wav_2, wav_1.shape)
# if wav_1.shape < wav_2.shape:
# wav_1 = to_shape(wav_1, wav_2.shape)
# #print('wav_1.shape: ', wav_1.shape)
# wav_mix = np.asfortranarray([wav_1, wav_2])
# return wav_mix, sr
def invert_audio(specs, invert_p=True):
ln = min([specs[0].shape[2], specs[1].shape[2]])
specs[0] = specs[0][:,:,:ln]
specs[1] = specs[1][:,:,:ln]
if invert_p:
X_mag = np.abs(specs[0])
y_mag = np.abs(specs[1])
max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
v_spec = specs[1] - max_mag * np.exp(1.j * np.angle(specs[0]))
else:
specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
v_spec = specs[0] - specs[1]
return v_spec
def invert_stem(mixture, stem):
mixture = wave_to_spectrogram_no_mp(mixture)
stem = wave_to_spectrogram_no_mp(stem)
output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))
return -output.T
def ensembling(a, specs):
for i in range(1, len(specs)):
if i == 1:
spec = specs[0]
ln = min([spec.shape[2], specs[i].shape[2]])
spec = spec[:,:,:ln]
specs[i] = specs[i][:,:,:ln]
#print('spec: ', a)
if MIN_SPEC == a:
spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec)
if MAX_SPEC == a:
spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec)
if AVERAGE == a:
spec = np.where(np.abs(specs[i]) == np.abs(spec), specs[i], spec)
return spec
def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path):
#print(algorithm)
if algorithm == AVERAGE:
output = average_audio(audio_input)
samplerate = 44100
else:
specs = []
for i in range(len(audio_input)):
wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
spec = wave_to_spectrogram_no_mp(wave)
specs.append(spec)
#print('output size: ', sys.getsizeof(spec))
#print('output size: ', sys.getsizeof(specs))
output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
def to_shape(x, target_shape):
padding_list = []
for x_dim, target_dim in zip(x.shape, target_shape):
pad_value = (target_dim - x_dim)
pad_tuple = ((0, pad_value))
padding_list.append(pad_tuple)
return np.pad(x, tuple(padding_list), mode='constant')
def to_shape_minimize(x: np.ndarray, target_shape):
padding_list = []
for x_dim, target_dim in zip(x.shape, target_shape):
pad_value = (target_dim - x_dim)
pad_tuple = ((0, pad_value))
padding_list.append(pad_tuple)
return np.pad(x, tuple(padding_list), mode='constant')
def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False):
#print(rate)
wav, sr = librosa.load(audio_file, sr=44100, mono=False)
if wav.ndim == 1:
wav = np.asfortranarray([wav,wav])
if is_pitch:
wav_1 = pyrubberband.pyrb.pitch_shift(wav[0], sr, rate, rbargs=None)
wav_2 = pyrubberband.pyrb.pitch_shift(wav[1], sr, rate, rbargs=None)
else:
wav_1 = pyrubberband.pyrb.time_stretch(wav[0], sr, rate, rbargs=None)
wav_2 = pyrubberband.pyrb.time_stretch(wav[1], sr, rate, rbargs=None)
if wav_1.shape > wav_2.shape:
wav_2 = to_shape(wav_2, wav_1.shape)
if wav_1.shape < wav_2.shape:
wav_1 = to_shape(wav_1, wav_2.shape)
wav_mix = np.asfortranarray([wav_1, wav_2])
sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
save_format(export_path)
def average_audio(audio):
waves = []
wave_shapes = []
final_waves = []
for i in range(len(audio)):
wave = librosa.load(audio[i], sr=44100, mono=False)
waves.append(wave[0])
wave_shapes.append(wave[0].shape[1])
wave_shapes_index = wave_shapes.index(max(wave_shapes))
target_shape = waves[wave_shapes_index]
waves.pop(wave_shapes_index)
final_waves.append(target_shape)
for n_array in waves:
wav_target = to_shape(n_array, target_shape.shape)
final_waves.append(wav_target)
waves = sum(final_waves)
waves = waves/len(audio)
return waves
def average_dual_sources(wav_1, wav_2, value):
if wav_1.shape > wav_2.shape:
wav_2 = to_shape(wav_2, wav_1.shape)
if wav_1.shape < wav_2.shape:
wav_1 = to_shape(wav_1, wav_2.shape)
wave = (wav_1 * value) + (wav_2 * (1-value))
return wave
def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
if wav_1.shape > wav_2.shape:
wav_2 = to_shape(wav_2, wav_1.shape)
if wav_1.shape < wav_2.shape:
ln = min([wav_1.shape[1], wav_2.shape[1]])
wav_2 = wav_2[:,:ln]
ln = min([wav_1.shape[1], wav_2.shape[1]])
wav_1 = wav_1[:,:ln]
wav_2 = wav_2[:,:ln]
return wav_2
def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_normalization, command_Text, progress_bar_main_var, save_format):
def get_diff(a, b):
corr = np.correlate(a, b, "full")
diff = corr.argmax() - (b.shape[0] - 1)
return diff
progress_bar_main_var.set(10)
# read tracks
wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
wav1 = wav1.transpose()
wav2 = wav2.transpose()
command_Text(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n")
wav2_org = wav2.copy()
progress_bar_main_var.set(20)
command_Text("Processing files... \n")
# pick random position and get diff
counts = {} # counting up for each diff value
progress = 20
check_range = 64
base = (64 / check_range)
for i in range(check_range):
index = int(random.uniform(44100 * 2, min(wav1.shape[0], wav2.shape[0]) - 44100 * 2))
shift = int(random.uniform(-22050,+22050))
samp1 = wav1[index :index +44100, 0] # currently use left channel
samp2 = wav2[index+shift:index+shift+44100, 0]
progress += 1 * base
progress_bar_main_var.set(progress)
diff = get_diff(samp1, samp2)
diff -= shift
if abs(diff) < 22050:
if not diff in counts:
counts[diff] = 0
counts[diff] += 1
# use max counted diff value
max_count = 0
est_diff = 0
for diff in counts.keys():
if counts[diff] > max_count:
max_count = counts[diff]
est_diff = diff
command_Text(f"Estimated difference is {est_diff} (count: {max_count})\n")
progress_bar_main_var.set(90)
audio_files = []
def save_aligned_audio(wav2_aligned):
command_Text(f"Aligned File 2 with File 1.\n")
command_Text(f"Saving files... ")
sf.write(file2_aligned, normalize(wav2_aligned, is_normalization), sr2, subtype=wav_type_set)
save_format(file2_aligned)
min_len = min(wav1.shape[0], wav2_aligned.shape[0])
wav_sub = wav1[:min_len] - wav2_aligned[:min_len]
audio_files.append(file2_aligned)
return min_len, wav_sub
# make aligned track 2
if est_diff > 0:
wav2_aligned = np.append(np.zeros((est_diff, 2)), wav2_org, axis=0)
min_len, wav_sub = save_aligned_audio(wav2_aligned)
elif est_diff < 0:
wav2_aligned = wav2_org[-est_diff:]
min_len, wav_sub = save_aligned_audio(wav2_aligned)
else:
command_Text(f"Audio files already aligned.\n")
command_Text(f"Saving inverted track... ")
min_len = min(wav1.shape[0], wav2.shape[0])
wav_sub = wav1[:min_len] - wav2[:min_len]
wav_sub = np.clip(wav_sub, -1, +1)
sf.write(file_subtracted, normalize(wav_sub, is_normalization), sr1, subtype=wav_type_set)
save_format(file_subtracted)
progress_bar_main_var.set(95)

View File

@ -0,0 +1 @@
# VR init.

143
lib_v5/vr_network/layers.py Normal file
View File

@ -0,0 +1,143 @@
import torch
from torch import nn
import torch.nn.functional as F
from lib_v5 import spec_utils
class Conv2DBNActiv(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
super(Conv2DBNActiv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
nin, nout,
kernel_size=ksize,
stride=stride,
padding=pad,
dilation=dilation,
bias=False),
nn.BatchNorm2d(nout),
activ()
)
def __call__(self, x):
return self.conv(x)
class SeperableConv2DBNActiv(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
super(SeperableConv2DBNActiv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
nin, nin,
kernel_size=ksize,
stride=stride,
padding=pad,
dilation=dilation,
groups=nin,
bias=False),
nn.Conv2d(
nin, nout,
kernel_size=1,
bias=False),
nn.BatchNorm2d(nout),
activ()
)
def __call__(self, x):
return self.conv(x)
class Encoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
super(Encoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
def __call__(self, x):
skip = self.conv1(x)
h = self.conv2(skip)
return h, skip
class Decoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
super(Decoder, self).__init__()
self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def __call__(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
if skip is not None:
skip = spec_utils.crop_center(skip, x)
x = torch.cat([x, skip], dim=1)
h = self.conv(x)
if self.dropout is not None:
h = self.dropout(h)
return h
class ASPPModule(nn.Module):
def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
super(ASPPModule, self).__init__()
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d((1, None)),
Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
)
self.nn_architecture = nn_architecture
self.six_layer = [129605]
self.seven_layer = [537238, 537227, 33966]
extra_conv = SeperableConv2DBNActiv(
nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
self.conv3 = SeperableConv2DBNActiv(
nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
self.conv4 = SeperableConv2DBNActiv(
nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
self.conv5 = SeperableConv2DBNActiv(
nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
if self.nn_architecture in self.six_layer:
self.conv6 = extra_conv
nin_x = 6
elif self.nn_architecture in self.seven_layer:
self.conv6 = extra_conv
self.conv7 = extra_conv
nin_x = 7
else:
nin_x = 5
self.bottleneck = nn.Sequential(
Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ),
nn.Dropout2d(0.1)
)
def forward(self, x):
_, _, h, w = x.size()
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
feat2 = self.conv2(x)
feat3 = self.conv3(x)
feat4 = self.conv4(x)
feat5 = self.conv5(x)
if self.nn_architecture in self.six_layer:
feat6 = self.conv6(x)
out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
elif self.nn_architecture in self.seven_layer:
feat6 = self.conv6(x)
feat7 = self.conv7(x)
out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
else:
out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
bottle = self.bottleneck(out)
return bottle

View File

@ -0,0 +1,126 @@
import torch
from torch import nn
import torch.nn.functional as F
from lib_v5 import spec_utils
class Conv2DBNActiv(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
super(Conv2DBNActiv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
nin, nout,
kernel_size=ksize,
stride=stride,
padding=pad,
dilation=dilation,
bias=False),
nn.BatchNorm2d(nout),
activ()
)
def __call__(self, x):
return self.conv(x)
class Encoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
super(Encoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
def __call__(self, x):
h = self.conv1(x)
h = self.conv2(h)
return h
class Decoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
super(Decoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
# self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def __call__(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
if skip is not None:
skip = spec_utils.crop_center(skip, x)
x = torch.cat([x, skip], dim=1)
h = self.conv1(x)
# h = self.conv2(h)
if self.dropout is not None:
h = self.dropout(h)
return h
class ASPPModule(nn.Module):
def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
super(ASPPModule, self).__init__()
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d((1, None)),
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
)
self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
self.conv3 = Conv2DBNActiv(
nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
)
self.conv4 = Conv2DBNActiv(
nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
)
self.conv5 = Conv2DBNActiv(
nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
)
self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def forward(self, x):
_, _, h, w = x.size()
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
feat2 = self.conv2(x)
feat3 = self.conv3(x)
feat4 = self.conv4(x)
feat5 = self.conv5(x)
out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
out = self.bottleneck(out)
if self.dropout is not None:
out = self.dropout(out)
return out
class LSTMModule(nn.Module):
def __init__(self, nin_conv, nin_lstm, nout_lstm):
super(LSTMModule, self).__init__()
self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
self.lstm = nn.LSTM(
input_size=nin_lstm,
hidden_size=nout_lstm // 2,
bidirectional=True
)
self.dense = nn.Sequential(
nn.Linear(nout_lstm, nin_lstm),
nn.BatchNorm1d(nin_lstm),
nn.ReLU()
)
def forward(self, x):
N, _, nbins, nframes = x.size()
h = self.conv(x)[:, 0] # N, nbins, nframes
h = h.permute(2, 0, 1) # nframes, N, nbins
h, _ = self.lstm(h)
h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins
h = h.reshape(nframes, N, 1, nbins)
h = h.permute(1, 2, 3, 0)
return h

View File

@ -0,0 +1,59 @@
import json
import pathlib
default_param = {}
default_param['bins'] = 768
default_param['unstable_bins'] = 9 # training only
default_param['reduction_bins'] = 762 # training only
default_param['sr'] = 44100
default_param['pre_filter_start'] = 757
default_param['pre_filter_stop'] = 768
default_param['band'] = {}
default_param['band'][1] = {
'sr': 11025,
'hl': 128,
'n_fft': 960,
'crop_start': 0,
'crop_stop': 245,
'lpf_start': 61, # inference only
'res_type': 'polyphase'
}
default_param['band'][2] = {
'sr': 44100,
'hl': 512,
'n_fft': 1536,
'crop_start': 24,
'crop_stop': 547,
'hpf_start': 81, # inference only
'res_type': 'sinc_best'
}
def int_keys(d):
r = {}
for k, v in d:
if k.isdigit():
k = int(k)
r[k] = v
return r
class ModelParameters(object):
def __init__(self, config_path=''):
if '.pth' == pathlib.Path(config_path).suffix:
import zipfile
with zipfile.ZipFile(config_path, 'r') as zip:
self.param = json.loads(zip.read('param.json'), object_pairs_hook=int_keys)
elif '.json' == pathlib.Path(config_path).suffix:
with open(config_path, 'r') as f:
self.param = json.loads(f.read(), object_pairs_hook=int_keys)
else:
self.param = default_param
for k in ['mid_side', 'mid_side_b', 'mid_side_b2', 'stereo_w', 'stereo_n', 'reverse']:
if not k in self.param:
self.param[k] = False

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 16000,
"hl": 512,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 16000,
"pre_filter_start": 1023,
"pre_filter_stop": 1024
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 32000,
"hl": 512,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "kaiser_fast"
}
},
"sr": 32000,
"pre_filter_start": 1000,
"pre_filter_stop": 1021
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 33075,
"hl": 384,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 33075,
"pre_filter_start": 1000,
"pre_filter_stop": 1021
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 44100,
"hl": 1024,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 44100,
"pre_filter_start": 1023,
"pre_filter_stop": 1024
}

View File

@ -0,0 +1,19 @@
{
"bins": 256,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 44100,
"hl": 256,
"n_fft": 512,
"crop_start": 0,
"crop_stop": 256,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 44100,
"pre_filter_start": 256,
"pre_filter_stop": 256
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 44100,
"hl": 512,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 44100,
"pre_filter_start": 1023,
"pre_filter_stop": 1024
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 44100,
"hl": 512,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 700,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 44100,
"pre_filter_start": 1023,
"pre_filter_stop": 700
}

View File

@ -0,0 +1,19 @@
{
"bins": 1024,
"unstable_bins": 0,
"reduction_bins": 0,
"band": {
"1": {
"sr": 44100,
"hl": 512,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 1024,
"hpf_start": -1,
"res_type": "sinc_best"
}
},
"sr": 44100,
"pre_filter_start": 1023,
"pre_filter_stop": 1024
}

View File

@ -0,0 +1,30 @@
{
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 705,
"band": {
"1": {
"sr": 6000,
"hl": 66,
"n_fft": 512,
"crop_start": 0,
"crop_stop": 240,
"lpf_start": 60,
"lpf_stop": 118,
"res_type": "sinc_fastest"
},
"2": {
"sr": 32000,
"hl": 352,
"n_fft": 1024,
"crop_start": 22,
"crop_stop": 505,
"hpf_start": 44,
"hpf_stop": 23,
"res_type": "sinc_medium"
}
},
"sr": 32000,
"pre_filter_start": 710,
"pre_filter_stop": 731
}

View File

@ -0,0 +1,30 @@
{
"bins": 512,
"unstable_bins": 7,
"reduction_bins": 510,
"band": {
"1": {
"sr": 11025,
"hl": 160,
"n_fft": 768,
"crop_start": 0,
"crop_stop": 192,
"lpf_start": 41,
"lpf_stop": 139,
"res_type": "sinc_fastest"
},
"2": {
"sr": 44100,
"hl": 640,
"n_fft": 1024,
"crop_start": 10,
"crop_stop": 320,
"hpf_start": 47,
"hpf_stop": 15,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 510,
"pre_filter_stop": 512
}

View File

@ -0,0 +1,30 @@
{
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 705,
"band": {
"1": {
"sr": 6000,
"hl": 66,
"n_fft": 512,
"crop_start": 0,
"crop_stop": 240,
"lpf_start": 60,
"lpf_stop": 240,
"res_type": "sinc_fastest"
},
"2": {
"sr": 48000,
"hl": 528,
"n_fft": 1536,
"crop_start": 22,
"crop_stop": 505,
"hpf_start": 82,
"hpf_stop": 22,
"res_type": "sinc_medium"
}
},
"sr": 48000,
"pre_filter_start": 710,
"pre_filter_stop": 731
}

View File

@ -0,0 +1,42 @@
{
"bins": 768,
"unstable_bins": 5,
"reduction_bins": 733,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 768,
"crop_start": 0,
"crop_stop": 278,
"lpf_start": 28,
"lpf_stop": 140,
"res_type": "polyphase"
},
"2": {
"sr": 22050,
"hl": 256,
"n_fft": 768,
"crop_start": 14,
"crop_stop": 322,
"hpf_start": 70,
"hpf_stop": 14,
"lpf_start": 283,
"lpf_stop": 314,
"res_type": "polyphase"
},
"3": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 131,
"crop_stop": 313,
"hpf_start": 154,
"hpf_stop": 141,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 757,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,43 @@
{
"mid_side": true,
"bins": 768,
"unstable_bins": 5,
"reduction_bins": 733,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 768,
"crop_start": 0,
"crop_stop": 278,
"lpf_start": 28,
"lpf_stop": 140,
"res_type": "polyphase"
},
"2": {
"sr": 22050,
"hl": 256,
"n_fft": 768,
"crop_start": 14,
"crop_stop": 322,
"hpf_start": 70,
"hpf_stop": 14,
"lpf_start": 283,
"lpf_stop": 314,
"res_type": "polyphase"
},
"3": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 131,
"crop_stop": 313,
"hpf_start": 154,
"hpf_stop": 141,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 757,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,43 @@
{
"mid_side_b2": true,
"bins": 640,
"unstable_bins": 7,
"reduction_bins": 565,
"band": {
"1": {
"sr": 11025,
"hl": 108,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 187,
"lpf_start": 92,
"lpf_stop": 186,
"res_type": "polyphase"
},
"2": {
"sr": 22050,
"hl": 216,
"n_fft": 768,
"crop_start": 0,
"crop_stop": 212,
"hpf_start": 68,
"hpf_stop": 34,
"lpf_start": 174,
"lpf_stop": 209,
"res_type": "polyphase"
},
"3": {
"sr": 44100,
"hl": 432,
"n_fft": 640,
"crop_start": 66,
"crop_stop": 307,
"hpf_start": 86,
"hpf_stop": 72,
"res_type": "kaiser_fast"
}
},
"sr": 44100,
"pre_filter_start": 639,
"pre_filter_stop": 640
}

View File

@ -0,0 +1,54 @@
{
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,55 @@
{
"bins": 768,
"unstable_bins": 7,
"mid_side": true,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,55 @@
{
"mid_side_b": true,
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,55 @@
{
"mid_side_b": true,
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,55 @@
{
"reverse": true,
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,55 @@
{
"stereo_w": true,
"bins": 768,
"unstable_bins": 7,
"reduction_bins": 668,
"band": {
"1": {
"sr": 11025,
"hl": 128,
"n_fft": 1024,
"crop_start": 0,
"crop_stop": 186,
"lpf_start": 37,
"lpf_stop": 73,
"res_type": "polyphase"
},
"2": {
"sr": 11025,
"hl": 128,
"n_fft": 512,
"crop_start": 4,
"crop_stop": 185,
"hpf_start": 36,
"hpf_stop": 18,
"lpf_start": 93,
"lpf_stop": 185,
"res_type": "polyphase"
},
"3": {
"sr": 22050,
"hl": 256,
"n_fft": 512,
"crop_start": 46,
"crop_stop": 186,
"hpf_start": 93,
"hpf_stop": 46,
"lpf_start": 164,
"lpf_stop": 186,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 512,
"n_fft": 768,
"crop_start": 121,
"crop_stop": 382,
"hpf_start": 138,
"hpf_stop": 123,
"res_type": "sinc_medium"
}
},
"sr": 44100,
"pre_filter_start": 740,
"pre_filter_stop": 768
}

View File

@ -0,0 +1,54 @@
{
"bins": 672,
"unstable_bins": 8,
"reduction_bins": 637,
"band": {
"1": {
"sr": 7350,
"hl": 80,
"n_fft": 640,
"crop_start": 0,
"crop_stop": 85,
"lpf_start": 25,
"lpf_stop": 53,
"res_type": "polyphase"
},
"2": {
"sr": 7350,
"hl": 80,
"n_fft": 320,
"crop_start": 4,
"crop_stop": 87,
"hpf_start": 25,
"hpf_stop": 12,
"lpf_start": 31,
"lpf_stop": 62,
"res_type": "polyphase"
},
"3": {
"sr": 14700,
"hl": 160,
"n_fft": 512,
"crop_start": 17,
"crop_stop": 216,
"hpf_start": 48,
"hpf_stop": 24,
"lpf_start": 139,
"lpf_stop": 210,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 480,
"n_fft": 960,
"crop_start": 78,
"crop_stop": 383,
"hpf_start": 130,
"hpf_stop": 86,
"res_type": "kaiser_fast"
}
},
"sr": 44100,
"pre_filter_start": 668,
"pre_filter_stop": 672
}

View File

@ -0,0 +1,55 @@
{
"bins": 672,
"unstable_bins": 8,
"reduction_bins": 637,
"band": {
"1": {
"sr": 7350,
"hl": 80,
"n_fft": 640,
"crop_start": 0,
"crop_stop": 85,
"lpf_start": 25,
"lpf_stop": 53,
"res_type": "polyphase"
},
"2": {
"sr": 7350,
"hl": 80,
"n_fft": 320,
"crop_start": 4,
"crop_stop": 87,
"hpf_start": 25,
"hpf_stop": 12,
"lpf_start": 31,
"lpf_stop": 62,
"res_type": "polyphase"
},
"3": {
"sr": 14700,
"hl": 160,
"n_fft": 512,
"crop_start": 17,
"crop_stop": 216,
"hpf_start": 48,
"hpf_stop": 24,
"lpf_start": 139,
"lpf_stop": 210,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 480,
"n_fft": 960,
"crop_start": 78,
"crop_stop": 383,
"hpf_start": 130,
"hpf_stop": 86,
"convert_channels": "stereo_n",
"res_type": "kaiser_fast"
}
},
"sr": 44100,
"pre_filter_start": 668,
"pre_filter_stop": 672
}

View File

@ -0,0 +1,54 @@
{
"bins": 672,
"unstable_bins": 8,
"reduction_bins": 530,
"band": {
"1": {
"sr": 7350,
"hl": 80,
"n_fft": 640,
"crop_start": 0,
"crop_stop": 85,
"lpf_start": 25,
"lpf_stop": 53,
"res_type": "polyphase"
},
"2": {
"sr": 7350,
"hl": 80,
"n_fft": 320,
"crop_start": 4,
"crop_stop": 87,
"hpf_start": 25,
"hpf_stop": 12,
"lpf_start": 31,
"lpf_stop": 62,
"res_type": "polyphase"
},
"3": {
"sr": 14700,
"hl": 160,
"n_fft": 512,
"crop_start": 17,
"crop_stop": 216,
"hpf_start": 48,
"hpf_stop": 24,
"lpf_start": 139,
"lpf_stop": 210,
"res_type": "polyphase"
},
"4": {
"sr": 44100,
"hl": 480,
"n_fft": 960,
"crop_start": 78,
"crop_stop": 383,
"hpf_start": 130,
"hpf_stop": 86,
"res_type": "kaiser_fast"
}
},
"sr": 44100,
"pre_filter_start": 668,
"pre_filter_stop": 672
}

View File

@ -0,0 +1,43 @@
{
"mid_side_b2": true,
"bins": 1280,
"unstable_bins": 7,
"reduction_bins": 565,
"band": {
"1": {
"sr": 11025,
"hl": 108,
"n_fft": 2048,
"crop_start": 0,
"crop_stop": 374,
"lpf_start": 92,
"lpf_stop": 186,
"res_type": "polyphase"
},
"2": {
"sr": 22050,
"hl": 216,
"n_fft": 1536,
"crop_start": 0,
"crop_stop": 424,
"hpf_start": 68,
"hpf_stop": 34,
"lpf_start": 348,
"lpf_stop": 418,
"res_type": "polyphase"
},
"3": {
"sr": 44100,
"hl": 432,
"n_fft": 1280,
"crop_start": 132,
"crop_stop": 614,
"hpf_start": 172,
"hpf_stop": 144,
"res_type": "polyphase"
}
},
"sr": 44100,
"pre_filter_start": 1280,
"pre_filter_stop": 1280
}

171
lib_v5/vr_network/nets.py Normal file
View File

@ -0,0 +1,171 @@
import torch
from torch import nn
import torch.nn.functional as F
from . import layers
class BaseASPPNet(nn.Module):
def __init__(self, nn_architecture, nin, ch, dilations=(4, 8, 16)):
super(BaseASPPNet, self).__init__()
self.nn_architecture = nn_architecture
self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
if self.nn_architecture == 129605:
self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1)
self.aspp = layers.ASPPModule(nn_architecture, ch * 16, ch * 32, dilations)
self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1)
else:
self.aspp = layers.ASPPModule(nn_architecture, ch * 8, ch * 16, dilations)
self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
def __call__(self, x):
h, e1 = self.enc1(x)
h, e2 = self.enc2(h)
h, e3 = self.enc3(h)
h, e4 = self.enc4(h)
if self.nn_architecture == 129605:
h, e5 = self.enc5(h)
h = self.aspp(h)
h = self.dec5(h, e5)
else:
h = self.aspp(h)
h = self.dec4(h, e4)
h = self.dec3(h, e3)
h = self.dec2(h, e2)
h = self.dec1(h, e1)
return h
def determine_model_capacity(n_fft_bins, nn_architecture):
sp_model_arch = [31191, 33966, 129605]
hp_model_arch = [123821, 123812]
hp2_model_arch = [537238, 537227]
if nn_architecture in sp_model_arch:
model_capacity_data = [
(2, 16),
(2, 16),
(18, 8, 1, 1, 0),
(8, 16),
(34, 16, 1, 1, 0),
(16, 32),
(32, 2, 1),
(16, 2, 1),
(16, 2, 1),
]
if nn_architecture in hp_model_arch:
model_capacity_data = [
(2, 32),
(2, 32),
(34, 16, 1, 1, 0),
(16, 32),
(66, 32, 1, 1, 0),
(32, 64),
(64, 2, 1),
(32, 2, 1),
(32, 2, 1),
]
if nn_architecture in hp2_model_arch:
model_capacity_data = [
(2, 64),
(2, 64),
(66, 32, 1, 1, 0),
(32, 64),
(130, 64, 1, 1, 0),
(64, 128),
(128, 2, 1),
(64, 2, 1),
(64, 2, 1),
]
cascaded = CascadedASPPNet
model = cascaded(n_fft_bins, model_capacity_data, nn_architecture)
return model
class CascadedASPPNet(nn.Module):
def __init__(self, n_fft, model_capacity_data, nn_architecture):
super(CascadedASPPNet, self).__init__()
self.stg1_low_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[0])
self.stg1_high_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[1])
self.stg2_bridge = layers.Conv2DBNActiv(*model_capacity_data[2])
self.stg2_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[3])
self.stg3_bridge = layers.Conv2DBNActiv(*model_capacity_data[4])
self.stg3_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[5])
self.out = nn.Conv2d(*model_capacity_data[6], bias=False)
self.aux1_out = nn.Conv2d(*model_capacity_data[7], bias=False)
self.aux2_out = nn.Conv2d(*model_capacity_data[8], bias=False)
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.offset = 128
def forward(self, x, aggressiveness=None):
mix = x.detach()
x = x.clone()
x = x[:, :, :self.max_bin]
bandw = x.size()[2] // 2
aux1 = torch.cat([
self.stg1_low_band_net(x[:, :, :bandw]),
self.stg1_high_band_net(x[:, :, bandw:])
], dim=2)
h = torch.cat([x, aux1], dim=1)
aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
h = torch.cat([x, aux1, aux2], dim=1)
h = self.stg3_full_band_net(self.stg3_bridge(h))
mask = torch.sigmoid(self.out(h))
mask = F.pad(
input=mask,
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
mode='replicate')
if self.training:
aux1 = torch.sigmoid(self.aux1_out(aux1))
aux1 = F.pad(
input=aux1,
pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
mode='replicate')
aux2 = torch.sigmoid(self.aux2_out(aux2))
aux2 = F.pad(
input=aux2,
pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
mode='replicate')
return mask * mix, aux1 * mix, aux2 * mix
else:
if aggressiveness:
mask[:, :, :aggressiveness['split_bin']] = torch.pow(mask[:, :, :aggressiveness['split_bin']], 1 + aggressiveness['value'] / 3)
mask[:, :, aggressiveness['split_bin']:] = torch.pow(mask[:, :, aggressiveness['split_bin']:], 1 + aggressiveness['value'])
return mask * mix
def predict(self, x_mag, aggressiveness=None):
h = self.forward(x_mag, aggressiveness)
if self.offset > 0:
h = h[:, :, :, self.offset:-self.offset]
assert h.size()[3] > 0
return h

View File

@ -0,0 +1,143 @@
import torch
from torch import nn
import torch.nn.functional as F
from . import layers_new as layers
class BaseNet(nn.Module):
def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
super(BaseNet, self).__init__()
self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1)
self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1)
self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1)
self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
def __call__(self, x):
e1 = self.enc1(x)
e2 = self.enc2(e1)
e3 = self.enc3(e2)
e4 = self.enc4(e3)
e5 = self.enc5(e4)
h = self.aspp(e5)
h = self.dec4(h, e4)
h = self.dec3(h, e3)
h = self.dec2(h, e2)
h = torch.cat([h, self.lstm_dec2(h)], dim=1)
h = self.dec1(h, e1)
return h
class CascadedNet(nn.Module):
def __init__(self, n_fft, nn_architecture):
super(CascadedNet, self).__init__()
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64
self.nn_architecture = nn_architecture
print('ARC SIZE: ', nn_architecture)
if nn_architecture == 218409:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64)
self.stg2_low_band_net = nn.Sequential(
BaseNet(18, 64, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(64, 32, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64)
self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128)
self.out = nn.Conv2d(64, 2, 1, bias=False)
self.aux_out = nn.Conv2d(48, 2, 1, bias=False)
else:
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 16, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(16, 8, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)
self.stg2_low_band_net = nn.Sequential(
BaseNet(10, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)
self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
self.out = nn.Conv2d(32, 2, 1, bias=False)
self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
def forward(self, x):
x = x[:, :, :self.max_bin]
bandw = x.size()[2] // 2
l1_in = x[:, :, :bandw]
h1_in = x[:, :, bandw:]
l1 = self.stg1_low_band_net(l1_in)
h1 = self.stg1_high_band_net(h1_in)
aux1 = torch.cat([l1, h1], dim=2)
l2_in = torch.cat([l1_in, l1], dim=1)
h2_in = torch.cat([h1_in, h1], dim=1)
l2 = self.stg2_low_band_net(l2_in)
h2 = self.stg2_high_band_net(h2_in)
aux2 = torch.cat([l2, h2], dim=2)
f3_in = torch.cat([x, aux1, aux2], dim=1)
f3 = self.stg3_full_band_net(f3_in)
mask = torch.sigmoid(self.out(f3))
mask = F.pad(
input=mask,
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
mode='replicate'
)
if self.training:
aux = torch.cat([aux1, aux2], dim=1)
aux = torch.sigmoid(self.aux_out(aux))
aux = F.pad(
input=aux,
pad=(0, 0, 0, self.output_bin - aux.size()[2]),
mode='replicate'
)
return mask, aux
else:
return mask
def predict_mask(self, x):
mask = self.forward(x)
if self.offset > 0:
mask = mask[:, :, :, self.offset:-self.offset]
assert mask.size()[3] > 0
return mask
def predict(self, x):
mask = self.forward(x)
pred_mag = x * mask
if self.offset > 0:
pred_mag = pred_mag[:, :, :, self.offset:-self.offset]
assert pred_mag.size()[3] > 0
return pred_mag

View File

@ -0,0 +1 @@
Demucs v3 and v4 models go here.

View File

@ -0,0 +1,184 @@
{
"0ddfc0eb5792638ad5dc27850236c246": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"26d308f91f3423a67dc69a6d12a8793d": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 9,
"mdx_n_fft_scale_set": 8192,
"primary_stem": "Other"
},
"2cdd429caac38f0194b133884160f2c6": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"2f5501189a2f6db6349916fabe8c90de": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"398580b6d5d973af3120df54cee6759d": {
"compensate": 1.75,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"488b3e6f8bd3717d9d7c428476be2d75": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"4910e7827f335048bdac11fa967772f9": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 7,
"mdx_n_fft_scale_set": 4096,
"primary_stem": "Drums"
},
"53c4baf4d12c3e6c3831bb8f5b532b93": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"5d343409ef0df48c7d78cce9f0106781": {
"compensate": 1.075,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"5f6483271e1efb9bfb59e4a3e6d4d098": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 9,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"65ab5919372a128e4167f5e01a8fda85": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 8192,
"primary_stem": "Other"
},
"6703e39f36f18aa7855ee1047765621d": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 9,
"mdx_n_fft_scale_set": 16384,
"primary_stem": "Bass"
},
"6b31de20e84392859a3d09d43f089515": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"867595e9de46f6ab699008295df62798": {
"compensate": 1.075,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"a3cd63058945e777505c01d2507daf37": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"b33d9b3950b6cbf5fe90a32608924700": {
"compensate": 1.075,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"c3b29bdce8c4fa17ec609e16220330ab": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 16384,
"primary_stem": "Bass"
},
"ceed671467c1f64ebdfac8a2490d0d52": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"d2a1376f310e4f7fa37fb9b5774eb701": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"d7bff498db9324db933d913388cba6be": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"d94058f8c7f1fae4164868ae8ae66b20": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 6144,
"primary_stem": "Vocals"
},
"dc41ede5961d50f277eb846db17f5319": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 9,
"mdx_n_fft_scale_set": 4096,
"primary_stem": "Drums"
},
"e5572e58abf111f80d8241d2e44e7fa4": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
},
"e7324c873b1f615c35c1967f912db92a": {
"compensate": 1.075,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Vocals"
},
"1c56ec0224f1d559c42fd6fd2a67b154": {
"compensate": 1.035,
"mdx_dim_f_set": 2048,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 5120,
"primary_stem": "Instrumental"
},
"f2df6d6863d8f435436d8b561594ff49": {
"compensate": 1.035,
"mdx_dim_f_set": 3072,
"mdx_dim_t_set": 8,
"mdx_n_fft_scale_set": 7680,
"primary_stem": "Instrumental"
}
}

View File

@ -0,0 +1,94 @@
{
"0d0e6d143046b0eecc41a22e60224582": {
"vr_model_param": "3band_44100_mid",
"primary_stem": "Instrumental"
},
"18b52f873021a0af556fb4ecd552bb8e": {
"vr_model_param": "2band_32000",
"primary_stem": "Instrumental"
},
"1fc66027c82b499c7d8f55f79e64cadc": {
"vr_model_param": "2band_32000",
"primary_stem": "Instrumental"
},
"2aa34fbc01f8e6d2bf509726481e7142": {
"vr_model_param": "4band_44100",
"primary_stem": "Other"
},
"3e18f639b11abea7361db1a4a91c2559": {
"vr_model_param": "4band_44100",
"primary_stem": "Instrumental"
},
"570b5f50054609a17741369a35007ddd": {
"vr_model_param": "4band_v3",
"primary_stem": "Instrumental"
},
"5a6e24c1b530f2dab045a522ef89b751": {
"vr_model_param": "1band_sr44100_hl512",
"primary_stem": "Instrumental"
},
"6b5916069a49be3fe29d4397ecfd73fa": {
"vr_model_param": "3band_44100_msb2",
"primary_stem": "Instrumental"
},
"74b3bc5fa2b69f29baf7839b858bc679": {
"vr_model_param": "4band_44100",
"primary_stem": "Instrumental"
},
"827213b316df36b52a1f3d04fec89369": {
"vr_model_param": "4band_44100",
"primary_stem": "Instrumental"
},
"911d4048eee7223eca4ee0efb7d29256": {
"vr_model_param": "4band_44100",
"primary_stem": "Vocals"
},
"941f3f7f0b0341f12087aacdfef644b1": {
"vr_model_param": "4band_v2",
"primary_stem": "Instrumental"
},
"a02827cf69d75781a35c0e8a327f3195": {
"vr_model_param": "1band_sr33075_hl384",
"primary_stem": "Instrumental"
},
"b165fbff113c959dba5303b74c6484bc": {
"vr_model_param": "3band_44100",
"primary_stem": "Instrumental"
},
"b5f988cd3e891dca7253bf5f0f3427c7": {
"vr_model_param": "4band_44100",
"primary_stem": "Instrumental"
},
"b99c35723bc35cb11ed14a4780006a80": {
"vr_model_param": "1band_sr44100_hl1024",
"primary_stem": "Instrumental"
},
"ba02fd25b71d620eebbdb49e18e4c336": {
"vr_model_param": "3band_44100_mid",
"primary_stem": "Instrumental"
},
"c4476ef424d8cba65f38d8d04e8514e2": {
"vr_model_param": "3band_44100_msb2",
"primary_stem": "Instrumental"
},
"da2d37b8be2972e550a409bae08335aa": {
"vr_model_param": "4band_44100",
"primary_stem": "Vocals"
},
"db57205d3133e39df8e050b435a78c80": {
"vr_model_param": "4band_44100",
"primary_stem": "Instrumental"
},
"ea83b08e32ec2303456fe50659035f69": {
"vr_model_param": "4band_v3",
"primary_stem": "Instrumental"
},
"f6ea8473ff86017b5ebd586ccacf156b": {
"vr_model_param": "4band_v2_sn",
"primary_stem": "Instrumental"
},
"fd297a61eafc9d829033f8b987c39a3d": {
"vr_model_param": "1band_sr32000_hl512",
"primary_stem": "Instrumental"
}
}

924
separate.py Normal file
View File

@ -0,0 +1,924 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from demucs.apply import apply_model, demucs_segments
from demucs.hdemucs import HDemucs
from demucs.model_v2 import auto_load_demucs_model_v2
from demucs.pretrained import get_model as _gm
from demucs.utils import apply_model_v1
from demucs.utils import apply_model_v2
from lib_v5 import spec_utils
from lib_v5.vr_network import nets
from lib_v5.vr_network import nets_new
#from lib_v5.vr_network.model_param_init import ModelParameters
from pathlib import Path
from gui_data.constants import *
import gzip
import librosa
import math
import numpy as np
import onnxruntime as ort
import os
import torch
import warnings
import pydub
import soundfile as sf
if TYPE_CHECKING:
from UVR import ModelData
warnings.filterwarnings("ignore")
cpu = torch.device('cpu')
class SeperateAttributes:
def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None):
self.list_all_models: list
self.process_data = process_data
self.progress_value = 0
self.set_progress_bar = process_data['set_progress_bar']
self.write_to_console = process_data['write_to_console']
self.audio_file = process_data['audio_file']
self.audio_file_base = process_data['audio_file_base']
self.export_path = process_data['export_path']
self.cached_source_callback = process_data['cached_source_callback']
self.cached_model_source_holder = process_data['cached_model_source_holder']
self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
self.list_all_models = process_data['list_all_models']
self.process_iteration = process_data['process_iteration']
self.model_samplerate = model_data.model_samplerate
self.is_pre_proc_model = model_data.is_pre_proc_model
self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False
self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True
self.process_method = model_data.process_method
self.model_path = model_data.model_path
self.model_name = model_data.model_name
self.model_basename = model_data.model_basename
self.wav_type_set = model_data.wav_type_set
self.mp3_bit_set = model_data.mp3_bit_set
self.save_format = model_data.save_format
self.is_gpu_conversion = model_data.is_gpu_conversion
self.is_normalization = model_data.is_normalization
self.is_primary_stem_only = model_data.is_primary_stem_only if not self.is_secondary_model else model_data.is_primary_model_primary_stem_only
self.is_secondary_stem_only = model_data.is_secondary_stem_only if not self.is_secondary_model else model_data.is_primary_model_secondary_stem_only
self.is_ensemble_mode = model_data.is_ensemble_mode
self.secondary_model = model_data.secondary_model #VERIFY WHERE
self.primary_model_primary_stem = model_data.primary_model_primary_stem
self.primary_stem = model_data.primary_stem #-
self.secondary_stem = model_data.secondary_stem #-
self.is_invert_spec = model_data.is_invert_spec #
self.secondary_model_scale = model_data.secondary_model_scale #
self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
#############################
self.primary_source_map = {}
self.secondary_source_map = {}
self.primary_source = None
self.secondary_source = None
self.secondary_source_primary = None
self.secondary_source_secondary = None
if not model_data.process_method == DEMUCS_ARCH_TYPE:
if process_data['is_ensemble_master'] and not self.is_4_stem_ensemble:
if not model_data.ensemble_primary_stem == self.primary_stem:
self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
if self.is_secondary_model and not process_data['is_ensemble_master']:
if not self.primary_model_primary_stem == self.primary_stem and not main_model_primary_stem_4_stem:
self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
if main_model_primary_stem_4_stem:
self.is_primary_stem_only = True if main_model_primary_stem_4_stem == self.primary_stem else False
self.is_secondary_stem_only = True if not main_model_primary_stem_4_stem == self.primary_stem else False
if self.is_pre_proc_model:
self.is_primary_stem_only = True if self.primary_stem == INST_STEM else False
self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False
if model_data.process_method == MDX_ARCH_TYPE:
self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
self.is_denoise = model_data.is_denoise
self.compensate = model_data.compensate
self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
self.n_fft = model_data.mdx_n_fft_scale_set
self.chunks = model_data.chunks
self.margin = model_data.margin
self.hop = 1024
self.n_bins = self.n_fft//2+1
self.chunk_size = self.hop * (self.dim_t-1)
self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(cpu)
self.dim_c = 4
out_c = self.dim_c
self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(cpu)
if model_data.process_method == DEMUCS_ARCH_TYPE:
self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
self.secondary_model_4_stem = model_data.secondary_model_4_stem
self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
self.is_chunk_demucs = model_data.is_chunk_demucs
self.segment = model_data.segment
self.demucs_version = model_data.demucs_version
self.demucs_source_list = model_data.demucs_source_list
self.demucs_source_map = model_data.demucs_source_map
self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
self.demucs_stem_count = model_data.demucs_stem_count
self.pre_proc_model = model_data.pre_proc_model
if self.is_secondary_model and not process_data['is_ensemble_master']:
if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
self.primary_stem = VOCAL_STEM
self.secondary_stem = INST_STEM
else:
self.primary_stem = model_data.primary_model_primary_stem
self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
if self.is_chunk_demucs:
self.chunks_demucs = model_data.chunks_demucs
self.margin_demucs = model_data.margin_demucs
else:
self.chunks_demucs = 0
self.margin_demucs = 44100
self.shifts = model_data.shifts
self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
self.overlap = model_data.overlap
self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)
if model_data.process_method == VR_ARCH_TYPE:
self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
self.mp = model_data.vr_model_param
self.high_end_process = model_data.is_high_end_process
self.is_tta = model_data.is_tta
self.is_post_process = model_data.is_post_process
self.is_gpu_conversion = model_data.is_gpu_conversion
self.batch_size = model_data.batch_size
self.crop_size = model_data.crop_size
self.window_size = model_data.window_size
self.input_high_end_h = None
self.aggressiveness = {'value': model_data.aggression_setting,
'split_bin': self.mp.param['band'][1]['crop_stop'],
'aggr_correction': self.mp.param.get('aggr_correction')}
def start_inference(self):
if self.is_secondary_model and not self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))
if self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
def running_inference(self, is_no_write=False):
self.write_to_console(DONE, base_text='') if not is_no_write else None
self.set_progress_bar(0.05) if not is_no_write else None
if self.is_secondary_model and not self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_1_SEC)
elif self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_1_PRE)
else:
self.write_to_console(INFERENCE_STEP_1)
def load_cached_sources(self, is_4_stem_demucs=False):
if self.is_secondary_model and not self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename))
elif self.is_pre_proc_model:
self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename))
else:
self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED)
if not is_4_stem_demucs:
primary_stem, secondary_stem = gather_sources(self.primary_stem, self.secondary_stem, self.primary_sources)
return primary_stem, secondary_stem
def cache_source(self, secondary_sources):
model_occurrences = self.list_all_models.count(self.model_basename)
if not model_occurrences <= 1:
if self.process_method == MDX_ARCH_TYPE:
self.cached_model_source_holder(MDX_ARCH_TYPE, secondary_sources, self.model_basename)
if self.process_method == VR_ARCH_TYPE:
self.cached_model_source_holder(VR_ARCH_TYPE, secondary_sources, self.model_basename)
if self.process_method == DEMUCS_ARCH_TYPE:
self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
# if isinstance(secondary_sources, np.ndarray):
# print('\n==================================\n', secondary_sources, f"\n\nMemory size of source for model {self.model_basename}: ", secondary_sources.size * secondary_sources.itemsize, 'BYTES', '\n==================================\n')
# if type(secondary_sources) is dict:
# print('\n==================================\n', secondary_sources, f"\n\nMemory size of source for model {self.model_basename}: ", sum(v.size * v.itemsize for v in secondary_sources.values()), 'BYTES', '\n==================================\n')
def write_audio(self, stem_path, stem_source, samplerate, secondary_model_source=None, model_scale=None):
if not self.is_secondary_model:
if self.is_secondary_model_activated:
if isinstance(secondary_model_source, np.ndarray):
secondary_model_scale = model_scale if model_scale else self.secondary_model_scale
stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale)
sf.write(stem_path, stem_source, samplerate, subtype=self.wav_type_set)
save_format(stem_path, self.save_format, self.mp3_bit_set) if not self.is_ensemble_mode else None
self.write_to_console(DONE, base_text='')
self.set_progress_bar(0.95)
class SeperateMDX(SeperateAttributes):
def seperate(self):
samplerate = 44100
if self.primary_model_name == self.model_basename and self.primary_sources:
self.primary_source, self.secondary_source = self.load_cached_sources()
else:
self.start_inference()
if self.is_gpu_conversion >= 0:
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
run_type = ['CUDAExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
else:
self.device = torch.device('cpu')
run_type = ['CPUExecutionProvider']
self.onnx_model = ort.InferenceSession(self.model_path, providers=run_type)
self.running_inference()
mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False
mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut)
source = self.demix_base(mix)
self.write_to_console(DONE, base_text='')
if self.is_secondary_model_activated:
if self.secondary_model:
self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
if not self.is_secondary_stem_only:
self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = spec_utils.normalize(source[0], self.is_normalization).T
self.primary_source_map = {self.primary_stem: self.primary_source}
self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
if not self.is_primary_stem_only:
self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
if not isinstance(self.secondary_source, np.ndarray):
raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix
self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source[0]*self.compensate, raw_mix, self.is_normalization)
if self.is_invert_spec:
self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
else:
self.secondary_source = (-self.secondary_source.T+raw_mix.T)
self.secondary_source_map = {self.secondary_stem: self.secondary_source}
self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary)
torch.cuda.empty_cache()
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
self.cache_source(secondary_sources)
if self.is_secondary_model:
return secondary_sources
def demix_base(self, mix, is_match_mix=False):
chunked_sources = []
for slice in mix:
self.progress_value += 1
self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if not is_match_mix else None
cmix = mix[slice]
sources = []
mix_waves = []
n_sample = cmix.shape[1]
trim = self.n_fft//2
gen_size = self.chunk_size-2*trim
pad = gen_size - n_sample%gen_size
mix_p = np.concatenate((np.zeros((2,trim)), cmix, np.zeros((2,pad)), np.zeros((2,trim))), 1)
i = 0
while i < n_sample + pad:
waves = np.array(mix_p[:, i:i+self.chunk_size])
mix_waves.append(waves)
i += gen_size
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
with torch.no_grad():
_ort = self.onnx_model if not is_match_mix else None
adjust = 1
spek = self.stft(mix_waves)*adjust
if not is_match_mix:
if self.is_denoise:
spec_pred = -_ort.run(None, {'input': -spek.cpu().numpy()})[0]*0.5+_ort.run(None, {'input': spek.cpu().numpy()})[0]*0.5
else:
spec_pred = _ort.run(None, {'input': spek.cpu().numpy()})[0]
else:
spec_pred = spek.cpu().numpy()
tar_waves = self.istft(torch.tensor(spec_pred))#.cpu()
tar_signal = tar_waves[:,:,trim:-trim].transpose(0,1).reshape(2, -1).numpy()[:, :-pad]
start = 0 if slice == 0 else self.margin
end = None if slice == list(mix.keys())[::-1][0] else -self.margin
if self.margin == 0:
end = None
sources.append(tar_signal[:,start:end]*(1/adjust))
chunked_sources.append(sources)
sources = np.concatenate(chunked_sources, axis=-1)
if not is_match_mix:
del self.onnx_model
return sources
def stft(self, x):
x = x.reshape([-1, self.chunk_size])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
x = x.permute([0,3,1,2])
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
return x[:,:,:self.dim_f]
def istft(self, x, freq_pad=None):
freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
x = torch.cat([x, freq_pad], -2)
c = 2
x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
x = x.permute([0,2,3,1])
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1,c,self.chunk_size])
class SeperateDemucs(SeperateAttributes):
def seperate(self):
samplerate = 44100
source = None
model_scale = None
stem_source = None
stem_source_secondary = None
inst_mix = None
inst_raw_mix = None
raw_mix = None
inst_source = None
is_no_write = False
is_no_piano_guitar = False
if self.primary_model_name == self.model_basename and type(self.primary_sources) is dict and not self.pre_proc_model:
self.primary_source, self.secondary_source = self.load_cached_sources()
elif self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model:
source = self.primary_sources
self.load_cached_sources(is_4_stem_demucs=True)
else:
self.start_inference()
if self.is_gpu_conversion >= 0:
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device('cpu')
if self.demucs_version == DEMUCS_V1:
if str(self.model_path).endswith(".gz"):
self.model_path = gzip.open(self.model_path, "rb")
klass, args, kwargs, state = torch.load(self.model_path)
self.demucs = klass(*args, **kwargs)
self.demucs.to(self.device)
self.demucs.load_state_dict(state)
elif self.demucs_version == DEMUCS_V2:
self.demucs = auto_load_demucs_model_v2(self.demucs_source_list, self.model_path)
self.demucs.to(self.device)
self.demucs.load_state_dict(torch.load(self.model_path))
self.demucs.eval()
else:
self.demucs = HDemucs(sources=self.demucs_source_list)
self.demucs = _gm(name=os.path.splitext(os.path.basename(self.model_path))[0],
repo=Path(os.path.dirname(self.model_path)))
self.demucs = demucs_segments(self.segment, self.demucs)
self.demucs.to(self.device)
self.demucs.eval()
if self.pre_proc_model:
if self.primary_stem not in [VOCAL_STEM, INST_STEM]:
is_no_write = True
self.write_to_console(DONE, base_text='')
mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs)
self.process_iteration()
self.running_inference(is_no_write=is_no_write)
inst_source = self.demix_demucs(inst_mix)
self.process_iteration()
self.running_inference(is_no_write=is_no_write) if not self.pre_proc_model else None
mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs)
if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
source = self.primary_sources
else:
source = self.demix_demucs(mix)
self.write_to_console(DONE, base_text='')
del self.demucs
if isinstance(inst_source, np.ndarray):
source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape
source = inst_source
if isinstance(source, np.ndarray):
if len(source) == 2:
self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
else:
self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER
if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model:
is_no_piano_guitar = True
six_stem_other_source = list(source)
six_stem_other_source = [i for n, i in enumerate(source) if n in [self.demucs_source_map[OTHER_STEM], self.demucs_source_map[GUITAR_STEM], self.demucs_source_map[PIANO_STEM]]]
other_source = np.zeros_like(six_stem_other_source[0])
for i in six_stem_other_source:
other_source += i
source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
source[self.demucs_source_map[OTHER_STEM]] = source_reshape
if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble:
self.cache_source(source)
for stem_name, stem_value in self.demucs_source_map.items():
if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
if self.secondary_model_4_stem[stem_value]:
model_scale = self.secondary_model_4_stem_scale[stem_value]
stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_4_stem_demucs=True)
if isinstance(stem_source_secondary, np.ndarray):
stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value]
stem_source_secondary = spec_utils.normalize(stem_source_secondary, self.is_normalization).T
elif type(stem_source_secondary) is dict:
stem_source_secondary = stem_source_secondary[stem_name]
stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
stem_source = spec_utils.normalize(source[stem_value], self.is_normalization).T
self.write_audio(stem_path, stem_source, samplerate, secondary_model_source=stem_source_secondary, model_scale=model_scale)
if self.is_secondary_model:
return source
else:
if self.is_secondary_model_activated:
if self.secondary_model:
self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
if not self.is_secondary_stem_only:
self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = spec_utils.normalize(source[self.demucs_source_map[self.primary_stem]], self.is_normalization).T
self.primary_source_map = {self.primary_stem: self.primary_source}
self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
if not self.is_primary_stem_only:
def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
secondary_source = self.secondary_source if not is_inst_mixture else None
self.write_to_console(f'{SAVING_STEM[0]}{sec_stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
secondary_source_secondary = None
if not isinstance(secondary_source, np.ndarray):
if self.is_demucs_combine_stems:
source = list(source)
if is_inst_mixture:
source = [i for n, i in enumerate(source) if not n in [self.demucs_source_map[self.primary_stem], self.demucs_source_map[VOCAL_STEM]]]
else:
source.pop(self.demucs_source_map[self.primary_stem])
source = source[:len(source) - 2] if is_no_piano_guitar else source
secondary_source = np.zeros_like(source[0])
for i in source:
secondary_source += i
secondary_source = spec_utils.normalize(secondary_source, self.is_normalization).T
else:
if not isinstance(raw_mixture, np.ndarray):
raw_mixture = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs, is_missing_mix=True)
secondary_source, raw_mixture = spec_utils.normalize_two_stem(source[self.demucs_source_map[self.primary_stem]], raw_mixture, self.is_normalization)
if self.is_invert_spec:
secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)
else:
raw_mixture = spec_utils.reshape_sources(secondary_source, raw_mixture)
secondary_source = (-secondary_source.T+raw_mixture.T)
if not is_inst_mixture:
self.secondary_source = secondary_source
secondary_source_secondary = self.secondary_source_secondary
self.secondary_source_map = {self.secondary_stem: self.secondary_source}
self.write_audio(secondary_stem_path, secondary_source, samplerate, secondary_source_secondary)
secondary_save(self.secondary_stem, source, raw_mixture=raw_mix)
if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True)
torch.cuda.empty_cache()
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
self.cache_source(secondary_sources)
if self.is_secondary_model:
return secondary_sources
def demix_demucs(self, mix):
processed = {}
set_progress_bar = None if self.is_chunk_demucs else self.set_progress_bar
for nmix in mix:
self.progress_value += 1
self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if self.is_chunk_demucs else None
cmix = mix[nmix]
cmix = torch.tensor(cmix, dtype=torch.float32)
ref = cmix.mean(0)
cmix = (cmix - ref.mean()) / ref.std()
mix_infer = cmix
with torch.no_grad():
if self.demucs_version == DEMUCS_V1:
sources = apply_model_v1(self.demucs,
mix_infer.to(self.device),
self.shifts,
self.is_split_mode,
set_progress_bar=set_progress_bar)
elif self.demucs_version == DEMUCS_V2:
sources = apply_model_v2(self.demucs,
mix_infer.to(self.device),
self.shifts,
self.is_split_mode,
self.overlap,
set_progress_bar=set_progress_bar)
else:
sources = apply_model(self.demucs,
mix_infer[None],
self.shifts,
self.is_split_mode,
self.overlap,
static_shifts=1 if self.shifts == 0 else self.shifts,
set_progress_bar=set_progress_bar,
device=self.device)[0]
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
sources[[0,1]] = sources[[1,0]]
start = 0 if nmix == 0 else self.margin_demucs
end = None if nmix == list(mix.keys())[::-1][0] else -self.margin_demucs
if self.margin_demucs == 0:
end = None
processed[nmix] = sources[:,:,start:end].copy()
sources = list(processed.values())
sources = np.concatenate(sources, axis=-1)
return sources
class SeperateVR(SeperateAttributes):
def seperate(self):
if self.primary_model_name == self.model_basename and self.primary_sources:
self.primary_source, self.secondary_source = self.load_cached_sources()
else:
self.start_inference()
if self.is_gpu_conversion >= 0:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device('cpu')
nn_arch_sizes = [
31191, # default
33966, 56817, 218409, 123821, 123812, 129605, 537238, 537227]
vr_5_1_models = [56817, 218409]
model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
nn_architecture = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
#print('ARC SIZE: ', nn_architecture)
if nn_architecture in vr_5_1_models:
model = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_architecture)
inference = self.inference_vr_new
else:
model = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_architecture)
inference = self.inference_vr
model.load_state_dict(torch.load(self.model_path, map_location=device))
model.to(device)
self.running_inference()
y_spec, v_spec = inference(self.loading_mix(), device, model, self.aggressiveness)
self.write_to_console(DONE, base_text='')
del model
if self.is_secondary_model_activated:
if self.secondary_model:
self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
if not self.is_secondary_stem_only:
self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = spec_utils.normalize(self.spec_to_wav(y_spec), self.is_normalization).T
if not self.model_samplerate == 44100:
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.primary_source_map = {self.primary_stem: self.primary_source}
self.write_audio(primary_stem_path, self.primary_source, 44100, self.secondary_source_primary)
if not self.is_primary_stem_only:
self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
if not isinstance(self.secondary_source, np.ndarray):
self.secondary_source = self.spec_to_wav(v_spec)
self.secondary_source = spec_utils.normalize(self.spec_to_wav(v_spec), self.is_normalization).T
if not self.model_samplerate == 44100:
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.secondary_source_map = {self.secondary_stem: self.secondary_source}
self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary)
torch.cuda.empty_cache()
secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
self.cache_source(secondary_sources)
if self.is_secondary_model:
return secondary_sources
def loading_mix(self):
X_wave, X_spec_s = {}, {}
bands_n = len(self.mp.param['band'])
for d in range(bands_n, 0, -1):
bp = self.mp.param['band'][d]
if d == bands_n: # high-end band
X_wave[d], _ = librosa.load(
self.audio_file, bp['sr'], False, dtype=np.float32, res_type=bp['res_type'])
if X_wave[d].ndim == 1:
X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
else: # lower bands
X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type'])
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], self.mp.param['mid_side'],
self.mp.param['mid_side_b2'], self.mp.param['reverse'])
if d == bands_n and self.high_end_process != 'none':
self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]
X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp)
del X_wave, X_spec_s
return X_spec
def inference_vr(self, X_spec, device, model, aggressiveness):
def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness):
model.eval()
total_iterations = sum([n_window]) if not self.is_tta else sum([n_window])*2
with torch.no_grad():
preds = []
for i in range(n_window):
self.progress_value +=1
self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
start = i * roi_size
X_mag_window = X_mag_pad[None, :, :, start:start + self.window_size]
X_mag_window = torch.from_numpy(X_mag_window).to(device)
pred = model.predict(X_mag_window, aggressiveness)
pred = pred.detach().cpu().numpy()
preds.append(pred[0])
pred = np.concatenate(preds, axis=2)
return pred
X_mag, X_phase = spec_utils.preprocess(X_spec)
coef = X_mag.max()
X_mag_pre = X_mag / coef
n_frame = X_mag_pre.shape[2]
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, model.offset)
n_window = int(np.ceil(n_frame / roi_size))
X_mag_pad = np.pad(
X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
pred = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
pred = pred[:, :, :n_frame]
if self.is_tta:
pad_l += roi_size // 2
pad_r += roi_size // 2
n_window += 1
X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
pred_tta = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness)
pred_tta = pred_tta[:, :, roi_size // 2:]
pred_tta = pred_tta[:, :, :n_frame]
pred, X_mag, X_phase = (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
else:
pred, X_mag, X_phase = pred * coef, X_mag, np.exp(1.j * X_phase)
if self.is_post_process:
pred_inv = np.clip(X_mag - pred, 0, np.inf)
pred = spec_utils.mask_silence(pred, pred_inv)
y_spec = pred * X_phase
v_spec = X_spec - y_spec
return y_spec, v_spec
def inference_vr_new(self, X_spec, device, model, aggressiveness):
def _execute(X_mag_pad, roi_size):
X_dataset = []
patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2
for i in range(patches):
start = i * roi_size
X_mag_crop = X_mag_pad[:, :, start:start + self.crop_size]
X_dataset.append(X_mag_crop)
X_dataset = np.asarray(X_dataset)
model.eval()
with torch.no_grad():
mask = []
# To reduce the overhead, dataloader is not used.
for i in range(0, patches, self.batch_size):
self.progress_value += 1
if self.progress_value >= total_iterations:
self.progress_value = total_iterations
self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value)
X_batch = X_dataset[i: i + self.batch_size]
X_batch = torch.from_numpy(X_batch).to(device)
pred = model.predict_mask(X_batch)
pred = pred.detach().cpu().numpy()
pred = np.concatenate(pred, axis=2)
mask.append(pred)
mask = np.concatenate(mask, axis=2)
return mask
def postprocess(mask, X_mag, X_phase, aggressiveness):
if self.primary_stem == VOCAL_STEM:
mask = (1.0 - spec_utils.adjust_aggr(mask, True, aggressiveness))
else:
mask = spec_utils.adjust_aggr(mask, False, aggressiveness)
if self.is_post_process:
mask = spec_utils.merge_artifacts(mask)
y_spec = mask * X_mag * np.exp(1.j * X_phase)
v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
return y_spec, v_spec
X_mag, X_phase = spec_utils.preprocess(X_spec)
n_frame = X_mag.shape[2]
pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.crop_size, model.offset)
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
X_mag_pad /= X_mag_pad.max()
mask = _execute(X_mag_pad, roi_size)
if self.is_tta:
pad_l += roi_size // 2
pad_r += roi_size // 2
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
X_mag_pad /= X_mag_pad.max()
mask_tta = _execute(X_mag_pad, roi_size)
mask_tta = mask_tta[:, :, roi_size // 2:]
mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
else:
mask = mask[:, :, :n_frame]
y_spec, v_spec = postprocess(mask, X_mag, X_phase, aggressiveness)
return y_spec, v_spec
def spec_to_wav(self, spec):
if self.high_end_process.startswith('mirroring'):
input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_)
else:
wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp)
return wav
def process_secondary_model(secondary_model: ModelData, process_data, main_model_primary_stem_4_stem=None, is_4_stem_demucs=False, main_process_method=None, is_pre_proc_model=False):
if not is_pre_proc_model:
process_iteration = process_data['process_iteration']
process_iteration()
if secondary_model.process_method == VR_ARCH_TYPE:
seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
if secondary_model.process_method == MDX_ARCH_TYPE:
seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
if secondary_model.process_method == DEMUCS_ARCH_TYPE:
seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
secondary_sources = seperator.seperate()
if type(secondary_sources) is dict and not is_4_stem_demucs and not is_pre_proc_model:
return gather_sources(secondary_model.primary_model_primary_stem, STEM_PAIR_MAPPER[secondary_model.primary_model_primary_stem], secondary_sources)
else:
return secondary_sources
def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):
source_primary = False
source_secondary = False
for key, value in secondary_sources.items():
if key in primary_stem_name:
source_primary = value
if key in secondary_stem_name:
source_secondary = value
return source_primary, source_secondary
def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=False):
samplerate = 44100
if not isinstance(mix, np.ndarray):
mix, samplerate = librosa.load(mix, mono=False, sr=44100)
else:
mix = mix.T
if mix.ndim == 1:
mix = np.asfortranarray([mix,mix])
def get_segmented_mix(chunk_set=chunk_set):
segmented_mix = {}
samples = mix.shape[-1]
margin = margin_set
chunk_size = chunk_set*44100
assert not margin == 0, 'margin cannot be zero!'
if margin > chunk_size:
margin = chunk_size
if chunk_set == 0 or samples < chunk_size:
chunk_size = samples
counter = -1
for skip in range(0, samples, chunk_size):
counter+=1
s_margin = 0 if counter == 0 else margin
end = min(skip+chunk_size+margin, samples)
start = skip-s_margin
segmented_mix[skip] = mix[:,start:end].copy()
if end == samples:
break
return segmented_mix
if is_missing_mix:
return mix
else:
segmented_mix = get_segmented_mix()
raw_mix = get_segmented_mix(chunk_set=0) if mdx_net_cut else mix
return segmented_mix, raw_mix, samplerate
def save_format(audio_path, save_format, mp3_bit_set):
if not save_format == WAV:
musfile = pydub.AudioSegment.from_wav(audio_path)
if save_format == FLAC:
audio_path_flac = audio_path.replace(".wav", ".flac")
musfile.export(audio_path_flac, format="flac")
if save_format == MP3:
audio_path_mp3 = audio_path.replace(".wav", ".mp3")
musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)
try:
os.remove(audio_path)
except Exception as e:
print(e)