diff --git a/demucs/__init__.py b/demucs/__init__.py new file mode 100644 index 0000000..d4182e3 --- /dev/null +++ b/demucs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +__version__ = "2.0.3" diff --git a/demucs/__main__.py b/demucs/__main__.py new file mode 100644 index 0000000..5148f20 --- /dev/null +++ b/demucs/__main__.py @@ -0,0 +1,317 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json +import math +import os +import sys +import time +from dataclasses import dataclass, field + +import torch as th +from torch import distributed, nn +from torch.nn.parallel.distributed import DistributedDataParallel + +from .augment import FlipChannels, FlipSign, Remix, Scale, Shift +from .compressed import get_compressed_datasets +from .model import Demucs +from .parser import get_name, get_parser +from .raw import Rawset +from .repitch import RepitchedWrapper +from .pretrained import load_pretrained, SOURCES +from .tasnet import ConvTasNet +from .test import evaluate +from .train import train_model, validate_model +from .utils import (human_seconds, load_model, save_model, get_state, + save_state, sizeof_fmt, get_quantizer) +from .wav import get_wav_datasets, get_musdb_wav_datasets + + +@dataclass +class SavedState: + metrics: list = field(default_factory=list) + last_state: dict = None + best_state: dict = None + optimizer: dict = None + + +def main(): + parser = get_parser() + args = parser.parse_args() + name = get_name(parser, args) + print(f"Experiment {name}") + + if args.musdb is None and args.rank == 0: + print( + "You must provide the path to the MusDB dataset with the --musdb flag. " + "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", + file=sys.stderr) + sys.exit(1) + + eval_folder = args.evals / name + eval_folder.mkdir(exist_ok=True, parents=True) + args.logs.mkdir(exist_ok=True) + metrics_path = args.logs / f"{name}.json" + eval_folder.mkdir(exist_ok=True, parents=True) + args.checkpoints.mkdir(exist_ok=True, parents=True) + args.models.mkdir(exist_ok=True, parents=True) + + if args.device is None: + device = "cpu" + if th.cuda.is_available(): + device = "cuda" + else: + device = args.device + + th.manual_seed(args.seed) + # Prevents too many threads to be started when running `museval` as it can be quite + # inefficient on NUMA architectures. 
+ os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + + if args.world_size > 1: + if device != "cuda" and args.rank == 0: + print("Error: distributed training is only available with cuda device", file=sys.stderr) + sys.exit(1) + th.cuda.set_device(args.rank % th.cuda.device_count()) + distributed.init_process_group(backend="nccl", + init_method="tcp://" + args.master, + rank=args.rank, + world_size=args.world_size) + + checkpoint = args.checkpoints / f"{name}.th" + checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" + if args.restart and checkpoint.exists() and args.rank == 0: + checkpoint.unlink() + + if args.test or args.test_pretrained: + args.epochs = 1 + args.repeat = 0 + if args.test: + model = load_model(args.models / args.test) + else: + model = load_pretrained(args.test_pretrained) + elif args.tasnet: + model = ConvTasNet(audio_channels=args.audio_channels, + samplerate=args.samplerate, X=args.X, + segment_length=4 * args.samples, + sources=SOURCES) + else: + model = Demucs( + audio_channels=args.audio_channels, + channels=args.channels, + context=args.context, + depth=args.depth, + glu=args.glu, + growth=args.growth, + kernel_size=args.kernel_size, + lstm_layers=args.lstm_layers, + rescale=args.rescale, + rewrite=args.rewrite, + stride=args.conv_stride, + resample=args.resample, + normalize=args.normalize, + samplerate=args.samplerate, + segment_length=4 * args.samples, + sources=SOURCES, + ) + model.to(device) + if args.init: + model.load_state_dict(load_pretrained(args.init).state_dict()) + + if args.show: + print(model) + size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) + print(f"Model size {size}") + return + + try: + saved = th.load(checkpoint, map_location='cpu') + except IOError: + saved = SavedState() + + optimizer = th.optim.Adam(model.parameters(), lr=args.lr) + + quantizer = None + quantizer = get_quantizer(model, args, optimizer) + + if saved.last_state is not None: + model.load_state_dict(saved.last_state, strict=False) + if saved.optimizer is not None: + optimizer.load_state_dict(saved.optimizer) + + model_name = f"{name}.th" + if args.save_model: + if args.rank == 0: + model.to("cpu") + model.load_state_dict(saved.best_state) + save_model(model, quantizer, args, args.models / model_name) + return + elif args.save_state: + model_name = f"{args.save_state}.th" + if args.rank == 0: + model.to("cpu") + model.load_state_dict(saved.best_state) + state = get_state(model, quantizer) + save_state(state, args.models / model_name) + return + + if args.rank == 0: + done = args.logs / f"{name}.done" + if done.exists(): + done.unlink() + + augment = [Shift(args.data_stride)] + if args.augment: + augment += [FlipSign(), FlipChannels(), Scale(), + Remix(group_size=args.remix_group_size)] + augment = nn.Sequential(*augment).to(device) + print("Agumentation pipeline:", augment) + + if args.mse: + criterion = nn.MSELoss() + else: + criterion = nn.L1Loss() + + # Setting number of samples so that all convolution windows are full. + # Prevents hard to debug mistake with the prediction being shifted compared + # to the input mixture. + samples = model.valid_length(args.samples) + print(f"Number of training samples adjusted to {samples}") + samples = samples + args.data_stride + if args.repitch: + # We need a bit more audio samples, to account for potential + # tempo change. 
+ samples = math.ceil(samples / (1 - 0.01 * args.max_tempo)) + + args.metadata.mkdir(exist_ok=True, parents=True) + if args.raw: + train_set = Rawset(args.raw / "train", + samples=samples, + channels=args.audio_channels, + streams=range(1, len(model.sources) + 1), + stride=args.data_stride) + + valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) + elif args.wav: + train_set, valid_set = get_wav_datasets(args, samples, model.sources) + elif args.is_wav: + train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources) + else: + train_set, valid_set = get_compressed_datasets(args, samples) + + if args.repitch: + train_set = RepitchedWrapper( + train_set, + proba=args.repitch, + max_tempo=args.max_tempo) + + best_loss = float("inf") + for epoch, metrics in enumerate(saved.metrics): + print(f"Epoch {epoch:03d}: " + f"train={metrics['train']:.8f} " + f"valid={metrics['valid']:.8f} " + f"best={metrics['best']:.4f} " + f"ms={metrics.get('true_model_size', 0):.2f}MB " + f"cms={metrics.get('compressed_model_size', 0):.2f}MB " + f"duration={human_seconds(metrics['duration'])}") + best_loss = metrics['best'] + + if args.world_size > 1: + dmodel = DistributedDataParallel(model, + device_ids=[th.cuda.current_device()], + output_device=th.cuda.current_device()) + else: + dmodel = model + + for epoch in range(len(saved.metrics), args.epochs): + begin = time.time() + model.train() + train_loss, model_size = train_model( + epoch, train_set, dmodel, criterion, optimizer, augment, + quantizer=quantizer, + batch_size=args.batch_size, + device=device, + repeat=args.repeat, + seed=args.seed, + diffq=args.diffq, + workers=args.workers, + world_size=args.world_size) + model.eval() + valid_loss = validate_model( + epoch, valid_set, model, criterion, + device=device, + rank=args.rank, + split=args.split_valid, + overlap=args.overlap, + world_size=args.world_size) + + ms = 0 + cms = 0 + if quantizer and args.rank == 0: + ms = quantizer.true_model_size() + cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10)) + + duration = time.time() - begin + if valid_loss < best_loss and ms <= args.ms_target: + best_loss = valid_loss + saved.best_state = { + key: value.to("cpu").clone() + for key, value in model.state_dict().items() + } + + saved.metrics.append({ + "train": train_loss, + "valid": valid_loss, + "best": best_loss, + "duration": duration, + "model_size": model_size, + "true_model_size": ms, + "compressed_model_size": cms, + }) + if args.rank == 0: + json.dump(saved.metrics, open(metrics_path, "w")) + + saved.last_state = model.state_dict() + saved.optimizer = optimizer.state_dict() + if args.rank == 0 and not args.test: + th.save(saved, checkpoint_tmp) + checkpoint_tmp.rename(checkpoint) + + print(f"Epoch {epoch:03d}: " + f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB " + f"cms={cms:.2f}MB " + f"duration={human_seconds(duration)}") + + if args.world_size > 1: + distributed.barrier() + + del dmodel + model.load_state_dict(saved.best_state) + if args.eval_cpu: + device = "cpu" + model.to(device) + model.eval() + evaluate(model, args.musdb, eval_folder, + is_wav=args.is_wav, + rank=args.rank, + world_size=args.world_size, + device=device, + save=args.save, + split=args.split_valid, + shifts=args.shifts, + overlap=args.overlap, + workers=args.eval_workers) + model.to("cpu") + if args.rank == 0: + if not (args.test or args.test_pretrained): + save_model(model, quantizer, args, args.models / model_name) + print("done") + 
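+        # At this point {name}.json under args.logs holds one dict per epoch with
+        # keys: train, valid, best, duration, model_size, true_model_size and
+        # compressed_model_size.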
done.write_text("done") + + +if __name__ == "__main__": + main() diff --git a/demucs/audio.py b/demucs/audio.py new file mode 100644 index 0000000..b29f156 --- /dev/null +++ b/demucs/audio.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import json +import subprocess as sp +from pathlib import Path + +import julius +import numpy as np +import torch + +from .utils import temp_filenames + + +def _read_info(path): + stdout_data = sp.check_output([ + 'ffprobe', "-loglevel", "panic", + str(path), '-print_format', 'json', '-show_format', '-show_streams' + ]) + return json.loads(stdout_data.decode('utf-8')) + + +class AudioFile: + """ + Allows to read audio from any format supported by ffmpeg, as well as resampling or + converting to mono on the fly. See :method:`read` for more details. + """ + def __init__(self, path: Path): + self.path = Path(path) + self._info = None + + def __repr__(self): + features = [("path", self.path)] + features.append(("samplerate", self.samplerate())) + features.append(("channels", self.channels())) + features.append(("streams", len(self))) + features_str = ", ".join(f"{name}={value}" for name, value in features) + return f"AudioFile({features_str})" + + @property + def info(self): + if self._info is None: + self._info = _read_info(self.path) + return self._info + + @property + def duration(self): + return float(self.info['format']['duration']) + + @property + def _audio_streams(self): + return [ + index for index, stream in enumerate(self.info["streams"]) + if stream["codec_type"] == "audio" + ] + + def __len__(self): + return len(self._audio_streams) + + def channels(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['channels']) + + def samplerate(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) + + def read(self, + seek_time=None, + duration=None, + streams=slice(None), + samplerate=None, + channels=None, + temp_folder=None): + """ + Slightly more efficient implementation than stempeg, + in particular, this will extract all stems at once + rather than having to loop over one file multiple times + for each stream. + + Args: + seek_time (float): seek time in seconds or None if no seeking is needed. + duration (float): duration in seconds to extract or None to extract until the end. + streams (slice, int or list): streams to extract, can be a single int, a list or + a slice. If it is a slice or list, the output will be of size [S, C, T] + with S the number of streams, C the number of channels and T the number of samples. + If it is an int, the output will be [C, T]. + samplerate (int): if provided, will resample on the fly. If None, no resampling will + be done. Original sampling rate can be obtained with :method:`samplerate`. + channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that + as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. + See https://sound.stackexchange.com/a/42710. + Our definition of mono is simply the average of the two channels. Any other + value will be ignored. + temp_folder (str or Path or None): temporary folder to use for decoding. 
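+        Returns:
+            torch.Tensor: the decoded audio, of shape `[S, C, T]` if `streams` is a
+                slice or list, or `[C, T]` if it is a single int.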
+ + + """ + streams = np.array(range(len(self)))[streams] + single = not isinstance(streams, np.ndarray) + if single: + streams = [streams] + + if duration is None: + target_size = None + query_duration = None + else: + target_size = int((samplerate or self.samplerate()) * duration) + query_duration = float((target_size + 1) / (samplerate or self.samplerate())) + + with temp_filenames(len(streams)) as filenames: + command = ['ffmpeg', '-y'] + command += ['-loglevel', 'panic'] + if seek_time: + command += ['-ss', str(seek_time)] + command += ['-i', str(self.path)] + for stream, filename in zip(streams, filenames): + command += ['-map', f'0:{self._audio_streams[stream]}'] + if query_duration is not None: + command += ['-t', str(query_duration)] + command += ['-threads', '1'] + command += ['-f', 'f32le'] + if samplerate is not None: + command += ['-ar', str(samplerate)] + command += [filename] + + sp.run(command, check=True) + wavs = [] + for filename in filenames: + wav = np.fromfile(filename, dtype=np.float32) + wav = torch.from_numpy(wav) + wav = wav.view(-1, self.channels()).t() + if channels is not None: + wav = convert_audio_channels(wav, channels) + if target_size is not None: + wav = wav[..., :target_size] + wavs.append(wav) + wav = torch.stack(wavs, dim=0) + if single: + wav = wav[0] + return wav + + +def convert_audio_channels(wav, channels=2): + """Convert audio to the given number of channels.""" + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, but the stream have multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file have + # one single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file have + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav, from_samplerate, to_samplerate, channels): + wav = convert_audio_channels(wav, channels) + return julius.resample_frac(wav, from_samplerate, to_samplerate) diff --git a/demucs/compressed.py b/demucs/compressed.py new file mode 100644 index 0000000..eb8fbb7 --- /dev/null +++ b/demucs/compressed.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import json +from fractions import Fraction +from concurrent import futures + +import musdb +from torch import distributed + +from .audio import AudioFile + + +def get_musdb_tracks(root, *args, **kwargs): + mus = musdb.DB(root, *args, **kwargs) + return {track.name: track.path for track in mus} + + +class StemsSet: + def __init__(self, tracks, metadata, duration=None, stride=1, + samplerate=44100, channels=2, streams=slice(None)): + + self.metadata = [] + for name, path in tracks.items(): + meta = dict(metadata[name]) + meta["path"] = path + meta["name"] = name + self.metadata.append(meta) + if duration is not None and meta["duration"] < duration: + raise ValueError(f"Track {name} duration is too small {meta['duration']}") + self.metadata.sort(key=lambda x: x["name"]) + self.duration = duration + self.stride = stride + self.channels = channels + self.samplerate = samplerate + self.streams = streams + + def __len__(self): + return sum(self._examples_count(m) for m in self.metadata) + + def _examples_count(self, meta): + if self.duration is None: + return 1 + else: + return int((meta["duration"] - self.duration) // self.stride + 1) + + def track_metadata(self, index): + for meta in self.metadata: + examples = self._examples_count(meta) + if index >= examples: + index -= examples + continue + return meta + + def __getitem__(self, index): + for meta in self.metadata: + examples = self._examples_count(meta) + if index >= examples: + index -= examples + continue + streams = AudioFile(meta["path"]).read(seek_time=index * self.stride, + duration=self.duration, + channels=self.channels, + samplerate=self.samplerate, + streams=self.streams) + return (streams - meta["mean"]) / meta["std"] + + +def _get_track_metadata(path): + # use mono at 44kHz as reference. For any other settings data won't be perfectly + # normalized but it should be good enough. + audio = AudioFile(path) + mix = audio.read(streams=0, channels=1, samplerate=44100) + return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()} + + +def _build_metadata(tracks, workers=10): + pendings = [] + with futures.ProcessPoolExecutor(workers) as pool: + for name, path in tracks.items(): + pendings.append((name, pool.submit(_get_track_metadata, path))) + return {name: p.result() for name, p in pendings} + + +def _build_musdb_metadata(path, musdb, workers): + tracks = get_musdb_tracks(musdb) + metadata = _build_metadata(tracks, workers) + path.parent.mkdir(exist_ok=True, parents=True) + json.dump(metadata, open(path, "w")) + + +def get_compressed_datasets(args, samples): + metadata_file = args.metadata / "musdb.json" + if not metadata_file.is_file() and args.rank == 0: + _build_musdb_metadata(metadata_file, args.musdb, args.workers) + if args.world_size > 1: + distributed.barrier() + metadata = json.load(open(metadata_file)) + duration = Fraction(samples, args.samplerate) + stride = Fraction(args.data_stride, args.samplerate) + train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), + metadata, + duration=duration, + stride=stride, + streams=slice(1, None), + samplerate=args.samplerate, + channels=args.audio_channels) + valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), + metadata, + samplerate=args.samplerate, + channels=args.audio_channels) + return train_set, valid_set diff --git a/demucs/model.py b/demucs/model.py new file mode 100644 index 0000000..e9d932f --- /dev/null +++ b/demucs/model.py @@ -0,0 +1,202 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import julius
+from torch import nn
+
+from .utils import capture_init, center_trim
+
+
+class BLSTM(nn.Module):
+    def __init__(self, dim, layers=1):
+        super().__init__()
+        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+        self.linear = nn.Linear(2 * dim, dim)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        x = self.lstm(x)[0]
+        x = self.linear(x)
+        x = x.permute(1, 2, 0)
+        return x
+
+
+def rescale_conv(conv, reference):
+    std = conv.weight.std().detach()
+    scale = (std / reference)**0.5
+    conv.weight.data /= scale
+    if conv.bias is not None:
+        conv.bias.data /= scale
+
+
+def rescale_module(module, reference):
+    for sub in module.modules():
+        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
+            rescale_conv(sub, reference)
+
+
+class Demucs(nn.Module):
+    @capture_init
+    def __init__(self,
+                 sources,
+                 audio_channels=2,
+                 channels=64,
+                 depth=6,
+                 rewrite=True,
+                 glu=True,
+                 rescale=0.1,
+                 resample=True,
+                 kernel_size=8,
+                 stride=4,
+                 growth=2.,
+                 lstm_layers=2,
+                 context=3,
+                 normalize=False,
+                 samplerate=44100,
+                 segment_length=4 * 10 * 44100):
+        """
+        Args:
+            sources (list[str]): list of source names
+            audio_channels (int): stereo or mono
+            channels (int): first convolution channels
+            depth (int): number of encoder/decoder layers
+            rewrite (bool): add 1x1 convolution to each encoder layer
+                and a convolution to each decoder layer.
+                For the decoder layer, `context` gives the kernel size.
+            glu (bool): use GLU instead of ReLU
+            resample (bool): upsample the input x2 and downsample the output /2.
+            rescale (float): rescale initial weights of convolutions
+                to get their standard deviation closer to `rescale`
+            kernel_size (int): kernel size for convolutions
+            stride (int): stride for convolutions
+            growth (float): multiply (resp. divide) the number of channels by this
+                factor for each layer of the encoder (resp. decoder)
+            lstm_layers (int): number of lstm layers, 0 = no lstm
+            context (int): kernel size of the convolution in the
+                decoder before the transposed convolution. If > 1,
+                will provide some context from neighboring time
+                steps.
+            samplerate (int): stored as meta information for easing
+                future evaluations of the model.
+            segment_length (int): stored as meta information for easing
+                future evaluations of the model. Length of the segments on which
+                the model was trained.
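+            normalize (bool): if True, `forward` normalizes the input by the mono
+                mixture's mean and std before separating, and undoes it on the output.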
+ """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment_length = segment_length + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + in_channels = audio_channels + for index in range(depth): + encode = [] + encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] + if rewrite: + encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] + decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length when context = 1. If context > 1, + the two signals can be center trimmed to match. + + For training, extracts should have a valid length.For evaluation + on full tracks we recommend passing `pad = True` to :method:`forward`. + """ + if self.resample: + length *= 2 + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + length += self.context - 1 + for _ in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + else: + mean = 0 + std = 1 + + x = (x - mean) / (1e-5 + std) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + if self.lstm: + x = self.lstm(x) + for decode in self.decoder: + skip = center_trim(saved.pop(-1), x) + x = x + skip + x = decode(x) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x diff --git a/demucs/parser.py b/demucs/parser.py new file mode 100644 index 0000000..4e8a19c --- /dev/null +++ b/demucs/parser.py @@ -0,0 +1,244 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +from pathlib import Path + + +def get_parser(): + parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.") + default_raw = None + default_musdb = None + if 'DEMUCS_RAW' in os.environ: + default_raw = Path(os.environ['DEMUCS_RAW']) + if 'DEMUCS_MUSDB' in os.environ: + default_musdb = Path(os.environ['DEMUCS_MUSDB']) + parser.add_argument( + "--raw", + type=Path, + default=default_raw, + help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.") + parser.add_argument("--no_raw", action="store_const", const=None, dest="raw") + parser.add_argument("-m", + "--musdb", + type=Path, + default=default_musdb, + help="Path to musdb root") + parser.add_argument("--is_wav", action="store_true", + help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).") + parser.add_argument("--metadata", type=Path, default=Path("metadata/"), + help="Folder where metadata information is stored.") + parser.add_argument("--wav", type=Path, + help="Path to a wav dataset. This should contain a 'train' and a 'valid' " + "subfolder.") + parser.add_argument("--samplerate", type=int, default=44100) + parser.add_argument("--audio_channels", type=int, default=2) + parser.add_argument("--samples", + default=44100 * 10, + type=int, + help="number of samples to feed in") + parser.add_argument("--data_stride", + default=44100, + type=int, + help="Stride for chunks, shorter = longer epochs") + parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers") + parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers") + parser.add_argument("-d", + "--device", + help="Device to train on, default is cuda if available else cpu") + parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.") + parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file") + parser.add_argument("--test", help="Just run the test pipeline + one validation. " + "This should be a filename relative to the models/ folder.") + parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, " + "on a pretrained model. 
") + + parser.add_argument("--rank", default=0, type=int) + parser.add_argument("--world_size", default=1, type=int) + parser.add_argument("--master") + + parser.add_argument("--checkpoints", + type=Path, + default=Path("checkpoints"), + help="Folder where to store checkpoints etc") + parser.add_argument("--evals", + type=Path, + default=Path("evals"), + help="Folder where to store evals and waveforms") + parser.add_argument("--save", + action="store_true", + help="Save estimated for the test set waveforms") + parser.add_argument("--logs", + type=Path, + default=Path("logs"), + help="Folder where to store logs") + parser.add_argument("--models", + type=Path, + default=Path("models"), + help="Folder where to store trained models") + parser.add_argument("-R", + "--restart", + action='store_true', + help='Restart training, ignoring previous run') + + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs") + parser.add_argument("-r", + "--repeat", + type=int, + default=2, + help="Repeat the train set, longer epochs") + parser.add_argument("-b", "--batch_size", type=int, default=64) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1") + parser.add_argument("--init", help="Initialize from a pre-trained model.") + + # Augmentation options + parser.add_argument("--no_augment", + action="store_false", + dest="augment", + default=True, + help="No basic data augmentation.") + parser.add_argument("--repitch", type=float, default=0.2, + help="Probability to do tempo/pitch change") + parser.add_argument("--max_tempo", type=float, default=12, + help="Maximum relative tempo change in %% when using repitch.") + + parser.add_argument("--remix_group_size", + type=int, + default=4, + help="Shuffle sources using group of this size. 
Useful to somewhat " + "replicate multi-gpu training " + "on less GPUs.") + parser.add_argument("--shifts", + type=int, + default=10, + help="Number of random shifts used for the shift trick.") + parser.add_argument("--overlap", + type=float, + default=0.25, + help="Overlap when --split_valid is passed.") + + # See model.py for doc + parser.add_argument("--growth", + type=float, + default=2., + help="Number of channels between two layers will increase by this factor") + parser.add_argument("--depth", + type=int, + default=6, + help="Number of layers for the encoder and decoder") + parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM") + parser.add_argument("--channels", + type=int, + default=64, + help="Number of channels for the first encoder layer") + parser.add_argument("--kernel_size", + type=int, + default=8, + help="Kernel size for the (transposed) convolutions") + parser.add_argument("--conv_stride", + type=int, + default=4, + help="Stride for the (transposed) convolutions") + parser.add_argument("--context", + type=int, + default=3, + help="Context size for the decoder convolutions " + "before the transposed convolutions") + parser.add_argument("--rescale", + type=float, + default=0.1, + help="Initial weight rescale reference") + parser.add_argument("--no_resample", action="store_false", + default=True, dest="resample", + help="No Resampling of the input/output x2") + parser.add_argument("--no_glu", + action="store_false", + default=True, + dest="glu", + help="Replace all GLUs by ReLUs") + parser.add_argument("--no_rewrite", + action="store_false", + default=True, + dest="rewrite", + help="No 1x1 rewrite convolutions") + parser.add_argument("--normalize", action="store_true") + parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True) + + # Tasnet options + parser.add_argument("--tasnet", action="store_true") + parser.add_argument("--split_valid", + action="store_true", + help="Predict chunks by chunks for valid and test. Required for tasnet") + parser.add_argument("--X", type=int, default=8) + + # Other options + parser.add_argument("--show", + action="store_true", + help="Show model architecture, size and exit") + parser.add_argument("--save_model", action="store_true", + help="Skip traning, just save final model " + "for the current checkpoint value.") + parser.add_argument("--save_state", + help="Skip training, just save state " + "for the current checkpoint value. You should " + "provide a model name as argument.") + + # Quantization options + parser.add_argument("--q-min-size", type=float, default=1, + help="Only quantize layers over this size (in MB)") + parser.add_argument( + "--qat", type=int, help="If provided, use QAT training with that many bits.") + + parser.add_argument("--diffq", type=float, default=0) + parser.add_argument( + "--ms-target", type=float, default=162, + help="Model size target in MB, when using DiffQ. Best model will be kept " + "only if it is smaller than this target.") + + return parser + + +def get_name(parser, args): + """ + Return the name of an experiment given the args. Some parameters are ignored, + for instance --workers, as they do not impact the final result. 
+ """ + ignore_args = set([ + "checkpoints", + "deterministic", + "eval", + "evals", + "eval_cpu", + "eval_workers", + "logs", + "master", + "rank", + "restart", + "save", + "save_model", + "save_state", + "show", + "workers", + "world_size", + ]) + parts = [] + name_args = dict(args.__dict__) + for name, value in name_args.items(): + if name in ignore_args: + continue + if value != parser.get_default(name): + if isinstance(value, Path): + parts.append(f"{name}={value.name}") + else: + parts.append(f"{name}={value}") + if parts: + name = " ".join(parts) + else: + name = "default" + return name diff --git a/demucs/pretrained.py b/demucs/pretrained.py new file mode 100644 index 0000000..6aac5db --- /dev/null +++ b/demucs/pretrained.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import logging + +from diffq import DiffQuantizer +import torch.hub + +from .model import Demucs +from .tasnet import ConvTasNet +from .utils import set_state + +logger = logging.getLogger(__name__) +ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" + +PRETRAINED_MODELS = { + 'demucs': 'e07c671f', + 'demucs48_hq': '28a1282c', + 'demucs_extra': '3646af93', + 'demucs_quantized': '07afea75', + 'tasnet': 'beb46fac', + 'tasnet_extra': 'df3777b2', + 'demucs_unittest': '09ebc15f', +} + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def get_url(name): + sig = PRETRAINED_MODELS[name] + return ROOT + name + "-" + sig[:8] + ".th" + + +def is_pretrained(name): + return name in PRETRAINED_MODELS + + +def load_pretrained(name): + if name == "demucs": + return demucs(pretrained=True) + elif name == "demucs48_hq": + return demucs(pretrained=True, hq=True, channels=48) + elif name == "demucs_extra": + return demucs(pretrained=True, extra=True) + elif name == "demucs_quantized": + return demucs(pretrained=True, quantized=True) + elif name == "demucs_unittest": + return demucs_unittest(pretrained=True) + elif name == "tasnet": + return tasnet(pretrained=True) + elif name == "tasnet_extra": + return tasnet(pretrained=True, extra=True) + else: + raise ValueError(f"Invalid pretrained name {name}") + + +def _load_state(name, model, quantizer=None): + url = get_url(name) + state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + set_state(model, quantizer, state) + if quantizer: + quantizer.detach() + + +def demucs_unittest(pretrained=True): + model = Demucs(channels=4, sources=SOURCES) + if pretrained: + _load_state('demucs_unittest', model) + return model + + +def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): + if not pretrained and (extra or quantized or hq): + raise ValueError("if extra or quantized is True, pretrained must be True.") + model = Demucs(sources=SOURCES, channels=channels) + if pretrained: + name = 'demucs' + if channels != 64: + name += str(channels) + quantizer = None + if sum([extra, quantized, hq]) > 1: + raise ValueError("Only one of extra, quantized, hq, can be True.") + if quantized: + quantizer = DiffQuantizer(model, group_size=8, min_size=1) + name += '_quantized' + if extra: + name += '_extra' + if hq: + name += '_hq' + _load_state(name, model, quantizer) + return model + + +def tasnet(pretrained=True, extra=False): + if not pretrained and extra: + raise ValueError("if extra is True, pretrained must be True.") + model = ConvTasNet(X=10, 
sources=SOURCES) + if pretrained: + name = 'tasnet' + if extra: + name = 'tasnet_extra' + _load_state(name, model) + return model diff --git a/demucs/raw.py b/demucs/raw.py new file mode 100644 index 0000000..d4941ad --- /dev/null +++ b/demucs/raw.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from collections import defaultdict, namedtuple +from pathlib import Path + +import musdb +import numpy as np +import torch as th +import tqdm +from torch.utils.data import DataLoader + +from .audio import AudioFile + +ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"]) + + +class Rawset: + """ + Dataset of raw, normalized, float32 audio files + """ + def __init__(self, path, samples=None, stride=None, channels=2, streams=None): + self.path = Path(path) + self.channels = channels + self.samples = samples + if stride is None: + stride = samples if samples is not None else 0 + self.stride = stride + entries = defaultdict(list) + for root, folders, files in os.walk(self.path, followlinks=True): + folders.sort() + files.sort() + for file in files: + if file.endswith(".raw"): + path = Path(root) / file + name, stream = path.stem.rsplit('.', 1) + entries[(path.parent.relative_to(self.path), name)].append(int(stream)) + + self._entries = list(entries.keys()) + + sizes = [] + self._lengths = [] + ref_streams = sorted(entries[self._entries[0]]) + assert ref_streams == list(range(len(ref_streams))) + if streams is None: + self.streams = ref_streams + else: + self.streams = streams + for entry in sorted(entries.keys()): + streams = entries[entry] + assert sorted(streams) == ref_streams + file = self._path(*entry) + length = file.stat().st_size // (4 * channels) + if samples is None: + sizes.append(1) + else: + if length < samples: + self._entries.remove(entry) + continue + sizes.append((length - samples) // stride + 1) + self._lengths.append(length) + if not sizes: + raise ValueError(f"Empty dataset {self.path}") + self._cumulative_sizes = np.cumsum(sizes) + self._sizes = sizes + + def __len__(self): + return self._cumulative_sizes[-1] + + @property + def total_length(self): + return sum(self._lengths) + + def chunk_info(self, index): + file_index = np.searchsorted(self._cumulative_sizes, index, side='right') + if file_index == 0: + local_index = index + else: + local_index = index - self._cumulative_sizes[file_index - 1] + return ChunkInfo(offset=local_index * self.stride, + file_index=file_index, + local_index=local_index) + + def _path(self, folder, name, stream=0): + return self.path / folder / (name + f'.{stream}.raw') + + def __getitem__(self, index): + chunk = self.chunk_info(index) + entry = self._entries[chunk.file_index] + + length = self.samples or self._lengths[chunk.file_index] + streams = [] + to_read = length * self.channels * 4 + for stream_index, stream in enumerate(self.streams): + offset = chunk.offset * 4 * self.channels + file = open(self._path(*entry, stream=stream), 'rb') + file.seek(offset) + content = file.read(to_read) + assert len(content) == to_read + content = np.frombuffer(content, dtype=np.float32) + content = content.copy() # make writable + streams.append(th.from_numpy(content).view(length, self.channels).t()) + return th.stack(streams, dim=0) + + def name(self, index): + chunk = self.chunk_info(index) + folder, name = self._entries[chunk.file_index] + 
return folder / name + + +class MusDBSet: + def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2): + self.mus = mus + self.streams = streams + self.samplerate = samplerate + self.channels = channels + + def __len__(self): + return len(self.mus.tracks) + + def __getitem__(self, index): + track = self.mus.tracks[index] + return (track.name, AudioFile(track.path).read(channels=self.channels, + seek_time=0, + streams=self.streams, + samplerate=self.samplerate)) + + +def build_raw(mus, destination, normalize, workers, samplerate, channels): + destination.mkdir(parents=True, exist_ok=True) + loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate), + batch_size=1, + num_workers=workers, + collate_fn=lambda x: x[0]) + for name, streams in tqdm.tqdm(loader): + if normalize: + ref = streams[0].mean(dim=0) # use mono mixture as reference + streams = (streams - ref.mean()) / ref.std() + for index, stream in enumerate(streams): + open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes()) + + +def main(): + parser = argparse.ArgumentParser('rawset') + parser.add_argument('--workers', type=int, default=10) + parser.add_argument('--samplerate', type=int, default=44100) + parser.add_argument('--channels', type=int, default=2) + parser.add_argument('musdb', type=Path) + parser.add_argument('destination', type=Path) + + args = parser.parse_args() + + build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"), + args.destination / "train", + normalize=True, + channels=args.channels, + samplerate=args.samplerate, + workers=args.workers) + build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"), + args.destination / "valid", + normalize=True, + samplerate=args.samplerate, + channels=args.channels, + workers=args.workers) + + +if __name__ == "__main__": + main() diff --git a/demucs/separate.py b/demucs/separate.py new file mode 100644 index 0000000..3fc7af9 --- /dev/null +++ b/demucs/separate.py @@ -0,0 +1,185 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path +import subprocess + +import julius +import torch as th +import torchaudio as ta + +from .audio import AudioFile, convert_audio_channels +from .pretrained import is_pretrained, load_pretrained +from .utils import apply_model, load_model + + +def load_track(track, device, audio_channels, samplerate): + errors = {} + wav = None + + try: + wav = AudioFile(track).read( + streams=0, + samplerate=samplerate, + channels=audio_channels).to(device) + except FileNotFoundError: + errors['ffmpeg'] = 'Ffmpeg is not installed.' + except subprocess.CalledProcessError: + errors['ffmpeg'] = 'FFmpeg could not read the file.' + + if wav is None: + try: + wav, sr = ta.load(str(track)) + except RuntimeError as err: + errors['torchaudio'] = err.args[0] + else: + wav = convert_audio_channels(wav, audio_channels) + wav = wav.to(device) + wav = julius.resample_frac(wav, sr, samplerate) + + if wav is None: + print(f"Could not load file {track}. " + "Maybe it is not a supported file format? 
") + for backend, error in errors.items(): + print(f"When trying to load using {backend}, got the following error: {error}") + sys.exit(1) + return wav + + +def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False): + try: + import lameenc + except ImportError: + print("Failed to call lame encoder. Maybe it is not installed? " + "On windows, run `python.exe -m pip install -U lameenc`, " + "on OSX/Linux, run `python3 -m pip install -U lameenc`, " + "then try again.", file=sys.stderr) + sys.exit(1) + encoder = lameenc.Encoder() + encoder.set_bit_rate(bitrate) + encoder.set_in_sample_rate(samplerate) + encoder.set_channels(channels) + encoder.set_quality(2) # 2-highest, 7-fastest + if not verbose: + encoder.silence() + wav = wav.transpose(0, 1).numpy() + mp3_data = encoder.encode(wav.tobytes()) + mp3_data += encoder.flush() + with open(path, "wb") as f: + f.write(mp3_data) + + +def main(): + parser = argparse.ArgumentParser("demucs.separate", + description="Separate the sources for the given tracks") + parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks') + parser.add_argument("-n", + "--name", + default="demucs_quantized", + help="Model name. See README.md for the list of pretrained models. " + "Default is demucs_quantized.") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-o", + "--out", + type=Path, + default=Path("separated"), + help="Folder where to put extracted tracks. A subfolder " + "with the model name will be created.") + parser.add_argument("--models", + type=Path, + default=Path("models"), + help="Path to trained models. " + "Also used to store downloaded pretrained models") + parser.add_argument("-d", + "--device", + default="cuda" if th.cuda.is_available() else "cpu", + help="Device to use, default is cuda if available else cpu") + parser.add_argument("--shifts", + default=0, + type=int, + help="Number of random shifts for equivariant stabilization." + "Increase separation time but improves quality for Demucs. 10 was used " + "in the original paper.") + parser.add_argument("--overlap", + default=0.25, + type=float, + help="Overlap between the splits.") + parser.add_argument("--no-split", + action="store_false", + dest="split", + default=True, + help="Doesn't split audio in chunks. This can use large amounts of memory.") + parser.add_argument("--float32", + action="store_true", + help="Convert the output wavefile to use pcm f32 format instead of s16. " + "This should not make a difference if you just plan on listening to the " + "audio but might be needed to compute exactly metrics like SDR etc.") + parser.add_argument("--int16", + action="store_false", + dest="float32", + help="Opposite of --float32, here for compatibility.") + parser.add_argument("--mp3", action="store_true", + help="Convert the output wavs to mp3.") + parser.add_argument("--mp3-bitrate", + default=320, + type=int, + help="Bitrate of converted mp3.") + + args = parser.parse_args() + name = args.name + ".th" + model_path = args.models / name + if model_path.is_file(): + model = load_model(model_path) + else: + if is_pretrained(args.name): + model = load_pretrained(args.name) + else: + print(f"No pre-trained model {args.name}", file=sys.stderr) + sys.exit(1) + model.to(args.device) + + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + print(f"Separated tracks will be stored in {out.resolve()}") + for track in args.tracks: + if not track.exists(): + print( + f"File {track} does not exist. 
If the path contains spaces, " + "please try again after surrounding the entire path with quotes \"\".", + file=sys.stderr) + continue + print(f"Separating track {track}") + wav = load_track(track, args.device, model.audio_channels, model.samplerate) + + ref = wav.mean(0) + wav = (wav - ref.mean()) / ref.std() + sources = apply_model(model, wav, shifts=args.shifts, split=args.split, + overlap=args.overlap, progress=True) + sources = sources * ref.std() + ref.mean() + + track_folder = out / track.name.rsplit(".", 1)[0] + track_folder.mkdir(exist_ok=True) + for source, name in zip(sources, model.sources): + source = source / max(1.01 * source.abs().max(), 1) + if args.mp3 or not args.float32: + source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short() + source = source.cpu() + stem = str(track_folder / name) + if args.mp3: + encode_mp3(source, stem + ".mp3", + bitrate=args.mp3_bitrate, + samplerate=model.samplerate, + channels=model.audio_channels, + verbose=args.verbose) + else: + wavname = str(track_folder / f"{name}.wav") + ta.save(wavname, source, sample_rate=model.samplerate) + + +if __name__ == "__main__": + main() diff --git a/demucs/tasnet.py b/demucs/tasnet.py new file mode 100644 index 0000000..ecc1257 --- /dev/null +++ b/demucs/tasnet.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Created on 2018/12 +# Author: Kaituo XU +# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels +# Here is the original license: +# The MIT License (MIT) +# +# Copyright (c) 2018 Kaituo XU +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
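A quick demonstration of the overlap_and_add helper defined just below, assuming this module is importable as demucs.tasnet:

import torch
from demucs.tasnet import overlap_and_add

frames = torch.ones(3, 4)            # 3 frames of length 4
out = overlap_and_add(frames, 2)     # hop of 2: output length = 2 * (3 - 1) + 4 = 8
print(out)                           # tensor([1., 1., 2., 2., 2., 2., 1., 1.])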
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import capture_init + +EPS = 1e-8 + + +def overlap_and_add(signal, frame_step): + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes, + device=signal.device).unfold(0, subframes_per_frame, subframe_step) + frame = frame.long() # signal may in GPU or CPU + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +class ConvTasNet(nn.Module): + @capture_init + def __init__(self, + sources, + N=256, + L=20, + B=256, + H=512, + P=3, + X=8, + R=4, + audio_channels=2, + norm_type="gLN", + causal=False, + mask_nonlinear='relu', + samplerate=44100, + segment_length=44100 * 2 * 4): + """ + Args: + sources: list of sources + N: Number of filters in autoencoder + L: Length of the filters (in samples) + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(ConvTasNet, self).__init__() + # Hyper-parameter + self.sources = sources + self.C = len(sources) + self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + self.audio_channels = audio_channels + self.samplerate = samplerate + self.segment_length = segment_length + # Components + self.encoder = Encoder(L, N, audio_channels) + self.separator = TemporalConvNet( + N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear) + self.decoder = Decoder(N, L, audio_channels) + # init + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def valid_length(self, length): + return length + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + est_source: [M, C, T] + """ + mixture_w = self.encoder(mixture) + est_mask = self.separator(mixture_w) + est_source = self.decoder(mixture_w, est_mask) + + # T changed after conv1d in encoder, fix it here + T_origin = mixture.size(-1) + T_conv = est_source.size(-1) + est_source = F.pad(est_source, (0, T_origin - T_conv)) + return est_source + + +class Encoder(nn.Module): + """Estimation of the nonnegative mixture weight by a 1-D conv layer. 
+ """ + def __init__(self, L, N, audio_channels): + super(Encoder, self).__init__() + # Hyper-parameter + self.L, self.N = L, N + # Components + # 50% overlap + self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 + """ + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, N, L, audio_channels): + super(Decoder, self).__init__() + # Hyper-parameter + self.N, self.L = N, L + self.audio_channels = audio_channels + # Components + self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) + + def forward(self, mixture_w, est_mask): + """ + Args: + mixture_w: [M, N, K] + est_mask: [M, C, N, K] + Returns: + est_source: [M, C, T] + """ + # D = W * M + source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] + source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] + # S = DV + est_source = self.basis_signals(source_w) # [M, C, K, ac * L] + m, c, k, _ = est_source.size() + est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() + est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T + return est_source + + +class TemporalConvNet(nn.Module): + def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(TemporalConvNet, self).__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock(B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, + mask_conv1x1) + + def forward(self, mixture_w): + """ + Keep this API same with TasNet + Args: + mixture_w: [M, N, K], M is batch size + returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == 'softmax': + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == 'relu': + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(TemporalBlock, self).__init__() + 
# [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, + dilation, norm_type, causal) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """ + Args: + x: [M, B, K] + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? + return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d(in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """ + Args: + x: [M, H, K] + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """ + Args: + x: [M, H, Kpad] + Returns: + [M, H, K] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +def chose_norm(norm_type, channel_size): + """The input of normlization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "id": + return nn.Identity() + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. 
+        return nn.BatchNorm1d(channel_size)
+
+
+# TODO: use nn.LayerNorm to implement cLN for speed
+class ChannelwiseLayerNorm(nn.Module):
+    """Channel-wise Layer Normalization (cLN)"""
+    def __init__(self, channel_size):
+        super(ChannelwiseLayerNorm, self).__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    def forward(self, y):
+        """
+        Args:
+            y: [M, N, K], M is batch size, N is channel size, K is length
+        Returns:
+            cLN_y: [M, N, K]
+        """
+        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
+        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
+        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+        return cLN_y
+
+
+class GlobalLayerNorm(nn.Module):
+    """Global Layer Normalization (gLN)"""
+    def __init__(self, channel_size):
+        super(GlobalLayerNorm, self).__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    def forward(self, y):
+        """
+        Args:
+            y: [M, N, K], M is batch size, N is channel size, K is length
+        Returns:
+            gLN_y: [M, N, K]
+        """
+        # TODO: since torch 1.0, torch.mean() supports a list of dims
+        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
+        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+        return gLN_y
+
+
+if __name__ == "__main__":
+    torch.manual_seed(123)
+    M, N, L, T = 2, 3, 4, 12
+    K = 2 * T // L - 1
+    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
+    audio_channels = 1
+    mixture = torch.randint(3, (M, audio_channels, T)).float()
+    # test Encoder
+    encoder = Encoder(L, N, audio_channels)
+    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
+    mixture_w = encoder(mixture)
+    print('mixture', mixture)
+    print('U', encoder.conv1d_U.weight)
+    print('mixture_w', mixture_w)
+    print('mixture_w size', mixture_w.size())
+
+    # test TemporalConvNet
+    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
+    est_mask = separator(mixture_w)
+    print('est_mask', est_mask)
+
+    # test Decoder
+    decoder = Decoder(N, L, audio_channels)
+    est_mask = torch.randint(2, (M, C, N, K)).float()
+    est_source = decoder(mixture_w, est_mask)
+    print('est_source', est_source)
+
+    # test Conv-TasNet
+    conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
+    est_source = conv_tasnet(mixture)
+    print('est_source', est_source)
+    print('est_source size', est_source.size())
diff --git a/demucs/test.py b/demucs/test.py
new file mode 100644
index 0000000..4140914
--- /dev/null
+++ b/demucs/test.py
@@ -0,0 +1,109 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
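+
+# Evaluation: run a trained model over the MusDB test set and score it with
+# museval (SDR/ISR/SIR/SAR), optionally farming the scoring out to worker processes.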
+
+import gzip
+import sys
+from concurrent import futures
+
+import musdb
+import museval
+import torch as th
+import tqdm
+from scipy.io import wavfile
+from torch import distributed
+
+from .audio import convert_audio
+from .utils import apply_model
+
+
+def evaluate(model,
+             musdb_path,
+             eval_folder,
+             workers=2,
+             device="cpu",
+             rank=0,
+             save=False,
+             shifts=0,
+             split=False,
+             overlap=0.25,
+             is_wav=False,
+             world_size=1):
+    """
+    Evaluate model using museval. The model is run on a single GPU,
+    as the bottleneck is the call to museval.
+    """
+
+    output_dir = eval_folder / "results"
+    output_dir.mkdir(exist_ok=True, parents=True)
+    json_folder = eval_folder / "results/test"
+    json_folder.mkdir(exist_ok=True, parents=True)
+
+    # We load tracks from the original musdb set.
+    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
+    src_rate = 44100  # hardcoded for now...
+
+    for p in model.parameters():
+        p.requires_grad = False
+        p.grad = None
+
+    pendings = []
+    with futures.ProcessPoolExecutor(workers or 1) as pool:
+        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
+            track = test_set.tracks[index]
+
+            out = json_folder / f"{track.name}.json.gz"
+            if out.exists():
+                continue
+
+            mix = th.from_numpy(track.audio).t().float()
+            ref = mix.mean(dim=0)  # mono mixture
+            mix = (mix - ref.mean()) / ref.std()
+            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
+            estimates = apply_model(model, mix.to(device),
+                                    shifts=shifts, split=split, overlap=overlap)
+            estimates = estimates * ref.std() + ref.mean()
+
+            estimates = estimates.transpose(1, 2)
+            references = th.stack(
+                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
+            references = convert_audio(references, src_rate,
+                                       model.samplerate, model.audio_channels)
+            references = references.transpose(1, 2).numpy()
+            estimates = estimates.cpu().numpy()
+            win = int(1. * model.samplerate)
+            hop = int(1. * model.samplerate)
+            if save:
+                folder = eval_folder / "wav/test" / track.name
+                folder.mkdir(exist_ok=True, parents=True)
+                for name, estimate in zip(model.sources, estimates):
+                    wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
+
+            if workers:
+                pendings.append((track.name, pool.submit(
+                    museval.evaluate, references, estimates, win=win, hop=hop)))
+            else:
+                pendings.append((track.name, museval.evaluate(
+                    references, estimates, win=win, hop=hop)))
+            del references, mix, estimates, track
+
+        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
+            if workers:
+                pending = pending.result()
+            sdr, isr, sir, sar = pending
+            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
+            for idx, target in enumerate(model.sources):
+                values = {
+                    "SDR": sdr[idx].tolist(),
+                    "SIR": sir[idx].tolist(),
+                    "ISR": isr[idx].tolist(),
+                    "SAR": sar[idx].tolist()
+                }
+
+                track_store.add_target(target_name=target, values=values)
+            json_path = json_folder / f"{track_name}.json.gz"
+            with gzip.open(json_path, "w") as f:
+                f.write(track_store.json.encode('utf-8'))
+    if world_size > 1:
+        distributed.barrier()
diff --git a/demucs/utils.py b/demucs/utils.py
new file mode 100644
index 0000000..4364184
--- /dev/null
+++ b/demucs/utils.py
@@ -0,0 +1,323 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
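+
+# Assorted helpers shared across training and evaluation: tensor trimming,
+# distributed metric averaging, chunked/shifted inference and model (de)serialization.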
+
+import errno
+import functools
+import hashlib
+import inspect
+import io
+import os
+import random
+import socket
+import tempfile
+import warnings
+import zlib
+from contextlib import contextmanager
+
+from diffq import UniformQuantizer, DiffQuantizer
+import torch as th
+import tqdm
+from torch import distributed
+from torch.nn import functional as F
+
+
+def center_trim(tensor, reference):
+    """
+    Center trim `tensor` with respect to `reference`, along the last dimension.
+    `reference` can also be a number, representing the length to trim to.
+    If the size difference != 0 mod 2, the extra sample is removed on the right side.
+    """
+    if hasattr(reference, "size"):
+        reference = reference.size(-1)
+    delta = tensor.size(-1) - reference
+    if delta < 0:
+        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
+    if delta:
+        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
+    return tensor
+
+
+def average_metric(metric, count=1.):
+    """
+    Average `metric`, which should be a float, across all hosts. `count` should be
+    the weight for this particular host (i.e. number of examples).
+    """
+    metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
+    distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
+    return metric[1].item() / metric[0].item()
+
+
+def free_port(host='', low=20000, high=40000):
+    """
+    Return a port number that is most likely free.
+    This could suffer from a race condition although
+    it should be quite rare.
+    """
+    sock = socket.socket()
+    while True:
+        port = random.randint(low, high)
+        try:
+            sock.bind((host, port))
+        except OSError as error:
+            if error.errno == errno.EADDRINUSE:
+                continue
+            raise
+        return port
+
+
+def sizeof_fmt(num, suffix='B'):
+    """
+    Given `num` bytes, return a human-readable size.
+    Taken from https://stackoverflow.com/a/1094933
+    """
+    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
+        if abs(num) < 1024.0:
+            return "%3.1f%s%s" % (num, unit, suffix)
+        num /= 1024.0
+    return "%.1f%s%s" % (num, 'Yi', suffix)
+
+
+def human_seconds(seconds, display='.2f'):
+    """
+    Given `seconds` seconds, return a human-readable duration.
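+    For example, `human_seconds(90)` returns "1.50 min".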
+ """ + value = seconds * 1e6 + ratios = [1e3, 1e3, 60, 60, 24] + names = ['us', 'ms', 's', 'min', 'hrs', 'days'] + last = names.pop(0) + for name, ratio in zip(names, ratios): + if value / ratio < 0.3: + break + value /= ratio + last = name + return f"{format(value, display)} {last}" + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def apply_model(model, mix, shifts=None, split=False, + overlap=0.25, transition_power=1., progress=False): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + """ + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + device = mix.device + channels, length = mix.shape + if split: + out = th.zeros(len(model.sources), channels, length, device=device) + sum_weight = th.zeros(length, device=device) + segment = model.segment_length + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = stride / model.samplerate + if progress: + offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1), + th.arange(segment - segment // 2, 0, -1)]).to(device) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. 
+        weight = (weight / weight.max())**transition_power
+        for offset in offsets:
+            chunk = TensorChunk(mix, offset, segment)
+            chunk_out = apply_model(model, chunk, shifts=shifts)
+            chunk_length = chunk_out.shape[-1]
+            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
+            sum_weight[offset:offset + segment] += weight[:chunk_length]
+        assert sum_weight.min() > 0
+        out /= sum_weight
+        return out
+    elif shifts:
+        max_shift = int(0.5 * model.samplerate)
+        mix = tensor_chunk(mix)
+        padded_mix = mix.padded(length + 2 * max_shift)
+        out = 0
+        for _ in range(shifts):
+            offset = random.randint(0, max_shift)
+            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
+            shifted_out = apply_model(model, shifted)
+            out += shifted_out[..., max_shift - offset:]
+        out /= shifts
+        return out
+    else:
+        valid_length = model.valid_length(length)
+        mix = tensor_chunk(mix)
+        padded_mix = mix.padded(valid_length)
+        with th.no_grad():
+            out = model(padded_mix.unsqueeze(0))[0]
+        return center_trim(out, length)
+
+
+@contextmanager
+def temp_filenames(count, delete=True):
+    names = []
+    try:
+        for _ in range(count):
+            names.append(tempfile.NamedTemporaryFile(delete=False).name)
+        yield names
+    finally:
+        if delete:
+            for name in names:
+                os.unlink(name)
+
+
+def get_quantizer(model, args, optimizer=None):
+    quantizer = None
+    if args.diffq:
+        quantizer = DiffQuantizer(
+            model, min_size=args.q_min_size, group_size=8)
+        if optimizer is not None:
+            quantizer.setup_optimizer(optimizer)
+    elif args.qat:
+        quantizer = UniformQuantizer(
+            model, bits=args.qat, min_size=args.q_min_size)
+    return quantizer
+
+
+def load_model(path, strict=False):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        load_from = path
+        package = th.load(load_from, 'cpu')
+
+    klass = package["klass"]
+    args = package["args"]
+    kwargs = package["kwargs"]
+
+    if strict:
+        model = klass(*args, **kwargs)
+    else:
+        sig = inspect.signature(klass)
+        for key in list(kwargs):
+            if key not in sig.parameters:
+                warnings.warn("Dropping nonexistent parameter " + key)
+                del kwargs[key]
+        model = klass(*args, **kwargs)
+
+    state = package["state"]
+    training_args = package["training_args"]
+    quantizer = get_quantizer(model, training_args)
+
+    set_state(model, quantizer, state)
+    return model
+
+
+def get_state(model, quantizer):
+    if quantizer is None:
+        state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
+    else:
+        state = quantizer.get_quantized_state()
+        buf = io.BytesIO()
+        th.save(state, buf)
+        state = {'compressed': zlib.compress(buf.getvalue())}
+    return state
+
+
+def set_state(model, quantizer, state):
+    if quantizer is None:
+        model.load_state_dict(state)
+    else:
+        buf = io.BytesIO(zlib.decompress(state["compressed"]))
+        state = th.load(buf, "cpu")
+        quantizer.restore_quantized_state(state)
+
+    return state
+
+
+def save_state(state, path):
+    buf = io.BytesIO()
+    th.save(state, buf)
+    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
+
+    path = path.parent / (path.stem + "-" + sig + path.suffix)
+    path.write_bytes(buf.getvalue())
+
+
+def save_model(model, quantizer, training_args, path):
+    args, kwargs = model._init_args_kwargs
+    klass = model.__class__
+
+    state = get_state(model, quantizer)
+
+    save_to = path
+    package = {
+        'klass': klass,
+        'args': args,
+        'kwargs': kwargs,
+        'state': state,
+        'training_args': training_args,
+    }
+    th.save(package, save_to)
+
+
+def capture_init(init):
+    @functools.wraps(init)
+    def __init__(self, *args, **kwargs):
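+        # Stash the constructor arguments so `save_model` above can re-create
+        # the model later from `package["args"]` / `package["kwargs"]`.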
+        self._init_args_kwargs = (args, kwargs)
+        init(self, *args, **kwargs)
+
+    return __init__
diff --git a/demucs/wav.py b/demucs/wav.py
new file mode 100644
index 0000000..a65c3b2
--- /dev/null
+++ b/demucs/wav.py
@@ -0,0 +1,174 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+import hashlib
+import math
+import json
+from pathlib import Path
+
+import julius
+import torch as th
+from torch import distributed
+import torchaudio as ta
+from torch.nn import functional as F
+
+from .audio import convert_audio_channels
+from .compressed import get_musdb_tracks
+
+MIXTURE = "mixture"
+EXT = ".wav"
+
+
+def _track_metadata(track, sources):
+    track_length = None
+    track_samplerate = None
+    for source in sources + [MIXTURE]:
+        file = track / f"{source}{EXT}"
+        info = ta.info(str(file))
+        length = info.num_frames
+        if track_length is None:
+            track_length = length
+            track_samplerate = info.sample_rate
+        elif track_length != length:
+            raise ValueError(
+                f"Invalid length for file {file}: "
+                f"expecting {track_length} but got {length}.")
+        elif info.sample_rate != track_samplerate:
+            raise ValueError(
+                f"Invalid sample rate for file {file}: "
+                f"expecting {track_samplerate} but got {info.sample_rate}.")
+        if source == MIXTURE:
+            wav, _ = ta.load(str(file))
+            wav = wav.mean(0)
+            mean = wav.mean().item()
+            std = wav.std().item()
+
+    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
+
+
+def _build_metadata(path, sources):
+    meta = {}
+    path = Path(path)
+    for file in path.iterdir():
+        meta[file.name] = _track_metadata(file, sources)
+    return meta
+
+
+class Wavset:
+    def __init__(
+            self,
+            root, metadata, sources,
+            length=None, stride=None, normalize=True,
+            samplerate=44100, channels=2):
+        """
+        Wav set (or mp3 set, for that matter). Can be used to train
+        with arbitrary sources. Each track should be one folder inside of `path`.
+        The folder should contain files named `{source}.{ext}`, with `sources`
+        giving the list of source names to load.
+
+        Sample rate and channels will be converted on the fly.
+
+        `length` is the number of samples to extract per example (a sample count,
+        not a duration). `stride` is how many samples to move by between each example.
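+
+        For instance, with `sources=["drums", "bass", "other", "vocals"]`, each
+        track folder is expected to contain `drums.wav`, `bass.wav`, and so on.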
+ """ + self.root = Path(root) + self.metadata = OrderedDict(metadata) + self.length = length + self.stride = stride or length + self.normalize = normalize + self.sources = sources + self.channels = channels + self.samplerate = samplerate + self.num_examples = [] + for name, meta in self.metadata.items(): + track_length = int(self.samplerate * meta['length'] / meta['samplerate']) + if length is None or track_length < length: + examples = 1 + else: + examples = int(math.ceil((track_length - self.length) / self.stride) + 1) + self.num_examples.append(examples) + + def __len__(self): + return sum(self.num_examples) + + def get_file(self, name, source): + return self.root / name / f"{source}{EXT}" + + def __getitem__(self, index): + for name, examples in zip(self.metadata, self.num_examples): + if index >= examples: + index -= examples + continue + meta = self.metadata[name] + num_frames = -1 + offset = 0 + if self.length is not None: + offset = int(math.ceil( + meta['samplerate'] * self.stride * index / self.samplerate)) + num_frames = int(math.ceil( + meta['samplerate'] * self.length / self.samplerate)) + wavs = [] + for source in self.sources: + file = self.get_file(name, source) + wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) + wav = convert_audio_channels(wav, self.channels) + wavs.append(wav) + + example = th.stack(wavs) + example = julius.resample_frac(example, meta['samplerate'], self.samplerate) + if self.normalize: + example = (example - meta['mean']) / meta['std'] + if self.length: + example = example[..., :self.length] + example = F.pad(example, (0, self.length - example.shape[-1])) + return example + + +def get_wav_datasets(args, samples, sources): + sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8] + metadata_file = args.metadata / (sig + ".json") + train_path = args.wav / "train" + valid_path = args.wav / "valid" + if not metadata_file.is_file() and args.rank == 0: + train = _build_metadata(train_path, sources) + valid = _build_metadata(valid_path, sources) + json.dump([train, valid], open(metadata_file, "w")) + if args.world_size > 1: + distributed.barrier() + train, valid = json.load(open(metadata_file)) + train_set = Wavset(train_path, train, sources, + length=samples, stride=args.data_stride, + samplerate=args.samplerate, channels=args.audio_channels, + normalize=args.norm_wav) + valid_set = Wavset(valid_path, valid, [MIXTURE] + sources, + samplerate=args.samplerate, channels=args.audio_channels, + normalize=args.norm_wav) + return train_set, valid_set + + +def get_musdb_wav_datasets(args, samples, sources): + metadata_file = args.metadata / "musdb_wav.json" + root = args.musdb / "train" + if not metadata_file.is_file() and args.rank == 0: + metadata = _build_metadata(root, sources) + json.dump(metadata, open(metadata_file, "w")) + if args.world_size > 1: + distributed.barrier() + metadata = json.load(open(metadata_file)) + + train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train") + metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks} + metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks} + train_set = Wavset(root, metadata_train, sources, + length=samples, stride=args.data_stride, + samplerate=args.samplerate, channels=args.audio_channels, + normalize=args.norm_wav) + valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources, + samplerate=args.samplerate, channels=args.audio_channels, + normalize=args.norm_wav) + return 
train_set, valid_set diff --git a/diffq/__init__.py b/diffq/__init__.py new file mode 100644 index 0000000..2b997ee --- /dev/null +++ b/diffq/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +""" +This package implements different quantization strategies: + +- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits. +- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection. + +Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers. +""" + +from .uniform import UniformQuantizer +from .diffq import DiffQuantizer diff --git a/diffq/base.py b/diffq/base.py new file mode 100644 index 0000000..9bd5276 --- /dev/null +++ b/diffq/base.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from concurrent import futures +from fnmatch import fnmatch +from functools import partial +import io +import math +from multiprocessing import cpu_count +import typing as tp +import zlib + +import torch + + +class BaseQuantizer: + @dataclass + class _QuantizedParam: + name: str + param: torch.nn.Parameter + module: torch.nn.Module + # If a Parameter is used multiple times, `other` can be used + # to share state between the different Quantizers + other: tp.Optional[tp.Any] + + def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False, + exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True): + self.model = model + self.min_size = min_size + self.float16 = float16 + self.exclude = exclude + self.detect_bound = detect_bound + self._quantized = False + self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook) + self._post_handle = self.model.register_forward_hook(self._forward_hook) + + self._quantized_state = None + self._qparams = [] + self._float16 = [] + self._others = [] + self._rnns = [] + + self._saved = [] + + self._find_params() + + def _find_params(self): + min_params = self.min_size * 2**20 // 4 + previous = {} + for module_name, module in self.model.named_modules(): + if isinstance(module, torch.nn.RNNBase): + self._rnns.append(module) + for name, param in list(module.named_parameters(recurse=False)): + full_name = f"{module_name}.{name}" + matched = False + for pattern in self.exclude: + if fnmatch(full_name, pattern) or fnmatch(name, pattern): + matched = True + break + + if param.numel() <= min_params or matched: + if id(param) in previous: + continue + if self.detect_bound: + previous[id(param)] = None + if self.float16: + self._float16.append(param) + else: + self._others.append(param) + else: + qparam = self._register_param(name, param, module, previous.get(id(param))) + if self.detect_bound: + previous[id(param)] = qparam + self._qparams.append(qparam) + + def _register_param(self, name, param, module, other): + return self.__class__._QuantizedParam(name, param, module, other) + + def _forward_pre_hook(self, module, input): + if self.model.training: + self._quantized_state = None + if self._quantized: + self.unquantize() + if self._pre_forward_train(): + self._fix_rnns() + else: + self.quantize() + + def _forward_hook(self, module, input, output): + if 
self.model.training:
+            if self._post_forward_train():
+                self._fix_rnns(flatten=False)  # Hacky, next forward will flatten
+
+    def quantize(self, save=True):
+        """
+        Immediately apply quantization to the model parameters.
+        If `save` is True, save a copy of the unquantized parameters, that can be
+        restored with `unquantize()`.
+        """
+        if self._quantized:
+            return
+        if save:
+            self._saved = [qp.param.data.to('cpu', copy=True)
+                           for qp in self._qparams if qp.other is None]
+        self.restore_quantized_state(self.get_quantized_state())
+        self._quantized = True
+        self._fix_rnns()
+
+    def unquantize(self):
+        """
+        Revert a previous call to `quantize()`.
+        """
+        if not self._quantized:
+            raise RuntimeError("Can only be called on a quantized model.")
+        if not self._saved:
+            raise RuntimeError("Nothing to restore.")
+        for qparam in self._qparams:
+            if qparam.other is None:
+                qparam.param.data[:] = self._saved.pop(0)
+        assert len(self._saved) == 0
+        self._quantized = False
+        self._fix_rnns()
+
+    def _pre_forward_train(self) -> bool:
+        """
+        Called once before each forward for continuous quantization.
+        Should return True if parameters were changed.
+        """
+        return False
+
+    def _post_forward_train(self) -> bool:
+        """
+        Called once after each forward (to restore state for instance).
+        Should return True if parameters were changed.
+        """
+        return False
+
+    def _fix_rnns(self, flatten=True):
+        """
+        To be called after quantization happened to fix RNNs.
+        """
+        for rnn in self._rnns:
+            rnn._flat_weights = [
+                (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
+                for wn in rnn._flat_weights_names]
+            if flatten:
+                rnn.flatten_parameters()
+
+    def get_quantized_state(self):
+        """
+        Returns sufficient quantized information to rebuild the model state.
+
+        ..Note::
+            To achieve maximum compression, you should compress this with
+            gzip or similar, as quantized weights are not optimally coded!
+        """
+        if self._quantized_state is None:
+            self._quantized_state = self._get_quantized_state()
+        return self._quantized_state
+
+    def _get_quantized_state(self):
+        """
+        Actual implementation for `get_quantized_state`.
+        """
+        float16_params = []
+        for p in self._float16:
+            q = p.data.half()
+            float16_params.append(q)
+
+        return {
+            "quantized": [self._quantize_param(qparam) for qparam in self._qparams
+                          if qparam.other is None],
+            "float16": float16_params,
+            "others": [p.data.clone() for p in self._others],
+        }
+
+    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
+        """
+        To be overridden.
+        """
+        raise NotImplementedError()
+
+    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
+        """
+        To be overridden.
+        """
+        raise NotImplementedError()
+
+    def restore_quantized_state(self, state) -> None:
+        """
+        Restore the state of the model from the quantized state.
+        """
+        for p, q in zip(self._float16, state["float16"]):
+            p.data[:] = q.to(p)
+
+        for p, q in zip(self._others, state["others"]):
+            p.data[:] = q
+
+        remaining = list(state["quantized"])
+        for qparam in self._qparams:
+            if qparam.other is not None:
+                # Only unquantize the first appearance of an nn.Parameter.
+                continue
+            quantized = remaining.pop(0)
+            qparam.param.data[:] = self._unquantize_param(qparam, quantized)
+        self._fix_rnns()
+
+    def detach(self) -> None:
+        """
+        Detach from the model, removing hooks and anything else.
+        """
+        self._pre_handle.remove()
+        self._post_handle.remove()
+
+    def model_size(self) -> torch.Tensor:
+        """
+        Returns an estimate of the quantized model size.
+        """
+        total = torch.tensor(0.)
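+        # Count the parameters that are not quantized: float16 ones at 16 bits
+        # each, everything else kept at full 32-bit precision.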
+        for p in self._float16:
+            total += 16 * p.numel()
+        for p in self._others:
+            total += 32 * p.numel()
+        return total / 2**20 / 8  # bits to MegaBytes
+
+    def true_model_size(self) -> float:
+        """
+        Return the true quantized model size, in MB, without extra
+        compression.
+        """
+        return self.model_size().item()
+
+    def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
+        """
+        Return the compressed quantized model size, in MB.
+
+        Args:
+            compress_level (int): compression level used with zlib,
+                see `zlib.compress` for details.
+            num_workers (int): will split the final big byte representation in that
+                many chunks processed in parallel.
+        """
+        out = io.BytesIO()
+        torch.save(self.get_quantized_state(), out)
+        ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
+        return ms / 2 ** 20
+
+
+def _compress_len(data, compress_level):
+    return len(zlib.compress(data, level=compress_level))
+
+
+def _parallel_compress_len(data, compress_level, num_workers):
+    num_workers = min(cpu_count(), num_workers)
+    chunk_size = int(math.ceil(len(data) / num_workers))
+    chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
+    with futures.ProcessPoolExecutor(num_workers) as pool:
+        return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
diff --git a/diffq/diffq.py b/diffq/diffq.py
new file mode 100644
index 0000000..b475ec7
--- /dev/null
+++ b/diffq/diffq.py
@@ -0,0 +1,286 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Differentiable quantizer based on scaled noise injection.
+"""
+from dataclasses import dataclass
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer
+from .uniform import uniform_quantize, uniform_unquantize
+from .utils import simple_repr
+
+
+class DiffQuantizer(BaseQuantizer):
+    @dataclass
+    class _QuantizedParam(BaseQuantizer._QuantizedParam):
+        logit: torch.nn.Parameter
+
+    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
+                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
+                 param="bits", noise="gaussian",
+                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
+                 exclude: tp.List[str] = [], detect_bound: bool = True):
+        """
+        Differentiable quantizer based on scaled noise injection.
+        For every parameter `p` in the model, this introduces a number of bits parameter
+        `b` with the same dimensions (when group_size = 1).
+        Before each forward, `p` is replaced by `p + U`
+        with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization
+        step for `b` bits.
+        This noise approximates the quantization noise in a differentiable manner, both
+        with respect to the unquantized parameter `p` and the number of bits `b`.
+
+        At evaluation (as detected with `model.eval()`), the model is replaced
+        by its true quantized version, and restored when going back to training.
+
+        When doing actual quantization (for serialization, or evaluation),
+        the number of bits is rounded to the nearest integer, and needs to be stored along.
+        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
+        which will use a single noise level for multiple weight entries.
+
+        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
+        model size in MB. You can then use this estimate as a penalty in your training loss.
+
+        Args:
+            model (torch.nn.Module): model to quantize
+            min_size (float): minimum size in MB of a parameter to be quantized.
+            float16 (bool): if a layer is smaller than min_size, should we still do float16?
+            group_size (int): weight entries are grouped together to reduce the number
+                of noise scales to store. This should divide the size of all parameters
+                bigger than min_size.
+            min_bits (float): minimal number of bits.
+            max_bits (float): maximal number of bits.
+            init_bits (float): initial number of bits.
+            extra_bits (float): extra bits to add for actual quantization (before roundoff).
+            suffix (str): suffix used for the name of the extra noise scale parameters.
+            exclude (list[str]): list of patterns used to match parameters to exclude.
+                For instance `['bias']` to exclude all bias terms.
+            detect_bound (bool): if True, will detect bound parameters and reuse
+                the same quantized tensor for both, as well as the same number of bits.
+
+        ..Warning::
+            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.
+
+        """
+        self.group_size = group_size
+        self.min_bits = min_bits
+        self.max_bits = max_bits
+        self.init_bits = init_bits
+        self.extra_bits = extra_bits
+        self.suffix = suffix
+        self.param = param
+        self.noise = noise
+        assert noise in ["gaussian", "uniform"]
+        self._optimizer_setup = False
+
+        self._min_noise = 1 / (2 ** self.max_bits - 1)
+        self._max_noise = 1 / (2 ** self.min_bits - 1)
+
+        assert group_size >= 0
+        assert min_bits < init_bits < max_bits, \
+            "init_bits must be strictly between min_bits and max_bits"
+
+        for name, _ in model.named_parameters():
+            if name.endswith(suffix):
+                raise RuntimeError("The model already has some noise scales parameters, "
+                                   "maybe you applied DiffQuantizer twice to the same model?")
+
+        super().__init__(model, min_size, float16, exclude, detect_bound)
+
+    def _get_bits(self, logit: torch.Tensor):
+        if self.param == "noise":
+            return torch.log2(1 + 1 / self._get_noise_scale(logit))
+        else:
+            t = torch.sigmoid(logit)
+            return self.max_bits * t + (1 - t) * self.min_bits
+
+    def _get_noise_scale(self, logit: torch.Tensor):
+        if self.param == "noise":
+            t = torch.sigmoid(logit)
+            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
+        else:
+            return 1 / (2 ** self._get_bits(logit) - 1)
+
+    def _register_param(self, name, param, module, other):
+        if other is not None:
+            return self.__class__._QuantizedParam(
+                name=name, param=param, module=module, logit=other.logit, other=other)
+        assert self.group_size == 0 or param.numel() % self.group_size == 0
+        # We want the initial number of bits to be init_bits.
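+        # Invert the logit -> bits mapping used by `_get_bits` / `_get_noise_scale`
+        # so that sigmoid(logit) lands exactly on `init_bits` at initialization.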
+        if self.param == "noise":
+            noise_scale = 1 / (2 ** self.init_bits - 1)
+            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
+                math.log(self._min_noise) - math.log(self._max_noise))
+        else:
+            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
+        assert 0 < t < 1
+        logit = torch.logit(torch.tensor(float(t)))
+        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
+        if self.group_size > 0:
+            nparam = param.numel() // self.group_size
+        else:
+            nparam = 1
+        logit = torch.nn.Parameter(
+            torch.full(
+                (nparam,),
+                logit,
+                device=param.device))
+        module.register_parameter(name + self.suffix, logit)
+        return self.__class__._QuantizedParam(
+            name=name, param=param, module=module, logit=logit, other=None)
+
+    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
+        params = [qp.logit for qp in self._qparams]
+
+        for group in optimizer.param_groups:
+            new_params = []
+            for q in list(group["params"]):
+                matched = False
+                for p in params:
+                    if p is q:
+                        matched = True
+                if not matched:
+                    new_params.append(q)
+            group["params"][:] = new_params
+
+    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
+                        lr: float = 1e-3, **kwargs):
+        """
+        Setup the optimizer to tune the number of bits. In particular, this will deactivate
+        weight decay for the bits parameters.
+
+        Args:
+            optimizer (torch.Optimizer): optimizer to use.
+            lr (float): specific learning rate for the bits parameters. 1e-3
+                is a good value for Adam.
+            kwargs (dict): overrides for other optimization parameters for the bits.
+        """
+        assert not self._optimizer_setup
+        self._optimizer_setup = True
+
+        params = [qp.logit for qp in self._qparams]
+
+        for group in optimizer.param_groups:
+            for q in list(group["params"]):
+                for p in params:
+                    if p is q:
+                        raise RuntimeError("You should create the optimizer "
+                                           "before the quantizer!")
+
+        group = {"params": params, "lr": lr, "weight_decay": 0}
+        group.update(kwargs)
+        optimizer.add_param_group(group)
+
+    def no_optimizer(self):
+        """
+        Call this if you do not want to use an optimizer.
+        """
+        self._optimizer_setup = True
+
+    def check_unused(self):
+        for qparam in self._qparams:
+            if qparam.other is not None:
+                continue
+            grad = qparam.param.grad
+            if grad is None or (grad == 0).all():
+                if qparam.logit.grad is not None:
+                    qparam.logit.grad.data.zero_()
+
+    def model_size(self, exact=False):
+        """
+        Differentiable estimate of the model size.
+        The size is returned in MB.
+
+        If `exact` is True, then the output is no longer differentiable but
+        reflects exactly an achievable size, even without compression,
+        i.e. the same as returned by `true_model_size()`.
+        """
+        total = super().model_size()
+        subtotal = 0
+        for qparam in self._qparams:
+            # Only count the first appearance of a Parameter.
+            if qparam.other is not None:
+                continue
+            bits = self.extra_bits + self._get_bits(qparam.logit)
+            if exact:
+                bits = bits.round().clamp(1, 15)
+            if self.group_size == 0:
+                group_size = qparam.param.numel()
+            else:
+                group_size = self.group_size
+            subtotal += group_size * bits.sum()
+            subtotal += 2 * 32  # param scale stored as two float32
+
+            # Number of bits to represent each number of bits
+            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
+            subtotal += 8  # 8 bits for bits_bits
+            subtotal += bits_bits * bits.numel()
+
+        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
+        return total + subtotal
+
+    def true_model_size(self):
+        """
+        Naive model size without zlib compression.
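+        Equivalent to `model_size(exact=True).item()`.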
+ """ + return self.model_size(exact=True).item() + + def _pre_forward_train(self): + if not self._optimizer_setup: + raise RuntimeError("You must call `setup_optimizer()` on your optimizer " + "before starting training.") + for qparam in self._qparams: + if qparam.other is not None: + noisy = qparam.other.module._parameters[qparam.other.name] + else: + bits = self._get_bits(qparam.logit)[:, None] + if self.group_size == 0: + p_flat = qparam.param.view(-1) + else: + p_flat = qparam.param.view(-1, self.group_size) + scale = p_flat.max() - p_flat.min() + unit = 1 / (2**bits - 1) + if self.noise == "uniform": + noise_source = (torch.rand_like(p_flat) - 0.5) + elif self.noise == "gaussian": + noise_source = torch.randn_like(p_flat) / 2 + noise = scale * unit * noise_source + noisy = p_flat + noise + # We bypass the checks by PyTorch on parameters being leafs + qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param) + return True + + def _post_forward_train(self): + for qparam in self._qparams: + qparam.module._parameters[qparam.name] = qparam.param + return True + + def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: + bits = self.extra_bits + self._get_bits(qparam.logit) + bits = bits.round().clamp(1, 15)[:, None].byte() + if self.group_size == 0: + p = qparam.param.data.view(-1) + else: + p = qparam.param.data.view(-1, self.group_size) + levels, scales = uniform_quantize(p, bits) + return levels, scales, bits + + def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: + levels, param_scale, bits = quantized + return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data) + + def detach(self): + super().detach() + for qparam in self._qparams: + delattr(qparam.module, qparam.name + self.suffix) + + def __repr__(self): + return simple_repr(self) diff --git a/diffq/uniform.py b/diffq/uniform.py new file mode 100644 index 0000000..f61e912 --- /dev/null +++ b/diffq/uniform.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Classic uniform quantization over n bits. +""" +from typing import Tuple +import torch + +from .base import BaseQuantizer +from .utils import simple_repr + + +def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)): + """ + Quantize the given weights over `bits` bits. + + Returns: + - quantized levels + - (min, max) range. + + """ + assert (bits >= 1).all() and (bits <= 15).all() + num_levels = (2 ** bits.float()).long() + mn = p.min().item() + mx = p.max().item() + p = (p - mn) / (mx - mn) # put p in [0, 1] + unit = 1 / (num_levels - 1) # quantization unit + levels = (p / unit).round() + if (bits <= 8).all(): + levels = levels.byte() + else: + levels = levels.short() + return levels, (mn, mx) + + +def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float], + bits: torch.Tensor = torch.tensor(8.)): + """ + Unquantize the weights from the levels and scale. Return a float32 tensor. 
+ """ + mn, mx = scales + num_levels = 2 ** bits.float() + unit = 1 / (num_levels - 1) + levels = levels.float() + p = levels * unit # in [0, 1] + return p * (mx - mn) + mn + + +class UniformQuantizer(BaseQuantizer): + def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01, + float16: bool = False, qat: bool = False, exclude=[], detect_bound=True): + """ + Args: + model (torch.nn.Module): model to quantize + bits (float): number of bits to quantize over. + min_size (float): minimum size in MB of a parameter to be quantized. + float16 (bool): if a layer is smaller than min_size, should we still do float16? + qat (bool): perform quantized aware training. + exclude (list[str]): list of patterns used to match parameters to exclude. + For instance `['bias']` to exclude all bias terms. + detect_bound (bool): if True, will detect bound parameters and reuse + the same quantized tensor for both. + """ + self.bits = float(bits) + self.qat = qat + + super().__init__(model, min_size, float16, exclude, detect_bound) + + def __repr__(self): + return simple_repr(self, ) + + def _pre_forward_train(self): + if self.qat: + for qparam in self._qparams: + if qparam.other is not None: + new_param = qparam.other.module._parameters[qparam.other.name] + else: + quantized = self._quantize_param(qparam) + qvalue = self._unquantize_param(qparam, quantized) + new_param = qparam.param + (qvalue - qparam.param).detach() + qparam.module._parameters[qparam.name] = new_param + return True + return False + + def _post_forward_train(self): + if self.qat: + for qparam in self._qparams: + qparam.module._parameters[qparam.name] = qparam.param + return True + return False + + def _quantize_param(self, qparam): + levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits)) + return (levels, scales) + + def _unquantize_param(self, qparam, quantized): + levels, scales = quantized + return uniform_unquantize(levels, scales, torch.tensor(self.bits)) + + def model_size(self): + """ + Non differentiable model size in MB. + """ + total = super().model_size() + subtotal = 0 + for qparam in self._qparams: + if qparam.other is None: # if parameter is bound, count only one copy. + subtotal += self.bits * qparam.param.numel() + 64 # 2 float for the overall scales + subtotal /= 2**20 * 8 # bits to MegaBytes + return total + subtotal + + def true_model_size(self): + """ + Return the true quantized model size, in MB, without extra + compression. + """ + return self.model_size().item() diff --git a/diffq/utils.py b/diffq/utils.py new file mode 100644 index 0000000..be6ab52 --- /dev/null +++ b/diffq/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from typing import Optional, List + + +def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}): + """ + Return a simple representation string for `obj`. + If `attrs` is not None, it should be a list of attributes to include. 
+ """ + params = inspect.signature(obj.__class__).parameters + attrs_repr = [] + if attrs is None: + attrs = params.keys() + for attr in attrs: + display = False + if attr in overrides: + value = overrides[attr] + elif hasattr(obj, attr): + value = getattr(obj, attr) + else: + continue + if attr in params: + param = params[attr] + if param.default is inspect._empty or value != param.default: + display = True + else: + display = True + + if display: + attrs_repr.append(f"{attr}={value}") + return f"{obj.__class__.__name__}({','.join(attrs_repr)})" diff --git a/models.py b/models.py new file mode 100644 index 0000000..15e5fb3 --- /dev/null +++ b/models.py @@ -0,0 +1,244 @@ +import torch +from torch._C import has_mkl +import torch.nn as nn +import numpy as np +import librosa + +dim_c = 4 +k = 3 +model_path = 'model' +n_fft_scale = {'bass': 8, 'drums':2, 'other':4, 'vocals':3, '*':2} + + +class Conv_TDF(nn.Module): + def __init__(self, c, l, f, k, bn, bias=True): + + super(Conv_TDF, self).__init__() + + self.use_tdf = bn is not None + + self.H = nn.ModuleList() + for i in range(l): + self.H.append( + nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k//2), + nn.BatchNorm2d(c), + nn.ReLU(), + ) + ) + + if self.use_tdf: + if bn==0: + self.tdf = nn.Sequential( + nn.Linear(f,f, bias=bias), + nn.BatchNorm2d(c), + nn.ReLU() + ) + else: + self.tdf = nn.Sequential( + nn.Linear(f,f//bn, bias=bias), + nn.BatchNorm2d(c), + nn.ReLU(), + nn.Linear(f//bn,f, bias=bias), + nn.BatchNorm2d(c), + nn.ReLU() + ) + + def forward(self, x): + for h in self.H: + x = h(x) + + return x + self.tdf(x) if self.use_tdf else x + + +class Conv_TDF_net_trim(nn.Module): + def __init__(self, device, load, model_name, target_name, lr, epoch, + L, l, g, dim_f, dim_t, k=3, hop=1024, bn=None, bias=True): + + super(Conv_TDF_net_trim, self).__init__() + + self.dim_f, self.dim_t = 2**dim_f, 2**dim_t + self.n_fft = self.dim_f * n_fft_scale[target_name] + self.hop = hop + self.n_bins = self.n_fft//2+1 + self.chunk_size = hop * (self.dim_t-1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device) + self.target_name = target_name + self.blender = 'blender' in model_name + + out_c = dim_c*4 if target_name=='*' else dim_c + in_c = dim_c*2 if self.blender else dim_c + #out_c = dim_c*2 if self.blender else dim_c + self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(device) + + self.n = L//2 + if load: + + self.first_conv = nn.Sequential( + nn.Conv2d(in_channels=in_c, out_channels=g, kernel_size=1, stride=1), + nn.BatchNorm2d(g), + nn.ReLU(), + ) + + f = self.dim_f + c = g + self.ds_dense = nn.ModuleList() + self.ds = nn.ModuleList() + for i in range(self.n): + self.ds_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias)) + + scale = (2,2) + self.ds.append( + nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=c+g, kernel_size=scale, stride=scale), + nn.BatchNorm2d(c+g), + nn.ReLU() + ) + ) + f = f//2 + c += g + + self.mid_dense = Conv_TDF(c, l, f, k, bn, bias=bias) + #if bn is None and mid_tdf: + # self.mid_dense = Conv_TDF(c, l, f, k, bn=0, bias=False) + + self.us_dense = nn.ModuleList() + self.us = nn.ModuleList() + for i in range(self.n): + scale = (2,2) + self.us.append( + nn.Sequential( + nn.ConvTranspose2d(in_channels=c, out_channels=c-g, kernel_size=scale, stride=scale), + nn.BatchNorm2d(c-g), + nn.ReLU() + ) + ) + f = f*2 + c -= g + + self.us_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias)) + + + self.final_conv = nn.Sequential( + 
+                nn.Conv2d(in_channels=c, out_channels=out_c, kernel_size=1, stride=1),
+            )
+
+            model_cfg = f'L{L}l{l}g{g}'
+            model_cfg += ', ' if (bn is None or bn == 0) else f'bn{bn}, '
+
+            stft_cfg = f'f{dim_f}t{dim_t}, '
+
+            model_name = model_name[:model_name.index('(') + 1] + model_cfg + stft_cfg + model_name[model_name.index('(') + 1:]
+            try:
+                self.load_state_dict(
+                    torch.load('{0}/{1}/{2}_lr{3}_e{4:05}.ckpt'.format(model_path, model_name, target_name, lr, epoch), map_location=device)
+                )
+                print(f'Loading model ({target_name})')
+            except FileNotFoundError:
+                print(f'Random init ({target_name})')

+    def stft(self, x):
+        x = x.reshape([-1, self.chunk_size])
+        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
+        x = x.permute([0, 3, 1, 2])
+        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, dim_c, self.n_bins, self.dim_t])
+        return x[:, :, :self.dim_f]
+
+    def istft(self, x, freq_pad=None):
+        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
+        x = torch.cat([x, freq_pad], -2)
+        c = 4 * 2 if self.target_name == '*' else 2
+        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
+        x = x.permute([0, 2, 3, 1])
+        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
+        return x.reshape([-1, c, self.chunk_size])
+
+    def forward(self, x):
+
+        x = self.first_conv(x)
+
+        x = x.transpose(-1, -2)
+
+        ds_outputs = []
+        for i in range(self.n):
+            x = self.ds_dense[i](x)
+            ds_outputs.append(x)
+            x = self.ds[i](x)
+
+        x = self.mid_dense(x)
+
+        for i in range(self.n):
+            x = self.us[i](x)
+            x *= ds_outputs[-i - 1]
+            x = self.us_dense[i](x)
+
+        x = x.transpose(-1, -2)
+
+        x = self.final_conv(x)
+
+        return x
+
+
+def stft(wave, nfft, hl):
+    wave_left = np.asfortranarray(wave[0])
+    wave_right = np.asfortranarray(wave[1])
+    spec_left = librosa.stft(wave_left, nfft, hop_length=hl)
+    spec_right = librosa.stft(wave_right, nfft, hop_length=hl)
+    spec = np.asfortranarray([spec_left, spec_right])
+
+    return spec
+
+
+def istft(spec, hl):
+    spec_left = np.asfortranarray(spec[0])
+    spec_right = np.asfortranarray(spec[1])
+
+    wave_left = librosa.istft(spec_left, hop_length=hl)
+    wave_right = librosa.istft(spec_right, hop_length=hl)
+    wave = np.asfortranarray([wave_left, wave_right])
+
+    return wave
+
+
+def spec_effects(wave, algorithm='default', value=None):
+    spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
+    if algorithm == 'min_mag':
+        v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
+        wave = istft(v_spec_m, 1024)
+    elif algorithm == 'max_mag':
+        v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
+        wave = istft(v_spec_m, 1024)
+    elif algorithm == 'default':
+        # wave = [istft(spec[0], 1024), istft(spec[1], 1024)]
+        wave = (wave[1] * value) + (wave[0] * (1 - value))
+    elif algorithm == 'invert_p':
+        X_mag = np.abs(spec[0])
+        y_mag = np.abs(spec[1])
+        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
+        v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0]))
+        wave = istft(v_spec, 1024)
+    return wave
+
+
+def get_models(name, device, load=True, stems='vocals'):
+
+    if name == 'tdf_extra':
+        models = []
+        if 'vocals' in stems:
+            models.append(
+                Conv_TDF_net_trim(
+                    device=device, load=load,
+                    model_name='Conv-TDF', target_name='vocals',
+                    lr=0.0001, epoch=0,
+                    L=11, l=3, g=32, bn=8, bias=False,
+                    dim_f=11, dim_t=8
+                )
+            )
+        return models
+
+    else:
+        print('Model undefined')
+        return None
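A quick smoke test for the TDF network above (an editorial sketch, not part of the patch; it assumes a torch version where `torch.stft` still returns real-valued output by default, as the `stft` method above does, and uses only names defined in `models.py`):

```python
import torch
from models import get_models

device = torch.device('cpu')
# With no checkpoint present under model/, the try/except in
# Conv_TDF_net_trim falls back to random initialization.
model = get_models('tdf_extra', device, load=True)[0]
model.eval()

# One stereo chunk of exactly the length the network expects.
mix = torch.rand(1, 2, model.chunk_size) * 2 - 1
with torch.no_grad():
    spec = model.stft(mix)            # [1, 4, dim_f, dim_t]
    est_spec = model(spec)            # same shape, estimated vocals spectrogram
    est_wave = model.istft(est_spec)  # [1, 2, chunk_size]
print(est_wave.shape)
```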