From 508c59d41b336c1b2e291256e3bc465d3a28c4f3 Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Wed, 2 Nov 2022 03:25:50 -0500 Subject: [PATCH] Delete demucs directory --- demucs/__init__.py | 5 - demucs/__main__.py | 272 ---------------- demucs/apply.py | 285 ---------------- demucs/demucs.py | 459 -------------------------- demucs/hdemucs.py | 761 ------------------------------------------- demucs/model.py | 218 ------------- demucs/model_v2.py | 202 ------------ demucs/pretrained.py | 167 ---------- demucs/repo.py | 154 --------- demucs/spec.py | 41 --- demucs/states.py | 148 --------- demucs/tasnet.py | 447 ------------------------- demucs/tasnet_v2.py | 452 ------------------------- demucs/utils-.py | 187 ----------- demucs/utils.py | 533 ------------------------------ 15 files changed, 4331 deletions(-) delete mode 100644 demucs/__init__.py delete mode 100644 demucs/__main__.py delete mode 100644 demucs/apply.py delete mode 100644 demucs/demucs.py delete mode 100644 demucs/hdemucs.py delete mode 100644 demucs/model.py delete mode 100644 demucs/model_v2.py delete mode 100644 demucs/pretrained.py delete mode 100644 demucs/repo.py delete mode 100644 demucs/spec.py delete mode 100644 demucs/states.py delete mode 100644 demucs/tasnet.py delete mode 100644 demucs/tasnet_v2.py delete mode 100644 demucs/utils-.py delete mode 100644 demucs/utils.py diff --git a/demucs/__init__.py b/demucs/__init__.py deleted file mode 100644 index 5656d59..0000000 --- a/demucs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/demucs/__main__.py b/demucs/__main__.py deleted file mode 100644 index 5de878f..0000000 --- a/demucs/__main__.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -import sys -import time -from dataclasses import dataclass, field -from fractions import Fraction - -import torch as th -from torch import distributed, nn -from torch.nn.parallel.distributed import DistributedDataParallel - -from .augment import FlipChannels, FlipSign, Remix, Shift -from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks -from .model import Demucs -from .parser import get_name, get_parser -from .raw import Rawset -from .tasnet import ConvTasNet -from .test import evaluate -from .train import train_model, validate_model -from .utils import human_seconds, load_model, save_model, sizeof_fmt - - -@dataclass -class SavedState: - metrics: list = field(default_factory=list) - last_state: dict = None - best_state: dict = None - optimizer: dict = None - - -def main(): - parser = get_parser() - args = parser.parse_args() - name = get_name(parser, args) - print(f"Experiment {name}") - - if args.musdb is None and args.rank == 0: - print( - "You must provide the path to the MusDB dataset with the --musdb flag. 
" - "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", - file=sys.stderr) - sys.exit(1) - - eval_folder = args.evals / name - eval_folder.mkdir(exist_ok=True, parents=True) - args.logs.mkdir(exist_ok=True) - metrics_path = args.logs / f"{name}.json" - eval_folder.mkdir(exist_ok=True, parents=True) - args.checkpoints.mkdir(exist_ok=True, parents=True) - args.models.mkdir(exist_ok=True, parents=True) - - if args.device is None: - device = "cpu" - if th.cuda.is_available(): - device = "cuda" - else: - device = args.device - - th.manual_seed(args.seed) - # Prevents too many threads to be started when running `museval` as it can be quite - # inefficient on NUMA architectures. - os.environ["OMP_NUM_THREADS"] = "1" - - if args.world_size > 1: - if device != "cuda" and args.rank == 0: - print("Error: distributed training is only available with cuda device", file=sys.stderr) - sys.exit(1) - th.cuda.set_device(args.rank % th.cuda.device_count()) - distributed.init_process_group(backend="nccl", - init_method="tcp://" + args.master, - rank=args.rank, - world_size=args.world_size) - - checkpoint = args.checkpoints / f"{name}.th" - checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" - if args.restart and checkpoint.exists(): - checkpoint.unlink() - - if args.test: - args.epochs = 1 - args.repeat = 0 - model = load_model(args.models / args.test) - elif args.tasnet: - model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X) - else: - model = Demucs( - audio_channels=args.audio_channels, - channels=args.channels, - context=args.context, - depth=args.depth, - glu=args.glu, - growth=args.growth, - kernel_size=args.kernel_size, - lstm_layers=args.lstm_layers, - rescale=args.rescale, - rewrite=args.rewrite, - sources=4, - stride=args.conv_stride, - upsample=args.upsample, - samplerate=args.samplerate - ) - model.to(device) - if args.show: - print(model) - size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) - print(f"Model size {size}") - return - - optimizer = th.optim.Adam(model.parameters(), lr=args.lr) - - try: - saved = th.load(checkpoint, map_location='cpu') - except IOError: - saved = SavedState() - else: - model.load_state_dict(saved.last_state) - optimizer.load_state_dict(saved.optimizer) - - if args.save_model: - if args.rank == 0: - model.to("cpu") - model.load_state_dict(saved.best_state) - save_model(model, args.models / f"{name}.th") - return - - if args.rank == 0: - done = args.logs / f"{name}.done" - if done.exists(): - done.unlink() - - if args.augment: - augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), - Remix(group_size=args.remix_group_size)).to(device) - else: - augment = Shift(args.data_stride) - - if args.mse: - criterion = nn.MSELoss() - else: - criterion = nn.L1Loss() - - # Setting number of samples so that all convolution windows are full. - # Prevents hard to debug mistake with the prediction being shifted compared - # to the input mixture. 
- samples = model.valid_length(args.samples) - print(f"Number of training samples adjusted to {samples}") - - if args.raw: - train_set = Rawset(args.raw / "train", - samples=samples + args.data_stride, - channels=args.audio_channels, - streams=[0, 1, 2, 3, 4], - stride=args.data_stride) - - valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) - else: - if not args.metadata.is_file() and args.rank == 0: - build_musdb_metadata(args.metadata, args.musdb, args.workers) - if args.world_size > 1: - distributed.barrier() - metadata = json.load(open(args.metadata)) - duration = Fraction(samples + args.data_stride, args.samplerate) - stride = Fraction(args.data_stride, args.samplerate) - train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), - metadata, - duration=duration, - stride=stride, - samplerate=args.samplerate, - channels=args.audio_channels) - valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), - metadata, - samplerate=args.samplerate, - channels=args.audio_channels) - - best_loss = float("inf") - for epoch, metrics in enumerate(saved.metrics): - print(f"Epoch {epoch:03d}: " - f"train={metrics['train']:.8f} " - f"valid={metrics['valid']:.8f} " - f"best={metrics['best']:.4f} " - f"duration={human_seconds(metrics['duration'])}") - best_loss = metrics['best'] - - if args.world_size > 1: - dmodel = DistributedDataParallel(model, - device_ids=[th.cuda.current_device()], - output_device=th.cuda.current_device()) - else: - dmodel = model - - for epoch in range(len(saved.metrics), args.epochs): - begin = time.time() - model.train() - train_loss = train_model(epoch, - train_set, - dmodel, - criterion, - optimizer, - augment, - batch_size=args.batch_size, - device=device, - repeat=args.repeat, - seed=args.seed, - workers=args.workers, - world_size=args.world_size) - model.eval() - valid_loss = validate_model(epoch, - valid_set, - model, - criterion, - device=device, - rank=args.rank, - split=args.split_valid, - world_size=args.world_size) - - duration = time.time() - begin - if valid_loss < best_loss: - best_loss = valid_loss - saved.best_state = { - key: value.to("cpu").clone() - for key, value in model.state_dict().items() - } - saved.metrics.append({ - "train": train_loss, - "valid": valid_loss, - "best": best_loss, - "duration": duration - }) - if args.rank == 0: - json.dump(saved.metrics, open(metrics_path, "w")) - - saved.last_state = model.state_dict() - saved.optimizer = optimizer.state_dict() - if args.rank == 0 and not args.test: - th.save(saved, checkpoint_tmp) - checkpoint_tmp.rename(checkpoint) - - print(f"Epoch {epoch:03d}: " - f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} " - f"duration={human_seconds(duration)}") - - del dmodel - model.load_state_dict(saved.best_state) - if args.eval_cpu: - device = "cpu" - model.to(device) - model.eval() - evaluate(model, - args.musdb, - eval_folder, - rank=args.rank, - world_size=args.world_size, - device=device, - save=args.save, - split=args.split_valid, - shifts=args.shifts, - workers=args.eval_workers) - model.to("cpu") - save_model(model, args.models / f"{name}.th") - if args.rank == 0: - print("done") - done.write_text("done") - - -if __name__ == "__main__": - main() diff --git a/demucs/apply.py b/demucs/apply.py deleted file mode 100644 index 7920ad5..0000000 --- a/demucs/apply.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Code to apply a model to a mix. It will handle chunking with overlaps and -inteprolation between chunks, as well as the "shift trick". -""" -from concurrent.futures import ThreadPoolExecutor -import random -import typing as tp -from multiprocessing import Process,Queue,Pipe - -import torch as th -from torch import nn -from torch.nn import functional as F -import tqdm -import tkinter as tk - -from .demucs import Demucs -from .hdemucs import HDemucs -from .utils import center_trim, DummyPoolExecutor - -Model = tp.Union[Demucs, HDemucs] - -progress_bar_num = 0 - -class BagOfModels(nn.Module): - def __init__(self, models: tp.List[Model], - weights: tp.Optional[tp.List[tp.List[float]]] = None, - segment: tp.Optional[float] = None): - """ - Represents a bag of models with specific weights. - You should call `apply_model` rather than calling directly the forward here for - optimal performance. - - Args: - models (list[nn.Module]): list of Demucs/HDemucs models. - weights (list[list[float]]): list of weights. If None, assumed to - be all ones, otherwise it should be a list of N list (N number of models), - each containing S floats (S number of sources). - segment (None or float): overrides the `segment` attribute of each model - (this is performed inplace, be careful if you reuse the models passed). - """ - super().__init__() - assert len(models) > 0 - first = models[0] - for other in models: - assert other.sources == first.sources - assert other.samplerate == first.samplerate - assert other.audio_channels == first.audio_channels - if segment is not None: - other.segment = segment - - self.audio_channels = first.audio_channels - self.samplerate = first.samplerate - self.sources = first.sources - self.models = nn.ModuleList(models) - - if weights is None: - weights = [[1. 
for _ in first.sources] for _ in models] - else: - assert len(weights) == len(models) - for weight in weights: - assert len(weight) == len(first.sources) - self.weights = weights - - def forward(self, x): - raise NotImplementedError("Call `apply_model` on this.") - - -class TensorChunk: - def __init__(self, tensor, offset=0, length=None): - total_length = tensor.shape[-1] - assert offset >= 0 - assert offset < total_length - - if length is None: - length = total_length - offset - else: - length = min(total_length - offset, length) - - self.tensor = tensor - self.offset = offset - self.length = length - self.device = tensor.device - - @property - def shape(self): - shape = list(self.tensor.shape) - shape[-1] = self.length - return shape - - def padded(self, target_length): - delta = target_length - self.length - total_length = self.tensor.shape[-1] - assert delta >= 0 - - start = self.offset - delta // 2 - end = start + target_length - - correct_start = max(0, start) - correct_end = min(total_length, end) - - pad_left = correct_start - start - pad_right = end - correct_end - - out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) - assert out.shape[-1] == target_length - return out - -def tensor_chunk(tensor_or_chunk): - if isinstance(tensor_or_chunk, TensorChunk): - return tensor_or_chunk - else: - assert isinstance(tensor_or_chunk, th.Tensor) - return TensorChunk(tensor_or_chunk) - -def apply_model(model, mix, gui_progress_bar: tk.Variable, widget_text: tk.Text, update_prog, total_files, file_num, inference_type, shifts=1, split=True, - overlap=0.25, transition_power=1., progress=True, device=None, - num_workers=0, pool=None, segmen=False): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. - progress (bool): if True, show a progress bar (requires split=True) - device (torch.device, str, or None): if provided, device on which to - execute the computation, otherwise `mix.device` is assumed. - When `device` is different from `mix.device`, only local computations will - be on `device`, while the entire tracks will be stored on `mix.device`. - """ - - base_text = 'File {file_num}/{total_files} '.format(file_num=file_num, - total_files=total_files) - - global fut_length - - if device is None: - device = mix.device - else: - device = th.device(device) - if pool is None: - if num_workers > 0 and device.type == 'cpu': - pool = ThreadPoolExecutor(num_workers) - else: - pool = DummyPoolExecutor() - - kwargs = { - 'gui_progress_bar': gui_progress_bar, - 'widget_text': widget_text, - 'update_prog': update_prog, - 'segmen': segmen, - 'shifts': shifts, - 'split': split, - 'overlap': overlap, - 'transition_power': transition_power, - 'progress': progress, - 'device': device, - 'pool': pool, - 'total_files': total_files, - 'file_num': file_num, - 'inference_type': inference_type - } - - if isinstance(model, BagOfModels): - # Special treatment for bag of model. - # We explicitely apply multiple times `apply_model` so that the random shifts - # are different for each model. 
- global bag_num - global current_model - global progress_bar - global prog_bar - #global percent_prog_del - - #percent_prog_del = gui_progress_bar.get() - - progress_bar = 0 - prog_bar = 0 - estimates = 0 - totals = [0] * len(model.sources) - bag_num = len(model.models) - fut_length = 0 - current_model = 0 #(bag_num + 1) - for sub_model, weight in zip(model.models, model.weights): - original_model_device = next(iter(sub_model.parameters())).device - sub_model.to(device) - fut_length += fut_length - current_model += 1 - out = apply_model(sub_model, mix, **kwargs) - sub_model.to(original_model_device) - for k, inst_weight in enumerate(weight): - out[:, k, :, :] *= inst_weight - totals[k] += inst_weight - estimates += out - del out - - for k in range(estimates.shape[1]): - estimates[:, k, :, :] /= totals[k] - return estimates - - model.to(device) - assert transition_power >= 1, "transition_power < 1 leads to weird behavior." - batch, channels, length = mix.shape - - if split: - kwargs['split'] = False - out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) - sum_weight = th.zeros(length, device=mix.device) - segment = int(model.samplerate * model.segment) - stride = int((1 - overlap) * segment) - offsets = range(0, length, stride) - scale = stride / model.samplerate - # We start from a triangle shaped weight, with maximal weight in the middle - # of the segment. Then we normalize and take to the power `transition_power`. - # Large values of transition power will lead to sharper transitions. - weight = th.cat([th.arange(1, segment // 2 + 1, device=device), - th.arange(segment - segment // 2, 0, -1, device=device)]) - assert len(weight) == segment - # If the overlap < 50%, this will translate to linear transition when - # transition_power is 1. 
- weight = (weight / weight.max())**transition_power - futures = [] - for offset in offsets: - chunk = TensorChunk(mix, offset, segment) - future = pool.submit(apply_model, model, chunk, **kwargs) - futures.append((future, offset)) - offset += segment - for future, offset in futures: - if segmen: - fut_length = len(futures) - full_fut_length = (fut_length * bag_num) - send_back = full_fut_length * 2 - progress_bar += 100 - prog_bar += 1 - full_step = (progress_bar / full_fut_length) - percent_prog = f"{base_text}Demucs Inference Progress: {prog_bar}/{full_fut_length} | {round(full_step)}%" - if inference_type == 'demucs_only': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.1 + (1.7/send_back * prog_bar))) - elif inference_type == 'inference_mdx': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.35 + (1.05/send_back * prog_bar))) - elif inference_type == 'inference_vr': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.6 + (0.7/send_back * prog_bar))) - widget_text.percentage(percent_prog) - #gui_progress_bar.set(step) - chunk_out = future.result() - chunk_length = chunk_out.shape[-1] - out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device) - sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device) - assert sum_weight.min() > 0 - out /= sum_weight - return out - elif shifts: - kwargs['shifts'] = 0 - max_shift = int(0.5 * model.samplerate) - mix = tensor_chunk(mix) - padded_mix = mix.padded(length + 2 * max_shift) - out = 0 - for _ in range(shifts): - offset = random.randint(0, max_shift) - shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) - shifted_out = apply_model(model, shifted, **kwargs) - out += shifted_out[..., max_shift - offset:] - out /= shifts - return out - else: - if hasattr(model, 'valid_length'): - valid_length = model.valid_length(length) - else: - valid_length = length - mix = tensor_chunk(mix) - padded_mix = mix.padded(valid_length).to(device) - with th.no_grad(): - out = model(padded_mix) - return center_trim(out, length) diff --git a/demucs/demucs.py b/demucs/demucs.py deleted file mode 100644 index d2c08e7..0000000 --- a/demucs/demucs.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math -import typing as tp - -import julius -import torch -from torch import nn -from torch.nn import functional as F - -from .states import capture_init -from .utils import center_trim, unfold - - -class BLSTM(nn.Module): - """ - BiLSTM with same hidden units as input dim. - If `max_steps` is not None, input will be splitting in overlapping - chunks and the LSTM applied separately on each chunk. 
- """ - def __init__(self, dim, layers=1, max_steps=None, skip=False): - super().__init__() - assert max_steps is None or max_steps % 4 == 0 - self.max_steps = max_steps - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - self.skip = skip - - def forward(self, x): - B, C, T = x.shape - y = x - framed = False - if self.max_steps is not None and T > self.max_steps: - width = self.max_steps - stride = width // 2 - frames = unfold(x, width, stride) - nframes = frames.shape[2] - framed = True - x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) - - x = x.permute(2, 0, 1) - - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - if framed: - out = [] - frames = x.reshape(B, -1, C, width) - limit = stride // 2 - for k in range(nframes): - if k == 0: - out.append(frames[:, k, :, :-limit]) - elif k == nframes - 1: - out.append(frames[:, k, :, limit:]) - else: - out.append(frames[:, k, :, limit:-limit]) - out = torch.cat(out, -1) - out = out[..., :T] - x = out - if self.skip: - x = x + y - return x - - -def rescale_conv(conv, reference): - """Rescale initial weight scale. It is unclear why it helps but it certainly does. - """ - std = conv.weight.std().detach() - scale = (std / reference)**0.5 - conv.weight.data /= scale - if conv.bias is not None: - conv.bias.data /= scale - - -def rescale_module(module, reference): - for sub in module.modules(): - if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): - rescale_conv(sub, reference) - - -class LayerScale(nn.Module): - """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). - This rescales diagonaly residual outputs close to 0 initially, then learnt. - """ - def __init__(self, channels: int, init: float = 0): - super().__init__() - self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) - self.scale.data[:] = init - - def forward(self, x): - return self.scale[:, None] * x - - -class DConv(nn.Module): - """ - New residual branches in each encoder layer. - This alternates dilated convolutions, potentially with LSTMs and attention. - Also before entering each residual branch, dimension is projected on a smaller subspace, - e.g. of dim `channels // compress`. - """ - def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, - norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, - kernel=3, dilate=True): - """ - Args: - channels: input/output channels for residual branch. - compress: amount of channel compression inside the branch. - depth: number of layers in the residual branch. Each layer has its own - projection, and potentially LSTM and attention. - init: initial scale for LayerNorm. - norm: use GroupNorm. - attn: use LocalAttention. - heads: number of heads for the LocalAttention. - ndecay: number of decay controls in the LocalAttention. - lstm: use LSTM. - gelu: Use GELU activation. - kernel: kernel size for the (dilated) convolutions. - dilate: if true, use dilation, increasing with the depth. 
- """ - - super().__init__() - assert kernel % 2 == 1 - self.channels = channels - self.compress = compress - self.depth = abs(depth) - dilate = depth > 0 - - norm_fn: tp.Callable[[int], nn.Module] - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(1, d) # noqa - - hidden = int(channels / compress) - - act: tp.Type[nn.Module] - if gelu: - act = nn.GELU - else: - act = nn.ReLU - - self.layers = nn.ModuleList([]) - for d in range(self.depth): - dilation = 2 ** d if dilate else 1 - padding = dilation * (kernel // 2) - mods = [ - nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), - norm_fn(hidden), act(), - nn.Conv1d(hidden, 2 * channels, 1), - norm_fn(2 * channels), nn.GLU(1), - LayerScale(channels, init), - ] - if attn: - mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) - if lstm: - mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) - layer = nn.Sequential(*mods) - self.layers.append(layer) - - def forward(self, x): - for layer in self.layers: - x = x + layer(x) - return x - - -class LocalState(nn.Module): - """Local state allows to have attention based only on data (no positional embedding), - but while setting a constraint on the time window (e.g. decaying penalty term). - - Also a failed experiments with trying to provide some frequency based attention. - """ - def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): - super().__init__() - assert channels % heads == 0, (channels, heads) - self.heads = heads - self.nfreqs = nfreqs - self.ndecay = ndecay - self.content = nn.Conv1d(channels, channels, 1) - self.query = nn.Conv1d(channels, channels, 1) - self.key = nn.Conv1d(channels, channels, 1) - if nfreqs: - self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) - if ndecay: - self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) - # Initialize decay close to zero (there is a sigmoid), for maximum initial window. - self.query_decay.weight.data *= 0.01 - assert self.query_decay.bias is not None # stupid type checker - self.query_decay.bias.data[:] = -2 - self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) - - def forward(self, x): - B, C, T = x.shape - heads = self.heads - indexes = torch.arange(T, device=x.device, dtype=x.dtype) - # left index are keys, right index are queries - delta = indexes[:, None] - indexes[None, :] - - queries = self.query(x).view(B, heads, -1, T) - keys = self.key(x).view(B, heads, -1, T) - # t are keys, s are queries - dots = torch.einsum("bhct,bhcs->bhts", keys, queries) - dots /= keys.shape[2]**0.5 - if self.nfreqs: - periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) - freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) - freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 - dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) - if self.ndecay: - decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) - decay_q = self.query_decay(x).view(B, heads, -1, T) - decay_q = torch.sigmoid(decay_q) / 2 - decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 - dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) - - # Kill self reference. 
- dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) - weights = torch.softmax(dots, dim=2) - - content = self.content(x).view(B, heads, -1, T) - result = torch.einsum("bhts,bhct->bhcs", weights, content) - if self.nfreqs: - time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) - result = torch.cat([result, time_sig], 2) - result = result.reshape(B, -1, T) - return x + self.proj(result) - - -class Demucs(nn.Module): - @capture_init - def __init__(self, - sources, - # Channels - audio_channels=2, - channels=64, - growth=2., - # Main structure - depth=6, - rewrite=True, - lstm_layers=0, - # Convolutions - kernel_size=8, - stride=4, - context=1, - # Activations - gelu=True, - glu=True, - # Normalization - norm_starts=4, - norm_groups=4, - # DConv residual branch - dconv_mode=1, - dconv_depth=2, - dconv_comp=4, - dconv_attn=4, - dconv_lstm=4, - dconv_init=1e-4, - # Pre/post processing - normalize=True, - resample=True, - # Weight init - rescale=0.1, - # Metadata - samplerate=44100, - segment=4 * 10): - """ - Args: - sources (list[str]): list of source names - audio_channels (int): stereo or mono - channels (int): first convolution channels - depth (int): number of encoder/decoder layers - growth (float): multiply (resp divide) number of channels by that - for each layer of the encoder (resp decoder) - depth (int): number of layers in the encoder and in the decoder. - rewrite (bool): add 1x1 convolution to each layer. - lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated - by default, as this is now replaced by the smaller and faster small LSTMs - in the DConv branches. - kernel_size (int): kernel size for convolutions - stride (int): stride for convolutions - context (int): kernel size of the convolution in the - decoder before the transposed convolution. If > 1, - will provide some context from neighboring time steps. - gelu: use GELU activation function. - glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. - norm_starts: layer at which group norm starts being used. - decoder layers are numbered in reverse order. - norm_groups: number of groups for group norm. - dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. - dconv_depth: depth of residual DConv branch. - dconv_comp: compression of DConv branch. - dconv_attn: adds attention layers in DConv branch starting at this layer. - dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. - dconv_init: initial scale for the DConv branch LayerScale. - normalize (bool): normalizes the input audio on the fly, and scales back - the output by the same amount. - resample (bool): upsample x2 the input and downsample /2 the output. - rescale (int): rescale initial weights of convolutions - to get their standard deviation closer to `rescale`. - samplerate (int): stored as meta information for easing - future evaluations of the model. - segment (float): duration of the chunks of audio to ideally evaluate the model on. - This is used by `demucs.apply.apply_model`. 
- """ - - super().__init__() - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.resample = resample - self.channels = channels - self.normalize = normalize - self.samplerate = samplerate - self.segment = segment - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - self.skip_scales = nn.ModuleList() - - if glu: - activation = nn.GLU(dim=1) - ch_scale = 2 - else: - activation = nn.ReLU() - ch_scale = 1 - if gelu: - act2 = nn.GELU - else: - act2 = nn.ReLU - - in_channels = audio_channels - padding = 0 - for index in range(depth): - norm_fn = lambda d: nn.Identity() # noqa - if index >= norm_starts: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - - encode = [] - encode += [ - nn.Conv1d(in_channels, channels, kernel_size, stride), - norm_fn(channels), - act2(), - ] - attn = index >= dconv_attn - lstm = index >= dconv_lstm - if dconv_mode & 1: - encode += [DConv(channels, depth=dconv_depth, init=dconv_init, - compress=dconv_comp, attn=attn, lstm=lstm)] - if rewrite: - encode += [ - nn.Conv1d(channels, ch_scale * channels, 1), - norm_fn(ch_scale * channels), activation] - self.encoder.append(nn.Sequential(*encode)) - - decode = [] - if index > 0: - out_channels = in_channels - else: - out_channels = len(self.sources) * audio_channels - if rewrite: - decode += [ - nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), - norm_fn(ch_scale * channels), activation] - if dconv_mode & 2: - decode += [DConv(channels, depth=dconv_depth, init=dconv_init, - compress=dconv_comp, attn=attn, lstm=lstm)] - decode += [nn.ConvTranspose1d(channels, out_channels, - kernel_size, stride, padding=padding)] - if index > 0: - decode += [norm_fn(out_channels), act2()] - self.decoder.insert(0, nn.Sequential(*decode)) - in_channels = channels - channels = int(growth * channels) - - channels = in_channels - if lstm_layers: - self.lstm = BLSTM(channels, lstm_layers) - else: - self.lstm = None - - if rescale: - rescale_module(self, reference=rescale) - - def valid_length(self, length): - """ - Return the nearest valid length to use with the model so that - there is no time steps left over in a convolution, e.g. for all - layers, size of the input - kernel_size % stride = 0. - - Note that input are automatically padded if necessary to ensure that the output - has the same length as the input. 
- """ - if self.resample: - length *= 2 - - for _ in range(self.depth): - length = math.ceil((length - self.kernel_size) / self.stride) + 1 - length = max(1, length) - - for idx in range(self.depth): - length = (length - 1) * self.stride + self.kernel_size - - if self.resample: - length = math.ceil(length / 2) - return int(length) - - def forward(self, mix): - x = mix - length = x.shape[-1] - - if self.normalize: - mono = mix.mean(dim=1, keepdim=True) - mean = mono.mean(dim=-1, keepdim=True) - std = mono.std(dim=-1, keepdim=True) - x = (x - mean) / (1e-5 + std) - else: - mean = 0 - std = 1 - - delta = self.valid_length(length) - length - x = F.pad(x, (delta // 2, delta - delta // 2)) - - if self.resample: - x = julius.resample_frac(x, 1, 2) - - saved = [] - for encode in self.encoder: - x = encode(x) - saved.append(x) - - if self.lstm: - x = self.lstm(x) - - for decode in self.decoder: - skip = saved.pop(-1) - skip = center_trim(skip, x) - x = decode(x + skip) - - if self.resample: - x = julius.resample_frac(x, 2, 1) - x = x * std + mean - x = center_trim(x, length) - x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) - return x - - def load_state_dict(self, state, strict=True): - # fix a mismatch with previous generation Demucs models. - for idx in range(self.depth): - for a in ['encoder', 'decoder']: - for b in ['bias', 'weight']: - new = f'{a}.{idx}.3.{b}' - old = f'{a}.{idx}.2.{b}' - if old in state and new not in state: - state[new] = state.pop(old) - super().load_state_dict(state, strict=strict) diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py deleted file mode 100644 index 864fd3f..0000000 --- a/demucs/hdemucs.py +++ /dev/null @@ -1,761 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -This code contains the spectrogram and Hybrid version of Demucs. -""" -from copy import deepcopy -import math - -from openunmix.filtering import wiener -import torch -from torch import nn -from torch.nn import functional as F - -from .demucs import DConv, rescale_module -from .states import capture_init -from .spec import spectro, ispectro - - -class ScaledEmbedding(nn.Module): - """ - Boost learning rate for embeddings (with `scale`). - Also, can make embeddings continuous with `smooth`. - """ - def __init__(self, num_embeddings: int, embedding_dim: int, - scale: float = 10., smooth=False): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim) - if smooth: - weight = torch.cumsum(self.embedding.weight.data, dim=0) - # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. - weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] - self.embedding.weight.data[:] = weight - self.embedding.weight.data /= scale - self.scale = scale - - @property - def weight(self): - return self.embedding.weight * self.scale - - def forward(self, x): - out = self.embedding(x) * self.scale - return out - - -class HEncLayer(nn.Module): - def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, - freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, - rewrite=True): - """Encoder layer. This used both by the time and the frequency branch. - - Args: - chin: number of input channels. - chout: number of output channels. - norm_groups: number of groups for group norm. - empty: used to make a layer with just the first conv. 
this is used - before merging the time and freq. branches. - freq: this is acting on frequencies. - dconv: insert DConv residual branches. - norm: use GroupNorm. - context: context size for the 1x1 conv. - dconv_kw: list of kwargs for the DConv class. - pad: pad the input. Padding is done so that the output size is - always the input size / stride. - rewrite: add 1x1 conv at the end of the layer. - """ - super().__init__() - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - if pad: - pad = kernel_size // 4 - else: - pad = 0 - klass = nn.Conv1d - self.freq = freq - self.kernel_size = kernel_size - self.stride = stride - self.empty = empty - self.norm = norm - self.pad = pad - if freq: - kernel_size = [kernel_size, 1] - stride = [stride, 1] - pad = [pad, 0] - klass = nn.Conv2d - self.conv = klass(chin, chout, kernel_size, stride, pad) - if self.empty: - return - self.norm1 = norm_fn(chout) - self.rewrite = None - if rewrite: - self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) - self.norm2 = norm_fn(2 * chout) - - self.dconv = None - if dconv: - self.dconv = DConv(chout, **dconv_kw) - - def forward(self, x, inject=None): - """ - `inject` is used to inject the result from the time branch into the frequency branch, - when both have the same stride. - """ - if not self.freq and x.dim() == 4: - B, C, Fr, T = x.shape - x = x.view(B, -1, T) - - if not self.freq: - le = x.shape[-1] - if not le % self.stride == 0: - x = F.pad(x, (0, self.stride - (le % self.stride))) - y = self.conv(x) - if self.empty: - return y - if inject is not None: - assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) - if inject.dim() == 3 and y.dim() == 4: - inject = inject[:, :, None] - y = y + inject - y = F.gelu(self.norm1(y)) - if self.dconv: - if self.freq: - B, C, Fr, T = y.shape - y = y.permute(0, 2, 1, 3).reshape(-1, C, T) - y = self.dconv(y) - if self.freq: - y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) - if self.rewrite: - z = self.norm2(self.rewrite(y)) - z = F.glu(z, dim=1) - else: - z = y - return z - - -class MultiWrap(nn.Module): - """ - Takes one layer and replicate it N times. each replica will act - on a frequency band. All is done so that if the N replica have the same weights, - then this is exactly equivalent to applying the original module on all frequencies. - - This is a bit over-engineered to avoid edge artifacts when splitting - the frequency bands, but it is possible the naive implementation would work as well... - """ - def __init__(self, layer, split_ratios): - """ - Args: - layer: module to clone, must be either HEncLayer or HDecLayer. - split_ratios: list of float indicating which ratio to keep for each band. 
- """ - super().__init__() - self.split_ratios = split_ratios - self.layers = nn.ModuleList() - self.conv = isinstance(layer, HEncLayer) - assert not layer.norm - assert layer.freq - assert layer.pad - if not self.conv: - assert not layer.context_freq - for k in range(len(split_ratios) + 1): - lay = deepcopy(layer) - if self.conv: - lay.conv.padding = (0, 0) - else: - lay.pad = False - for m in lay.modules(): - if hasattr(m, 'reset_parameters'): - m.reset_parameters() - self.layers.append(lay) - - def forward(self, x, skip=None, length=None): - B, C, Fr, T = x.shape - - ratios = list(self.split_ratios) + [1] - start = 0 - outs = [] - for ratio, layer in zip(ratios, self.layers): - if self.conv: - pad = layer.kernel_size // 4 - if ratio == 1: - limit = Fr - frames = -1 - else: - limit = int(round(Fr * ratio)) - le = limit - start - if start == 0: - le += pad - frames = round((le - layer.kernel_size) / layer.stride + 1) - limit = start + (frames - 1) * layer.stride + layer.kernel_size - if start == 0: - limit -= pad - assert limit - start > 0, (limit, start) - assert limit <= Fr, (limit, Fr) - y = x[:, :, start:limit, :] - if start == 0: - y = F.pad(y, (0, 0, pad, 0)) - if ratio == 1: - y = F.pad(y, (0, 0, 0, pad)) - outs.append(layer(y)) - start = limit - layer.kernel_size + layer.stride - else: - if ratio == 1: - limit = Fr - else: - limit = int(round(Fr * ratio)) - last = layer.last - layer.last = True - - y = x[:, :, start:limit] - s = skip[:, :, start:limit] - out, _ = layer(y, s, None) - if outs: - outs[-1][:, :, -layer.stride:] += ( - out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) - out = out[:, :, layer.stride:] - if ratio == 1: - out = out[:, :, :-layer.stride // 2, :] - if start == 0: - out = out[:, :, layer.stride // 2:, :] - outs.append(out) - layer.last = last - start = limit - out = torch.cat(outs, dim=2) - if not self.conv and not last: - out = F.gelu(out) - if self.conv: - return out - else: - return out, None - - -class HDecLayer(nn.Module): - def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, - freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, - context_freq=True, rewrite=True): - """ - Same as HEncLayer but for decoder. See `HEncLayer` for documentation. 
- """ - super().__init__() - norm_fn = lambda d: nn.Identity() # noqa - if norm: - norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa - if pad: - pad = kernel_size // 4 - else: - pad = 0 - self.pad = pad - self.last = last - self.freq = freq - self.chin = chin - self.empty = empty - self.stride = stride - self.kernel_size = kernel_size - self.norm = norm - self.context_freq = context_freq - klass = nn.Conv1d - klass_tr = nn.ConvTranspose1d - if freq: - kernel_size = [kernel_size, 1] - stride = [stride, 1] - klass = nn.Conv2d - klass_tr = nn.ConvTranspose2d - self.conv_tr = klass_tr(chin, chout, kernel_size, stride) - self.norm2 = norm_fn(chout) - if self.empty: - return - self.rewrite = None - if rewrite: - if context_freq: - self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) - else: - self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, - [0, context]) - self.norm1 = norm_fn(2 * chin) - - self.dconv = None - if dconv: - self.dconv = DConv(chin, **dconv_kw) - - def forward(self, x, skip, length): - if self.freq and x.dim() == 3: - B, C, T = x.shape - x = x.view(B, self.chin, -1, T) - - if not self.empty: - x = x + skip - - if self.rewrite: - y = F.glu(self.norm1(self.rewrite(x)), dim=1) - else: - y = x - if self.dconv: - if self.freq: - B, C, Fr, T = y.shape - y = y.permute(0, 2, 1, 3).reshape(-1, C, T) - y = self.dconv(y) - if self.freq: - y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) - else: - y = x - assert skip is None - z = self.norm2(self.conv_tr(y)) - if self.freq: - if self.pad: - z = z[..., self.pad:-self.pad, :] - else: - z = z[..., self.pad:self.pad + length] - assert z.shape[-1] == length, (z.shape[-1], length) - if not self.last: - z = F.gelu(z) - return z, y - - -class HDemucs(nn.Module): - """ - Spectrogram and hybrid Demucs model. - The spectrogram model has the same structure as Demucs, except the first few layers are over the - frequency axis, until there is only 1 frequency, and then it moves to time convolutions. - Frequency layers can still access information across time steps thanks to the DConv residual. - - Hybrid model have a parallel time branch. At some layer, the time branch has the same stride - as the frequency branch and then the two are combined. The opposite happens in the decoder. - - Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), - or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on - Open Unmix implementation [Stoter et al. 2019]. - - The loss is always on the temporal domain, by backpropagating through the above - output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks - a bit Wiener filtering, as doing more iteration at test time will change the spectrogram - contribution, without changing the one from the waveform, which will lead to worse performance. - I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. - CaC on the other hand provides similar performance for hybrid, and works naturally with - hybrid models. - - This model also uses frequency embeddings are used to improve efficiency on convolutions - over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). - - Unlike classic Demucs, there is no resampling here, and normalization is always applied. 
- """ - @capture_init - def __init__(self, - sources, - # Channels - audio_channels=2, - channels=48, - channels_time=None, - growth=2, - # STFT - nfft=4096, - wiener_iters=0, - end_iters=0, - wiener_residual=False, - cac=True, - # Main structure - depth=6, - rewrite=True, - hybrid=True, - hybrid_old=False, - # Frequency branch - multi_freqs=None, - multi_freqs_depth=2, - freq_emb=0.2, - emb_scale=10, - emb_smooth=True, - # Convolutions - kernel_size=8, - time_stride=2, - stride=4, - context=1, - context_enc=0, - # Normalization - norm_starts=4, - norm_groups=4, - # DConv residual branch - dconv_mode=1, - dconv_depth=2, - dconv_comp=4, - dconv_attn=4, - dconv_lstm=4, - dconv_init=1e-4, - # Weight init - rescale=0.1, - # Metadata - samplerate=44100, - segment=4 * 10): - """ - Args: - sources (list[str]): list of source names. - audio_channels (int): input/output audio channels. - channels (int): initial number of hidden channels. - channels_time: if not None, use a different `channels` value for the time branch. - growth: increase the number of hidden channels by this factor at each layer. - nfft: number of fft bins. Note that changing this require careful computation of - various shape parameters and will not work out of the box for hybrid models. - wiener_iters: when using Wiener filtering, number of iterations at test time. - end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. - wiener_residual: add residual source before wiener filtering. - cac: uses complex as channels, i.e. complex numbers are 2 channels each - in input and output. no further processing is done before ISTFT. - depth (int): number of layers in the encoder and in the decoder. - rewrite (bool): add 1x1 convolution to each layer. - hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. - hybrid_old: some models trained for MDX had a padding bug. This replicates - this bug to avoid retraining them. - multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. - multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost - layers will be wrapped. - freq_emb: add frequency embedding after the first frequency layer if > 0, - the actual value controls the weight of the embedding. - emb_scale: equivalent to scaling the embedding learning rate - emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). - kernel_size: kernel_size for encoder and decoder layers. - stride: stride for encoder and decoder layers. - time_stride: stride for the final time layer, after the merge. - context: context for 1x1 conv in the decoder. - context_enc: context for 1x1 conv in the encoder. - norm_starts: layer at which group norm starts being used. - decoder layers are numbered in reverse order. - norm_groups: number of groups for group norm. - dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. - dconv_depth: depth of residual DConv branch. - dconv_comp: compression of DConv branch. - dconv_attn: adds attention layers in DConv branch starting at this layer. - dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. - dconv_init: initial scale for the DConv branch LayerScale. 
- rescale: weight recaling trick - - """ - super().__init__() - self.cac = cac - self.wiener_residual = wiener_residual - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.channels = channels - self.samplerate = samplerate - self.segment = segment - - self.nfft = nfft - self.hop_length = nfft // 4 - self.wiener_iters = wiener_iters - self.end_iters = end_iters - self.freq_emb = None - self.hybrid = hybrid - self.hybrid_old = hybrid_old - if hybrid_old: - assert hybrid, "hybrid_old must come with hybrid=True" - if hybrid: - assert wiener_iters == end_iters - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - if hybrid: - self.tencoder = nn.ModuleList() - self.tdecoder = nn.ModuleList() - - chin = audio_channels - chin_z = chin # number of channels for the freq branch - if self.cac: - chin_z *= 2 - chout = channels_time or channels - chout_z = channels - freqs = nfft // 2 - - for index in range(depth): - lstm = index >= dconv_lstm - attn = index >= dconv_attn - norm = index >= norm_starts - freq = freqs > 1 - stri = stride - ker = kernel_size - if not freq: - assert freqs == 1 - ker = time_stride * 2 - stri = time_stride - - pad = True - last_freq = False - if freq and freqs <= kernel_size: - ker = freqs - pad = False - last_freq = True - - kw = { - 'kernel_size': ker, - 'stride': stri, - 'freq': freq, - 'pad': pad, - 'norm': norm, - 'rewrite': rewrite, - 'norm_groups': norm_groups, - 'dconv_kw': { - 'lstm': lstm, - 'attn': attn, - 'depth': dconv_depth, - 'compress': dconv_comp, - 'init': dconv_init, - 'gelu': True, - } - } - kwt = dict(kw) - kwt['freq'] = 0 - kwt['kernel_size'] = kernel_size - kwt['stride'] = stride - kwt['pad'] = True - kw_dec = dict(kw) - multi = False - if multi_freqs and index < multi_freqs_depth: - multi = True - kw_dec['context_freq'] = False - - if last_freq: - chout_z = max(chout, chout_z) - chout = chout_z - - enc = HEncLayer(chin_z, chout_z, - dconv=dconv_mode & 1, context=context_enc, **kw) - if hybrid and freq: - tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, - empty=last_freq, **kwt) - self.tencoder.append(tenc) - - if multi: - enc = MultiWrap(enc, multi_freqs) - self.encoder.append(enc) - if index == 0: - chin = self.audio_channels * len(self.sources) - chin_z = chin - if self.cac: - chin_z *= 2 - dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, - last=index == 0, context=context, **kw_dec) - if multi: - dec = MultiWrap(dec, multi_freqs) - if hybrid and freq: - tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, - last=index == 0, context=context, **kwt) - self.tdecoder.insert(0, tdec) - self.decoder.insert(0, dec) - - chin = chout - chin_z = chout_z - chout = int(growth * chout) - chout_z = int(growth * chout_z) - if freq: - if freqs <= kernel_size: - freqs = 1 - else: - freqs //= stride - if index == 0 and freq_emb: - self.freq_emb = ScaledEmbedding( - freqs, chin_z, smooth=emb_smooth, scale=emb_scale) - self.freq_emb_scale = freq_emb - - if rescale: - rescale_module(self, reference=rescale) - - def _spec(self, x): - hl = self.hop_length - nfft = self.nfft - x0 = x # noqa - - if self.hybrid: - # We re-pad the signal in order to keep the property - # that the size of the output is exactly the size of the input - # divided by the stride (here hop_length), when divisible. - # This is achieved by padding by 1/4th of the kernel size (here nfft). 
- # which is not supported by torch.stft. - # Having all convolution operations follow this convention allow to easily - # align the time and frequency branches later on. - assert hl == nfft // 4 - le = int(math.ceil(x.shape[-1] / hl)) - pad = hl // 2 * 3 - if not self.hybrid_old: - x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') - else: - x = F.pad(x, (pad, pad + le * hl - x.shape[-1])) - - z = spectro(x, nfft, hl)[..., :-1, :] - if self.hybrid: - assert z.shape[-1] == le + 4, (z.shape, x.shape, le) - z = z[..., 2:2+le] - return z - - def _ispec(self, z, length=None, scale=0): - hl = self.hop_length // (4 ** scale) - z = F.pad(z, (0, 0, 0, 1)) - if self.hybrid: - z = F.pad(z, (2, 2)) - pad = hl // 2 * 3 - if not self.hybrid_old: - le = hl * int(math.ceil(length / hl)) + 2 * pad - else: - le = hl * int(math.ceil(length / hl)) - x = ispectro(z, hl, length=le) - if not self.hybrid_old: - x = x[..., pad:pad + length] - else: - x = x[..., :length] - else: - x = ispectro(z, hl, length) - return x - - def _magnitude(self, z): - # return the magnitude of the spectrogram, except when cac is True, - # in which case we just move the complex dimension to the channel one. - if self.cac: - B, C, Fr, T = z.shape - m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) - m = m.reshape(B, C * 2, Fr, T) - else: - m = z.abs() - return m - - def _mask(self, z, m): - # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. - # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. - niters = self.wiener_iters - if self.cac: - B, S, C, Fr, T = m.shape - out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) - out = torch.view_as_complex(out.contiguous()) - return out - if self.training: - niters = self.end_iters - if niters < 0: - z = z[:, None] - return z / (1e-8 + z.abs()) * m - else: - return self._wiener(m, z, niters) - - def _wiener(self, mag_out, mix_stft, niters): - # apply wiener filtering from OpenUnmix. - init = mix_stft.dtype - wiener_win_len = 300 - residual = self.wiener_residual - - B, S, C, Fq, T = mag_out.shape - mag_out = mag_out.permute(0, 4, 3, 2, 1) - mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) - - outs = [] - for sample in range(B): - pos = 0 - out = [] - for pos in range(0, T, wiener_win_len): - frame = slice(pos, pos + wiener_win_len) - z_out = wiener( - mag_out[sample, frame], mix_stft[sample, frame], niters, - residual=residual) - out.append(z_out.transpose(-1, -2)) - outs.append(torch.cat(out, dim=0)) - out = torch.view_as_complex(torch.stack(outs, 0)) - out = out.permute(0, 4, 3, 2, 1).contiguous() - if residual: - out = out[:, :-1] - assert list(out.shape) == [B, S, C, Fq, T] - return out.to(init) - - def forward(self, mix): - x = mix - length = x.shape[-1] - - z = self._spec(mix) - mag = self._magnitude(z) - x = mag - - B, C, Fq, T = x.shape - - # unlike previous Demucs, we always normalize because it is easier. - mean = x.mean(dim=(1, 2, 3), keepdim=True) - std = x.std(dim=(1, 2, 3), keepdim=True) - x = (x - mean) / (1e-5 + std) - # x will be the freq. branch input. - - if self.hybrid: - # Prepare the time branch input. - xt = mix - meant = xt.mean(dim=(1, 2), keepdim=True) - stdt = xt.std(dim=(1, 2), keepdim=True) - xt = (xt - meant) / (1e-5 + stdt) - - # okay, this is a giant mess I know... - saved = [] # skip connections, freq. - saved_t = [] # skip connections, time. - lengths = [] # saved lengths to properly remove padding, freq branch. - lengths_t = [] # saved lengths for time branch. 
- for idx, encode in enumerate(self.encoder): - lengths.append(x.shape[-1]) - inject = None - if self.hybrid and idx < len(self.tencoder): - # we have not yet merged branches. - lengths_t.append(xt.shape[-1]) - tenc = self.tencoder[idx] - xt = tenc(xt) - if not tenc.empty: - # save for skip connection - saved_t.append(xt) - else: - # tenc contains just the first conv., so that now time and freq. - # branches have the same shape and can be merged. - inject = xt - x = encode(x, inject) - if idx == 0 and self.freq_emb is not None: - # add frequency embedding to allow for non equivariant convolutions - # over the frequency axis. - frs = torch.arange(x.shape[-2], device=x.device) - emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) - x = x + self.freq_emb_scale * emb - - saved.append(x) - - x = torch.zeros_like(x) - if self.hybrid: - xt = torch.zeros_like(x) - # initialize everything to zero (signal will go through u-net skips). - - for idx, decode in enumerate(self.decoder): - skip = saved.pop(-1) - x, pre = decode(x, skip, lengths.pop(-1)) - # `pre` contains the output just before final transposed convolution, - # which is used when the freq. and time branch separate. - - if self.hybrid: - offset = self.depth - len(self.tdecoder) - if self.hybrid and idx >= offset: - tdec = self.tdecoder[idx - offset] - length_t = lengths_t.pop(-1) - if tdec.empty: - assert pre.shape[2] == 1, pre.shape - pre = pre[:, :, 0] - xt, _ = tdec(pre, None, length_t) - else: - skip = saved_t.pop(-1) - xt, _ = tdec(xt, skip, length_t) - - # Let's make sure we used all stored skip connections. - assert len(saved) == 0 - assert len(lengths_t) == 0 - assert len(saved_t) == 0 - - S = len(self.sources) - x = x.view(B, S, -1, Fq, T) - x = x * std[:, None] + mean[:, None] - - zout = self._mask(z, x) - x = self._ispec(zout, length) - - if self.hybrid: - xt = xt.view(B, S, -1, length) - xt = xt * stdt[:, None] + meant[:, None] - x = xt + x - return x diff --git a/demucs/model.py b/demucs/model.py deleted file mode 100644 index e2745b8..0000000 --- a/demucs/model.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import torch as th -from torch import nn - -from .utils import capture_init, center_trim - - -class BLSTM(nn.Module): - def __init__(self, dim, layers=1): - super().__init__() - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - - def forward(self, x): - x = x.permute(2, 0, 1) - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - return x - - -def rescale_conv(conv, reference): - std = conv.weight.std().detach() - scale = (std / reference)**0.5 - conv.weight.data /= scale - if conv.bias is not None: - conv.bias.data /= scale - - -def rescale_module(module, reference): - for sub in module.modules(): - if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): - rescale_conv(sub, reference) - - -def upsample(x, stride): - """ - Linear upsampling, the output will be `stride` times longer. - """ - batch, channels, time = x.size() - weight = th.arange(stride, device=x.device, dtype=th.float) / stride - x = x.view(batch, channels, time, 1) - out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight - return out.reshape(batch, channels, -1) - - -def downsample(x, stride): - """ - Downsample x by decimation. 
- """ - return x[:, :, ::stride] - - -class Demucs(nn.Module): - @capture_init - def __init__(self, - sources=4, - audio_channels=2, - channels=64, - depth=6, - rewrite=True, - glu=True, - upsample=False, - rescale=0.1, - kernel_size=8, - stride=4, - growth=2., - lstm_layers=2, - context=3, - samplerate=44100): - """ - Args: - sources (int): number of sources to separate - audio_channels (int): stereo or mono - channels (int): first convolution channels - depth (int): number of encoder/decoder layers - rewrite (bool): add 1x1 convolution to each encoder layer - and a convolution to each decoder layer. - For the decoder layer, `context` gives the kernel size. - glu (bool): use glu instead of ReLU - upsample (bool): use linear upsampling with convolutions - Wave-U-Net style, instead of transposed convolutions - rescale (int): rescale initial weights of convolutions - to get their standard deviation closer to `rescale` - kernel_size (int): kernel size for convolutions - stride (int): stride for convolutions - growth (float): multiply (resp divide) number of channels by that - for each layer of the encoder (resp decoder) - lstm_layers (int): number of lstm layers, 0 = no lstm - context (int): kernel size of the convolution in the - decoder before the transposed convolution. If > 1, - will provide some context from neighboring time - steps. - """ - - super().__init__() - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.upsample = upsample - self.channels = channels - self.samplerate = samplerate - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - self.final = None - if upsample: - self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1) - stride = 1 - - if glu: - activation = nn.GLU(dim=1) - ch_scale = 2 - else: - activation = nn.ReLU() - ch_scale = 1 - in_channels = audio_channels - for index in range(depth): - encode = [] - encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] - if rewrite: - encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] - self.encoder.append(nn.Sequential(*encode)) - - decode = [] - if index > 0: - out_channels = in_channels - else: - if upsample: - out_channels = channels - else: - out_channels = sources * audio_channels - if rewrite: - decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] - if upsample: - decode += [ - nn.Conv1d(channels, out_channels, kernel_size, stride=1), - ] - else: - decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] - if index > 0: - decode.append(nn.ReLU()) - self.decoder.insert(0, nn.Sequential(*decode)) - in_channels = channels - channels = int(growth * channels) - - channels = in_channels - - if lstm_layers: - self.lstm = BLSTM(channels, lstm_layers) - else: - self.lstm = None - - if rescale: - rescale_module(self, reference=rescale) - - def valid_length(self, length): - """ - Return the nearest valid length to use with the model so that - there is no time steps left over in a convolutions, e.g. for all - layers, size of the input - kernel_size % stride = 0. - - If the mixture has a valid length, the estimated sources - will have exactly the same length when context = 1. If context > 1, - the two signals can be center trimmed to match. - - For training, extracts should have a valid length.For evaluation - on full tracks we recommend passing `pad = True` to :method:`forward`. 
- """ - for _ in range(self.depth): - if self.upsample: - length = math.ceil(length / self.stride) + self.kernel_size - 1 - else: - length = math.ceil((length - self.kernel_size) / self.stride) + 1 - length = max(1, length) - length += self.context - 1 - for _ in range(self.depth): - if self.upsample: - length = length * self.stride + self.kernel_size - 1 - else: - length = (length - 1) * self.stride + self.kernel_size - - return int(length) - - def forward(self, mix): - x = mix - saved = [x] - for encode in self.encoder: - x = encode(x) - saved.append(x) - if self.upsample: - x = downsample(x, self.stride) - if self.lstm: - x = self.lstm(x) - for decode in self.decoder: - if self.upsample: - x = upsample(x, stride=self.stride) - skip = center_trim(saved.pop(-1), x) - x = x + skip - x = decode(x) - if self.final: - skip = center_trim(saved.pop(-1), x) - x = th.cat([x, skip], dim=1) - x = self.final(x) - - x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1)) - return x diff --git a/demucs/model_v2.py b/demucs/model_v2.py deleted file mode 100644 index e9d932f..0000000 --- a/demucs/model_v2.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import julius -from torch import nn - -from .utils import capture_init, center_trim - - -class BLSTM(nn.Module): - def __init__(self, dim, layers=1): - super().__init__() - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - - def forward(self, x): - x = x.permute(2, 0, 1) - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - return x - - -def rescale_conv(conv, reference): - std = conv.weight.std().detach() - scale = (std / reference)**0.5 - conv.weight.data /= scale - if conv.bias is not None: - conv.bias.data /= scale - - -def rescale_module(module, reference): - for sub in module.modules(): - if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): - rescale_conv(sub, reference) - - -class Demucs(nn.Module): - @capture_init - def __init__(self, - sources, - audio_channels=2, - channels=64, - depth=6, - rewrite=True, - glu=True, - rescale=0.1, - resample=True, - kernel_size=8, - stride=4, - growth=2., - lstm_layers=2, - context=3, - normalize=False, - samplerate=44100, - segment_length=4 * 10 * 44100): - """ - Args: - sources (list[str]): list of source names - audio_channels (int): stereo or mono - channels (int): first convolution channels - depth (int): number of encoder/decoder layers - rewrite (bool): add 1x1 convolution to each encoder layer - and a convolution to each decoder layer. - For the decoder layer, `context` gives the kernel size. - glu (bool): use glu instead of ReLU - resample_input (bool): upsample x2 the input and downsample /2 the output. - rescale (int): rescale initial weights of convolutions - to get their standard deviation closer to `rescale` - kernel_size (int): kernel size for convolutions - stride (int): stride for convolutions - growth (float): multiply (resp divide) number of channels by that - for each layer of the encoder (resp decoder) - lstm_layers (int): number of lstm layers, 0 = no lstm - context (int): kernel size of the convolution in the - decoder before the transposed convolution. If > 1, - will provide some context from neighboring time - steps. 
- samplerate (int): stored as meta information for easing - future evaluations of the model. - segment_length (int): stored as meta information for easing - future evaluations of the model. Length of the segments on which - the model was trained. - """ - - super().__init__() - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.resample = resample - self.channels = channels - self.normalize = normalize - self.samplerate = samplerate - self.segment_length = segment_length - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - if glu: - activation = nn.GLU(dim=1) - ch_scale = 2 - else: - activation = nn.ReLU() - ch_scale = 1 - in_channels = audio_channels - for index in range(depth): - encode = [] - encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] - if rewrite: - encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] - self.encoder.append(nn.Sequential(*encode)) - - decode = [] - if index > 0: - out_channels = in_channels - else: - out_channels = len(self.sources) * audio_channels - if rewrite: - decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] - decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] - if index > 0: - decode.append(nn.ReLU()) - self.decoder.insert(0, nn.Sequential(*decode)) - in_channels = channels - channels = int(growth * channels) - - channels = in_channels - - if lstm_layers: - self.lstm = BLSTM(channels, lstm_layers) - else: - self.lstm = None - - if rescale: - rescale_module(self, reference=rescale) - - def valid_length(self, length): - """ - Return the nearest valid length to use with the model so that - there is no time steps left over in a convolutions, e.g. for all - layers, size of the input - kernel_size % stride = 0. - - If the mixture has a valid length, the estimated sources - will have exactly the same length when context = 1. If context > 1, - the two signals can be center trimmed to match. - - For training, extracts should have a valid length.For evaluation - on full tracks we recommend passing `pad = True` to :method:`forward`. - """ - if self.resample: - length *= 2 - for _ in range(self.depth): - length = math.ceil((length - self.kernel_size) / self.stride) + 1 - length = max(1, length) - length += self.context - 1 - for _ in range(self.depth): - length = (length - 1) * self.stride + self.kernel_size - - if self.resample: - length = math.ceil(length / 2) - return int(length) - - def forward(self, mix): - x = mix - - if self.normalize: - mono = mix.mean(dim=1, keepdim=True) - mean = mono.mean(dim=-1, keepdim=True) - std = mono.std(dim=-1, keepdim=True) - else: - mean = 0 - std = 1 - - x = (x - mean) / (1e-5 + std) - - if self.resample: - x = julius.resample_frac(x, 1, 2) - - saved = [] - for encode in self.encoder: - x = encode(x) - saved.append(x) - if self.lstm: - x = self.lstm(x) - for decode in self.decoder: - skip = center_trim(saved.pop(-1), x) - x = x + skip - x = decode(x) - - if self.resample: - x = julius.resample_frac(x, 2, 1) - x = x * std + mean - x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) - return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py deleted file mode 100644 index 152e19c..0000000 --- a/demucs/pretrained.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
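The forward pass above normalizes the mixture with statistics of its mono downmix and, when `resample` is set, runs the network at twice the sample rate using `julius`. A condensed sketch of just that pre/post-processing (the separation network itself is elided):

import torch
import julius

x = torch.randn(1, 2, 44100)                      # [batch, channels, time]
mono = x.mean(dim=1, keepdim=True)
mean = mono.mean(dim=-1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)

x = (x - mean) / (1e-5 + std)
x = julius.resample_frac(x, 1, 2)                 # work internally at 2x the rate
# ... encoder / LSTM / decoder would run here ...
x = julius.resample_frac(x, 2, 1)                 # back to the original rate
x = x * std + mean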
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Loading pretrained models. -""" - -import logging -from pathlib import Path -import typing as tp - -from dora.log import fatal - -import logging - -from diffq import DiffQuantizer -import torch.hub - -from .model import Demucs -from .tasnet_v2 import ConvTasNet -from .utils import set_state - -from .hdemucs import HDemucs -from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa - -logger = logging.getLogger(__name__) -ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" -REMOTE_ROOT = Path(__file__).parent / 'remote' - -SOURCES = ["drums", "bass", "other", "vocals"] - - -def demucs_unittest(): - model = HDemucs(channels=4, sources=SOURCES) - return model - - -def add_model_flags(parser): - group = parser.add_mutually_exclusive_group(required=False) - group.add_argument("-s", "--sig", help="Locally trained XP signature.") - group.add_argument("-n", "--name", default="mdx_extra_q", - help="Pretrained model name or signature. Default is mdx_extra_q.") - parser.add_argument("--repo", type=Path, - help="Folder containing all pre-trained models for use with -n.") - - -def get_model(name: str, - repo: tp.Optional[Path] = None): - """`name` must be a bag of models name or a pretrained signature - from the remote AWS model repo or the specified local repo if `repo` is not None. - """ - if name == 'demucs_unittest': - return demucs_unittest() - model_repo: ModelOnlyRepo - if repo is None: - remote_files = [line.strip() - for line in (REMOTE_ROOT / 'files.txt').read_text().split('\n') - if line.strip()] - model_repo = RemoteRepo(ROOT_URL, remote_files) - bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) - else: - if not repo.is_dir(): - fatal(f"{repo} must exist and be a directory.") - model_repo = LocalRepo(repo) - bag_repo = BagOnlyRepo(repo, model_repo) - any_repo = AnyModelRepo(model_repo, bag_repo) - return any_repo.get_model(name) - - -def get_model_from_args(args): - """ - Load local model package or pre-trained model. 
- """ - return get_model(name=args.name, repo=args.repo) - -logger = logging.getLogger(__name__) -ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" - -PRETRAINED_MODELS = { - 'demucs': 'e07c671f', - 'demucs48_hq': '28a1282c', - 'demucs_extra': '3646af93', - 'demucs_quantized': '07afea75', - 'tasnet': 'beb46fac', - 'tasnet_extra': 'df3777b2', - 'demucs_unittest': '09ebc15f', -} - -SOURCES = ["drums", "bass", "other", "vocals"] - - -def get_url(name): - sig = PRETRAINED_MODELS[name] - return ROOT + name + "-" + sig[:8] + ".th" - - -def is_pretrained(name): - return name in PRETRAINED_MODELS - - -def load_pretrained(name): - if name == "demucs": - return demucs(pretrained=True) - elif name == "demucs48_hq": - return demucs(pretrained=True, hq=True, channels=48) - elif name == "demucs_extra": - return demucs(pretrained=True, extra=True) - elif name == "demucs_quantized": - return demucs(pretrained=True, quantized=True) - elif name == "demucs_unittest": - return demucs_unittest(pretrained=True) - elif name == "tasnet": - return tasnet(pretrained=True) - elif name == "tasnet_extra": - return tasnet(pretrained=True, extra=True) - else: - raise ValueError(f"Invalid pretrained name {name}") - - -def _load_state(name, model, quantizer=None): - url = get_url(name) - state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) - set_state(model, quantizer, state) - if quantizer: - quantizer.detach() - - -def demucs_unittest(pretrained=True): - model = Demucs(channels=4, sources=SOURCES) - if pretrained: - _load_state('demucs_unittest', model) - return model - - -def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): - if not pretrained and (extra or quantized or hq): - raise ValueError("if extra or quantized is True, pretrained must be True.") - model = Demucs(sources=SOURCES, channels=channels) - if pretrained: - name = 'demucs' - if channels != 64: - name += str(channels) - quantizer = None - if sum([extra, quantized, hq]) > 1: - raise ValueError("Only one of extra, quantized, hq, can be True.") - if quantized: - quantizer = DiffQuantizer(model, group_size=8, min_size=1) - name += '_quantized' - if extra: - name += '_extra' - if hq: - name += '_hq' - _load_state(name, model, quantizer) - return model - - -def tasnet(pretrained=True, extra=False): - if not pretrained and extra: - raise ValueError("if extra is True, pretrained must be True.") - model = ConvTasNet(X=10, sources=SOURCES) - if pretrained: - name = 'tasnet' - if extra: - name = 'tasnet_extra' - _load_state(name, model) - return model \ No newline at end of file diff --git a/demucs/repo.py b/demucs/repo.py deleted file mode 100644 index 178f348..0000000 --- a/demucs/repo.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Represents a model repository, including pre-trained models and bags of models. -A repo can either be the main remote repository stored in AWS, or a local repository -with your own models. 
-""" - -from hashlib import sha256 -from pathlib import Path -import typing as tp - -import torch -import yaml - -from .apply import BagOfModels, Model -from .states import load_model - - -AnyModel = tp.Union[Model, BagOfModels] - - -class ModelLoadingError(RuntimeError): - pass - - -def check_checksum(path: Path, checksum: str): - sha = sha256() - with open(path, 'rb') as file: - while True: - buf = file.read(2**20) - if not buf: - break - sha.update(buf) - actual_checksum = sha.hexdigest()[:len(checksum)] - if actual_checksum != checksum: - raise ModelLoadingError(f'Invalid checksum for file {path}, ' - f'expected {checksum} but got {actual_checksum}') - - -class ModelOnlyRepo: - """Base class for all model only repos. - """ - def has_model(self, sig: str) -> bool: - raise NotImplementedError() - - def get_model(self, sig: str) -> Model: - raise NotImplementedError() - - -class RemoteRepo(ModelOnlyRepo): - def __init__(self, root_url: str, remote_files: tp.List[str]): - if not root_url.endswith('/'): - root_url += '/' - self._models: tp.Dict[str, str] = {} - for file in remote_files: - sig, checksum = file.split('.')[0].split('-') - assert sig not in self._models - self._models[sig] = root_url + file - - def has_model(self, sig: str) -> bool: - return sig in self._models - - def get_model(self, sig: str) -> Model: - try: - url = self._models[sig] - except KeyError: - raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') - pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) - return load_model(pkg) - - -class LocalRepo(ModelOnlyRepo): - def __init__(self, root: Path): - self.root = root - self.scan() - - def scan(self): - self._models = {} - self._checksums = {} - for file in self.root.iterdir(): - if file.suffix == '.th': - if '-' in file.stem: - xp_sig, checksum = file.stem.split('-') - self._checksums[xp_sig] = checksum - else: - xp_sig = file.stem - if xp_sig in self._models: - print('Whats xp? ', xp_sig) - raise ModelLoadingError( - f'Duplicate pre-trained model exist for signature {xp_sig}. ' - 'Please delete all but one.') - self._models[xp_sig] = file - - def has_model(self, sig: str) -> bool: - return sig in self._models - - def get_model(self, sig: str) -> Model: - try: - file = self._models[sig] - except KeyError: - raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') - if sig in self._checksums: - check_checksum(file, self._checksums[sig]) - return load_model(file) - - -class BagOnlyRepo: - """Handles only YAML files containing bag of models, leaving the actual - model loading to some Repo. 
- """ - def __init__(self, root: Path, model_repo: ModelOnlyRepo): - self.root = root - self.model_repo = model_repo - self.scan() - - def scan(self): - self._bags = {} - for file in self.root.iterdir(): - if file.suffix == '.yaml': - self._bags[file.stem] = file - - def has_model(self, name: str) -> bool: - return name in self._bags - - def get_model(self, name: str) -> BagOfModels: - try: - yaml_file = self._bags[name] - except KeyError: - raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' - 'a bag of models.') - bag = yaml.safe_load(open(yaml_file)) - signatures = bag['models'] - models = [self.model_repo.get_model(sig) for sig in signatures] - weights = bag.get('weights') - segment = bag.get('segment') - return BagOfModels(models, weights, segment) - - -class AnyModelRepo: - def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): - self.model_repo = model_repo - self.bag_repo = bag_repo - - def has_model(self, name_or_sig: str) -> bool: - return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) - - def get_model(self, name_or_sig: str) -> AnyModel: - if self.model_repo.has_model(name_or_sig): - return self.model_repo.get_model(name_or_sig) - else: - return self.bag_repo.get_model(name_or_sig) diff --git a/demucs/spec.py b/demucs/spec.py deleted file mode 100644 index 85e5dc9..0000000 --- a/demucs/spec.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Conveniance wrapper to perform STFT and iSTFT""" - -import torch as th - - -def spectro(x, n_fft=512, hop_length=None, pad=0): - *other, length = x.shape - x = x.reshape(-1, length) - z = th.stft(x, - n_fft * (1 + pad), - hop_length or n_fft // 4, - window=th.hann_window(n_fft).to(x), - win_length=n_fft, - normalized=True, - center=True, - return_complex=True, - pad_mode='reflect') - _, freqs, frame = z.shape - return z.view(*other, freqs, frame) - - -def ispectro(z, hop_length=None, length=None, pad=0): - *other, freqs, frames = z.shape - n_fft = 2 * freqs - 2 - z = z.view(-1, freqs, frames) - win_length = n_fft // (1 + pad) - x = th.istft(z, - n_fft, - hop_length, - window=th.hann_window(win_length).to(z.real), - win_length=win_length, - normalized=True, - length=length, - center=True) - _, length = x.shape - return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py deleted file mode 100644 index db17a18..0000000 --- a/demucs/states.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Utilities to save and load models. 
-""" -from contextlib import contextmanager - -import functools -import hashlib -import inspect -import io -from pathlib import Path -import warnings - -from omegaconf import OmegaConf -from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state -import torch - - -def get_quantizer(model, args, optimizer=None): - """Return the quantizer given the XP quantization args.""" - quantizer = None - if args.diffq: - quantizer = DiffQuantizer( - model, min_size=args.min_size, group_size=args.group_size) - if optimizer is not None: - quantizer.setup_optimizer(optimizer) - elif args.qat: - quantizer = UniformQuantizer( - model, bits=args.qat, min_size=args.min_size) - return quantizer - - -def load_model(path_or_package, strict=False): - """Load a model from the given serialized model, either given as a dict (already loaded) - or a path to a file on disk.""" - if isinstance(path_or_package, dict): - package = path_or_package - elif isinstance(path_or_package, (str, Path)): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - path = path_or_package - package = torch.load(path, 'cpu') - else: - raise ValueError(f"Invalid type for {path_or_package}.") - - klass = package["klass"] - args = package["args"] - kwargs = package["kwargs"] - - if strict: - model = klass(*args, **kwargs) - else: - sig = inspect.signature(klass) - for key in list(kwargs): - if key not in sig.parameters: - warnings.warn("Dropping inexistant parameter " + key) - del kwargs[key] - model = klass(*args, **kwargs) - - state = package["state"] - - set_state(model, state) - return model - - -def get_state(model, quantizer, half=False): - """Get the state from a model, potentially with quantization applied. - If `half` is True, model are stored as half precision, which shouldn't impact performance - but half the state size.""" - if quantizer is None: - dtype = torch.half if half else None - state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} - else: - state = quantizer.get_quantized_state() - state['__quantized'] = True - return state - - -def set_state(model, state, quantizer=None): - """Set the state on a given model.""" - if state.get('__quantized'): - if quantizer is not None: - quantizer.restore_quantized_state(model, state['quantized']) - else: - restore_quantized_state(model, state) - else: - model.load_state_dict(state) - return state - - -def save_with_checksum(content, path): - """Save the given value on disk, along with a sha256 hash. 
- Should be used with the output of either `serialize_model` or `get_state`.""" - buf = io.BytesIO() - torch.save(content, buf) - sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] - - path = path.parent / (path.stem + "-" + sig + path.suffix) - path.write_bytes(buf.getvalue()) - - -def serialize_model(model, training_args, quantizer=None, half=True): - args, kwargs = model._init_args_kwargs - klass = model.__class__ - - state = get_state(model, quantizer, half) - return { - 'klass': klass, - 'args': args, - 'kwargs': kwargs, - 'state': state, - 'training_args': OmegaConf.to_container(training_args, resolve=True), - } - - -def copy_state(state): - return {k: v.cpu().clone() for k, v in state.items()} - - -@contextmanager -def swap_state(model, state): - """ - Context manager that swaps the state of a model, e.g: - - # model is in old state - with swap_state(model, new_state): - # model in new state - # model back to old state - """ - old_state = copy_state(model.state_dict()) - model.load_state_dict(state, strict=False) - try: - yield - finally: - model.load_state_dict(old_state) - - -def capture_init(init): - @functools.wraps(init) - def __init__(self, *args, **kwargs): - self._init_args_kwargs = (args, kwargs) - init(self, *args, **kwargs) - - return __init__ diff --git a/demucs/tasnet.py b/demucs/tasnet.py deleted file mode 100644 index 9cb7a95..0000000 --- a/demucs/tasnet.py +++ /dev/null @@ -1,447 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# Created on 2018/12 -# Author: Kaituo XU -# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels -# Here is the original license: -# The MIT License (MIT) -# -# Copyright (c) 2018 Kaituo XU -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
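`capture_init` above records the constructor arguments on the instance, which is what lets `serialize_model` store `klass` / `args` / `kwargs` / `state` and `load_model` rebuild the network later. A toy illustration; `Toy` is a made-up module, and `demucs.states` is the module shown above:

import torch
from torch import nn
from demucs.states import capture_init

class Toy(nn.Module):
    @capture_init
    def __init__(self, channels=8):
        super().__init__()
        self.conv = nn.Conv1d(2, channels, 3)

model = Toy(channels=16)
args, kwargs = model._init_args_kwargs            # ((), {'channels': 16})
package = {
    'klass': Toy, 'args': args, 'kwargs': kwargs,
    'state': {k: v.cpu() for k, v in model.state_dict().items()},
}
# load_model(package) above rebuilds Toy(*args, **kwargs) and loads the saved state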
- -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import capture_init - -EPS = 1e-8 - - -def overlap_and_add(signal, frame_step): - outer_dimensions = signal.size()[:-2] - frames, frame_length = signal.size()[-2:] - - subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor - subframe_step = frame_step // subframe_length - subframes_per_frame = frame_length // subframe_length - output_size = frame_step * (frames - 1) + frame_length - output_subframes = output_size // subframe_length - - subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) - - frame = torch.arange(0, output_subframes, - device=signal.device).unfold(0, subframes_per_frame, subframe_step) - frame = frame.long() # signal may in GPU or CPU - frame = frame.contiguous().view(-1) - - result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) - result.index_add_(-2, frame, subframe_signal) - result = result.view(*outer_dimensions, -1) - return result - - -class ConvTasNet(nn.Module): - @capture_init - def __init__(self, - N=256, - L=20, - B=256, - H=512, - P=3, - X=8, - R=4, - C=4, - audio_channels=1, - samplerate=44100, - norm_type="gLN", - causal=False, - mask_nonlinear='relu'): - """ - Args: - N: Number of filters in autoencoder - L: Length of the filters (in samples) - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - C: Number of speakers - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(ConvTasNet, self).__init__() - # Hyper-parameter - self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C - self.norm_type = norm_type - self.causal = causal - self.mask_nonlinear = mask_nonlinear - self.audio_channels = audio_channels - self.samplerate = samplerate - # Components - self.encoder = Encoder(L, N, audio_channels) - self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear) - self.decoder = Decoder(N, L, audio_channels) - # init - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_normal_(p) - - def valid_length(self, length): - return length - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - est_source: [M, C, T] - """ - mixture_w = self.encoder(mixture) - est_mask = self.separator(mixture_w) - est_source = self.decoder(mixture_w, est_mask) - - # T changed after conv1d in encoder, fix it here - T_origin = mixture.size(-1) - T_conv = est_source.size(-1) - est_source = F.pad(est_source, (0, T_origin - T_conv)) - return est_source - - -class Encoder(nn.Module): - """Estimation of the nonnegative mixture weight by a 1-D conv layer. 
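`overlap_and_add` above is the synthesis half of the Conv-TasNet decoder: frames produced with 50% overlap are summed back into a waveform. A tiny worked example, assuming the function is importable from `demucs.tasnet`:

import torch
from demucs.tasnet import overlap_and_add

frames = torch.arange(12.).view(3, 4)            # 3 frames of length 4
out = overlap_and_add(frames, frame_step=2)      # hop of 2 -> length 2 * (3 - 1) + 4 = 8
print(out)                                       # tensor([ 0.,  1.,  6.,  8., 14., 16., 10., 11.])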
- """ - def __init__(self, L, N, audio_channels): - super(Encoder, self).__init__() - # Hyper-parameter - self.L, self.N = L, N - # Components - # 50% overlap - self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 - """ - mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] - return mixture_w - - -class Decoder(nn.Module): - def __init__(self, N, L, audio_channels): - super(Decoder, self).__init__() - # Hyper-parameter - self.N, self.L = N, L - self.audio_channels = audio_channels - # Components - self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) - - def forward(self, mixture_w, est_mask): - """ - Args: - mixture_w: [M, N, K] - est_mask: [M, C, N, K] - Returns: - est_source: [M, C, T] - """ - # D = W * M - source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] - source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] - # S = DV - est_source = self.basis_signals(source_w) # [M, C, K, ac * L] - m, c, k, _ = est_source.size() - est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() - est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T - return est_source - - -class TemporalConvNet(nn.Module): - def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): - """ - Args: - N: Number of filters in autoencoder - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - C: Number of speakers - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(TemporalConvNet, self).__init__() - # Hyper-parameter - self.C = C - self.mask_nonlinear = mask_nonlinear - # Components - # [M, N, K] -> [M, N, K] - layer_norm = ChannelwiseLayerNorm(N) - # [M, N, K] -> [M, B, K] - bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) - # [M, B, K] -> [M, B, K] - repeats = [] - for r in range(R): - blocks = [] - for x in range(X): - dilation = 2**x - padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 - blocks += [ - TemporalBlock(B, - H, - P, - stride=1, - padding=padding, - dilation=dilation, - norm_type=norm_type, - causal=causal) - ] - repeats += [nn.Sequential(*blocks)] - temporal_conv_net = nn.Sequential(*repeats) - # [M, B, K] -> [M, C*N, K] - mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) - # Put together - self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, - mask_conv1x1) - - def forward(self, mixture_w): - """ - Keep this API same with TasNet - Args: - mixture_w: [M, N, K], M is batch size - returns: - est_mask: [M, C, N, K] - """ - M, N, K = mixture_w.size() - score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] - score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] - if self.mask_nonlinear == 'softmax': - est_mask = F.softmax(score, dim=1) - elif self.mask_nonlinear == 'relu': - est_mask = F.relu(score) - else: - raise ValueError("Unsupported mask non-linear function") - return est_mask - - -class TemporalBlock(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(TemporalBlock, self).__init__() - 
# [M, B, K] -> [M, H, K] - conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) - prelu = nn.PReLU() - norm = chose_norm(norm_type, out_channels) - # [M, H, K] -> [M, B, K] - dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, - dilation, norm_type, causal) - # Put together - self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) - - def forward(self, x): - """ - Args: - x: [M, B, K] - Returns: - [M, B, K] - """ - residual = x - out = self.net(x) - # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? - return out + residual # look like w/o F.relu is better than w/ F.relu - # return F.relu(out + residual) - - -class DepthwiseSeparableConv(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(DepthwiseSeparableConv, self).__init__() - # Use `groups` option to implement depthwise convolution - # [M, H, K] -> [M, H, K] - depthwise_conv = nn.Conv1d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=in_channels, - bias=False) - if causal: - chomp = Chomp1d(padding) - prelu = nn.PReLU() - norm = chose_norm(norm_type, in_channels) - # [M, H, K] -> [M, B, K] - pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) - # Put together - if causal: - self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) - else: - self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) - - def forward(self, x): - """ - Args: - x: [M, H, K] - Returns: - result: [M, B, K] - """ - return self.net(x) - - -class Chomp1d(nn.Module): - """To ensure the output length is the same as the input. - """ - def __init__(self, chomp_size): - super(Chomp1d, self).__init__() - self.chomp_size = chomp_size - - def forward(self, x): - """ - Args: - x: [M, H, Kpad] - Returns: - [M, H, K] - """ - return x[:, :, :-self.chomp_size].contiguous() - - -def chose_norm(norm_type, channel_size): - """The input of normlization will be (M, C, K), where M is batch size, - C is channel size and K is sequence length. - """ - if norm_type == "gLN": - return GlobalLayerNorm(channel_size) - elif norm_type == "cLN": - return ChannelwiseLayerNorm(channel_size) - elif norm_type == "id": - return nn.Identity() - else: # norm_type == "BN": - # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics - # along M and K, so this BN usage is right. 
- return nn.BatchNorm1d(channel_size) - - -# TODO: Use nn.LayerNorm to impl cLN to speed up -class ChannelwiseLayerNorm(nn.Module): - """Channel-wise Layer Normalization (cLN)""" - def __init__(self, channel_size): - super(ChannelwiseLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.reset_parameters() - - def reset_parameters(self): - self.gamma.data.fill_(1) - self.beta.data.zero_() - - def forward(self, y): - """ - Args: - y: [M, N, K], M is batch size, N is channel size, K is length - Returns: - cLN_y: [M, N, K] - """ - mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] - var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] - cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - return cLN_y - - -class GlobalLayerNorm(nn.Module): - """Global Layer Normalization (gLN)""" - def __init__(self, channel_size): - super(GlobalLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.reset_parameters() - - def reset_parameters(self): - self.gamma.data.fill_(1) - self.beta.data.zero_() - - def forward(self, y): - """ - Args: - y: [M, N, K], M is batch size, N is channel size, K is length - Returns: - gLN_y: [M, N, K] - """ - # TODO: in torch 1.0, torch.mean() support dim list - mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] - var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) - gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - return gLN_y - - -if __name__ == "__main__": - torch.manual_seed(123) - M, N, L, T = 2, 3, 4, 12 - K = 2 * T // L - 1 - B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False - mixture = torch.randint(3, (M, T)) - # test Encoder - encoder = Encoder(L, N) - encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) - mixture_w = encoder(mixture) - print('mixture', mixture) - print('U', encoder.conv1d_U.weight) - print('mixture_w', mixture_w) - print('mixture_w size', mixture_w.size()) - - # test TemporalConvNet - separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) - est_mask = separator(mixture_w) - print('est_mask', est_mask) - - # test Decoder - decoder = Decoder(N, L) - est_mask = torch.randint(2, (B, K, C, N)) - est_source = decoder(mixture_w, est_mask) - print('est_source', est_source) - - # test Conv-TasNet - conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) - est_source = conv_tasnet(mixture) - print('est_source', est_source) - print('est_source size', est_source.size()) diff --git a/demucs/tasnet_v2.py b/demucs/tasnet_v2.py deleted file mode 100644 index ecc1257..0000000 --- a/demucs/tasnet_v2.py +++ /dev/null @@ -1,452 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
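`chose_norm` above switches between channel-wise and global layer normalization; both preserve the [M, N, K] layout and differ only in whether the statistics are taken per time step or over the whole utterance. A minimal shape check, assuming the classes are importable from `demucs.tasnet`:

import torch
from demucs.tasnet import ChannelwiseLayerNorm, GlobalLayerNorm

y = torch.randn(4, 256, 100)           # [M, N, K]: batch, channels, frames
cln = ChannelwiseLayerNorm(256)        # mean/var over N, separately for each frame
gln = GlobalLayerNorm(256)             # mean/var over N and K jointly
assert cln(y).shape == gln(y).shape == y.shape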
-# -# Created on 2018/12 -# Author: Kaituo XU -# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels -# Here is the original license: -# The MIT License (MIT) -# -# Copyright (c) 2018 Kaituo XU -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import capture_init - -EPS = 1e-8 - - -def overlap_and_add(signal, frame_step): - outer_dimensions = signal.size()[:-2] - frames, frame_length = signal.size()[-2:] - - subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor - subframe_step = frame_step // subframe_length - subframes_per_frame = frame_length // subframe_length - output_size = frame_step * (frames - 1) + frame_length - output_subframes = output_size // subframe_length - - subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) - - frame = torch.arange(0, output_subframes, - device=signal.device).unfold(0, subframes_per_frame, subframe_step) - frame = frame.long() # signal may in GPU or CPU - frame = frame.contiguous().view(-1) - - result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) - result.index_add_(-2, frame, subframe_signal) - result = result.view(*outer_dimensions, -1) - return result - - -class ConvTasNet(nn.Module): - @capture_init - def __init__(self, - sources, - N=256, - L=20, - B=256, - H=512, - P=3, - X=8, - R=4, - audio_channels=2, - norm_type="gLN", - causal=False, - mask_nonlinear='relu', - samplerate=44100, - segment_length=44100 * 2 * 4): - """ - Args: - sources: list of sources - N: Number of filters in autoencoder - L: Length of the filters (in samples) - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(ConvTasNet, self).__init__() - # Hyper-parameter - self.sources = sources - self.C = len(sources) - self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R - self.norm_type = norm_type - self.causal = causal - self.mask_nonlinear = mask_nonlinear - self.audio_channels = audio_channels - self.samplerate = samplerate - self.segment_length = segment_length - # Components - self.encoder = Encoder(L, N, audio_channels) - self.separator = 
TemporalConvNet( - N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear) - self.decoder = Decoder(N, L, audio_channels) - # init - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_normal_(p) - - def valid_length(self, length): - return length - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - est_source: [M, C, T] - """ - mixture_w = self.encoder(mixture) - est_mask = self.separator(mixture_w) - est_source = self.decoder(mixture_w, est_mask) - - # T changed after conv1d in encoder, fix it here - T_origin = mixture.size(-1) - T_conv = est_source.size(-1) - est_source = F.pad(est_source, (0, T_origin - T_conv)) - return est_source - - -class Encoder(nn.Module): - """Estimation of the nonnegative mixture weight by a 1-D conv layer. - """ - def __init__(self, L, N, audio_channels): - super(Encoder, self).__init__() - # Hyper-parameter - self.L, self.N = L, N - # Components - # 50% overlap - self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 - """ - mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] - return mixture_w - - -class Decoder(nn.Module): - def __init__(self, N, L, audio_channels): - super(Decoder, self).__init__() - # Hyper-parameter - self.N, self.L = N, L - self.audio_channels = audio_channels - # Components - self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) - - def forward(self, mixture_w, est_mask): - """ - Args: - mixture_w: [M, N, K] - est_mask: [M, C, N, K] - Returns: - est_source: [M, C, T] - """ - # D = W * M - source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] - source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] - # S = DV - est_source = self.basis_signals(source_w) # [M, C, K, ac * L] - m, c, k, _ = est_source.size() - est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() - est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T - return est_source - - -class TemporalConvNet(nn.Module): - def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): - """ - Args: - N: Number of filters in autoencoder - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - C: Number of speakers - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(TemporalConvNet, self).__init__() - # Hyper-parameter - self.C = C - self.mask_nonlinear = mask_nonlinear - # Components - # [M, N, K] -> [M, N, K] - layer_norm = ChannelwiseLayerNorm(N) - # [M, N, K] -> [M, B, K] - bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) - # [M, B, K] -> [M, B, K] - repeats = [] - for r in range(R): - blocks = [] - for x in range(X): - dilation = 2**x - padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 - blocks += [ - TemporalBlock(B, - H, - P, - stride=1, - padding=padding, - dilation=dilation, - norm_type=norm_type, - causal=causal) - ] - repeats += [nn.Sequential(*blocks)] - temporal_conv_net = nn.Sequential(*repeats) - # [M, B, K] -> [M, C*N, K] - mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) - # Put together - self.network = 
nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, - mask_conv1x1) - - def forward(self, mixture_w): - """ - Keep this API same with TasNet - Args: - mixture_w: [M, N, K], M is batch size - returns: - est_mask: [M, C, N, K] - """ - M, N, K = mixture_w.size() - score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] - score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] - if self.mask_nonlinear == 'softmax': - est_mask = F.softmax(score, dim=1) - elif self.mask_nonlinear == 'relu': - est_mask = F.relu(score) - else: - raise ValueError("Unsupported mask non-linear function") - return est_mask - - -class TemporalBlock(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(TemporalBlock, self).__init__() - # [M, B, K] -> [M, H, K] - conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) - prelu = nn.PReLU() - norm = chose_norm(norm_type, out_channels) - # [M, H, K] -> [M, B, K] - dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, - dilation, norm_type, causal) - # Put together - self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) - - def forward(self, x): - """ - Args: - x: [M, B, K] - Returns: - [M, B, K] - """ - residual = x - out = self.net(x) - # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? - return out + residual # look like w/o F.relu is better than w/ F.relu - # return F.relu(out + residual) - - -class DepthwiseSeparableConv(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(DepthwiseSeparableConv, self).__init__() - # Use `groups` option to implement depthwise convolution - # [M, H, K] -> [M, H, K] - depthwise_conv = nn.Conv1d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=in_channels, - bias=False) - if causal: - chomp = Chomp1d(padding) - prelu = nn.PReLU() - norm = chose_norm(norm_type, in_channels) - # [M, H, K] -> [M, B, K] - pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) - # Put together - if causal: - self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) - else: - self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) - - def forward(self, x): - """ - Args: - x: [M, H, K] - Returns: - result: [M, B, K] - """ - return self.net(x) - - -class Chomp1d(nn.Module): - """To ensure the output length is the same as the input. - """ - def __init__(self, chomp_size): - super(Chomp1d, self).__init__() - self.chomp_size = chomp_size - - def forward(self, x): - """ - Args: - x: [M, H, Kpad] - Returns: - [M, H, K] - """ - return x[:, :, :-self.chomp_size].contiguous() - - -def chose_norm(norm_type, channel_size): - """The input of normlization will be (M, C, K), where M is batch size, - C is channel size and K is sequence length. - """ - if norm_type == "gLN": - return GlobalLayerNorm(channel_size) - elif norm_type == "cLN": - return ChannelwiseLayerNorm(channel_size) - elif norm_type == "id": - return nn.Identity() - else: # norm_type == "BN": - # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics - # along M and K, so this BN usage is right. 
- return nn.BatchNorm1d(channel_size) - - -# TODO: Use nn.LayerNorm to impl cLN to speed up -class ChannelwiseLayerNorm(nn.Module): - """Channel-wise Layer Normalization (cLN)""" - def __init__(self, channel_size): - super(ChannelwiseLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.reset_parameters() - - def reset_parameters(self): - self.gamma.data.fill_(1) - self.beta.data.zero_() - - def forward(self, y): - """ - Args: - y: [M, N, K], M is batch size, N is channel size, K is length - Returns: - cLN_y: [M, N, K] - """ - mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] - var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] - cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - return cLN_y - - -class GlobalLayerNorm(nn.Module): - """Global Layer Normalization (gLN)""" - def __init__(self, channel_size): - super(GlobalLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] - self.reset_parameters() - - def reset_parameters(self): - self.gamma.data.fill_(1) - self.beta.data.zero_() - - def forward(self, y): - """ - Args: - y: [M, N, K], M is batch size, N is channel size, K is length - Returns: - gLN_y: [M, N, K] - """ - # TODO: in torch 1.0, torch.mean() support dim list - mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] - var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) - gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta - return gLN_y - - -if __name__ == "__main__": - torch.manual_seed(123) - M, N, L, T = 2, 3, 4, 12 - K = 2 * T // L - 1 - B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False - mixture = torch.randint(3, (M, T)) - # test Encoder - encoder = Encoder(L, N) - encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) - mixture_w = encoder(mixture) - print('mixture', mixture) - print('U', encoder.conv1d_U.weight) - print('mixture_w', mixture_w) - print('mixture_w size', mixture_w.size()) - - # test TemporalConvNet - separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) - est_mask = separator(mixture_w) - print('est_mask', est_mask) - - # test Decoder - decoder = Decoder(N, L) - est_mask = torch.randint(2, (B, K, C, N)) - est_source = decoder(mixture_w, est_mask) - print('est_source', est_source) - - # test Conv-TasNet - conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) - est_source = conv_tasnet(mixture) - print('est_source', est_source) - print('est_source size', est_source.size()) diff --git a/demucs/utils-.py b/demucs/utils-.py deleted file mode 100644 index e69c926..0000000 --- a/demucs/utils-.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import errno -import functools -import gzip -import os -import random -import socket -import tempfile -import warnings -from contextlib import contextmanager - -import torch as th -import tqdm -from torch import distributed -from torch.nn import functional as F - - -def center_trim(tensor, reference): - """ - Center trim `tensor` with respect to `reference`, along the last dimension. 
- `reference` can also be a number, representing the length to trim to. - If the size difference != 0 mod 2, the extra sample is removed on the right side. - """ - if hasattr(reference, "size"): - reference = reference.size(-1) - delta = tensor.size(-1) - reference - if delta < 0: - raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") - if delta: - tensor = tensor[..., delta // 2:-(delta - delta // 2)] - return tensor - - -def average_metric(metric, count=1.): - """ - Average `metric` which should be a float across all hosts. `count` should be - the weight for this particular host (i.e. number of examples). - """ - metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda') - distributed.all_reduce(metric, op=distributed.ReduceOp.SUM) - return metric[1].item() / metric[0].item() - - -def free_port(host='', low=20000, high=40000): - """ - Return a port number that is most likely free. - This could suffer from a race condition although - it should be quite rare. - """ - sock = socket.socket() - while True: - port = random.randint(low, high) - try: - sock.bind((host, port)) - except OSError as error: - if error.errno == errno.EADDRINUSE: - continue - raise - return port - - -def sizeof_fmt(num, suffix='B'): - """ - Given `num` bytes, return human readable size. - Taken from https://stackoverflow.com/a/1094933 - """ - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: - if abs(num) < 1024.0: - return "%3.1f%s%s" % (num, unit, suffix) - num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) - - -def human_seconds(seconds, display='.2f'): - """ - Given `seconds` seconds, return human readable duration. - """ - value = seconds * 1e6 - ratios = [1e3, 1e3, 60, 60, 24] - names = ['us', 'ms', 's', 'min', 'hrs', 'days'] - last = names.pop(0) - for name, ratio in zip(names, ratios): - if value / ratio < 0.3: - break - value /= ratio - last = name - return f"{format(value, display)} {last}" - - -def apply_model_v1(model, mix, shifts=None, split=False, progress=False): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. 
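`center_trim` above is the counterpart of the valid-length padding: after the network has run on a padded input, the estimate is trimmed symmetrically back to the reference length. A short usage sketch, assuming the helper is importable from `demucs.utils`:

import torch
from demucs.utils import center_trim

estimate = torch.randn(1, 4, 2, 44110)       # e.g. a padded model output
reference = torch.randn(1, 2, 44100)
trimmed = center_trim(estimate, reference)   # drops 5 samples on each side
assert trimmed.shape[-1] == 44100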
- progress (bool): if True, show a progress bar (requires split=True) - """ - channels, length = mix.size() - device = mix.device - if split: - out = th.zeros(4, channels, length, device=device) - shift = model.samplerate * 10 - offsets = range(0, length, shift) - scale = 10 - if progress: - offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') - for offset in offsets: - chunk = mix[..., offset:offset + shift] - chunk_out = apply_model_v1(model, chunk, shifts=shifts) - out[..., offset:offset + shift] = chunk_out - offset += shift - return out - elif shifts: - max_shift = int(model.samplerate / 2) - mix = F.pad(mix, (max_shift, max_shift)) - offsets = list(range(max_shift)) - random.shuffle(offsets) - out = 0 - for offset in offsets[:shifts]: - shifted = mix[..., offset:offset + length + max_shift] - shifted_out = apply_model_v1(model, shifted) - out += shifted_out[..., max_shift - offset:max_shift - offset + length] - out /= shifts - return out - else: - valid_length = model.valid_length(length) - print('valid_length: ', valid_length) - delta = valid_length - length - padded = F.pad(mix, (delta // 2, delta - delta // 2)) - with th.no_grad(): - out = model(padded.unsqueeze(0))[0] - return center_trim(out, mix) - - -@contextmanager -def temp_filenames(count, delete=True, **kwargs): - names = [] - try: - for _ in range(count): - names.append(tempfile.NamedTemporaryFile(delete=False).name) - yield names - finally: - if delete: - for name in names: - os.unlink(name) - - -def load_model(path): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - load_from = path - if str(path).endswith(".gz"): - load_from = gzip.open(path, "rb") - klass, args, kwargs, state = th.load(load_from, 'cpu') - model = klass(*args, **kwargs) - model.load_state_dict(state) - return model - - -def save_model(model, path): - args, kwargs = model._init_args_kwargs - klass = model.__class__ - state = {k: p.data.to('cpu') for k, p in model.state_dict().items()} - save_to = path - if str(path).endswith(".gz"): - save_to = gzip.open(path, "wb", compresslevel=5) - th.save((klass, args, kwargs, state), save_to) - - -def capture_init(init): - @functools.wraps(init) - def __init__(self, *args, **kwargs): - self._init_args_kwargs = (args, kwargs) - init(self, *args, **kwargs) - - return __init__ diff --git a/demucs/utils.py b/demucs/utils.py deleted file mode 100644 index f09bc18..0000000 --- a/demucs/utils.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -from contextlib import contextmanager -import math -import os -import tempfile -import typing as tp - -import errno -import functools -import hashlib -import inspect -import io -import os -import random -import socket -import tempfile -import warnings -import zlib -import tkinter as tk - -from diffq import UniformQuantizer, DiffQuantizer -import torch as th -import tqdm -from torch import distributed -from torch.nn import functional as F - -import torch - -def unfold(a, kernel_size, stride): - """Given input of size [*OT, T], output Tensor of size [*OT, F, K] - with K the kernel size, by extracting frames with the given stride. - - This will pad the input so that `F = ceil(T / K)`. 
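The `shifts` option documented above implements randomized equivariant stabilization: the mixture is shifted by a random sub-second offset, separated, shifted back, and the predictions are averaged. A condensed standalone paraphrase of that branch; `shifted_average` is a hypothetical helper and `apply` stands for any callable mapping [channels, T] to [sources, channels, T]:

import torch as th
import torch.nn.functional as F

def shifted_average(apply, mix, samplerate=44100, shifts=2):
    # average over random time shifts, undoing each shift on the output
    channels, length = mix.shape
    max_shift = samplerate // 2
    padded = F.pad(mix, (max_shift, max_shift))
    out = 0
    for offset in th.randperm(max_shift)[:shifts].tolist():
        shifted = padded[..., offset:offset + length + max_shift]
        shifted_out = apply(shifted)
        out = out + shifted_out[..., max_shift - offset:max_shift - offset + length]
    return out / shifts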
- - see https://github.com/pytorch/pytorch/issues/60466 - """ - *shape, length = a.shape - n_frames = math.ceil(length / stride) - tgt_length = (n_frames - 1) * stride + kernel_size - a = F.pad(a, (0, tgt_length - length)) - strides = list(a.stride()) - assert strides[-1] == 1, 'data should be contiguous' - strides = strides[:-1] + [stride, 1] - return a.as_strided([*shape, n_frames, kernel_size], strides) - - -def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): - """ - Center trim `tensor` with respect to `reference`, along the last dimension. - `reference` can also be a number, representing the length to trim to. - If the size difference != 0 mod 2, the extra sample is removed on the right side. - """ - ref_size: int - if isinstance(reference, torch.Tensor): - ref_size = reference.size(-1) - else: - ref_size = reference - delta = tensor.size(-1) - ref_size - if delta < 0: - raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") - if delta: - tensor = tensor[..., delta // 2:-(delta - delta // 2)] - return tensor - - -def pull_metric(history: tp.List[dict], name: str): - out = [] - for metrics in history: - metric = metrics - for part in name.split("."): - metric = metric[part] - out.append(metric) - return out - - -def EMA(beta: float = 1): - """ - Exponential Moving Average callback. - Returns a single function that can be called to repeatidly update the EMA - with a dict of metrics. The callback will return - the new averaged dict of metrics. - - Note that for `beta=1`, this is just plain averaging. - """ - fix: tp.Dict[str, float] = defaultdict(float) - total: tp.Dict[str, float] = defaultdict(float) - - def _update(metrics: dict, weight: float = 1) -> dict: - nonlocal total, fix - for key, value in metrics.items(): - total[key] = total[key] * beta + weight * float(value) - fix[key] = fix[key] * beta + weight - return {key: tot / fix[key] for key, tot in total.items()} - return _update - - -def sizeof_fmt(num: float, suffix: str = 'B'): - """ - Given `num` bytes, return human readable size. - Taken from https://stackoverflow.com/a/1094933 - """ - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: - if abs(num) < 1024.0: - return "%3.1f%s%s" % (num, unit, suffix) - num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) - - -@contextmanager -def temp_filenames(count: int, delete=True): - names = [] - try: - for _ in range(count): - names.append(tempfile.NamedTemporaryFile(delete=False).name) - yield names - finally: - if delete: - for name in names: - os.unlink(name) - -def average_metric(metric, count=1.): - """ - Average `metric` which should be a float across all hosts. `count` should be - the weight for this particular host (i.e. number of examples). - """ - metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda') - distributed.all_reduce(metric, op=distributed.ReduceOp.SUM) - return metric[1].item() / metric[0].item() - - -def free_port(host='', low=20000, high=40000): - """ - Return a port number that is most likely free. - This could suffer from a race condition although - it should be quite rare. - """ - sock = socket.socket() - while True: - port = random.randint(low, high) - try: - sock.bind((host, port)) - except OSError as error: - if error.errno == errno.EADDRINUSE: - continue - raise - return port - - -def sizeof_fmt(num, suffix='B'): - """ - Given `num` bytes, return human readable size. 
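# A short, runnable demo of the EMA helper defined above (restated here so the
# snippet stands alone). With beta=1 the callback reduces to a plain weighted mean
# per metric name; the values below are chosen to be exact in floating point.
from collections import defaultdict
import typing as tp

def EMA(beta: float = 1):
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}
    return _update

update = EMA(beta=1)
print(update({"loss": 0.5}))   # {'loss': 0.5}
print(update({"loss": 0.25}))  # {'loss': 0.375} -- plain mean of the two updates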
- Taken from https://stackoverflow.com/a/1094933 - """ - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: - if abs(num) < 1024.0: - return "%3.1f%s%s" % (num, unit, suffix) - num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) - - -def human_seconds(seconds, display='.2f'): - """ - Given `seconds` seconds, return human readable duration. - """ - value = seconds * 1e6 - ratios = [1e3, 1e3, 60, 60, 24] - names = ['us', 'ms', 's', 'min', 'hrs', 'days'] - last = names.pop(0) - for name, ratio in zip(names, ratios): - if value / ratio < 0.3: - break - value /= ratio - last = name - return f"{format(value, display)} {last}" - - -class TensorChunk: - def __init__(self, tensor, offset=0, length=None): - total_length = tensor.shape[-1] - assert offset >= 0 - assert offset < total_length - - if length is None: - length = total_length - offset - else: - length = min(total_length - offset, length) - - self.tensor = tensor - self.offset = offset - self.length = length - self.device = tensor.device - - @property - def shape(self): - shape = list(self.tensor.shape) - shape[-1] = self.length - return shape - - def padded(self, target_length): - delta = target_length - self.length - total_length = self.tensor.shape[-1] - assert delta >= 0 - - start = self.offset - delta // 2 - end = start + target_length - - correct_start = max(0, start) - correct_end = min(total_length, end) - - pad_left = correct_start - start - pad_right = end - correct_end - - out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) - assert out.shape[-1] == target_length - return out - - -def tensor_chunk(tensor_or_chunk): - if isinstance(tensor_or_chunk, TensorChunk): - return tensor_or_chunk - else: - assert isinstance(tensor_or_chunk, th.Tensor) - return TensorChunk(tensor_or_chunk) - - -def apply_model_v1(model, mix, gui_progress_bar: tk.Variable, widget_text: tk.Text, update_prog, total_files, file_num, inference_type, shifts=None, split=False, progress=False, segmen=True): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. 
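# A stand-alone rewrite of the centering arithmetic used by TensorChunk.padded
# above: grow a chunk to `target_length` by borrowing real neighbouring samples
# from the parent tensor and zero-padding only past its ends. `padded_view` is a
# hypothetical helper name for illustration.
import torch
import torch.nn.functional as F

def padded_view(tensor, offset, length, target_length):
    # Extend by delta // 2 on the left, the remainder on the right, clamped to the tensor.
    delta = target_length - length
    total_length = tensor.shape[-1]
    start = offset - delta // 2
    end = start + target_length
    correct_start, correct_end = max(0, start), min(total_length, end)
    pad_left, pad_right = correct_start - start, end - correct_end
    return F.pad(tensor[..., correct_start:correct_end], (pad_left, pad_right))

x = torch.arange(10.)
# A 4-sample chunk at offset 3 padded to 8 samples reuses real neighbours on both
# sides; zeros would only appear if the window ran off the ends of `x`.
print(padded_view(x, offset=3, length=4, target_length=8))  # tensor([1., 2., 3., 4., 5., 6., 7., 8.])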
- progress (bool): if True, show a progress bar (requires split=True) - """ - - base_text = 'File {file_num}/{total_files} '.format(file_num=file_num, - total_files=total_files) - - channels, length = mix.size() - device = mix.device - if split: - out = th.zeros(4, channels, length, device=device) - shift = model.samplerate * 10 - offsets = range(0, length, shift) - scale = 10 - progress_bar = 0 - prog_bar = 0 - if progress: - offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') - for offset in offsets: - if segmen: - fut_length = len(offsets) - send_back = fut_length * 2 - progress_bar += 100 - prog_bar += 1 - if inference_type == 'demucs_only': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.1 + (1.7/send_back * prog_bar))) - elif inference_type == 'inference_mdx': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.35 + (1.05/send_back * prog_bar))) - elif inference_type == 'inference_vr': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.6 + (0.7/send_back * prog_bar))) - step = (progress_bar / fut_length) - percent_prog = f"{base_text}Demucs v1 Inference Progress: {prog_bar}/{fut_length} | {round(step)}%" - widget_text.percentage(percent_prog) - #gui_progress_bar.set(step) - chunk = mix[..., offset:offset + shift] - chunk_out = apply_model_v1(model, chunk, gui_progress_bar, widget_text, update_prog, total_files, file_num, inference_type, shifts=shifts) - out[..., offset:offset + shift] = chunk_out - offset += shift - return out - elif shifts: - max_shift = int(model.samplerate / 2) - mix = F.pad(mix, (max_shift, max_shift)) - offsets = list(range(max_shift)) - random.shuffle(offsets) - out = 0 - for offset in offsets[:shifts]: - shifted = mix[..., offset:offset + length + max_shift] - shifted_out = apply_model_v1(model, shifted, gui_progress_bar, widget_text, update_prog, total_files, file_num, inference_type) - out += shifted_out[..., max_shift - offset:max_shift - offset + length] - out /= shifts - return out - else: - valid_length = model.valid_length(length) - delta = valid_length - length - padded = F.pad(mix, (delta // 2, delta - delta // 2)) - with th.no_grad(): - out = model(padded.unsqueeze(0))[0] - return center_trim(out, mix) - -def apply_model_v2(model, mix, gui_progress_bar: tk.Variable, widget_text: tk.Text, update_prog, total_files, file_num, inference_type, shifts=None, split=False, - overlap=0.25, transition_power=1., progress=False, segmen=True): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. - progress (bool): if True, show a progress bar (requires split=True) - """ - - global prog_space - global percent_prog - - percent_prog = 0 - - base_text = 'File {file_num}/{total_files} '.format(file_num=file_num, - total_files=total_files) - - #widget_text.remove(percent_prog) - assert transition_power >= 1, "transition_power < 1 leads to weird behavior." 
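# A minimal sketch of the split=True path above, with the GUI plumbing stripped out:
# cut the mixture into fixed 10-second windows, run the model on each window, and
# write the results back in place. `split_apply` and the toy model are hypothetical
# stand-ins, not demucs APIs.
import torch

def split_apply(toy_model, mix, samplerate=44100, sources=4):
    channels, length = mix.shape
    out = torch.zeros(sources, channels, length)
    shift = samplerate * 10                      # 10-second windows, no overlap
    for offset in range(0, length, shift):
        chunk = mix[..., offset:offset + shift]  # the last chunk may be shorter
        out[..., offset:offset + chunk.shape[-1]] = toy_model(chunk)
    return out

# Toy "model" that just copies the mixture into each of the 4 sources.
mix = torch.randn(2, 44100 * 25)                 # 25 s of stereo audio
out = split_apply(lambda c: c.expand(4, *c.shape).clone(), mix)
assert out.shape == (4, 2, 44100 * 25)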
- device = mix.device - channels, length = mix.shape - if split: - out = th.zeros(len(model.sources), channels, length, device=device) - sum_weight = th.zeros(length, device=device) - segment = model.segment_length - stride = int((1 - overlap) * segment) - offsets = range(0, length, stride) - scale = stride / model.samplerate - if progress: - offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') - # We start from a triangle shaped weight, with maximal weight in the middle - # of the segment. Then we normalize and take to the power `transition_power`. - # Large values of transition power will lead to sharper transitions. - weight = th.cat([th.arange(1, segment // 2 + 1), - th.arange(segment - segment // 2, 0, -1)]).to(device) - assert len(weight) == segment - # If the overlap < 50%, this will translate to linear transition when - # transition_power is 1. - weight = (weight / weight.max())**transition_power - progress_bar = 0 - prog_bar = 0 - for offset in offsets: - if segmen: - fut_length = len(offsets) - send_back = fut_length * 2 - progress_bar += 100 - prog_bar += 1 - if inference_type == 'demucs_only': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.1 + (1.7/send_back * prog_bar))) - elif inference_type == 'inference_mdx': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.35 + (1.05/send_back * prog_bar))) - elif inference_type == 'inference_vr': - update_prog(gui_progress_bar, total_files, file_num, - step=(0.6 + (0.7/send_back * prog_bar))) - step = (progress_bar / fut_length) - percent_prog = f"{base_text}Demucs v2 Inference Progress: {prog_bar}/{fut_length} | {round(step)}%" - prog_space = len(percent_prog) - prog_space = prog_bar*prog_space - widget_text.percentage(percent_prog) - chunk = TensorChunk(mix, offset, segment) - chunk_out = apply_model_v2(model, chunk, gui_progress_bar, widget_text, update_prog, total_files, file_num, inference_type, shifts=shifts) - chunk_length = chunk_out.shape[-1] - out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out - sum_weight[offset:offset + segment] += weight[:chunk_length] - offset += segment - assert sum_weight.min() > 0 - out /= sum_weight - return out - elif shifts: - max_shift = int(0.5 * model.samplerate) - mix = tensor_chunk(mix) - padded_mix = mix.padded(length + 2 * max_shift) - out = 0 - for _ in range(shifts): - offset = random.randint(0, max_shift) - shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) - shifted_out = apply_model_v2(model, shifted, gui_progress_bar, widget_text, update_prog, total_files, file_num, inference_type) - out += shifted_out[..., max_shift - offset:] - out /= shifts - return out - else: - valid_length = model.valid_length(length) - mix = tensor_chunk(mix) - padded_mix = mix.padded(valid_length) - with th.no_grad(): - out = model(padded_mix.unsqueeze(0))[0] - return center_trim(out, length) - - -@contextmanager -def temp_filenames(count, delete=True): - names = [] - try: - for _ in range(count): - names.append(tempfile.NamedTemporaryFile(delete=False).name) - yield names - finally: - if delete: - for name in names: - os.unlink(name) - - -def get_quantizer(model, args, optimizer=None): - quantizer = None - if args.diffq: - quantizer = DiffQuantizer( - model, min_size=args.q_min_size, group_size=8) - if optimizer is not None: - quantizer.setup_optimizer(optimizer) - elif args.qat: - quantizer = UniformQuantizer( - model, bits=args.qat, min_size=args.q_min_size) - return quantizer - - -def load_model(path, strict=False): - 
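# A stand-alone sketch of the overlap-add scheme implemented above: each segment is
# weighted by a triangle (largest in the middle, raised to `transition_power`) and
# the accumulated output is divided by the summed weights. `triangle_weight` and
# `overlap_add` are hypothetical helper names; every sample must be covered by at
# least one segment, which the original enforces with sum_weight.min() > 0.
import torch

def triangle_weight(segment, transition_power=1.0):
    # Ramp up to the middle of the segment and back down, normalised to a max of 1.
    weight = torch.cat([torch.arange(1, segment // 2 + 1),
                        torch.arange(segment - segment // 2, 0, -1)]).float()
    return (weight / weight.max()) ** transition_power

def overlap_add(chunks, length, segment, stride, transition_power=1.0):
    # chunks[i] holds the model output for the window starting at i * stride.
    weight = triangle_weight(segment, transition_power)
    out = torch.zeros(length)
    sum_weight = torch.zeros(length)
    for i, chunk in enumerate(chunks):
        offset, n = i * stride, chunk.shape[-1]
        out[offset:offset + n] += weight[:n] * chunk
        sum_weight[offset:offset + n] += weight[:n]
    return out / sum_weight

# Mono toy example: all-ones chunks covering a 20-sample signal blend back to ones.
segment, stride, length = 8, 6, 20
chunks = [torch.ones(min(segment, length - o)) for o in range(0, length, stride)]
print(overlap_add(chunks, length, segment, stride))  # tensor of twenty 1.0 values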
with warnings.catch_warnings(): - warnings.simplefilter("ignore") - load_from = path - package = th.load(load_from, 'cpu') - - klass = package["klass"] - args = package["args"] - kwargs = package["kwargs"] - - if strict: - model = klass(*args, **kwargs) - else: - sig = inspect.signature(klass) - for key in list(kwargs): - if key not in sig.parameters: - warnings.warn("Dropping inexistant parameter " + key) - del kwargs[key] - model = klass(*args, **kwargs) - - state = package["state"] - training_args = package["training_args"] - quantizer = get_quantizer(model, training_args) - - set_state(model, quantizer, state) - return model - - -def get_state(model, quantizer): - if quantizer is None: - state = {k: p.data.to('cpu') for k, p in model.state_dict().items()} - else: - state = quantizer.get_quantized_state() - buf = io.BytesIO() - th.save(state, buf) - state = {'compressed': zlib.compress(buf.getvalue())} - return state - - -def set_state(model, quantizer, state): - if quantizer is None: - model.load_state_dict(state) - else: - buf = io.BytesIO(zlib.decompress(state["compressed"])) - state = th.load(buf, "cpu") - quantizer.restore_quantized_state(state) - - return state - - -def save_state(state, path): - buf = io.BytesIO() - th.save(state, buf) - sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] - - path = path.parent / (path.stem + "-" + sig + path.suffix) - path.write_bytes(buf.getvalue()) - - -def save_model(model, quantizer, training_args, path): - args, kwargs = model._init_args_kwargs - klass = model.__class__ - - state = get_state(model, quantizer) - - save_to = path - package = { - 'klass': klass, - 'args': args, - 'kwargs': kwargs, - 'state': state, - 'training_args': training_args, - } - th.save(package, save_to) - - -def capture_init(init): - @functools.wraps(init) - def __init__(self, *args, **kwargs): - self._init_args_kwargs = (args, kwargs) - init(self, *args, **kwargs) - - return __init__ - -class DummyPoolExecutor: - class DummyResult: - def __init__(self, func, *args, **kwargs): - self.func = func - self.args = args - self.kwargs = kwargs - - def result(self): - return self.func(*self.args, **self.kwargs) - - def __init__(self, workers=0): - pass - - def submit(self, func, *args, **kwargs): - return DummyPoolExecutor.DummyResult(func, *args, **kwargs) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - return
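# A stand-alone sketch of the quantized-state round trip used by get_state/set_state
# above: serialise the state dict with torch.save into a buffer, compress it with
# zlib into a single {'compressed': ...} blob, and inflate it again on load.
# `pack`/`unpack` are hypothetical helper names for illustration.
import io
import zlib
import torch

def pack(state: dict) -> dict:
    buf = io.BytesIO()
    torch.save(state, buf)
    return {"compressed": zlib.compress(buf.getvalue())}

def unpack(packed: dict) -> dict:
    buf = io.BytesIO(zlib.decompress(packed["compressed"]))
    return torch.load(buf, map_location="cpu")

state = {"weight": torch.zeros(4, 4)}
assert torch.equal(unpack(pack(state))["weight"], state["weight"])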