diff --git a/demucs/__init__.py b/demucs/__init__.py deleted file mode 100644 index d4182e3..0000000 --- a/demucs/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -__version__ = "2.0.3" diff --git a/demucs/__main__.py b/demucs/__main__.py deleted file mode 100644 index 5148f20..0000000 --- a/demucs/__main__.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import json -import math -import os -import sys -import time -from dataclasses import dataclass, field - -import torch as th -from torch import distributed, nn -from torch.nn.parallel.distributed import DistributedDataParallel - -from .augment import FlipChannels, FlipSign, Remix, Scale, Shift -from .compressed import get_compressed_datasets -from .model import Demucs -from .parser import get_name, get_parser -from .raw import Rawset -from .repitch import RepitchedWrapper -from .pretrained import load_pretrained, SOURCES -from .tasnet import ConvTasNet -from .test import evaluate -from .train import train_model, validate_model -from .utils import (human_seconds, load_model, save_model, get_state, - save_state, sizeof_fmt, get_quantizer) -from .wav import get_wav_datasets, get_musdb_wav_datasets - - -@dataclass -class SavedState: - metrics: list = field(default_factory=list) - last_state: dict = None - best_state: dict = None - optimizer: dict = None - - -def main(): - parser = get_parser() - args = parser.parse_args() - name = get_name(parser, args) - print(f"Experiment {name}") - - if args.musdb is None and args.rank == 0: - print( - "You must provide the path to the MusDB dataset with the --musdb flag. " - "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", - file=sys.stderr) - sys.exit(1) - - eval_folder = args.evals / name - eval_folder.mkdir(exist_ok=True, parents=True) - args.logs.mkdir(exist_ok=True) - metrics_path = args.logs / f"{name}.json" - eval_folder.mkdir(exist_ok=True, parents=True) - args.checkpoints.mkdir(exist_ok=True, parents=True) - args.models.mkdir(exist_ok=True, parents=True) - - if args.device is None: - device = "cpu" - if th.cuda.is_available(): - device = "cuda" - else: - device = args.device - - th.manual_seed(args.seed) - # Prevents too many threads to be started when running `museval` as it can be quite - # inefficient on NUMA architectures. 
- os.environ["OMP_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - - if args.world_size > 1: - if device != "cuda" and args.rank == 0: - print("Error: distributed training is only available with cuda device", file=sys.stderr) - sys.exit(1) - th.cuda.set_device(args.rank % th.cuda.device_count()) - distributed.init_process_group(backend="nccl", - init_method="tcp://" + args.master, - rank=args.rank, - world_size=args.world_size) - - checkpoint = args.checkpoints / f"{name}.th" - checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" - if args.restart and checkpoint.exists() and args.rank == 0: - checkpoint.unlink() - - if args.test or args.test_pretrained: - args.epochs = 1 - args.repeat = 0 - if args.test: - model = load_model(args.models / args.test) - else: - model = load_pretrained(args.test_pretrained) - elif args.tasnet: - model = ConvTasNet(audio_channels=args.audio_channels, - samplerate=args.samplerate, X=args.X, - segment_length=4 * args.samples, - sources=SOURCES) - else: - model = Demucs( - audio_channels=args.audio_channels, - channels=args.channels, - context=args.context, - depth=args.depth, - glu=args.glu, - growth=args.growth, - kernel_size=args.kernel_size, - lstm_layers=args.lstm_layers, - rescale=args.rescale, - rewrite=args.rewrite, - stride=args.conv_stride, - resample=args.resample, - normalize=args.normalize, - samplerate=args.samplerate, - segment_length=4 * args.samples, - sources=SOURCES, - ) - model.to(device) - if args.init: - model.load_state_dict(load_pretrained(args.init).state_dict()) - - if args.show: - print(model) - size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) - print(f"Model size {size}") - return - - try: - saved = th.load(checkpoint, map_location='cpu') - except IOError: - saved = SavedState() - - optimizer = th.optim.Adam(model.parameters(), lr=args.lr) - - quantizer = None - quantizer = get_quantizer(model, args, optimizer) - - if saved.last_state is not None: - model.load_state_dict(saved.last_state, strict=False) - if saved.optimizer is not None: - optimizer.load_state_dict(saved.optimizer) - - model_name = f"{name}.th" - if args.save_model: - if args.rank == 0: - model.to("cpu") - model.load_state_dict(saved.best_state) - save_model(model, quantizer, args, args.models / model_name) - return - elif args.save_state: - model_name = f"{args.save_state}.th" - if args.rank == 0: - model.to("cpu") - model.load_state_dict(saved.best_state) - state = get_state(model, quantizer) - save_state(state, args.models / model_name) - return - - if args.rank == 0: - done = args.logs / f"{name}.done" - if done.exists(): - done.unlink() - - augment = [Shift(args.data_stride)] - if args.augment: - augment += [FlipSign(), FlipChannels(), Scale(), - Remix(group_size=args.remix_group_size)] - augment = nn.Sequential(*augment).to(device) - print("Agumentation pipeline:", augment) - - if args.mse: - criterion = nn.MSELoss() - else: - criterion = nn.L1Loss() - - # Setting number of samples so that all convolution windows are full. - # Prevents hard to debug mistake with the prediction being shifted compared - # to the input mixture. - samples = model.valid_length(args.samples) - print(f"Number of training samples adjusted to {samples}") - samples = samples + args.data_stride - if args.repitch: - # We need a bit more audio samples, to account for potential - # tempo change. 
- samples = math.ceil(samples / (1 - 0.01 * args.max_tempo)) - - args.metadata.mkdir(exist_ok=True, parents=True) - if args.raw: - train_set = Rawset(args.raw / "train", - samples=samples, - channels=args.audio_channels, - streams=range(1, len(model.sources) + 1), - stride=args.data_stride) - - valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) - elif args.wav: - train_set, valid_set = get_wav_datasets(args, samples, model.sources) - elif args.is_wav: - train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources) - else: - train_set, valid_set = get_compressed_datasets(args, samples) - - if args.repitch: - train_set = RepitchedWrapper( - train_set, - proba=args.repitch, - max_tempo=args.max_tempo) - - best_loss = float("inf") - for epoch, metrics in enumerate(saved.metrics): - print(f"Epoch {epoch:03d}: " - f"train={metrics['train']:.8f} " - f"valid={metrics['valid']:.8f} " - f"best={metrics['best']:.4f} " - f"ms={metrics.get('true_model_size', 0):.2f}MB " - f"cms={metrics.get('compressed_model_size', 0):.2f}MB " - f"duration={human_seconds(metrics['duration'])}") - best_loss = metrics['best'] - - if args.world_size > 1: - dmodel = DistributedDataParallel(model, - device_ids=[th.cuda.current_device()], - output_device=th.cuda.current_device()) - else: - dmodel = model - - for epoch in range(len(saved.metrics), args.epochs): - begin = time.time() - model.train() - train_loss, model_size = train_model( - epoch, train_set, dmodel, criterion, optimizer, augment, - quantizer=quantizer, - batch_size=args.batch_size, - device=device, - repeat=args.repeat, - seed=args.seed, - diffq=args.diffq, - workers=args.workers, - world_size=args.world_size) - model.eval() - valid_loss = validate_model( - epoch, valid_set, model, criterion, - device=device, - rank=args.rank, - split=args.split_valid, - overlap=args.overlap, - world_size=args.world_size) - - ms = 0 - cms = 0 - if quantizer and args.rank == 0: - ms = quantizer.true_model_size() - cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10)) - - duration = time.time() - begin - if valid_loss < best_loss and ms <= args.ms_target: - best_loss = valid_loss - saved.best_state = { - key: value.to("cpu").clone() - for key, value in model.state_dict().items() - } - - saved.metrics.append({ - "train": train_loss, - "valid": valid_loss, - "best": best_loss, - "duration": duration, - "model_size": model_size, - "true_model_size": ms, - "compressed_model_size": cms, - }) - if args.rank == 0: - json.dump(saved.metrics, open(metrics_path, "w")) - - saved.last_state = model.state_dict() - saved.optimizer = optimizer.state_dict() - if args.rank == 0 and not args.test: - th.save(saved, checkpoint_tmp) - checkpoint_tmp.rename(checkpoint) - - print(f"Epoch {epoch:03d}: " - f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB " - f"cms={cms:.2f}MB " - f"duration={human_seconds(duration)}") - - if args.world_size > 1: - distributed.barrier() - - del dmodel - model.load_state_dict(saved.best_state) - if args.eval_cpu: - device = "cpu" - model.to(device) - model.eval() - evaluate(model, args.musdb, eval_folder, - is_wav=args.is_wav, - rank=args.rank, - world_size=args.world_size, - device=device, - save=args.save, - split=args.split_valid, - shifts=args.shifts, - overlap=args.overlap, - workers=args.eval_workers) - model.to("cpu") - if args.rank == 0: - if not (args.test or args.test_pretrained): - save_model(model, quantizer, args, args.models / model_name) - print("done") - 
done.write_text("done") - - -if __name__ == "__main__": - main() diff --git a/demucs/audio.py b/demucs/audio.py deleted file mode 100644 index b29f156..0000000 --- a/demucs/audio.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -import json -import subprocess as sp -from pathlib import Path - -import julius -import numpy as np -import torch - -from .utils import temp_filenames - - -def _read_info(path): - stdout_data = sp.check_output([ - 'ffprobe', "-loglevel", "panic", - str(path), '-print_format', 'json', '-show_format', '-show_streams' - ]) - return json.loads(stdout_data.decode('utf-8')) - - -class AudioFile: - """ - Allows to read audio from any format supported by ffmpeg, as well as resampling or - converting to mono on the fly. See :method:`read` for more details. - """ - def __init__(self, path: Path): - self.path = Path(path) - self._info = None - - def __repr__(self): - features = [("path", self.path)] - features.append(("samplerate", self.samplerate())) - features.append(("channels", self.channels())) - features.append(("streams", len(self))) - features_str = ", ".join(f"{name}={value}" for name, value in features) - return f"AudioFile({features_str})" - - @property - def info(self): - if self._info is None: - self._info = _read_info(self.path) - return self._info - - @property - def duration(self): - return float(self.info['format']['duration']) - - @property - def _audio_streams(self): - return [ - index for index, stream in enumerate(self.info["streams"]) - if stream["codec_type"] == "audio" - ] - - def __len__(self): - return len(self._audio_streams) - - def channels(self, stream=0): - return int(self.info['streams'][self._audio_streams[stream]]['channels']) - - def samplerate(self, stream=0): - return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) - - def read(self, - seek_time=None, - duration=None, - streams=slice(None), - samplerate=None, - channels=None, - temp_folder=None): - """ - Slightly more efficient implementation than stempeg, - in particular, this will extract all stems at once - rather than having to loop over one file multiple times - for each stream. - - Args: - seek_time (float): seek time in seconds or None if no seeking is needed. - duration (float): duration in seconds to extract or None to extract until the end. - streams (slice, int or list): streams to extract, can be a single int, a list or - a slice. If it is a slice or list, the output will be of size [S, C, T] - with S the number of streams, C the number of channels and T the number of samples. - If it is an int, the output will be [C, T]. - samplerate (int): if provided, will resample on the fly. If None, no resampling will - be done. Original sampling rate can be obtained with :method:`samplerate`. - channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that - as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. - See https://sound.stackexchange.com/a/42710. - Our definition of mono is simply the average of the two channels. Any other - value will be ignored. - temp_folder (str or Path or None): temporary folder to use for decoding. 
- - - """ - streams = np.array(range(len(self)))[streams] - single = not isinstance(streams, np.ndarray) - if single: - streams = [streams] - - if duration is None: - target_size = None - query_duration = None - else: - target_size = int((samplerate or self.samplerate()) * duration) - query_duration = float((target_size + 1) / (samplerate or self.samplerate())) - - with temp_filenames(len(streams)) as filenames: - command = ['ffmpeg', '-y'] - command += ['-loglevel', 'panic'] - if seek_time: - command += ['-ss', str(seek_time)] - command += ['-i', str(self.path)] - for stream, filename in zip(streams, filenames): - command += ['-map', f'0:{self._audio_streams[stream]}'] - if query_duration is not None: - command += ['-t', str(query_duration)] - command += ['-threads', '1'] - command += ['-f', 'f32le'] - if samplerate is not None: - command += ['-ar', str(samplerate)] - command += [filename] - - sp.run(command, check=True) - wavs = [] - for filename in filenames: - wav = np.fromfile(filename, dtype=np.float32) - wav = torch.from_numpy(wav) - wav = wav.view(-1, self.channels()).t() - if channels is not None: - wav = convert_audio_channels(wav, channels) - if target_size is not None: - wav = wav[..., :target_size] - wavs.append(wav) - wav = torch.stack(wavs, dim=0) - if single: - wav = wav[0] - return wav - - -def convert_audio_channels(wav, channels=2): - """Convert audio to the given number of channels.""" - *shape, src_channels, length = wav.shape - if src_channels == channels: - pass - elif channels == 1: - # Case 1: - # The caller asked for 1-channel audio, but the stream has multiple - # channels: downmix all channels. - wav = wav.mean(dim=-2, keepdim=True) - elif src_channels == 1: - # Case 2: - # The caller asked for multiple channels, but the input file has - # a single channel: replicate the audio over all channels. - wav = wav.expand(*shape, channels, length) - elif src_channels >= channels: - # Case 3: - # The caller asked for multiple channels, and the input file has - # more channels than requested: return the first `channels` channels. - wav = wav[..., :channels, :] - else: - # Case 4: What is a reasonable choice here? - raise ValueError('The audio file has fewer channels than requested but is not mono.') - return wav - - -def convert_audio(wav, from_samplerate, to_samplerate, channels): - wav = convert_audio_channels(wav, channels) - return julius.resample_frac(wav, from_samplerate, to_samplerate) diff --git a/demucs/compressed.py b/demucs/compressed.py deleted file mode 100644 index eb8fbb7..0000000 --- a/demucs/compressed.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree.
- -import json -from fractions import Fraction -from concurrent import futures - -import musdb -from torch import distributed - -from .audio import AudioFile - - -def get_musdb_tracks(root, *args, **kwargs): - mus = musdb.DB(root, *args, **kwargs) - return {track.name: track.path for track in mus} - - -class StemsSet: - def __init__(self, tracks, metadata, duration=None, stride=1, - samplerate=44100, channels=2, streams=slice(None)): - - self.metadata = [] - for name, path in tracks.items(): - meta = dict(metadata[name]) - meta["path"] = path - meta["name"] = name - self.metadata.append(meta) - if duration is not None and meta["duration"] < duration: - raise ValueError(f"Track {name} duration is too small {meta['duration']}") - self.metadata.sort(key=lambda x: x["name"]) - self.duration = duration - self.stride = stride - self.channels = channels - self.samplerate = samplerate - self.streams = streams - - def __len__(self): - return sum(self._examples_count(m) for m in self.metadata) - - def _examples_count(self, meta): - if self.duration is None: - return 1 - else: - return int((meta["duration"] - self.duration) // self.stride + 1) - - def track_metadata(self, index): - for meta in self.metadata: - examples = self._examples_count(meta) - if index >= examples: - index -= examples - continue - return meta - - def __getitem__(self, index): - for meta in self.metadata: - examples = self._examples_count(meta) - if index >= examples: - index -= examples - continue - streams = AudioFile(meta["path"]).read(seek_time=index * self.stride, - duration=self.duration, - channels=self.channels, - samplerate=self.samplerate, - streams=self.streams) - return (streams - meta["mean"]) / meta["std"] - - -def _get_track_metadata(path): - # use mono at 44kHz as reference. For any other settings data won't be perfectly - # normalized but it should be good enough. - audio = AudioFile(path) - mix = audio.read(streams=0, channels=1, samplerate=44100) - return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()} - - -def _build_metadata(tracks, workers=10): - pendings = [] - with futures.ProcessPoolExecutor(workers) as pool: - for name, path in tracks.items(): - pendings.append((name, pool.submit(_get_track_metadata, path))) - return {name: p.result() for name, p in pendings} - - -def _build_musdb_metadata(path, musdb, workers): - tracks = get_musdb_tracks(musdb) - metadata = _build_metadata(tracks, workers) - path.parent.mkdir(exist_ok=True, parents=True) - json.dump(metadata, open(path, "w")) - - -def get_compressed_datasets(args, samples): - metadata_file = args.metadata / "musdb.json" - if not metadata_file.is_file() and args.rank == 0: - _build_musdb_metadata(metadata_file, args.musdb, args.workers) - if args.world_size > 1: - distributed.barrier() - metadata = json.load(open(metadata_file)) - duration = Fraction(samples, args.samplerate) - stride = Fraction(args.data_stride, args.samplerate) - train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), - metadata, - duration=duration, - stride=stride, - streams=slice(1, None), - samplerate=args.samplerate, - channels=args.audio_channels) - valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), - metadata, - samplerate=args.samplerate, - channels=args.audio_channels) - return train_set, valid_set diff --git a/demucs/model.py b/demucs/model.py deleted file mode 100644 index e9d932f..0000000 --- a/demucs/model.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) Facebook, Inc. 
and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import julius -from torch import nn - -from .utils import capture_init, center_trim - - -class BLSTM(nn.Module): - def __init__(self, dim, layers=1): - super().__init__() - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - - def forward(self, x): - x = x.permute(2, 0, 1) - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - return x - - -def rescale_conv(conv, reference): - std = conv.weight.std().detach() - scale = (std / reference)**0.5 - conv.weight.data /= scale - if conv.bias is not None: - conv.bias.data /= scale - - -def rescale_module(module, reference): - for sub in module.modules(): - if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): - rescale_conv(sub, reference) - - -class Demucs(nn.Module): - @capture_init - def __init__(self, - sources, - audio_channels=2, - channels=64, - depth=6, - rewrite=True, - glu=True, - rescale=0.1, - resample=True, - kernel_size=8, - stride=4, - growth=2., - lstm_layers=2, - context=3, - normalize=False, - samplerate=44100, - segment_length=4 * 10 * 44100): - """ - Args: - sources (list[str]): list of source names - audio_channels (int): stereo or mono - channels (int): first convolution channels - depth (int): number of encoder/decoder layers - rewrite (bool): add 1x1 convolution to each encoder layer - and a convolution to each decoder layer. - For the decoder layer, `context` gives the kernel size. - glu (bool): use GLU instead of ReLU - resample (bool): upsample x2 the input and downsample /2 the output. - rescale (float): rescale initial weights of convolutions - to get their standard deviation closer to `rescale` - kernel_size (int): kernel size for convolutions - stride (int): stride for convolutions - growth (float): multiply (resp. divide) the number of channels by this - for each layer of the encoder (resp. decoder) - lstm_layers (int): number of LSTM layers, 0 = no LSTM - context (int): kernel size of the convolution in the - decoder before the transposed convolution. If > 1, - will provide some context from neighboring time - steps. - normalize (bool): normalize the input by the mono mixture mean and std. - samplerate (int): stored as meta information for easing - future evaluations of the model. - segment_length (int): stored as meta information for easing - future evaluations of the model. Length of the segments on which - the model was trained.
- """ - - super().__init__() - self.audio_channels = audio_channels - self.sources = sources - self.kernel_size = kernel_size - self.context = context - self.stride = stride - self.depth = depth - self.resample = resample - self.channels = channels - self.normalize = normalize - self.samplerate = samplerate - self.segment_length = segment_length - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - if glu: - activation = nn.GLU(dim=1) - ch_scale = 2 - else: - activation = nn.ReLU() - ch_scale = 1 - in_channels = audio_channels - for index in range(depth): - encode = [] - encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] - if rewrite: - encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] - self.encoder.append(nn.Sequential(*encode)) - - decode = [] - if index > 0: - out_channels = in_channels - else: - out_channels = len(self.sources) * audio_channels - if rewrite: - decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] - decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] - if index > 0: - decode.append(nn.ReLU()) - self.decoder.insert(0, nn.Sequential(*decode)) - in_channels = channels - channels = int(growth * channels) - - channels = in_channels - - if lstm_layers: - self.lstm = BLSTM(channels, lstm_layers) - else: - self.lstm = None - - if rescale: - rescale_module(self, reference=rescale) - - def valid_length(self, length): - """ - Return the nearest valid length to use with the model so that - there is no time steps left over in a convolutions, e.g. for all - layers, size of the input - kernel_size % stride = 0. - - If the mixture has a valid length, the estimated sources - will have exactly the same length when context = 1. If context > 1, - the two signals can be center trimmed to match. - - For training, extracts should have a valid length.For evaluation - on full tracks we recommend passing `pad = True` to :method:`forward`. - """ - if self.resample: - length *= 2 - for _ in range(self.depth): - length = math.ceil((length - self.kernel_size) / self.stride) + 1 - length = max(1, length) - length += self.context - 1 - for _ in range(self.depth): - length = (length - 1) * self.stride + self.kernel_size - - if self.resample: - length = math.ceil(length / 2) - return int(length) - - def forward(self, mix): - x = mix - - if self.normalize: - mono = mix.mean(dim=1, keepdim=True) - mean = mono.mean(dim=-1, keepdim=True) - std = mono.std(dim=-1, keepdim=True) - else: - mean = 0 - std = 1 - - x = (x - mean) / (1e-5 + std) - - if self.resample: - x = julius.resample_frac(x, 1, 2) - - saved = [] - for encode in self.encoder: - x = encode(x) - saved.append(x) - if self.lstm: - x = self.lstm(x) - for decode in self.decoder: - skip = center_trim(saved.pop(-1), x) - x = x + skip - x = decode(x) - - if self.resample: - x = julius.resample_frac(x, 2, 1) - x = x * std + mean - x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) - return x diff --git a/demucs/parser.py b/demucs/parser.py deleted file mode 100644 index 4e8a19c..0000000 --- a/demucs/parser.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import argparse -import os -from pathlib import Path - - -def get_parser(): - parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.") - default_raw = None - default_musdb = None - if 'DEMUCS_RAW' in os.environ: - default_raw = Path(os.environ['DEMUCS_RAW']) - if 'DEMUCS_MUSDB' in os.environ: - default_musdb = Path(os.environ['DEMUCS_MUSDB']) - parser.add_argument( - "--raw", - type=Path, - default=default_raw, - help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.") - parser.add_argument("--no_raw", action="store_const", const=None, dest="raw") - parser.add_argument("-m", - "--musdb", - type=Path, - default=default_musdb, - help="Path to musdb root") - parser.add_argument("--is_wav", action="store_true", - help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).") - parser.add_argument("--metadata", type=Path, default=Path("metadata/"), - help="Folder where metadata information is stored.") - parser.add_argument("--wav", type=Path, - help="Path to a wav dataset. This should contain a 'train' and a 'valid' " - "subfolder.") - parser.add_argument("--samplerate", type=int, default=44100) - parser.add_argument("--audio_channels", type=int, default=2) - parser.add_argument("--samples", - default=44100 * 10, - type=int, - help="number of samples to feed in") - parser.add_argument("--data_stride", - default=44100, - type=int, - help="Stride for chunks, shorter = longer epochs") - parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers") - parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers") - parser.add_argument("-d", - "--device", - help="Device to train on, default is cuda if available else cpu") - parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.") - parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file") - parser.add_argument("--test", help="Just run the test pipeline + one validation. " - "This should be a filename relative to the models/ folder.") - parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, " - "on a pretrained model. 
") - - parser.add_argument("--rank", default=0, type=int) - parser.add_argument("--world_size", default=1, type=int) - parser.add_argument("--master") - - parser.add_argument("--checkpoints", - type=Path, - default=Path("checkpoints"), - help="Folder where to store checkpoints etc") - parser.add_argument("--evals", - type=Path, - default=Path("evals"), - help="Folder where to store evals and waveforms") - parser.add_argument("--save", - action="store_true", - help="Save estimated for the test set waveforms") - parser.add_argument("--logs", - type=Path, - default=Path("logs"), - help="Folder where to store logs") - parser.add_argument("--models", - type=Path, - default=Path("models"), - help="Folder where to store trained models") - parser.add_argument("-R", - "--restart", - action='store_true', - help='Restart training, ignoring previous run') - - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs") - parser.add_argument("-r", - "--repeat", - type=int, - default=2, - help="Repeat the train set, longer epochs") - parser.add_argument("-b", "--batch_size", type=int, default=64) - parser.add_argument("--lr", type=float, default=3e-4) - parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1") - parser.add_argument("--init", help="Initialize from a pre-trained model.") - - # Augmentation options - parser.add_argument("--no_augment", - action="store_false", - dest="augment", - default=True, - help="No basic data augmentation.") - parser.add_argument("--repitch", type=float, default=0.2, - help="Probability to do tempo/pitch change") - parser.add_argument("--max_tempo", type=float, default=12, - help="Maximum relative tempo change in %% when using repitch.") - - parser.add_argument("--remix_group_size", - type=int, - default=4, - help="Shuffle sources using group of this size. 
Useful to somewhat " - "replicate multi-gpu training " - "on fewer GPUs.") - parser.add_argument("--shifts", - type=int, - default=10, - help="Number of random shifts used for the shift trick.") - parser.add_argument("--overlap", - type=float, - default=0.25, - help="Overlap when --split_valid is passed.") - - # See model.py for doc - parser.add_argument("--growth", - type=float, - default=2., - help="Number of channels between two layers will increase by this factor") - parser.add_argument("--depth", - type=int, - default=6, - help="Number of layers for the encoder and decoder") - parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM") - parser.add_argument("--channels", - type=int, - default=64, - help="Number of channels for the first encoder layer") - parser.add_argument("--kernel_size", - type=int, - default=8, - help="Kernel size for the (transposed) convolutions") - parser.add_argument("--conv_stride", - type=int, - default=4, - help="Stride for the (transposed) convolutions") - parser.add_argument("--context", - type=int, - default=3, - help="Context size for the decoder convolutions " - "before the transposed convolutions") - parser.add_argument("--rescale", - type=float, - default=0.1, - help="Initial weight rescale reference") - parser.add_argument("--no_resample", action="store_false", - default=True, dest="resample", - help="No x2 resampling of the input/output") - parser.add_argument("--no_glu", - action="store_false", - default=True, - dest="glu", - help="Replace all GLUs by ReLUs") - parser.add_argument("--no_rewrite", - action="store_false", - default=True, - dest="rewrite", - help="No 1x1 rewrite convolutions") - parser.add_argument("--normalize", action="store_true") - parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True) - - # Tasnet options - parser.add_argument("--tasnet", action="store_true") - parser.add_argument("--split_valid", - action="store_true", - help="Predict chunk by chunk for valid and test. Required for tasnet") - parser.add_argument("--X", type=int, default=8) - - # Other options - parser.add_argument("--show", - action="store_true", - help="Show model architecture, size and exit") - parser.add_argument("--save_model", action="store_true", - help="Skip training, just save the final model " - "for the current checkpoint value.") - parser.add_argument("--save_state", - help="Skip training, just save the state " - "for the current checkpoint value. You should " - "provide a model name as argument.") - - # Quantization options - parser.add_argument("--q-min-size", type=float, default=1, - help="Only quantize layers over this size (in MB)") - parser.add_argument( - "--qat", type=int, help="If provided, use QAT training with that many bits.") - - parser.add_argument("--diffq", type=float, default=0) - parser.add_argument( - "--ms-target", type=float, default=162, - help="Model size target in MB, when using DiffQ. The best model will be kept " - "only if it is smaller than this target.") - - return parser - - -def get_name(parser, args): - """ - Return the name of an experiment given the args. Some parameters are ignored, - for instance --workers, as they do not impact the final result.
- """ - ignore_args = set([ - "checkpoints", - "deterministic", - "eval", - "evals", - "eval_cpu", - "eval_workers", - "logs", - "master", - "rank", - "restart", - "save", - "save_model", - "save_state", - "show", - "workers", - "world_size", - ]) - parts = [] - name_args = dict(args.__dict__) - for name, value in name_args.items(): - if name in ignore_args: - continue - if value != parser.get_default(name): - if isinstance(value, Path): - parts.append(f"{name}={value.name}") - else: - parts.append(f"{name}={value}") - if parts: - name = " ".join(parts) - else: - name = "default" - return name diff --git a/demucs/pretrained.py b/demucs/pretrained.py deleted file mode 100644 index 6aac5db..0000000 --- a/demucs/pretrained.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# author: adefossez - -import logging - -from diffq import DiffQuantizer -import torch.hub - -from .model import Demucs -from .tasnet import ConvTasNet -from .utils import set_state - -logger = logging.getLogger(__name__) -ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" - -PRETRAINED_MODELS = { - 'demucs': 'e07c671f', - 'demucs48_hq': '28a1282c', - 'demucs_extra': '3646af93', - 'demucs_quantized': '07afea75', - 'tasnet': 'beb46fac', - 'tasnet_extra': 'df3777b2', - 'demucs_unittest': '09ebc15f', -} - -SOURCES = ["drums", "bass", "other", "vocals"] - - -def get_url(name): - sig = PRETRAINED_MODELS[name] - return ROOT + name + "-" + sig[:8] + ".th" - - -def is_pretrained(name): - return name in PRETRAINED_MODELS - - -def load_pretrained(name): - if name == "demucs": - return demucs(pretrained=True) - elif name == "demucs48_hq": - return demucs(pretrained=True, hq=True, channels=48) - elif name == "demucs_extra": - return demucs(pretrained=True, extra=True) - elif name == "demucs_quantized": - return demucs(pretrained=True, quantized=True) - elif name == "demucs_unittest": - return demucs_unittest(pretrained=True) - elif name == "tasnet": - return tasnet(pretrained=True) - elif name == "tasnet_extra": - return tasnet(pretrained=True, extra=True) - else: - raise ValueError(f"Invalid pretrained name {name}") - - -def _load_state(name, model, quantizer=None): - url = get_url(name) - state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) - set_state(model, quantizer, state) - if quantizer: - quantizer.detach() - - -def demucs_unittest(pretrained=True): - model = Demucs(channels=4, sources=SOURCES) - if pretrained: - _load_state('demucs_unittest', model) - return model - - -def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): - if not pretrained and (extra or quantized or hq): - raise ValueError("if extra or quantized is True, pretrained must be True.") - model = Demucs(sources=SOURCES, channels=channels) - if pretrained: - name = 'demucs' - if channels != 64: - name += str(channels) - quantizer = None - if sum([extra, quantized, hq]) > 1: - raise ValueError("Only one of extra, quantized, hq, can be True.") - if quantized: - quantizer = DiffQuantizer(model, group_size=8, min_size=1) - name += '_quantized' - if extra: - name += '_extra' - if hq: - name += '_hq' - _load_state(name, model, quantizer) - return model - - -def tasnet(pretrained=True, extra=False): - if not pretrained and extra: - raise ValueError("if extra is True, pretrained must be True.") - model = ConvTasNet(X=10, 
sources=SOURCES) - if pretrained: - name = 'tasnet' - if extra: - name = 'tasnet_extra' - _load_state(name, model) - return model diff --git a/demucs/raw.py b/demucs/raw.py deleted file mode 100644 index d4941ad..0000000 --- a/demucs/raw.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import os -from collections import defaultdict, namedtuple -from pathlib import Path - -import musdb -import numpy as np -import torch as th -import tqdm -from torch.utils.data import DataLoader - -from .audio import AudioFile - -ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"]) - - -class Rawset: - """ - Dataset of raw, normalized, float32 audio files - """ - def __init__(self, path, samples=None, stride=None, channels=2, streams=None): - self.path = Path(path) - self.channels = channels - self.samples = samples - if stride is None: - stride = samples if samples is not None else 0 - self.stride = stride - entries = defaultdict(list) - for root, folders, files in os.walk(self.path, followlinks=True): - folders.sort() - files.sort() - for file in files: - if file.endswith(".raw"): - path = Path(root) / file - name, stream = path.stem.rsplit('.', 1) - entries[(path.parent.relative_to(self.path), name)].append(int(stream)) - - self._entries = list(entries.keys()) - - sizes = [] - self._lengths = [] - ref_streams = sorted(entries[self._entries[0]]) - assert ref_streams == list(range(len(ref_streams))) - if streams is None: - self.streams = ref_streams - else: - self.streams = streams - for entry in sorted(entries.keys()): - streams = entries[entry] - assert sorted(streams) == ref_streams - file = self._path(*entry) - length = file.stat().st_size // (4 * channels) - if samples is None: - sizes.append(1) - else: - if length < samples: - self._entries.remove(entry) - continue - sizes.append((length - samples) // stride + 1) - self._lengths.append(length) - if not sizes: - raise ValueError(f"Empty dataset {self.path}") - self._cumulative_sizes = np.cumsum(sizes) - self._sizes = sizes - - def __len__(self): - return self._cumulative_sizes[-1] - - @property - def total_length(self): - return sum(self._lengths) - - def chunk_info(self, index): - file_index = np.searchsorted(self._cumulative_sizes, index, side='right') - if file_index == 0: - local_index = index - else: - local_index = index - self._cumulative_sizes[file_index - 1] - return ChunkInfo(offset=local_index * self.stride, - file_index=file_index, - local_index=local_index) - - def _path(self, folder, name, stream=0): - return self.path / folder / (name + f'.{stream}.raw') - - def __getitem__(self, index): - chunk = self.chunk_info(index) - entry = self._entries[chunk.file_index] - - length = self.samples or self._lengths[chunk.file_index] - streams = [] - to_read = length * self.channels * 4 - for stream_index, stream in enumerate(self.streams): - offset = chunk.offset * 4 * self.channels - file = open(self._path(*entry, stream=stream), 'rb') - file.seek(offset) - content = file.read(to_read) - assert len(content) == to_read - content = np.frombuffer(content, dtype=np.float32) - content = content.copy() # make writable - streams.append(th.from_numpy(content).view(length, self.channels).t()) - return th.stack(streams, dim=0) - - def name(self, index): - chunk = self.chunk_info(index) - folder, name = 
self._entries[chunk.file_index] - return folder / name - - -class MusDBSet: - def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2): - self.mus = mus - self.streams = streams - self.samplerate = samplerate - self.channels = channels - - def __len__(self): - return len(self.mus.tracks) - - def __getitem__(self, index): - track = self.mus.tracks[index] - return (track.name, AudioFile(track.path).read(channels=self.channels, - seek_time=0, - streams=self.streams, - samplerate=self.samplerate)) - - -def build_raw(mus, destination, normalize, workers, samplerate, channels): - destination.mkdir(parents=True, exist_ok=True) - loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate), - batch_size=1, - num_workers=workers, - collate_fn=lambda x: x[0]) - for name, streams in tqdm.tqdm(loader): - if normalize: - ref = streams[0].mean(dim=0) # use mono mixture as reference - streams = (streams - ref.mean()) / ref.std() - for index, stream in enumerate(streams): - open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes()) - - -def main(): - parser = argparse.ArgumentParser('rawset') - parser.add_argument('--workers', type=int, default=10) - parser.add_argument('--samplerate', type=int, default=44100) - parser.add_argument('--channels', type=int, default=2) - parser.add_argument('musdb', type=Path) - parser.add_argument('destination', type=Path) - - args = parser.parse_args() - - build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"), - args.destination / "train", - normalize=True, - channels=args.channels, - samplerate=args.samplerate, - workers=args.workers) - build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"), - args.destination / "valid", - normalize=True, - samplerate=args.samplerate, - channels=args.channels, - workers=args.workers) - - -if __name__ == "__main__": - main() diff --git a/demucs/separate.py b/demucs/separate.py deleted file mode 100644 index 3fc7af9..0000000 --- a/demucs/separate.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import sys -from pathlib import Path -import subprocess - -import julius -import torch as th -import torchaudio as ta - -from .audio import AudioFile, convert_audio_channels -from .pretrained import is_pretrained, load_pretrained -from .utils import apply_model, load_model - - -def load_track(track, device, audio_channels, samplerate): - errors = {} - wav = None - - try: - wav = AudioFile(track).read( - streams=0, - samplerate=samplerate, - channels=audio_channels).to(device) - except FileNotFoundError: - errors['ffmpeg'] = 'Ffmpeg is not installed.' - except subprocess.CalledProcessError: - errors['ffmpeg'] = 'FFmpeg could not read the file.' - - if wav is None: - try: - wav, sr = ta.load(str(track)) - except RuntimeError as err: - errors['torchaudio'] = err.args[0] - else: - wav = convert_audio_channels(wav, audio_channels) - wav = wav.to(device) - wav = julius.resample_frac(wav, sr, samplerate) - - if wav is None: - print(f"Could not load file {track}. " - "Maybe it is not a supported file format? 
") - for backend, error in errors.items(): - print(f"When trying to load using {backend}, got the following error: {error}") - sys.exit(1) - return wav - - -def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False): - try: - import lameenc - except ImportError: - print("Failed to call lame encoder. Maybe it is not installed? " - "On windows, run `python.exe -m pip install -U lameenc`, " - "on OSX/Linux, run `python3 -m pip install -U lameenc`, " - "then try again.", file=sys.stderr) - sys.exit(1) - encoder = lameenc.Encoder() - encoder.set_bit_rate(bitrate) - encoder.set_in_sample_rate(samplerate) - encoder.set_channels(channels) - encoder.set_quality(2) # 2-highest, 7-fastest - if not verbose: - encoder.silence() - wav = wav.transpose(0, 1).numpy() - mp3_data = encoder.encode(wav.tobytes()) - mp3_data += encoder.flush() - with open(path, "wb") as f: - f.write(mp3_data) - - -def main(): - parser = argparse.ArgumentParser("demucs.separate", - description="Separate the sources for the given tracks") - parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks') - parser.add_argument("-n", - "--name", - default="demucs_quantized", - help="Model name. See README.md for the list of pretrained models. " - "Default is demucs_quantized.") - parser.add_argument("-v", "--verbose", action="store_true") - parser.add_argument("-o", - "--out", - type=Path, - default=Path("separated"), - help="Folder where to put extracted tracks. A subfolder " - "with the model name will be created.") - parser.add_argument("--models", - type=Path, - default=Path("models"), - help="Path to trained models. " - "Also used to store downloaded pretrained models") - parser.add_argument("-d", - "--device", - default="cuda" if th.cuda.is_available() else "cpu", - help="Device to use, default is cuda if available else cpu") - parser.add_argument("--shifts", - default=0, - type=int, - help="Number of random shifts for equivariant stabilization." - "Increase separation time but improves quality for Demucs. 10 was used " - "in the original paper.") - parser.add_argument("--overlap", - default=0.25, - type=float, - help="Overlap between the splits.") - parser.add_argument("--no-split", - action="store_false", - dest="split", - default=True, - help="Doesn't split audio in chunks. This can use large amounts of memory.") - parser.add_argument("--float32", - action="store_true", - help="Convert the output wavefile to use pcm f32 format instead of s16. " - "This should not make a difference if you just plan on listening to the " - "audio but might be needed to compute exactly metrics like SDR etc.") - parser.add_argument("--int16", - action="store_false", - dest="float32", - help="Opposite of --float32, here for compatibility.") - parser.add_argument("--mp3", action="store_true", - help="Convert the output wavs to mp3.") - parser.add_argument("--mp3-bitrate", - default=320, - type=int, - help="Bitrate of converted mp3.") - - args = parser.parse_args() - name = args.name + ".th" - model_path = args.models / name - if model_path.is_file(): - model = load_model(model_path) - else: - if is_pretrained(args.name): - model = load_pretrained(args.name) - else: - print(f"No pre-trained model {args.name}", file=sys.stderr) - sys.exit(1) - model.to(args.device) - - out = args.out / args.name - out.mkdir(parents=True, exist_ok=True) - print(f"Separated tracks will be stored in {out.resolve()}") - for track in args.tracks: - if not track.exists(): - print( - f"File {track} does not exist. 
If the path contains spaces, " - "please try again after surrounding the entire path with quotes \"\".", - file=sys.stderr) - continue - print(f"Separating track {track}") - wav = load_track(track, args.device, model.audio_channels, model.samplerate) - - ref = wav.mean(0) - wav = (wav - ref.mean()) / ref.std() - sources = apply_model(model, wav, shifts=args.shifts, split=args.split, - overlap=args.overlap, progress=True) - sources = sources * ref.std() + ref.mean() - - track_folder = out / track.name.rsplit(".", 1)[0] - track_folder.mkdir(exist_ok=True) - for source, name in zip(sources, model.sources): - source = source / max(1.01 * source.abs().max(), 1) - if args.mp3 or not args.float32: - source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short() - source = source.cpu() - stem = str(track_folder / name) - if args.mp3: - encode_mp3(source, stem + ".mp3", - bitrate=args.mp3_bitrate, - samplerate=model.samplerate, - channels=model.audio_channels, - verbose=args.verbose) - else: - wavname = str(track_folder / f"{name}.wav") - ta.save(wavname, source, sample_rate=model.samplerate) - - -if __name__ == "__main__": - main() diff --git a/demucs/tasnet.py b/demucs/tasnet.py deleted file mode 100644 index ecc1257..0000000 --- a/demucs/tasnet.py +++ /dev/null @@ -1,452 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# Created on 2018/12 -# Author: Kaituo XU -# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels -# Here is the original license: -# The MIT License (MIT) -# -# Copyright (c) 2018 Kaituo XU -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import capture_init - -EPS = 1e-8 - - -def overlap_and_add(signal, frame_step): - outer_dimensions = signal.size()[:-2] - frames, frame_length = signal.size()[-2:] - - subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor - subframe_step = frame_step // subframe_length - subframes_per_frame = frame_length // subframe_length - output_size = frame_step * (frames - 1) + frame_length - output_subframes = output_size // subframe_length - - subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) - - frame = torch.arange(0, output_subframes, - device=signal.device).unfold(0, subframes_per_frame, subframe_step) - frame = frame.long() # signal may in GPU or CPU - frame = frame.contiguous().view(-1) - - result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) - result.index_add_(-2, frame, subframe_signal) - result = result.view(*outer_dimensions, -1) - return result - - -class ConvTasNet(nn.Module): - @capture_init - def __init__(self, - sources, - N=256, - L=20, - B=256, - H=512, - P=3, - X=8, - R=4, - audio_channels=2, - norm_type="gLN", - causal=False, - mask_nonlinear='relu', - samplerate=44100, - segment_length=44100 * 2 * 4): - """ - Args: - sources: list of sources - N: Number of filters in autoencoder - L: Length of the filters (in samples) - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(ConvTasNet, self).__init__() - # Hyper-parameter - self.sources = sources - self.C = len(sources) - self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R - self.norm_type = norm_type - self.causal = causal - self.mask_nonlinear = mask_nonlinear - self.audio_channels = audio_channels - self.samplerate = samplerate - self.segment_length = segment_length - # Components - self.encoder = Encoder(L, N, audio_channels) - self.separator = TemporalConvNet( - N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear) - self.decoder = Decoder(N, L, audio_channels) - # init - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_normal_(p) - - def valid_length(self, length): - return length - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - est_source: [M, C, T] - """ - mixture_w = self.encoder(mixture) - est_mask = self.separator(mixture_w) - est_source = self.decoder(mixture_w, est_mask) - - # T changed after conv1d in encoder, fix it here - T_origin = mixture.size(-1) - T_conv = est_source.size(-1) - est_source = F.pad(est_source, (0, T_origin - T_conv)) - return est_source - - -class Encoder(nn.Module): - """Estimation of the nonnegative mixture weight by a 1-D conv layer. 
- """ - def __init__(self, L, N, audio_channels): - super(Encoder, self).__init__() - # Hyper-parameter - self.L, self.N = L, N - # Components - # 50% overlap - self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) - - def forward(self, mixture): - """ - Args: - mixture: [M, T], M is batch size, T is #samples - Returns: - mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 - """ - mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] - return mixture_w - - -class Decoder(nn.Module): - def __init__(self, N, L, audio_channels): - super(Decoder, self).__init__() - # Hyper-parameter - self.N, self.L = N, L - self.audio_channels = audio_channels - # Components - self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) - - def forward(self, mixture_w, est_mask): - """ - Args: - mixture_w: [M, N, K] - est_mask: [M, C, N, K] - Returns: - est_source: [M, C, T] - """ - # D = W * M - source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] - source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] - # S = DV - est_source = self.basis_signals(source_w) # [M, C, K, ac * L] - m, c, k, _ = est_source.size() - est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() - est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T - return est_source - - -class TemporalConvNet(nn.Module): - def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): - """ - Args: - N: Number of filters in autoencoder - B: Number of channels in bottleneck 1 × 1-conv block - H: Number of channels in convolutional blocks - P: Kernel size in convolutional blocks - X: Number of convolutional blocks in each repeat - R: Number of repeats - C: Number of speakers - norm_type: BN, gLN, cLN - causal: causal or non-causal - mask_nonlinear: use which non-linear function to generate mask - """ - super(TemporalConvNet, self).__init__() - # Hyper-parameter - self.C = C - self.mask_nonlinear = mask_nonlinear - # Components - # [M, N, K] -> [M, N, K] - layer_norm = ChannelwiseLayerNorm(N) - # [M, N, K] -> [M, B, K] - bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) - # [M, B, K] -> [M, B, K] - repeats = [] - for r in range(R): - blocks = [] - for x in range(X): - dilation = 2**x - padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 - blocks += [ - TemporalBlock(B, - H, - P, - stride=1, - padding=padding, - dilation=dilation, - norm_type=norm_type, - causal=causal) - ] - repeats += [nn.Sequential(*blocks)] - temporal_conv_net = nn.Sequential(*repeats) - # [M, B, K] -> [M, C*N, K] - mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) - # Put together - self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, - mask_conv1x1) - - def forward(self, mixture_w): - """ - Keep this API same with TasNet - Args: - mixture_w: [M, N, K], M is batch size - returns: - est_mask: [M, C, N, K] - """ - M, N, K = mixture_w.size() - score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] - score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] - if self.mask_nonlinear == 'softmax': - est_mask = F.softmax(score, dim=1) - elif self.mask_nonlinear == 'relu': - est_mask = F.relu(score) - else: - raise ValueError("Unsupported mask non-linear function") - return est_mask - - -class TemporalBlock(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(TemporalBlock, self).__init__() - 
# [M, B, K] -> [M, H, K] - conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) - prelu = nn.PReLU() - norm = chose_norm(norm_type, out_channels) - # [M, H, K] -> [M, B, K] - dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, - dilation, norm_type, causal) - # Put together - self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) - - def forward(self, x): - """ - Args: - x: [M, B, K] - Returns: - [M, B, K] - """ - residual = x - out = self.net(x) - # TODO: when P = 3 this works fine, but when P = 2 we may need to pad? - return out + residual # looks like w/o F.relu is better than w/ F.relu - # return F.relu(out + residual) - - -class DepthwiseSeparableConv(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - norm_type="gLN", - causal=False): - super(DepthwiseSeparableConv, self).__init__() - # Use the `groups` option to implement a depthwise convolution - # [M, H, K] -> [M, H, K] - depthwise_conv = nn.Conv1d(in_channels, - in_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=in_channels, - bias=False) - if causal: - chomp = Chomp1d(padding) - prelu = nn.PReLU() - norm = chose_norm(norm_type, in_channels) - # [M, H, K] -> [M, B, K] - pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) - # Put together - if causal: - self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) - else: - self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) - - def forward(self, x): - """ - Args: - x: [M, H, K] - Returns: - result: [M, B, K] - """ - return self.net(x) - - -class Chomp1d(nn.Module): - """To ensure the output length is the same as the input. - """ - def __init__(self, chomp_size): - super(Chomp1d, self).__init__() - self.chomp_size = chomp_size - - def forward(self, x): - """ - Args: - x: [M, H, Kpad] - Returns: - [M, H, K] - """ - return x[:, :, :-self.chomp_size].contiguous() - - -def chose_norm(norm_type, channel_size): - """The input of normalization will be (M, C, K), where M is the batch size, - C is the channel size and K is the sequence length. - """ - if norm_type == "gLN": - return GlobalLayerNorm(channel_size) - elif norm_type == "cLN": - return ChannelwiseLayerNorm(channel_size) - elif norm_type == "id": - return nn.Identity() - else: # norm_type == "BN": - # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics - # along M and K, so this BN usage is correct.
-        return nn.BatchNorm1d(channel_size)
-
-
-# TODO: use nn.LayerNorm to implement cLN, for speed
-class ChannelwiseLayerNorm(nn.Module):
-    """Channel-wise Layer Normalization (cLN)"""
-    def __init__(self, channel_size):
-        super(ChannelwiseLayerNorm, self).__init__()
-        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        self.gamma.data.fill_(1)
-        self.beta.data.zero_()
-
-    def forward(self, y):
-        """
-        Args:
-            y: [M, N, K], M is batch size, N is channel size, K is length
-        Returns:
-            cLN_y: [M, N, K]
-        """
-        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
-        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
-        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
-        return cLN_y
-
-
-class GlobalLayerNorm(nn.Module):
-    """Global Layer Normalization (gLN)"""
-    def __init__(self, channel_size):
-        super(GlobalLayerNorm, self).__init__()
-        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        self.gamma.data.fill_(1)
-        self.beta.data.zero_()
-
-    def forward(self, y):
-        """
-        Args:
-            y: [M, N, K], M is batch size, N is channel size, K is length
-        Returns:
-            gLN_y: [M, N, K]
-        """
-        # TODO: torch.mean() supports a tuple of dims in recent torch versions,
-        # so these two calls could be merged.
-        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
-        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
-        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
-        return gLN_y
-
-
-if __name__ == "__main__":
-    torch.manual_seed(123)
-    M, N, L, T = 2, 3, 4, 12
-    K = 2 * T // L - 1
-    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
-    # audio_channels was introduced by the multi-channel refactor; 1 keeps this
-    # toy example mono, so the mixture is [M, 1, T].
-    audio_channels = 1
-    mixture = torch.randint(3, (M, audio_channels, T)).float()
-    # test Encoder
-    encoder = Encoder(L, N, audio_channels)
-    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
-    mixture_w = encoder(mixture)
-    print('mixture', mixture)
-    print('U', encoder.conv1d_U.weight)
-    print('mixture_w', mixture_w)
-    print('mixture_w size', mixture_w.size())
-
-    # test TemporalConvNet
-    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
-    est_mask = separator(mixture_w)
-    print('est_mask', est_mask)
-
-    # test Decoder; the mask must be [M, C, N, K]
-    decoder = Decoder(N, L, audio_channels)
-    est_mask = torch.randint(2, (M, C, N, K)).float()
-    est_source = decoder(mixture_w, est_mask)
-    print('est_source', est_source)
-
-    # test Conv-TasNet end to end. NOTE: this repo's ConvTasNet takes the list of
-    # source names rather than C; the remaining hyper-parameter keywords below are
-    # assumed to keep their original Conv-TasNet names.
-    conv_tasnet = ConvTasNet(sources=["source_0", "source_1"], N=N, L=L, B=B, H=H,
-                             P=P, X=X, R=R, audio_channels=audio_channels,
-                             norm_type=norm_type)
-    est_source = conv_tasnet(mixture)
-    print('est_source', est_source)
-    print('est_source size', est_source.size())
diff --git a/demucs/test.py b/demucs/test.py
deleted file mode 100644
index 4140914..0000000
--- a/demucs/test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
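-
-# Evaluation of a trained separator on the MusDB test set, scored with museval;
-# each rank processes a slice of the tracks and writes gzipped museval JSON.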
-
-import gzip
-import sys
-from concurrent import futures
-
-import musdb
-import museval
-import torch as th
-import tqdm
-from scipy.io import wavfile
-from torch import distributed
-
-from .audio import convert_audio
-from .utils import apply_model
-
-
-def evaluate(model,
-             musdb_path,
-             eval_folder,
-             workers=2,
-             device="cpu",
-             rank=0,
-             save=False,
-             shifts=0,
-             split=False,
-             overlap=0.25,
-             is_wav=False,
-             world_size=1):
-    """
-    Evaluate the model using museval. The model runs on a single GPU;
-    the bottleneck is the call to museval.
-    """
-
-    output_dir = eval_folder / "results"
-    output_dir.mkdir(exist_ok=True, parents=True)
-    json_folder = eval_folder / "results/test"
-    json_folder.mkdir(exist_ok=True, parents=True)
-
-    # We load tracks from the original MusDB set.
-    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
-    src_rate = 44100  # hardcoded for now...
-
-    for p in model.parameters():
-        p.requires_grad = False
-        p.grad = None
-
-    pendings = []
-    with futures.ProcessPoolExecutor(workers or 1) as pool:
-        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
-            track = test_set.tracks[index]
-
-            out = json_folder / f"{track.name}.json.gz"
-            if out.exists():
-                continue
-
-            mix = th.from_numpy(track.audio).t().float()
-            ref = mix.mean(dim=0)  # mono mixture
-            mix = (mix - ref.mean()) / ref.std()
-            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
-            estimates = apply_model(model, mix.to(device),
-                                    shifts=shifts, split=split, overlap=overlap)
-            estimates = estimates * ref.std() + ref.mean()
-
-            estimates = estimates.transpose(1, 2)
-            references = th.stack(
-                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
-            references = convert_audio(references, src_rate,
-                                       model.samplerate, model.audio_channels)
-            references = references.transpose(1, 2).numpy()
-            estimates = estimates.cpu().numpy()
-            win = int(1. * model.samplerate)
-            hop = int(1. * model.samplerate)
-            if save:
-                folder = eval_folder / "wav/test" / track.name
-                folder.mkdir(exist_ok=True, parents=True)
-                for name, estimate in zip(model.sources, estimates):
-                    wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
-
-            if workers:
-                pendings.append((track.name, pool.submit(
-                    museval.evaluate, references, estimates, win=win, hop=hop)))
-            else:
-                pendings.append((track.name, museval.evaluate(
-                    references, estimates, win=win, hop=hop)))
-            del references, mix, estimates, track
-
-        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
-            if workers:
-                pending = pending.result()
-            sdr, isr, sir, sar = pending
-            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
-            for idx, target in enumerate(model.sources):
-                values = {
-                    "SDR": sdr[idx].tolist(),
-                    "SIR": sir[idx].tolist(),
-                    "ISR": isr[idx].tolist(),
-                    "SAR": sar[idx].tolist()
-                }
-
-                track_store.add_target(target_name=target, values=values)
-            json_path = json_folder / f"{track_name}.json.gz"
-            with gzip.open(json_path, "w") as f:
-                f.write(track_store.json.encode('utf-8'))
-    if world_size > 1:
-        distributed.barrier()
diff --git a/demucs/utils.py b/demucs/utils.py
deleted file mode 100644
index 4364184..0000000
--- a/demucs/utils.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
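-
-# Assorted helpers: tensor trimming, distributed metric averaging, chunked and
-# shifted model application (apply_model), quantizer setup and model
-# (de)serialization.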
-
-import errno
-import functools
-import hashlib
-import inspect
-import io
-import os
-import random
-import socket
-import tempfile
-import warnings
-import zlib
-from contextlib import contextmanager
-
-from diffq import UniformQuantizer, DiffQuantizer
-import torch as th
-import tqdm
-from torch import distributed
-from torch.nn import functional as F
-
-
-def center_trim(tensor, reference):
-    """
-    Center trim `tensor` with respect to `reference`, along the last dimension.
-    `reference` can also be a number, representing the length to trim to.
-    If the size difference is odd, the extra sample is removed on the right side.
-    """
-    if hasattr(reference, "size"):
-        reference = reference.size(-1)
-    delta = tensor.size(-1) - reference
-    if delta < 0:
-        raise ValueError(f"tensor must be larger than reference. Delta is {delta}.")
-    if delta:
-        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
-    return tensor
-
-
-def average_metric(metric, count=1.):
-    """
-    Average `metric`, which should be a float, across all hosts. `count` should be
-    the weight for this particular host (i.e. number of examples).
-    """
-    metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
-    distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
-    return metric[1].item() / metric[0].item()
-
-
-def free_port(host='', low=20000, high=40000):
-    """
-    Return a port number that is most likely free.
-    This could suffer from a race condition although
-    it should be quite rare.
-    """
-    sock = socket.socket()
-    while True:
-        port = random.randint(low, high)
-        try:
-            sock.bind((host, port))
-        except OSError as error:
-            if error.errno == errno.EADDRINUSE:
-                continue
-            raise
-        # Close explicitly so the port is actually free when we return.
-        sock.close()
-        return port
-
-
-def sizeof_fmt(num, suffix='B'):
-    """
-    Given `num` bytes, return a human-readable size.
-    Taken from https://stackoverflow.com/a/1094933
-    """
-    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
-        if abs(num) < 1024.0:
-            return "%3.1f%s%s" % (num, unit, suffix)
-        num /= 1024.0
-    return "%.1f%s%s" % (num, 'Yi', suffix)
-
-
-def human_seconds(seconds, display='.2f'):
-    """
-    Given `seconds` seconds, return a human-readable duration.
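-    For example, human_seconds(90) returns "1.50 min".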
- """ - value = seconds * 1e6 - ratios = [1e3, 1e3, 60, 60, 24] - names = ['us', 'ms', 's', 'min', 'hrs', 'days'] - last = names.pop(0) - for name, ratio in zip(names, ratios): - if value / ratio < 0.3: - break - value /= ratio - last = name - return f"{format(value, display)} {last}" - - -class TensorChunk: - def __init__(self, tensor, offset=0, length=None): - total_length = tensor.shape[-1] - assert offset >= 0 - assert offset < total_length - - if length is None: - length = total_length - offset - else: - length = min(total_length - offset, length) - - self.tensor = tensor - self.offset = offset - self.length = length - self.device = tensor.device - - @property - def shape(self): - shape = list(self.tensor.shape) - shape[-1] = self.length - return shape - - def padded(self, target_length): - delta = target_length - self.length - total_length = self.tensor.shape[-1] - assert delta >= 0 - - start = self.offset - delta // 2 - end = start + target_length - - correct_start = max(0, start) - correct_end = min(total_length, end) - - pad_left = correct_start - start - pad_right = end - correct_end - - out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) - assert out.shape[-1] == target_length - return out - - -def tensor_chunk(tensor_or_chunk): - if isinstance(tensor_or_chunk, TensorChunk): - return tensor_or_chunk - else: - assert isinstance(tensor_or_chunk, th.Tensor) - return TensorChunk(tensor_or_chunk) - - -def apply_model(model, mix, shifts=None, split=False, - overlap=0.25, transition_power=1., progress=False): - """ - Apply model to a given mixture. - - Args: - shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec - and apply the oppositve shift to the output. This is repeated `shifts` time and - all predictions are averaged. This effectively makes the model time equivariant - and improves SDR by up to 0.2 points. - split (bool): if True, the input will be broken down in 8 seconds extracts - and predictions will be performed individually on each and concatenated. - Useful for model with large memory footprint like Tasnet. - progress (bool): if True, show a progress bar (requires split=True) - """ - assert transition_power >= 1, "transition_power < 1 leads to weird behavior." - device = mix.device - channels, length = mix.shape - if split: - out = th.zeros(len(model.sources), channels, length, device=device) - sum_weight = th.zeros(length, device=device) - segment = model.segment_length - stride = int((1 - overlap) * segment) - offsets = range(0, length, stride) - scale = stride / model.samplerate - if progress: - offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') - # We start from a triangle shaped weight, with maximal weight in the middle - # of the segment. Then we normalize and take to the power `transition_power`. - # Large values of transition power will lead to sharper transitions. - weight = th.cat([th.arange(1, segment // 2 + 1), - th.arange(segment - segment // 2, 0, -1)]).to(device) - assert len(weight) == segment - # If the overlap < 50%, this will translate to linear transition when - # transition_power is 1. 
-        weight = (weight / weight.max())**transition_power
-        for offset in offsets:
-            chunk = TensorChunk(mix, offset, segment)
-            chunk_out = apply_model(model, chunk, shifts=shifts)
-            chunk_length = chunk_out.shape[-1]
-            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
-            sum_weight[offset:offset + segment] += weight[:chunk_length]
-        assert sum_weight.min() > 0
-        out /= sum_weight
-        return out
-    elif shifts:
-        max_shift = int(0.5 * model.samplerate)
-        mix = tensor_chunk(mix)
-        padded_mix = mix.padded(length + 2 * max_shift)
-        out = 0
-        for _ in range(shifts):
-            offset = random.randint(0, max_shift)
-            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
-            shifted_out = apply_model(model, shifted)
-            out += shifted_out[..., max_shift - offset:]
-        out /= shifts
-        return out
-    else:
-        valid_length = model.valid_length(length)
-        mix = tensor_chunk(mix)
-        padded_mix = mix.padded(valid_length)
-        with th.no_grad():
-            out = model(padded_mix.unsqueeze(0))[0]
-        return center_trim(out, length)
-
-
-@contextmanager
-def temp_filenames(count, delete=True):
-    names = []
-    try:
-        for _ in range(count):
-            names.append(tempfile.NamedTemporaryFile(delete=False).name)
-        yield names
-    finally:
-        if delete:
-            for name in names:
-                os.unlink(name)
-
-
-def get_quantizer(model, args, optimizer=None):
-    quantizer = None
-    if args.diffq:
-        quantizer = DiffQuantizer(
-            model, min_size=args.q_min_size, group_size=8)
-        if optimizer is not None:
-            quantizer.setup_optimizer(optimizer)
-    elif args.qat:
-        quantizer = UniformQuantizer(
-            model, bits=args.qat, min_size=args.q_min_size)
-    return quantizer
-
-
-def load_model(path, strict=False):
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        load_from = path
-        package = th.load(load_from, 'cpu')
-
-    klass = package["klass"]
-    args = package["args"]
-    kwargs = package["kwargs"]
-
-    if strict:
-        model = klass(*args, **kwargs)
-    else:
-        sig = inspect.signature(klass)
-        for key in list(kwargs):
-            if key not in sig.parameters:
-                warnings.warn("Dropping nonexistent parameter " + key)
-                del kwargs[key]
-        model = klass(*args, **kwargs)
-
-    state = package["state"]
-    training_args = package["training_args"]
-    quantizer = get_quantizer(model, training_args)
-
-    set_state(model, quantizer, state)
-    return model
-
-
-def get_state(model, quantizer):
-    if quantizer is None:
-        state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
-    else:
-        state = quantizer.get_quantized_state()
-        buf = io.BytesIO()
-        th.save(state, buf)
-        state = {'compressed': zlib.compress(buf.getvalue())}
-    return state
-
-
-def set_state(model, quantizer, state):
-    if quantizer is None:
-        model.load_state_dict(state)
-    else:
-        buf = io.BytesIO(zlib.decompress(state["compressed"]))
-        state = th.load(buf, "cpu")
-        quantizer.restore_quantized_state(state)
-
-    return state
-
-
-def save_state(state, path):
-    buf = io.BytesIO()
-    th.save(state, buf)
-    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
-
-    path = path.parent / (path.stem + "-" + sig + path.suffix)
-    path.write_bytes(buf.getvalue())
-
-
-def save_model(model, quantizer, training_args, path):
-    args, kwargs = model._init_args_kwargs
-    klass = model.__class__
-
-    state = get_state(model, quantizer)
-
-    save_to = path
-    package = {
-        'klass': klass,
-        'args': args,
-        'kwargs': kwargs,
-        'state': state,
-        'training_args': training_args,
-    }
-    th.save(package, save_to)
-
-
-def capture_init(init):
-    @functools.wraps(init)
-    def __init__(self, *args, **kwargs):
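-        # Stash the constructor arguments so that save_model can serialize the
-        # class together with everything needed to re-instantiate it.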
-        self._init_args_kwargs = (args, kwargs)
-        init(self, *args, **kwargs)
-
-    return __init__
diff --git a/demucs/wav.py b/demucs/wav.py
deleted file mode 100644
index a65c3b2..0000000
--- a/demucs/wav.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from collections import OrderedDict
-import hashlib
-import math
-import json
-from pathlib import Path
-
-import julius
-import torch as th
-from torch import distributed
-import torchaudio as ta
-from torch.nn import functional as F
-
-from .audio import convert_audio_channels
-from .compressed import get_musdb_tracks
-
-MIXTURE = "mixture"
-EXT = ".wav"
-
-
-def _track_metadata(track, sources):
-    track_length = None
-    track_samplerate = None
-    for source in sources + [MIXTURE]:
-        file = track / f"{source}{EXT}"
-        info = ta.info(str(file))
-        length = info.num_frames
-        if track_length is None:
-            track_length = length
-            track_samplerate = info.sample_rate
-        elif track_length != length:
-            raise ValueError(
-                f"Invalid length for file {file}: "
-                f"expecting {track_length} but got {length}.")
-        elif info.sample_rate != track_samplerate:
-            raise ValueError(
-                f"Invalid sample rate for file {file}: "
-                f"expecting {track_samplerate} but got {info.sample_rate}.")
-        if source == MIXTURE:
-            wav, _ = ta.load(str(file))
-            wav = wav.mean(0)
-            mean = wav.mean().item()
-            std = wav.std().item()
-
-    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
-
-
-def _build_metadata(path, sources):
-    meta = {}
-    path = Path(path)
-    for file in path.iterdir():
-        meta[file.name] = _track_metadata(file, sources)
-    return meta
-
-
-class Wavset:
-    def __init__(
-            self,
-            root, metadata, sources,
-            length=None, stride=None, normalize=True,
-            samplerate=44100, channels=2):
-        """
-        Waveset (or mp3 set for that matter). Can be used to train
-        with arbitrary sources. Each track should be one folder inside of `root`.
-        The folder should contain files named `{source}.{ext}`.
-        Files are grouped according to `sources`, the list of source names.
-
-        Sample rate and channels will be converted on the fly.
-
-        `length` is the sample size to extract (in samples, not duration).
-        `stride` is how many samples to move by between each example.
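-        When `normalize` is True, examples are normalized by the mean and
-        standard deviation of the mixture, as stored in the metadata.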
- """ - self.root = Path(root) - self.metadata = OrderedDict(metadata) - self.length = length - self.stride = stride or length - self.normalize = normalize - self.sources = sources - self.channels = channels - self.samplerate = samplerate - self.num_examples = [] - for name, meta in self.metadata.items(): - track_length = int(self.samplerate * meta['length'] / meta['samplerate']) - if length is None or track_length < length: - examples = 1 - else: - examples = int(math.ceil((track_length - self.length) / self.stride) + 1) - self.num_examples.append(examples) - - def __len__(self): - return sum(self.num_examples) - - def get_file(self, name, source): - return self.root / name / f"{source}{EXT}" - - def __getitem__(self, index): - for name, examples in zip(self.metadata, self.num_examples): - if index >= examples: - index -= examples - continue - meta = self.metadata[name] - num_frames = -1 - offset = 0 - if self.length is not None: - offset = int(math.ceil( - meta['samplerate'] * self.stride * index / self.samplerate)) - num_frames = int(math.ceil( - meta['samplerate'] * self.length / self.samplerate)) - wavs = [] - for source in self.sources: - file = self.get_file(name, source) - wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) - wav = convert_audio_channels(wav, self.channels) - wavs.append(wav) - - example = th.stack(wavs) - example = julius.resample_frac(example, meta['samplerate'], self.samplerate) - if self.normalize: - example = (example - meta['mean']) / meta['std'] - if self.length: - example = example[..., :self.length] - example = F.pad(example, (0, self.length - example.shape[-1])) - return example - - -def get_wav_datasets(args, samples, sources): - sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8] - metadata_file = args.metadata / (sig + ".json") - train_path = args.wav / "train" - valid_path = args.wav / "valid" - if not metadata_file.is_file() and args.rank == 0: - train = _build_metadata(train_path, sources) - valid = _build_metadata(valid_path, sources) - json.dump([train, valid], open(metadata_file, "w")) - if args.world_size > 1: - distributed.barrier() - train, valid = json.load(open(metadata_file)) - train_set = Wavset(train_path, train, sources, - length=samples, stride=args.data_stride, - samplerate=args.samplerate, channels=args.audio_channels, - normalize=args.norm_wav) - valid_set = Wavset(valid_path, valid, [MIXTURE] + sources, - samplerate=args.samplerate, channels=args.audio_channels, - normalize=args.norm_wav) - return train_set, valid_set - - -def get_musdb_wav_datasets(args, samples, sources): - metadata_file = args.metadata / "musdb_wav.json" - root = args.musdb / "train" - if not metadata_file.is_file() and args.rank == 0: - metadata = _build_metadata(root, sources) - json.dump(metadata, open(metadata_file, "w")) - if args.world_size > 1: - distributed.barrier() - metadata = json.load(open(metadata_file)) - - train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train") - metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks} - metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks} - train_set = Wavset(root, metadata_train, sources, - length=samples, stride=args.data_stride, - samplerate=args.samplerate, channels=args.audio_channels, - normalize=args.norm_wav) - valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources, - samplerate=args.samplerate, channels=args.audio_channels, - normalize=args.norm_wav) - return 
train_set, valid_set