diff --git a/demucs/__pycache__/apply.cpython-39.pyc b/demucs/__pycache__/apply.cpython-39.pyc new file mode 100644 index 0000000..cbd1c41 Binary files /dev/null and b/demucs/__pycache__/apply.cpython-39.pyc differ diff --git a/demucs/__pycache__/audio.cpython-39.pyc b/demucs/__pycache__/audio.cpython-39.pyc new file mode 100644 index 0000000..d257ebb Binary files /dev/null and b/demucs/__pycache__/audio.cpython-39.pyc differ diff --git a/demucs/__pycache__/demucs.cpython-39.pyc b/demucs/__pycache__/demucs.cpython-39.pyc new file mode 100644 index 0000000..f608b5a Binary files /dev/null and b/demucs/__pycache__/demucs.cpython-39.pyc differ diff --git a/demucs/__pycache__/hdemucs.cpython-39.pyc b/demucs/__pycache__/hdemucs.cpython-39.pyc new file mode 100644 index 0000000..c9d8ff4 Binary files /dev/null and b/demucs/__pycache__/hdemucs.cpython-39.pyc differ diff --git a/demucs/__pycache__/pretrained.cpython-39.pyc b/demucs/__pycache__/pretrained.cpython-39.pyc new file mode 100644 index 0000000..d3cc165 Binary files /dev/null and b/demucs/__pycache__/pretrained.cpython-39.pyc differ diff --git a/demucs/__pycache__/repo.cpython-39.pyc b/demucs/__pycache__/repo.cpython-39.pyc new file mode 100644 index 0000000..18ef7a3 Binary files /dev/null and b/demucs/__pycache__/repo.cpython-39.pyc differ diff --git a/demucs/__pycache__/spec.cpython-39.pyc b/demucs/__pycache__/spec.cpython-39.pyc new file mode 100644 index 0000000..1c75010 Binary files /dev/null and b/demucs/__pycache__/spec.cpython-39.pyc differ diff --git a/demucs/__pycache__/states.cpython-39.pyc b/demucs/__pycache__/states.cpython-39.pyc new file mode 100644 index 0000000..d013690 Binary files /dev/null and b/demucs/__pycache__/states.cpython-39.pyc differ diff --git a/demucs/__pycache__/utils.cpython-39.pyc b/demucs/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..0ad9279 Binary files /dev/null and b/demucs/__pycache__/utils.cpython-39.pyc differ diff --git a/demucs/apply.py b/demucs/apply.py new file mode 100644 index 0000000..10ebd66 --- /dev/null +++ b/demucs/apply.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Code to apply a model to a mix. It will handle chunking with overlaps and +inteprolation between chunks, as well as the "shift trick". +""" +from concurrent.futures import ThreadPoolExecutor +import random +import typing as tp + +import torch as th +from torch import nn +from torch.nn import functional as F +import tqdm + +from .demucs import Demucs +from .hdemucs import HDemucs +from .utils import center_trim, DummyPoolExecutor + +Model = tp.Union[Demucs, HDemucs] + + +class BagOfModels(nn.Module): + def __init__(self, models: tp.List[Model], + weights: tp.Optional[tp.List[tp.List[float]]] = None, + segment: tp.Optional[float] = None): + """ + Represents a bag of models with specific weights. + You should call `apply_model` rather than calling directly the forward here for + optimal performance. + + Args: + models (list[nn.Module]): list of Demucs/HDemucs models. + weights (list[list[float]]): list of weights. If None, assumed to + be all ones, otherwise it should be a list of N list (N number of models), + each containing S floats (S number of sources). + segment (None or float): overrides the `segment` attribute of each model + (this is performed inplace, be careful is you reuse the models passed). + """ + super().__init__() + assert len(models) > 0 + first = models[0] + for other in models: + assert other.sources == first.sources + assert other.samplerate == first.samplerate + assert other.audio_channels == first.audio_channels + if segment is not None: + other.segment = segment + + self.audio_channels = first.audio_channels + self.samplerate = first.samplerate + self.sources = first.sources + self.models = nn.ModuleList(models) + + if weights is None: + weights = [[1. for _ in first.sources] for _ in models] + else: + assert len(weights) == len(models) + for weight in weights: + assert len(weight) == len(first.sources) + self.weights = weights + + def forward(self, x): + raise NotImplementedError("Call `apply_model` on this.") + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def apply_model(model, mix, shifts=1, split=True, + overlap=0.25, transition_power=1., progress=False, device=None, + num_workers=0, pool=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + device (torch.device, str, or None): if provided, device on which to + execute the computation, otherwise `mix.device` is assumed. + When `device` is different from `mix.device`, only local computations will + be on `device`, while the entire tracks will be stored on `mix.device`. + """ + if device is None: + device = mix.device + else: + device = th.device(device) + if pool is None: + if num_workers > 0 and device.type == 'cpu': + pool = ThreadPoolExecutor(num_workers) + else: + pool = DummyPoolExecutor() + kwargs = { + 'shifts': shifts, + 'split': split, + 'overlap': overlap, + 'transition_power': transition_power, + 'progress': progress, + 'device': device, + 'pool': pool, + } + if isinstance(model, BagOfModels): + # Special treatment for bag of model. + # We explicitely apply multiple times `apply_model` so that the random shifts + # are different for each model. + estimates = 0 + totals = [0] * len(model.sources) + for sub_model, weight in zip(model.models, model.weights): + original_model_device = next(iter(sub_model.parameters())).device + sub_model.to(device) + + out = apply_model(sub_model, mix, **kwargs) + sub_model.to(original_model_device) + for k, inst_weight in enumerate(weight): + out[:, k, :, :] *= inst_weight + totals[k] += inst_weight + estimates += out + del out + + for k in range(estimates.shape[1]): + estimates[:, k, :, :] /= totals[k] + return estimates + + model.to(device) + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + batch, channels, length = mix.shape + if split: + kwargs['split'] = False + out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) + sum_weight = th.zeros(length, device=mix.device) + segment = int(model.samplerate * model.segment) + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = stride / model.samplerate + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1, device=device), + th.arange(segment - segment // 2, 0, -1, device=device)]) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + futures = [] + for offset in offsets: + chunk = TensorChunk(mix, offset, segment) + future = pool.submit(apply_model, model, chunk, **kwargs) + futures.append((future, offset)) + offset += segment + if progress: + futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + for future, offset in futures: + chunk_out = future.result() + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device) + sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device) + assert sum_weight.min() > 0 + out /= sum_weight + return out + elif shifts: + kwargs['shifts'] = 0 + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0 + for _ in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + shifted_out = apply_model(model, shifted, **kwargs) + out += shifted_out[..., max_shift - offset:] + out /= shifts + return out + else: + if hasattr(model, 'valid_length'): + valid_length = model.valid_length(length) + else: + valid_length = length + mix = tensor_chunk(mix) + padded_mix = mix.padded(valid_length).to(device) + with th.no_grad(): + out = model(padded_mix) + return center_trim(out, length) diff --git a/demucs/audio.py b/demucs/audio.py new file mode 100644 index 0000000..d1ba194 --- /dev/null +++ b/demucs/audio.py @@ -0,0 +1,256 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import json +import subprocess as sp +from pathlib import Path + +import lameenc +import julius +import numpy as np +import torch +import torchaudio as ta + +from .utils import temp_filenames + + +def _read_info(path): + stdout_data = sp.check_output([ + 'ffprobe', "-loglevel", "panic", + str(path), '-print_format', 'json', '-show_format', '-show_streams' + ]) + return json.loads(stdout_data.decode('utf-8')) + + +class AudioFile: + """ + Allows to read audio from any format supported by ffmpeg, as well as resampling or + converting to mono on the fly. See :method:`read` for more details. + """ + def __init__(self, path: Path): + self.path = Path(path) + self._info = None + + def __repr__(self): + features = [("path", self.path)] + features.append(("samplerate", self.samplerate())) + features.append(("channels", self.channels())) + features.append(("streams", len(self))) + features_str = ", ".join(f"{name}={value}" for name, value in features) + return f"AudioFile({features_str})" + + @property + def info(self): + if self._info is None: + self._info = _read_info(self.path) + return self._info + + @property + def duration(self): + return float(self.info['format']['duration']) + + @property + def _audio_streams(self): + return [ + index for index, stream in enumerate(self.info["streams"]) + if stream["codec_type"] == "audio" + ] + + def __len__(self): + return len(self._audio_streams) + + def channels(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['channels']) + + def samplerate(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) + + def read(self, + seek_time=None, + duration=None, + streams=slice(None), + samplerate=None, + channels=None, + temp_folder=None): + """ + Slightly more efficient implementation than stempeg, + in particular, this will extract all stems at once + rather than having to loop over one file multiple times + for each stream. + + Args: + seek_time (float): seek time in seconds or None if no seeking is needed. + duration (float): duration in seconds to extract or None to extract until the end. + streams (slice, int or list): streams to extract, can be a single int, a list or + a slice. If it is a slice or list, the output will be of size [S, C, T] + with S the number of streams, C the number of channels and T the number of samples. + If it is an int, the output will be [C, T]. + samplerate (int): if provided, will resample on the fly. If None, no resampling will + be done. Original sampling rate can be obtained with :method:`samplerate`. + channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that + as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. + See https://sound.stackexchange.com/a/42710. + Our definition of mono is simply the average of the two channels. Any other + value will be ignored. + temp_folder (str or Path or None): temporary folder to use for decoding. + + + """ + streams = np.array(range(len(self)))[streams] + single = not isinstance(streams, np.ndarray) + if single: + streams = [streams] + + if duration is None: + target_size = None + query_duration = None + else: + target_size = int((samplerate or self.samplerate()) * duration) + query_duration = float((target_size + 1) / (samplerate or self.samplerate())) + + with temp_filenames(len(streams)) as filenames: + command = ['ffmpeg', '-y'] + command += ['-loglevel', 'panic'] + if seek_time: + command += ['-ss', str(seek_time)] + command += ['-i', str(self.path)] + for stream, filename in zip(streams, filenames): + command += ['-map', f'0:{self._audio_streams[stream]}'] + if query_duration is not None: + command += ['-t', str(query_duration)] + command += ['-threads', '1'] + command += ['-f', 'f32le'] + if samplerate is not None: + command += ['-ar', str(samplerate)] + command += [filename] + + sp.run(command, check=True) + wavs = [] + for filename in filenames: + wav = np.fromfile(filename, dtype=np.float32) + wav = torch.from_numpy(wav) + wav = wav.view(-1, self.channels()).t() + if channels is not None: + wav = convert_audio_channels(wav, channels) + if target_size is not None: + wav = wav[..., :target_size] + wavs.append(wav) + wav = torch.stack(wavs, dim=0) + if single: + wav = wav[0] + return wav + + +def convert_audio_channels(wav, channels=2): + """Convert audio to the given number of channels.""" + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, but the stream have multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file have + # one single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file have + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav, from_samplerate, to_samplerate, channels): + """Convert audio from a given samplerate to a target one and target number of channels.""" + wav = convert_audio_channels(wav, channels) + return julius.resample_frac(wav, from_samplerate, to_samplerate) + + +def i16_pcm(wav): + """Convert audio to 16 bits integer PCM format.""" + if wav.dtype.is_floating_point: + return (wav.clamp_(-1, 1) * (2**15 - 1)).short() + else: + return wav + + +def f32_pcm(wav): + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + else: + return wav.float() / (2**15 - 1) + + +def as_dtype_pcm(wav, dtype): + """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" + if wav.dtype.is_floating_point: + return f32_pcm(wav) + else: + return i16_pcm(wav) + + +def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False): + """Save given audio as mp3. This should work on all OSes.""" + C, T = wav.shape + wav = i16_pcm(wav) + encoder = lameenc.Encoder() + encoder.set_bit_rate(bitrate) + encoder.set_in_sample_rate(samplerate) + encoder.set_channels(C) + encoder.set_quality(2) # 2-highest, 7-fastest + if not verbose: + encoder.silence() + wav = wav.transpose(0, 1).numpy() + mp3_data = encoder.encode(wav.tobytes()) + mp3_data += encoder.flush() + with open(path, "wb") as f: + f.write(mp3_data) + + +def prevent_clip(wav, mode='rescale'): + """ + different strategies for avoiding raw clipping. + """ + assert wav.dtype.is_floating_point, "too late for clipping" + if mode == 'rescale': + wav = wav / max(1.01 * wav.abs().max(), 1) + elif mode == 'clamp': + wav = wav.clamp(-0.99, 0.99) + elif mode == 'tanh': + wav = torch.tanh(wav) + else: + raise ValueError(f"Invalid mode {mode}") + return wav + + +def save_audio(wav, path, samplerate, bitrate=320, clip='rescale', + bits_per_sample=16, as_float=False): + """Save audio file, automatically preventing clipping if necessary + based on the given `clip` strategy. If the path ends in `.mp3`, this + will save as mp3 with the given `bitrate`. + """ + wav = prevent_clip(wav, mode=clip) + path = Path(path) + suffix = path.suffix.lower() + if suffix == ".mp3": + encode_mp3(wav, path, samplerate, bitrate) + elif suffix == ".wav": + if as_float: + bits_per_sample = 32 + encoding = 'PCM_F' + else: + encoding = 'PCM_S' + ta.save(str(path), wav, sample_rate=samplerate, + encoding=encoding, bits_per_sample=bits_per_sample) + else: + raise ValueError(f"Invalid suffix for path: {suffix}") diff --git a/demucs/demucs.py b/demucs/demucs.py new file mode 100644 index 0000000..d2c08e7 --- /dev/null +++ b/demucs/demucs.py @@ -0,0 +1,459 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +from torch import nn +from torch.nn import functional as F + +from .states import capture_init +from .utils import center_trim, unfold + + +class BLSTM(nn.Module): + """ + BiLSTM with same hidden units as input dim. + If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. + """ + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +def rescale_conv(conv, reference): + """Rescale initial weight scale. It is unclear why it helps but it certainly does. + """ + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + rescale_conv(sub, reference) + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + def __init__(self, channels: int, init: float = 0): + super().__init__() + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + return self.scale[:, None] * x + + +class DConv(nn.Module): + """ + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + """ + def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, + norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, + kernel=3, dilate=True): + """ + Args: + channels: input/output channels for residual branch. + compress: amount of channel compression inside the branch. + depth: number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention. + init: initial scale for LayerNorm. + norm: use GroupNorm. + attn: use LocalAttention. + heads: number of heads for the LocalAttention. + ndecay: number of decay controls in the LocalAttention. + lstm: use LSTM. + gelu: Use GELU activation. + kernel: kernel size for the (dilated) convolutions. + dilate: if true, use dilation, increasing with the depth. + """ + + super().__init__() + assert kernel % 2 == 1 + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act: tp.Type[nn.Module] + if gelu: + act = nn.GELU + else: + act = nn.ReLU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = 2 ** d if dilate else 1 + padding = dilation * (kernel // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), + norm_fn(hidden), act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), nn.GLU(1), + LayerScale(channels, init), + ] + if attn: + mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + + Also a failed experiments with trying to provide some frequency based attention. + """ + def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): + super().__init__() + assert channels % heads == 0, (channels, heads) + self.heads = heads + self.nfreqs = nfreqs + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + if nfreqs: + self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) + if ndecay: + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + assert self.query_decay.bias is not None # stupid type checker + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) + + def forward(self, x): + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= keys.shape[2]**0.5 + if self.nfreqs: + periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) + freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) + freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 + dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. + dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + if self.nfreqs: + time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) + result = torch.cat([result, time_sig], 2) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=64, + growth=2., + # Main structure + depth=6, + rewrite=True, + lstm_layers=0, + # Convolutions + kernel_size=8, + stride=4, + context=1, + # Activations + gelu=True, + glu=True, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Pre/post processing + normalize=True, + resample=True, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated + by default, as this is now replaced by the smaller and faster small LSTMs + in the DConv branches. + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time steps. + gelu: use GELU activation function. + glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + normalize (bool): normalizes the input audio on the fly, and scales back + the output by the same amount. + resample (bool): upsample x2 the input and downsample /2 the output. + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale`. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment (float): duration of the chunks of audio to ideally evaluate the model on. + This is used by `demucs.apply.apply_model`. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment = segment + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + self.skip_scales = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + if gelu: + act2 = nn.GELU + else: + act2 = nn.ReLU + + in_channels = audio_channels + padding = 0 + for index in range(depth): + norm_fn = lambda d: nn.Identity() # noqa + if index >= norm_starts: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + + encode = [] + encode += [ + nn.Conv1d(in_channels, channels, kernel_size, stride), + norm_fn(channels), + act2(), + ] + attn = index >= dconv_attn + lstm = index >= dconv_lstm + if dconv_mode & 1: + encode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + if rewrite: + encode += [ + nn.Conv1d(channels, ch_scale * channels, 1), + norm_fn(ch_scale * channels), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [ + nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), + norm_fn(ch_scale * channels), activation] + if dconv_mode & 2: + decode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + decode += [nn.ConvTranspose1d(channels, out_channels, + kernel_size, stride, padding=padding)] + if index > 0: + decode += [norm_fn(out_channels), act2()] + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolution, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + Note that input are automatically padded if necessary to ensure that the output + has the same length as the input. + """ + if self.resample: + length *= 2 + + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + + for idx in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + x = (x - mean) / (1e-5 + std) + else: + mean = 0 + std = 1 + + delta = self.valid_length(length) - length + x = F.pad(x, (delta // 2, delta - delta // 2)) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + + if self.lstm: + x = self.lstm(x) + + for decode in self.decoder: + skip = saved.pop(-1) + skip = center_trim(skip, x) + x = decode(x + skip) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = center_trim(x, length) + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x + + def load_state_dict(self, state, strict=True): + # fix a mismatch with previous generation Demucs models. + for idx in range(self.depth): + for a in ['encoder', 'decoder']: + for b in ['bias', 'weight']: + new = f'{a}.{idx}.3.{b}' + old = f'{a}.{idx}.2.{b}' + if old in state and new not in state: + state[new] = state.pop(old) + super().load_state_dict(state, strict=strict) diff --git a/demucs/distrib.py b/demucs/distrib.py new file mode 100644 index 0000000..b73011a --- /dev/null +++ b/demucs/distrib.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Distributed training utilities. +""" +import logging +import pickle + +import numpy as np +import torch +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader, Subset +from torch.nn.parallel.distributed import DistributedDataParallel + +from dora import distrib as dora_distrib + +logger = logging.getLogger(__name__) +rank = 0 +world_size = 1 + + +def init(): + global rank, world_size + if not torch.distributed.is_initialized(): + dora_distrib.init() + rank = dora_distrib.rank() + world_size = dora_distrib.world_size() + + +def average(metrics, count=1.): + if isinstance(metrics, dict): + keys, values = zip(*sorted(metrics.items())) + values = average(values, count) + return dict(zip(keys, values)) + if world_size == 1: + return metrics + tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32) + tensor *= count + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() + + +def wrap(model): + if world_size == 1: + return model + else: + return DistributedDataParallel( + model, + # find_unused_parameters=True, + device_ids=[torch.cuda.current_device()], + output_device=torch.cuda.current_device()) + + +def barrier(): + if world_size > 1: + torch.distributed.barrier() + + +def share(obj=None, src=0): + if world_size == 1: + return obj + size = torch.empty(1, device='cuda', dtype=torch.long) + if rank == src: + dump = pickle.dumps(obj) + size[0] = len(dump) + torch.distributed.broadcast(size, src=src) + # size variable is now set to the length of pickled obj in all processes + + if rank == src: + buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda() + else: + buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8) + torch.distributed.broadcast(buffer, src=src) + # buffer variable is now set to pickled obj in all processes + + if rank != src: + obj = pickle.loads(buffer.cpu().numpy().tobytes()) + logger.debug(f"Shared object of size {len(buffer)}") + return obj + + +def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): + """ + Create a dataloader properly in case of distributed training. + If a gradient is going to be computed you must set `shuffle=True`. + """ + if world_size == 1: + return klass(dataset, *args, shuffle=shuffle, **kwargs) + + if shuffle: + # train means we will compute backward, we use DistributedSampler + sampler = DistributedSampler(dataset) + # We ignore shuffle, DistributedSampler already shuffles + return klass(dataset, *args, **kwargs, sampler=sampler) + else: + # We make a manual shard, as DistributedSampler otherwise replicate some examples + dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) + return klass(dataset, *args, shuffle=shuffle, **kwargs) diff --git a/demucs/ema.py b/demucs/ema.py new file mode 100644 index 0000000..958c595 --- /dev/null +++ b/demucs/ema.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Inspired from https://github.com/rwightman/pytorch-image-models +from contextlib import contextmanager + +import torch + +from .states import swap_state + + +class ModelEMA: + """ + Perform EMA on a model. You can switch to the EMA weights temporarily + with the `swap` method. + + ema = ModelEMA(model) + with ema.swap(): + # compute valid metrics with averaged model. + """ + def __init__(self, model, decay=0.9999, unbias=True, device='cpu'): + self.decay = decay + self.model = model + self.state = {} + self.count = 0 + self.device = device + self.unbias = unbias + + self._init() + + def _init(self): + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + if key not in self.state: + self.state[key] = val.detach().to(device, copy=True) + + def update(self): + if self.unbias: + self.count = self.count * self.decay + 1 + w = 1 / self.count + else: + w = 1 - self.decay + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + self.state[key].mul_(1 - w) + self.state[key].add_(val.detach().to(device), alpha=w) + + @contextmanager + def swap(self): + with swap_state(self.model, self.state): + yield + + def state_dict(self): + return {'state': self.state, 'count': self.count} + + def load_state_dict(self, state): + self.count = state['count'] + for k, v in state['state'].items(): + self.state[k].copy_(v) diff --git a/demucs/evaluate.py b/demucs/evaluate.py new file mode 100644 index 0000000..badb35e --- /dev/null +++ b/demucs/evaluate.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Test time evaluation, either using the original SDR from [Vincent et al. 2006] +or the newest SDR definition from the MDX 2021 competition (this one will +be reported as `nsdr` for `new sdr`). +""" + +from concurrent import futures +import logging + +from dora.log import LogProgress +import numpy as np +import musdb +import museval +import torch as th + +from .apply import apply_model +from .audio import convert_audio, save_audio +from . import distrib +from .utils import DummyPoolExecutor + + +logger = logging.getLogger(__name__) + + +def new_sdr(references, estimates): + """ + Compute the SDR according to the MDX challenge definition. + Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license) + """ + assert references.dim() == 4 + assert estimates.dim() == 4 + delta = 1e-7 # avoid numerical errors + num = th.sum(th.square(references), dim=(2, 3)) + den = th.sum(th.square(references - estimates), dim=(2, 3)) + num += delta + den += delta + scores = 10 * th.log10(num / den) + return scores + + +def eval_track(references, estimates, win, hop, compute_sdr=True): + references = references.transpose(1, 2).double() + estimates = estimates.transpose(1, 2).double() + + new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0] + + if not compute_sdr: + return None, new_scores + else: + references = references.numpy() + estimates = estimates.numpy() + scores = museval.metrics.bss_eval( + references, estimates, + compute_permutation=False, + window=win, + hop=hop, + framewise_filters=False, + bsseval_sources_version=False)[:-1] + return scores, new_scores + + +def evaluate(solver, compute_sdr=False): + """ + Evaluate model using museval. + `new_only` means using only the MDX definition of the SDR, which is much faster to evaluate. + """ + + args = solver.args + + output_dir = solver.folder / "results" + output_dir.mkdir(exist_ok=True, parents=True) + json_folder = solver.folder / "results/test" + json_folder.mkdir(exist_ok=True, parents=True) + + # we load tracks from the original musdb set + if args.test.nonhq is None: + test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) + else: + test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) + src_rate = args.dset.musdb_samplerate + + eval_device = 'cpu' + + model = solver.model + win = int(1. * model.samplerate) + hop = int(1. * model.samplerate) + + indexes = range(distrib.rank, len(test_set), distrib.world_size) + indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, + name='Eval') + pendings = [] + + pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor + with pool(args.test.workers) as pool: + for index in indexes: + track = test_set.tracks[index] + + mix = th.from_numpy(track.audio).t().float() + if mix.dim() == 1: + mix = mix[None] + mix = mix.to(solver.device) + ref = mix.mean(dim=0) # mono mixture + mix = (mix - ref.mean()) / ref.std() + mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) + estimates = apply_model(model, mix[None], + shifts=args.test.shifts, split=args.test.split, + overlap=args.test.overlap)[0] + estimates = estimates * ref.std() + ref.mean() + estimates = estimates.to(eval_device) + + references = th.stack( + [th.from_numpy(track.targets[name].audio).t() for name in model.sources]) + if references.dim() == 2: + references = references[:, None] + references = references.to(eval_device) + references = convert_audio(references, src_rate, + model.samplerate, model.audio_channels) + if args.test.save: + folder = solver.folder / "wav" / track.name + folder.mkdir(exist_ok=True, parents=True) + for name, estimate in zip(model.sources, estimates): + save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate) + + pendings.append((track.name, pool.submit( + eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr))) + + pendings = LogProgress(logger, pendings, updates=args.misc.num_prints, + name='Eval (BSS)') + tracks = {} + for track_name, pending in pendings: + pending = pending.result() + scores, nsdrs = pending + tracks[track_name] = {} + for idx, target in enumerate(model.sources): + tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]} + if scores is not None: + (sdr, isr, sir, sar) = scores + for idx, target in enumerate(model.sources): + values = { + "SDR": sdr[idx].tolist(), + "SIR": sir[idx].tolist(), + "ISR": isr[idx].tolist(), + "SAR": sar[idx].tolist() + } + tracks[track_name][target].update(values) + + all_tracks = {} + for src in range(distrib.world_size): + all_tracks.update(distrib.share(tracks, src)) + + result = {} + metric_names = next(iter(all_tracks.values()))[model.sources[0]] + for metric_name in metric_names: + avg = 0 + avg_of_medians = 0 + for source in model.sources: + medians = [ + np.nanmedian(all_tracks[track][source][metric_name]) + for track in all_tracks.keys()] + mean = np.mean(medians) + median = np.median(medians) + result[metric_name.lower() + "_" + source] = mean + result[metric_name.lower() + "_med" + "_" + source] = median + avg += mean / len(model.sources) + avg_of_medians += median / len(model.sources) + result[metric_name.lower()] = avg + result[metric_name.lower() + "_med"] = avg_of_medians + return result diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py new file mode 100644 index 0000000..864fd3f --- /dev/null +++ b/demucs/hdemucs.py @@ -0,0 +1,761 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +from copy import deepcopy +import math + +from openunmix.filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F + +from .demucs import DConv, rescale_module +from .states import capture_init +from .spec import spectro, ispectro + + +class ScaledEmbedding(nn.Module): + """ + Boost learning rate for embeddings (with `scale`). + Also, can make embeddings continuous with `smooth`. + """ + def __init__(self, num_embeddings: int, embedding_dim: int, + scale: float = 10., smooth=False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self): + return self.embedding.weight * self.scale + + def forward(self, x): + out = self.embedding(x) * self.scale + return out + + +class HEncLayer(nn.Module): + def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, + rewrite=True): + """Encoder layer. This used both by the time and the frequency branch. + + Args: + chin: number of input channels. + chout: number of output channels. + norm_groups: number of groups for group norm. + empty: used to make a layer with just the first conv. this is used + before merging the time and freq. branches. + freq: this is acting on frequencies. + dconv: insert DConv residual branches. + norm: use GroupNorm. + context: context size for the 1x1 conv. + dconv_kw: list of kwargs for the DConv class. + pad: pad the input. Padding is done so that the output size is + always the input size / stride. + rewrite: add 1x1 conv at the end of the layer. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.norm = norm + self.pad = pad + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad = [pad, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad) + if self.empty: + return + self.norm1 = norm_fn(chout) + self.rewrite = None + if rewrite: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + + self.dconv = None + if dconv: + self.dconv = DConv(chout, **dconv_kw) + + def forward(self, x, inject=None): + """ + `inject` is used to inject the result from the time branch into the frequency branch, + when both have the same stride. + """ + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + if self.rewrite: + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + else: + z = y + return z + + +class MultiWrap(nn.Module): + """ + Takes one layer and replicate it N times. each replica will act + on a frequency band. All is done so that if the N replica have the same weights, + then this is exactly equivalent to applying the original module on all frequencies. + + This is a bit over-engineered to avoid edge artifacts when splitting + the frequency bands, but it is possible the naive implementation would work as well... + """ + def __init__(self, layer, split_ratios): + """ + Args: + layer: module to clone, must be either HEncLayer or HDecLayer. + split_ratios: list of float indicating which ratio to keep for each band. + """ + super().__init__() + self.split_ratios = split_ratios + self.layers = nn.ModuleList() + self.conv = isinstance(layer, HEncLayer) + assert not layer.norm + assert layer.freq + assert layer.pad + if not self.conv: + assert not layer.context_freq + for k in range(len(split_ratios) + 1): + lay = deepcopy(layer) + if self.conv: + lay.conv.padding = (0, 0) + else: + lay.pad = False + for m in lay.modules(): + if hasattr(m, 'reset_parameters'): + m.reset_parameters() + self.layers.append(lay) + + def forward(self, x, skip=None, length=None): + B, C, Fr, T = x.shape + + ratios = list(self.split_ratios) + [1] + start = 0 + outs = [] + for ratio, layer in zip(ratios, self.layers): + if self.conv: + pad = layer.kernel_size // 4 + if ratio == 1: + limit = Fr + frames = -1 + else: + limit = int(round(Fr * ratio)) + le = limit - start + if start == 0: + le += pad + frames = round((le - layer.kernel_size) / layer.stride + 1) + limit = start + (frames - 1) * layer.stride + layer.kernel_size + if start == 0: + limit -= pad + assert limit - start > 0, (limit, start) + assert limit <= Fr, (limit, Fr) + y = x[:, :, start:limit, :] + if start == 0: + y = F.pad(y, (0, 0, pad, 0)) + if ratio == 1: + y = F.pad(y, (0, 0, 0, pad)) + outs.append(layer(y)) + start = limit - layer.kernel_size + layer.stride + else: + if ratio == 1: + limit = Fr + else: + limit = int(round(Fr * ratio)) + last = layer.last + layer.last = True + + y = x[:, :, start:limit] + s = skip[:, :, start:limit] + out, _ = layer(y, s, None) + if outs: + outs[-1][:, :, -layer.stride:] += ( + out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) + out = out[:, :, layer.stride:] + if ratio == 1: + out = out[:, :, :-layer.stride // 2, :] + if start == 0: + out = out[:, :, layer.stride // 2:, :] + outs.append(out) + layer.last = last + start = limit + out = torch.cat(outs, dim=2) + if not self.conv and not last: + out = F.gelu(out) + if self.conv: + return out + else: + return out, None + + +class HDecLayer(nn.Module): + def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, + context_freq=True, rewrite=True): + """ + Same as HEncLayer but for decoder. See `HEncLayer` for documentation. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + self.norm = norm + self.context_freq = context_freq + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + return + self.rewrite = None + if rewrite: + if context_freq: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + else: + self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, + [0, context]) + self.norm1 = norm_fn(2 * chin) + + self.dconv = None + if dconv: + self.dconv = DConv(chin, **dconv_kw) + + def forward(self, x, skip, length): + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + + if self.rewrite: + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = x + assert skip is None + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad:-self.pad, :] + else: + z = z[..., self.pad:self.pad + length] + assert z.shape[-1] == length, (z.shape[-1], length) + if not self.last: + z = F.gelu(z) + return z, y + + +class HDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=6, + rewrite=True, + hybrid=True, + hybrid_old=False, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=2, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. + hybrid_old: some models trained for MDX had a padding bug. This replicates + this bug to avoid retraining them. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + rescale: weight recaling trick + + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.channels = channels + self.samplerate = samplerate + self.segment = segment + + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + self.hybrid = hybrid + self.hybrid_old = hybrid_old + if hybrid_old: + assert hybrid, "hybrid_old must come with hybrid=True" + if hybrid: + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if hybrid: + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + 'kernel_size': ker, + 'stride': stri, + 'freq': freq, + 'pad': pad, + 'norm': norm, + 'rewrite': rewrite, + 'norm_groups': norm_groups, + 'dconv_kw': { + 'lstm': lstm, + 'attn': attn, + 'depth': dconv_depth, + 'compress': dconv_comp, + 'init': dconv_init, + 'gelu': True, + } + } + kwt = dict(kw) + kwt['freq'] = 0 + kwt['kernel_size'] = kernel_size + kwt['stride'] = stride + kwt['pad'] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec['context_freq'] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer(chin_z, chout_z, + dconv=dconv_mode & 1, context=context_enc, **kw) + if hybrid and freq: + tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, + empty=last_freq, **kwt) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, + last=index == 0, context=context, **kw_dec) + if multi: + dec = MultiWrap(dec, multi_freqs) + if hybrid and freq: + tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, + last=index == 0, context=context, **kwt) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + if self.hybrid: + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + if not self.hybrid_old: + x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') + else: + x = F.pad(x, (pad, pad + le * hl - x.shape[-1])) + + z = spectro(x, nfft, hl)[..., :-1, :] + if self.hybrid: + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2:2+le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4 ** scale) + z = F.pad(z, (0, 0, 0, 1)) + if self.hybrid: + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + if not self.hybrid_old: + le = hl * int(math.ceil(length / hl)) + 2 * pad + else: + le = hl * int(math.ceil(length / hl)) + x = ispectro(z, hl, length=le) + if not self.hybrid_old: + x = x[..., pad:pad + length] + else: + x = x[..., :length] + else: + x = ispectro(z, hl, length) + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], mix_stft[sample, frame], niters, + residual=residual) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + z = self._spec(mix) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + if self.hybrid: + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if self.hybrid and idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + if self.hybrid: + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + if self.hybrid: + offset = self.depth - len(self.tdecoder) + if self.hybrid and idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(z, x) + x = self._ispec(zout, length) + + if self.hybrid: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py new file mode 100644 index 0000000..1c976c6 --- /dev/null +++ b/demucs/pretrained.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading pretrained models. +""" + +import logging +from pathlib import Path +import typing as tp + +from dora.log import fatal + +from .hdemucs import HDemucs +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa + +logger = logging.getLogger(__name__) +ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" +REMOTE_ROOT = Path(__file__).parent / 'remote' + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def demucs_unittest(): + model = HDemucs(channels=4, sources=SOURCES) + return model + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("-s", "--sig", help="Locally trained XP signature.") + group.add_argument("-n", "--name", default="mdx_extra_q", + help="Pretrained model name or signature. Default is mdx_extra_q.") + parser.add_argument("--repo", type=Path, + help="Folder containing all pre-trained models for use with -n.") + + +def get_model(name: str, + repo: tp.Optional[Path] = None): + """`name` must be a bag of models name or a pretrained signature + from the remote AWS model repo or the specified local repo if `repo` is not None. + """ + if name == 'demucs_unittest': + return demucs_unittest() + model_repo: ModelOnlyRepo + if repo is None: + remote_files = [line.strip() + for line in (REMOTE_ROOT / 'files.txt').read_text().split('\n') + if line.strip()] + model_repo = RemoteRepo(ROOT_URL, remote_files) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + any_repo = AnyModelRepo(model_repo, bag_repo) + return any_repo.get_model(name) + + +def get_model_from_args(args): + """ + Load local model package or pre-trained model. + """ + return get_model(name=args.name, repo=args.repo) diff --git a/demucs/repo.py b/demucs/repo.py new file mode 100644 index 0000000..f79c532 --- /dev/null +++ b/demucs/repo.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Represents a model repository, including pre-trained models and bags of models. +A repo can either be the main remote repository stored in AWS, or a local repository +with your own models. +""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import yaml + +from .apply import BagOfModels, Model +from .states import load_model + + +AnyModel = tp.Union[Model, BagOfModels] + + +class ModelLoadingError(RuntimeError): + pass + + +def check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise ModelLoadingError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +class ModelOnlyRepo: + """Base class for all model only repos. + """ + def has_model(self, sig: str) -> bool: + raise NotImplementedError() + + def get_model(self, sig: str) -> Model: + raise NotImplementedError() + + +class RemoteRepo(ModelOnlyRepo): + def __init__(self, root_url: str, remote_files: tp.List[str]): + if not root_url.endswith('/'): + root_url += '/' + self._models: tp.Dict[str, str] = {} + for file in remote_files: + sig, checksum = file.split('.')[0].split('-') + assert sig not in self._models + self._models[sig] = root_url + file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + url = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') + pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + return load_model(pkg) + + +class LocalRepo(ModelOnlyRepo): + def __init__(self, root: Path): + self.root = root + self.scan() + + def scan(self): + self._models = {} + self._checksums = {} + for file in self.root.iterdir(): + if file.suffix == '.th': + if '-' in file.stem: + xp_sig, checksum = file.stem.split('-') + self._checksums[xp_sig] = checksum + else: + xp_sig = file.stem + if xp_sig in self._models: + raise ModelLoadingError( + f'Duplicate pre-trained model exist for signature {xp_sig}. ' + 'Please delete all but one.') + self._models[xp_sig] = file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + file = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') + if sig in self._checksums: + check_checksum(file, self._checksums[sig]) + return load_model(file) + + +class BagOnlyRepo: + """Handles only YAML files containing bag of models, leaving the actual + model loading to some Repo. + """ + def __init__(self, root: Path, model_repo: ModelOnlyRepo): + self.root = root + self.model_repo = model_repo + self.scan() + + def scan(self): + self._bags = {} + for file in self.root.iterdir(): + if file.suffix == '.yaml': + self._bags[file.stem] = file + + def has_model(self, name: str) -> bool: + return name in self._bags + + def get_model(self, name: str) -> BagOfModels: + try: + yaml_file = self._bags[name] + except KeyError: + raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' + 'a bag of models.') + bag = yaml.safe_load(open(yaml_file)) + signatures = bag['models'] + models = [self.model_repo.get_model(sig) for sig in signatures] + weights = bag.get('weights') + segment = bag.get('segment') + return BagOfModels(models, weights, segment) + + +class AnyModelRepo: + def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): + self.model_repo = model_repo + self.bag_repo = bag_repo + + def has_model(self, name_or_sig: str) -> bool: + return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) + + def get_model(self, name_or_sig: str) -> AnyModel: + if self.model_repo.has_model(name_or_sig): + return self.model_repo.get_model(name_or_sig) + else: + return self.bag_repo.get_model(name_or_sig) diff --git a/demucs/separate.py b/demucs/separate.py new file mode 100644 index 0000000..1554ce3 --- /dev/null +++ b/demucs/separate.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path +import subprocess + +from dora.log import fatal +import torch as th +import torchaudio as ta + +from .apply import apply_model, BagOfModels +from .audio import AudioFile, convert_audio, save_audio +from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError + + +def load_track(track, audio_channels, samplerate): + errors = {} + wav = None + + try: + wav = AudioFile(track).read( + streams=0, + samplerate=samplerate, + channels=audio_channels) + except FileNotFoundError: + errors['ffmpeg'] = 'Ffmpeg is not installed.' + except subprocess.CalledProcessError: + errors['ffmpeg'] = 'FFmpeg could not read the file.' + + if wav is None: + try: + wav, sr = ta.load(str(track)) + except RuntimeError as err: + errors['torchaudio'] = err.args[0] + else: + wav = convert_audio(wav, sr, samplerate, audio_channels) + + if wav is None: + print(f"Could not load file {track}. " + "Maybe it is not a supported file format? ") + for backend, error in errors.items(): + print(f"When trying to load using {backend}, got the following error: {error}") + sys.exit(1) + return wav + + +def main(): + parser = argparse.ArgumentParser("demucs.separate", + description="Separate the sources for the given tracks") + parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks') + add_model_flags(parser) + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-o", + "--out", + type=Path, + default=Path("separated"), + help="Folder where to put extracted tracks. A subfolder " + "with the model name will be created.") + parser.add_argument("-d", + "--device", + default="cuda" if th.cuda.is_available() else "cpu", + help="Device to use, default is cuda if available else cpu") + parser.add_argument("--shifts", + default=1, + type=int, + help="Number of random shifts for equivariant stabilization." + "Increase separation time but improves quality for Demucs. 10 was used " + "in the original paper.") + parser.add_argument("--overlap", + default=0.25, + type=float, + help="Overlap between the splits.") + split_group = parser.add_mutually_exclusive_group() + split_group.add_argument("--no-split", + action="store_false", + dest="split", + default=True, + help="Doesn't split audio in chunks. " + "This can use large amounts of memory.") + split_group.add_argument("--segment", type=int, + help="Set split size of each chunk. " + "This can help save memory of graphic card. ") + parser.add_argument("--two-stems", + dest="stem", metavar="STEM", + help="Only separate audio into {STEM} and no_{STEM}. ") + group = parser.add_mutually_exclusive_group() + group.add_argument("--int24", action="store_true", + help="Save wav output as 24 bits wav.") + group.add_argument("--float32", action="store_true", + help="Save wav output as float32 (2x bigger).") + parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp"], + help="Strategy for avoiding clipping: rescaling entire signal " + "if necessary (rescale) or hard clipping (clamp).") + parser.add_argument("--mp3", action="store_true", + help="Convert the output wavs to mp3.") + parser.add_argument("--mp3-bitrate", + default=320, + type=int, + help="Bitrate of converted mp3.") + parser.add_argument("-j", "--jobs", + default=0, + type=int, + help="Number of jobs. This can increase memory usage but will " + "be much faster when multiple cores are available.") + + args = parser.parse_args() + + try: + model = get_model_from_args(args) + except ModelLoadingError as error: + fatal(error.args[0]) + + if args.segment is not None and args.segment < 8: + fatal('Segment must greater than 8. ') + + if isinstance(model, BagOfModels): + if args.segment is not None: + for sub in model.models: + sub.segment = args.segment + else: + if args.segment is not None: + sub.segment = args.segment + + model.cpu() + model.eval() + + if args.stem is not None and args.stem not in model.sources: + fatal( + 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format( + stem=args.stem, sources=', '.join(model.sources))) + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + print(f"Separated tracks will be stored in {out.resolve()}") + for track in args.tracks: + if not track.exists(): + print( + f"File {track} does not exist. If the path contains spaces, " + "please try again after surrounding the entire path with quotes \"\".", + file=sys.stderr) + continue + print(f"Separating track {track}") + wav = load_track(track, model.audio_channels, model.samplerate) + + ref = wav.mean(0) + wav = (wav - ref.mean()) / ref.std() + sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts, + split=args.split, overlap=args.overlap, progress=True, + num_workers=args.jobs)[0] + sources = sources * ref.std() + ref.mean() + + track_folder = out / track.name.rsplit(".", 1)[0] + track_folder.mkdir(exist_ok=True) + if args.mp3: + ext = ".mp3" + else: + ext = ".wav" + kwargs = { + 'samplerate': model.samplerate, + 'bitrate': args.mp3_bitrate, + 'clip': args.clip_mode, + 'as_float': args.float32, + 'bits_per_sample': 24 if args.int24 else 16, + } + if args.stem is None: + for source, name in zip(sources, model.sources): + stem = str(track_folder / (name + ext)) + save_audio(source, stem, **kwargs) + else: + sources = list(sources) + stem = str(track_folder / (args.stem + ext)) + save_audio(sources.pop(model.sources.index(args.stem)), stem, **kwargs) + # Warning : after poping the stem, selected stem is no longer in the list 'sources' + other_stem = th.zeros_like(sources[0]) + for i in sources: + other_stem += i + stem = str(track_folder / ("no_" + args.stem + ext)) + save_audio(other_stem, stem, **kwargs) + + +if __name__ == "__main__": + main() diff --git a/demucs/solver.py b/demucs/solver.py new file mode 100644 index 0000000..9970615 --- /dev/null +++ b/demucs/solver.py @@ -0,0 +1,404 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Main training loop.""" + +import logging + +from dora import get_xp +from dora.utils import write_and_rename +from dora.log import LogProgress, bold +import torch +import torch.nn.functional as F + +from . import augment, distrib, states, pretrained +from .apply import apply_model +from .ema import ModelEMA +from .evaluate import evaluate, new_sdr +from .svd import svd_penalty +from .utils import pull_metric, EMA + +logger = logging.getLogger(__name__) + + +def _summary(metrics): + return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items()) + + +class Solver(object): + def __init__(self, loaders, model, optimizer, args): + self.args = args + self.loaders = loaders + + self.model = model + self.optimizer = optimizer + self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer) + self.dmodel = distrib.wrap(model) + self.device = next(iter(self.model.parameters())).device + + # Exponential moving average of the model, either updated every batch or epoch. + # The best model from all the EMAs and the original one is kept based on the valid + # loss for the final best model. + self.emas = {'batch': [], 'epoch': []} + for kind in self.emas.keys(): + decays = getattr(args.ema, kind) + device = self.device if kind == 'batch' else 'cpu' + if decays: + for decay in decays: + self.emas[kind].append(ModelEMA(self.model, decay, device=device)) + + # data augment + augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift), + same=args.augment.shift_same)] + if args.augment.flip: + augments += [augment.FlipChannels(), augment.FlipSign()] + for aug in ['scale', 'remix']: + kw = getattr(args.augment, aug) + if kw.proba: + augments.append(getattr(augment, aug.capitalize())(**kw)) + self.augment = torch.nn.Sequential(*augments) + + xp = get_xp() + self.folder = xp.folder + # Checkpoints + self.checkpoint_file = xp.folder / 'checkpoint.th' + self.best_file = xp.folder / 'best.th' + logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve()) + self.best_state = None + self.best_changed = False + + self.link = xp.link + self.history = self.link.history + + self._reset() + + def _serialize(self, epoch): + package = {} + package['state'] = self.model.state_dict() + package['optimizer'] = self.optimizer.state_dict() + package['history'] = self.history + package['best_state'] = self.best_state + package['args'] = self.args + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + package[f'ema_{kind}_{k}'] = ema.state_dict() + with write_and_rename(self.checkpoint_file) as tmp: + torch.save(package, tmp) + + save_every = self.args.save_every + if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs: + with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp: + torch.save(package, tmp) + + if self.best_changed: + # Saving only the latest best model. + with write_and_rename(self.best_file) as tmp: + package = states.serialize_model(self.model, self.args) + package['state'] = self.best_state + torch.save(package, tmp) + self.best_changed = False + + def _reset(self): + """Reset state of the solver, potentially using checkpoint.""" + if self.checkpoint_file.exists(): + logger.info(f'Loading checkpoint model: {self.checkpoint_file}') + package = torch.load(self.checkpoint_file, 'cpu') + self.model.load_state_dict(package['state']) + self.optimizer.load_state_dict(package['optimizer']) + self.history[:] = package['history'] + self.best_state = package['best_state'] + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + ema.load_state_dict(package[f'ema_{kind}_{k}']) + elif self.args.continue_pretrained: + model = pretrained.get_model( + name=self.args.continue_pretrained, + repo=self.args.pretrained_repo) + self.model.load_state_dict(model.state_dict()) + elif self.args.continue_from: + name = 'checkpoint.th' + root = self.folder.parent + cf = root / str(self.args.continue_from) / name + logger.info("Loading from %s", cf) + package = torch.load(cf, 'cpu') + self.best_state = package['best_state'] + if self.args.continue_best: + self.model.load_state_dict(package['best_state'], strict=False) + else: + self.model.load_state_dict(package['state'], strict=False) + if self.args.continue_opt: + self.optimizer.load_state_dict(package['optimizer']) + + def _format_train(self, metrics: dict) -> dict: + """Formatting for train/valid metrics.""" + losses = { + 'loss': format(metrics['loss'], ".4f"), + 'reco': format(metrics['reco'], ".4f"), + } + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], ".3f") + if self.quantizer is not None: + losses['ms'] = format(metrics['ms'], ".2f") + if 'grad' in metrics: + losses['grad'] = format(metrics['grad'], ".4f") + if 'best' in metrics: + losses['best'] = format(metrics['best'], '.4f') + if 'bname' in metrics: + losses['bname'] = metrics['bname'] + if 'penalty' in metrics: + losses['penalty'] = format(metrics['penalty'], ".4f") + if 'hloss' in metrics: + losses['hloss'] = format(metrics['hloss'], ".4f") + return losses + + def _format_test(self, metrics: dict) -> dict: + """Formatting for test metrics.""" + losses = {} + if 'sdr' in metrics: + losses['sdr'] = format(metrics['sdr'], '.3f') + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], '.3f') + for source in self.model.sources: + key = f'sdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + key = f'nsdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + return losses + + def train(self): + # Optimizing the model + if self.history: + logger.info("Replaying metrics from previous run") + for epoch, metrics in enumerate(self.history): + formatted = self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + if 'test' in metrics: + formatted = self._format_test(metrics['test']) + if formatted: + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + + epoch = 0 + for epoch in range(len(self.history), self.args.epochs): + # Train one epoch + self.model.train() # Turn on BatchNorm & Dropout + metrics = {} + logger.info('-' * 70) + logger.info("Training...") + metrics['train'] = self._run_one_epoch(epoch) + formatted = self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Cross validation + logger.info('-' * 70) + logger.info('Cross validation...') + self.model.eval() # Turn off Batchnorm & Dropout + with torch.no_grad(): + valid = self._run_one_epoch(epoch, train=False) + bvalid = valid + bname = 'main' + state = states.copy_state(self.model.state_dict()) + metrics['valid'] = {} + metrics['valid']['main'] = valid + key = self.args.test.metric + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + with ema.swap(): + valid = self._run_one_epoch(epoch, train=False) + name = f'ema_{kind}_{k}' + metrics['valid'][name] = valid + a = valid[key] + b = bvalid[key] + if key.startswith('nsdr'): + a = -a + b = -b + if a < b: + bvalid = valid + state = ema.state + bname = name + metrics['valid'].update(bvalid) + metrics['valid']['bname'] = bname + + valid_loss = metrics['valid'][key] + mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss] + if key.startswith('nsdr'): + best_loss = max(mets) + else: + best_loss = min(mets) + metrics['valid']['best'] = best_loss + if self.args.svd.penalty > 0: + kw = dict(self.args.svd) + kw.pop('penalty') + with torch.no_grad(): + penalty = svd_penalty(self.model, exact=True, **kw) + metrics['valid']['penalty'] = penalty + + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Save the best model + if valid_loss == best_loss or self.args.dset.train_valid: + logger.info(bold('New best valid loss %.4f'), valid_loss) + self.best_state = states.copy_state(state) + self.best_changed = True + + # Eval model every `test.every` epoch or on last epoch + should_eval = (epoch + 1) % self.args.test.every == 0 + is_last = epoch == self.args.epochs - 1 + reco = metrics['valid']['main']['reco'] + # Tries to detect divergence in a reliable way and finish job + # not to waste compute. + div = epoch >= 180 and reco > 0.18 + div = div or epoch >= 100 and reco > 0.25 + div = div and self.args.optim.loss == 'l1' + if div: + logger.warning("Finishing training early because valid loss is too high.") + is_last = True + if should_eval or is_last: + # Evaluate on the testset + logger.info('-' * 70) + logger.info('Evaluating on the test set...') + # We switch to the best known model for testing + if self.args.test.best: + state = self.best_state + else: + state = states.copy_state(self.model.state_dict()) + compute_sdr = self.args.test.sdr and is_last + with states.swap_state(self.model, state): + with torch.no_grad(): + metrics['test'] = evaluate(self, compute_sdr=compute_sdr) + formatted = self._format_test(metrics['test']) + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + self.link.push_metrics(metrics) + + if distrib.rank == 0: + # Save model each epoch + self._serialize(epoch) + logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve()) + if is_last: + break + + def _run_one_epoch(self, epoch, train=True): + args = self.args + data_loader = self.loaders['train'] if train else self.loaders['valid'] + # get a different order for distributed training, otherwise this will get ignored + data_loader.sampler.epoch = epoch + + label = ["Valid", "Train"][train] + name = label + f" | Epoch {epoch + 1}" + total = len(data_loader) + if args.max_batches: + total = min(total, args.max_batches) + logprog = LogProgress(logger, data_loader, total=total, + updates=self.args.misc.num_prints, name=name) + averager = EMA() + + for idx, sources in enumerate(logprog): + sources = sources.to(self.device) + if train: + sources = self.augment(sources) + mix = sources.sum(dim=1) + else: + mix = sources[:, 0] + sources = sources[:, 1:] + + if not train and self.args.valid_apply: + estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0) + else: + estimate = self.dmodel(mix) + if train and hasattr(self.model, 'transform_target'): + sources = self.model.transform_target(mix, sources) + assert estimate.shape == sources.shape, (estimate.shape, sources.shape) + dims = tuple(range(2, sources.dim())) + + if args.optim.loss == 'l1': + loss = F.l1_loss(estimate, sources, reduction='none') + loss = loss.mean(dims).mean(0) + reco = loss + elif args.optim.loss == 'mse': + loss = F.mse_loss(estimate, sources, reduction='none') + loss = loss.mean(dims) + reco = loss**0.5 + reco = reco.mean(0) + else: + raise ValueError(f"Invalid loss {self.args.loss}") + weights = torch.tensor(args.weights).to(sources) + loss = (loss * weights).sum() / weights.sum() + + ms = 0 + if self.quantizer is not None: + ms = self.quantizer.model_size() + if args.quant.diffq: + loss += args.quant.diffq * ms + + losses = {} + losses['reco'] = (reco * weights).sum() / weights.sum() + losses['ms'] = ms + + if not train: + nsdrs = new_sdr(sources, estimate.detach()).mean(0) + total = 0 + for source, nsdr, w in zip(self.model.sources, nsdrs, weights): + losses[f'nsdr_{source}'] = nsdr + total += w * nsdr + losses['nsdr'] = total / weights.sum() + + if train and args.svd.penalty > 0: + kw = dict(args.svd) + kw.pop('penalty') + penalty = svd_penalty(self.model, **kw) + losses['penalty'] = penalty + loss += args.svd.penalty * penalty + + losses['loss'] = loss + + for k, source in enumerate(self.model.sources): + losses[f'reco_{source}'] = reco[k] + + # optimize model in training mode + if train: + loss.backward() + grad_norm = 0 + grads = [] + for p in self.model.parameters(): + if p.grad is not None: + grad_norm += p.grad.data.norm()**2 + grads.append(p.grad.data) + losses['grad'] = grad_norm ** 0.5 + if args.optim.clip_grad: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + args.optim.clip_grad) + + if self.args.flag == 'uns': + for n, p in self.model.named_parameters(): + if p.grad is None: + print('no grad', n) + self.optimizer.step() + self.optimizer.zero_grad() + for ema in self.emas['batch']: + ema.update() + losses = averager(losses) + logs = self._format_train(losses) + logprog.update(**logs) + # Just in case, clear some memory + del loss, estimate, reco, ms + if args.max_batches == idx: + break + if self.args.debug and train: + break + if self.args.flag == 'debug': + break + if train: + for ema in self.emas['epoch']: + ema.update() + return distrib.average(losses, idx + 1) diff --git a/demucs/spec.py b/demucs/spec.py new file mode 100644 index 0000000..85e5dc9 --- /dev/null +++ b/demucs/spec.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Conveniance wrapper to perform STFT and iSTFT""" + +import torch as th + + +def spectro(x, n_fft=512, hop_length=None, pad=0): + *other, length = x.shape + x = x.reshape(-1, length) + z = th.stft(x, + n_fft * (1 + pad), + hop_length or n_fft // 4, + window=th.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode='reflect') + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + +def ispectro(z, hop_length=None, length=None, pad=0): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + x = th.istft(z, + n_fft, + hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True) + _, length = x.shape + return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py new file mode 100644 index 0000000..db17a18 --- /dev/null +++ b/demucs/states.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Utilities to save and load models. +""" +from contextlib import contextmanager + +import functools +import hashlib +import inspect +import io +from pathlib import Path +import warnings + +from omegaconf import OmegaConf +from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state +import torch + + +def get_quantizer(model, args, optimizer=None): + """Return the quantizer given the XP quantization args.""" + quantizer = None + if args.diffq: + quantizer = DiffQuantizer( + model, min_size=args.min_size, group_size=args.group_size) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.min_size) + return quantizer + + +def load_model(path_or_package, strict=False): + """Load a model from the given serialized model, either given as a dict (already loaded) + or a path to a file on disk.""" + if isinstance(path_or_package, dict): + package = path_or_package + elif isinstance(path_or_package, (str, Path)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + path = path_or_package + package = torch.load(path, 'cpu') + else: + raise ValueError(f"Invalid type for {path_or_package}.") + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + + set_state(model, state) + return model + + +def get_state(model, quantizer, half=False): + """Get the state from a model, potentially with quantization applied. + If `half` is True, model are stored as half precision, which shouldn't impact performance + but half the state size.""" + if quantizer is None: + dtype = torch.half if half else None + state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + state['__quantized'] = True + return state + + +def set_state(model, state, quantizer=None): + """Set the state on a given model.""" + if state.get('__quantized'): + if quantizer is not None: + quantizer.restore_quantized_state(model, state['quantized']) + else: + restore_quantized_state(model, state) + else: + model.load_state_dict(state) + return state + + +def save_with_checksum(content, path): + """Save the given value on disk, along with a sha256 hash. + Should be used with the output of either `serialize_model` or `get_state`.""" + buf = io.BytesIO() + torch.save(content, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def serialize_model(model, training_args, quantizer=None, half=True): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer, half) + return { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': OmegaConf.to_container(training_args, resolve=True), + } + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, strict=False) + try: + yield + finally: + model.load_state_dict(old_state) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ diff --git a/demucs/svd.py b/demucs/svd.py new file mode 100644 index 0000000..96a74e2 --- /dev/null +++ b/demucs/svd.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Ways to make the model stronger.""" +import random +import torch + + +def power_iteration(m, niters=1, bs=1): + """This is the power method. batch size is used to try multiple starting point in parallel.""" + assert m.dim() == 2 + assert m.shape[0] == m.shape[1] + dim = m.shape[0] + b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) + + for _ in range(niters): + n = m.mm(b) + norm = n.norm(dim=0, keepdim=True) + b = n / (1e-10 + norm) + + return norm.mean() + + +# We need a shared RNG to make sure all the distributed worker will skip the penalty together, +# as otherwise we wouldn't get any speed up. +penalty_rng = random.Random(1234) + + +def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True, + proba=1, conv_only=False, exact=False, bs=1): + """ + Penalty on the largest singular value for a layer. + Args: + - model: model to penalize + - min_size: minimum size in MB of a layer to penalize. + - dim: projection dimension for the svd_lowrank. Higher is better but slower. + - niters: number of iterations in the algorithm used by svd_lowrank. + - powm: use power method instead of lowrank SVD, my own experience + is that it is both slower and less stable. + - convtr: when True, differentiate between Conv and Transposed Conv. + this is kept for compatibility with older experiments. + - proba: probability to apply the penalty. + - conv_only: only apply to conv and conv transposed, not LSTM + (might not be reliable for other models than Demucs). + - exact: use exact SVD (slow but useful at validation). + - bs: batch_size for power method. + """ + total = 0 + if penalty_rng.random() > proba: + return 0. + + for m in model.modules(): + for name, p in m.named_parameters(recurse=False): + if p.numel() / 2**18 < min_size: + continue + if convtr: + if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): + if p.dim() in [3, 4]: + p = p.transpose(0, 1).contiguous() + if p.dim() == 3: + p = p.view(len(p), -1) + elif p.dim() == 4: + p = p.view(len(p), -1) + elif p.dim() == 1: + continue + elif conv_only: + continue + assert p.dim() == 2, (name, p.shape) + if exact: + estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() + elif powm: + a, b = p.shape + if a < b: + n = p.mm(p.t()) + else: + n = p.t().mm(p) + estimate = power_iteration(n, niters, bs) + else: + estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) + total += estimate + return total / proba diff --git a/demucs/utils.py b/demucs/utils.py new file mode 100644 index 0000000..3f2afaa --- /dev/null +++ b/demucs/utils.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from contextlib import contextmanager +import math +import os +import tempfile +import typing as tp + +import torch +from torch.nn import functional as F + + +def unfold(a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, 'data should be contiguous' + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): + """ + Center trim `tensor` with respect to `reference`, along the last dimension. + `reference` can also be a number, representing the length to trim to. + If the size difference != 0 mod 2, the extra sample is removed on the right side. + """ + ref_size: int + if isinstance(reference, torch.Tensor): + ref_size = reference.size(-1) + else: + ref_size = reference + delta = tensor.size(-1) - ref_size + if delta < 0: + raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") + if delta: + tensor = tensor[..., delta // 2:-(delta - delta // 2)] + return tensor + + +def pull_metric(history: tp.List[dict], name: str): + out = [] + for metrics in history: + metric = metrics + for part in name.split("."): + metric = metric[part] + out.append(metric) + return out + + +def EMA(beta: float = 1): + """ + Exponential Moving Average callback. + Returns a single function that can be called to repeatidly update the EMA + with a dict of metrics. The callback will return + the new averaged dict of metrics. + + Note that for `beta=1`, this is just plain averaging. + """ + fix: tp.Dict[str, float] = defaultdict(float) + total: tp.Dict[str, float] = defaultdict(float) + + def _update(metrics: dict, weight: float = 1) -> dict: + nonlocal total, fix + for key, value in metrics.items(): + total[key] = total[key] * beta + weight * float(value) + fix[key] = fix[key] * beta + weight + return {key: tot / fix[key] for key, tot in total.items()} + return _update + + +def sizeof_fmt(num: float, suffix: str = 'B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +class DummyPoolExecutor: + class DummyResult: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + def __init__(self, workers=0): + pass + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return diff --git a/demucs/wav.py b/demucs/wav.py new file mode 100644 index 0000000..1c023a7 --- /dev/null +++ b/demucs/wav.py @@ -0,0 +1,242 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading wav based datasets, including MusdbHQ.""" + +from collections import OrderedDict +import hashlib +import math +import json +import os +from pathlib import Path +import tqdm + +import musdb +import julius +import torch as th +from torch import distributed +import torchaudio as ta +from torch.nn import functional as F + +from .audio import convert_audio_channels +from . import distrib + +MIXTURE = "mixture" +EXT = ".wav" + + +def _track_metadata(track, sources, normalize=True, ext=EXT): + track_length = None + track_samplerate = None + mean = 0 + std = 1 + for source in sources + [MIXTURE]: + file = track / f"{source}{ext}" + try: + info = ta.info(str(file)) + except RuntimeError: + print(file) + raise + length = info.num_frames + if track_length is None: + track_length = length + track_samplerate = info.sample_rate + elif track_length != length: + raise ValueError( + f"Invalid length for file {file}: " + f"expecting {track_length} but got {length}.") + elif info.sample_rate != track_samplerate: + raise ValueError( + f"Invalid sample rate for file {file}: " + f"expecting {track_samplerate} but got {info.sample_rate}.") + if source == MIXTURE and normalize: + try: + wav, _ = ta.load(str(file)) + except RuntimeError: + print(file) + raise + wav = wav.mean(0) + mean = wav.mean().item() + std = wav.std().item() + + return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate} + + +def build_metadata(path, sources, normalize=True, ext=EXT): + """ + Build the metadata for `Wavset`. + + Args: + path (str or Path): path to dataset. + sources (list[str]): list of sources to look for. + normalize (bool): if True, loads full track and store normalization + values based on the mixture file. + ext (str): extension of audio files (default is .wav). + """ + + meta = {} + path = Path(path) + pendings = [] + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(8) as pool: + for root, folders, files in os.walk(path, followlinks=True): + root = Path(root) + if root.name.startswith('.') or folders or root == path: + continue + name = str(root.relative_to(path)) + pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext))) + # meta[name] = _track_metadata(root, sources, normalize, ext) + for name, pending in tqdm.tqdm(pendings, ncols=120): + meta[name] = pending.result() + return meta + + +class Wavset: + def __init__( + self, + root, metadata, sources, + segment=None, shift=None, normalize=True, + samplerate=44100, channels=2, ext=EXT): + """ + Waveset (or mp3 set for that matter). Can be used to train + with arbitrary sources. Each track should be one folder inside of `path`. + The folder should contain files named `{source}.{ext}`. + + Args: + root (Path or str): root folder for the dataset. + metadata (dict): output from `build_metadata`. + sources (list[str]): list of source names. + segment (None or float): segment length in seconds. If `None`, returns entire tracks. + shift (None or float): stride in seconds bewteen samples. + normalize (bool): normalizes input audio, **based on the metadata content**, + i.e. the entire track is normalized, not individual extracts. + samplerate (int): target sample rate. if the file sample rate + is different, it will be resampled on the fly. + channels (int): target nb of channels. if different, will be + changed onthe fly. + ext (str): extension for audio files (default is .wav). + + samplerate and channels are converted on the fly. + """ + self.root = Path(root) + self.metadata = OrderedDict(metadata) + self.segment = segment + self.shift = shift or segment + self.normalize = normalize + self.sources = sources + self.channels = channels + self.samplerate = samplerate + self.ext = ext + self.num_examples = [] + for name, meta in self.metadata.items(): + track_duration = meta['length'] / meta['samplerate'] + if segment is None or track_duration < segment: + examples = 1 + else: + examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) + self.num_examples.append(examples) + + def __len__(self): + return sum(self.num_examples) + + def get_file(self, name, source): + return self.root / name / f"{source}{self.ext}" + + def __getitem__(self, index): + for name, examples in zip(self.metadata, self.num_examples): + if index >= examples: + index -= examples + continue + meta = self.metadata[name] + num_frames = -1 + offset = 0 + if self.segment is not None: + offset = int(meta['samplerate'] * self.shift * index) + num_frames = int(math.ceil(meta['samplerate'] * self.segment)) + wavs = [] + for source in self.sources: + file = self.get_file(name, source) + wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) + wav = convert_audio_channels(wav, self.channels) + wavs.append(wav) + + example = th.stack(wavs) + example = julius.resample_frac(example, meta['samplerate'], self.samplerate) + if self.normalize: + example = (example - meta['mean']) / meta['std'] + if self.segment: + length = int(self.segment * self.samplerate) + example = example[..., :length] + example = F.pad(example, (0, length - example.shape[-1])) + return example + + +def get_wav_datasets(args): + """Extract the wav datasets from the XP arguments.""" + sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('wav_' + sig + ".json") + train_path = Path(args.wav) / "train" + valid_path = Path(args.wav) / "valid" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + train = build_metadata(train_path, args.sources) + valid = build_metadata(valid_path, args.sources) + json.dump([train, valid], open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + train, valid = json.load(open(metadata_file)) + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(train_path, train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set + + +def _get_musdb_valid(): + # Return musdb valid set. + import yaml + setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml' + setup = yaml.safe_load(open(setup_path, 'r')) + return setup['validation_tracks'] + + +def get_musdb_wav_datasets(args): + """Extract the musdb dataset from the XP arguments.""" + sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json") + root = Path(args.musdb) / "train" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + metadata = build_metadata(root, args.sources) + json.dump(metadata, open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + metadata = json.load(open(metadata_file)) + + valid_tracks = _get_musdb_valid() + if args.train_valid: + metadata_train = metadata + else: + metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} + metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks} + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(root, metadata_train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set diff --git a/demucs/wdemucs.py b/demucs/wdemucs.py new file mode 100644 index 0000000..b0d799e --- /dev/null +++ b/demucs/wdemucs.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# For compat +from .hdemucs import HDemucs + +WDemucs = HDemucs diff --git a/diffq/__init__.py b/diffq/__init__.py new file mode 100644 index 0000000..b67e784 --- /dev/null +++ b/diffq/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +""" +This package implements different quantization strategies: + +- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits. +- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection. +- `diffq.lsq.LSQ`: Learnt Step size Quantizer based on [Esser et al. 2019] https://arxiv.org/abs/1902.08153 +- `diffq.bitpack`: efficient CPU bit-packing for returning quantized states. +- `diffq.torch_pack`: torch based bit-packing compatible with torchscript. +- `diffq.ts_export`: support exporting DiffQ based models to torchscript. + + +Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers. +""" + +from .uniform import UniformQuantizer +from .diffq import DiffQuantizer +from .lsq import LSQ +from .base import restore_quantized_state +from . import ts_export + + +__version__ = "0.2.2" diff --git a/diffq/base.py b/diffq/base.py new file mode 100644 index 0000000..acbded3 --- /dev/null +++ b/diffq/base.py @@ -0,0 +1,343 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Base class for all quantizers.""" +from contextlib import contextmanager +from dataclasses import dataclass +from concurrent import futures +from fnmatch import fnmatch +from functools import partial +import io +import math +from multiprocessing import cpu_count +import pickle +import typing as tp +import zlib + +import torch + +from . import bitpack +from . import torch_pack as torch_pack_mod + + +class BaseQuantizer: + @dataclass + class _QuantizedParam: + name: str + param: torch.nn.Parameter + module: torch.nn.Module + # If a Parameter is used multiple times, `other` can be used + # to share state between the different Quantizers + other: tp.Optional[tp.Any] + + def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False, + exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True): + self.model = model + self.min_size = min_size + self.float16 = float16 + self.exclude = exclude + self.detect_bound = detect_bound + self._quantized = False + self._need_unquantize = None + self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook) + self._post_handle = self.model.register_forward_hook(self._forward_hook) + + self._qparams = [] + self._float16 = [] + self._others = [] + self._rnns = [] + + self._saved = [] + + self._find_params() + + def _find_params(self): + min_params = self.min_size * 2**20 // 4 + previous = {} + for module_name, module in self.model.named_modules(): + if isinstance(module, torch.nn.RNNBase): + self._rnns.append(module) + for name, param in list(module.named_parameters(recurse=False)): + full_name = f"{module_name}.{name}" + matched = False + for pattern in self.exclude: + if fnmatch(full_name, pattern) or fnmatch(name, pattern): + matched = True + break + + if param.numel() <= min_params or matched: + if id(param) in previous: + continue + if self.detect_bound: + previous[id(param)] = None + if self.float16: + self._float16.append(param) + else: + self._others.append(param) + else: + qparam = self._register_param(name, param, module, previous.get(id(param))) + if self.detect_bound: + previous[id(param)] = qparam + self._qparams.append(qparam) + + def _register_param(self, name, param, module, other): + return self.__class__._QuantizedParam(name, param, module, other) + + def _forward_pre_hook(self, module, input): + if self.model.training: + self._quantized_state = None + self.unquantize() + if self._pre_forward_train(): + self._fix_rnns() + else: + assert self._need_unquantize is None + self._need_unquantize = self.quantize() + + def _forward_hook(self, module, input, output): + if self.model.training: + if self._post_forward_train(): + self._fix_rnns(flatten=False) # Hacky, next forward will flatten + else: + if self._need_unquantize: + self._need_unquantize = None + self.unquantize() + + def quantize(self): + """ + Immediately apply quantization to the model parameters. + Model parameters are saved to later allow restoring the unquantized state. + + Note that you shouldn't need to call this for model evaluation, as long as + you properly call `model.train()` and `model.eval()`, but this can be + useful for weight inspection. + """ + if self._quantized: + return False + self._saved = [qp.param.data.to('cpu', copy=True) + for qp in self._qparams if qp.other is None] + self.restore_quantized_state(self.get_quantized_state(packed=False)) + self._quantized = True + self._fix_rnns() + return True + + @contextmanager + def enter_quantize(self): + """Context manager for entering quantized state.""" + self.quantize() + try: + yield + finally: + self.unquantize() + + def unquantize(self): + """ + Revert a previous call to `quantize()`. + """ + if not self._quantized: + return + if not self._saved: + raise RuntimeError("Nothing to restore. This shouldn't happen") + for qparam in self._qparams: + if qparam.other is None: + qparam.param.data[:] = self._saved.pop(0) + assert len(self._saved) == 0 + self._quantized = False + self._fix_rnns() + + def _pre_forward_train(self) -> bool: + """ + Called once before each forward for continuous quantization. + Should return True if parameters were changed. + """ + return False + + def _post_forward_train(self) -> bool: + """ + Called once after each forward (to restore state for instance). + Should return True if parameters were changed. + """ + return False + + def _fix_rnns(self, flatten=True): + """ + To be called after quantization happened to fix RNNs. + """ + for rnn in self._rnns: + rnn._flat_weights = [ + (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn) + for wn in rnn._flat_weights_names] + if flatten: + rnn.flatten_parameters() + + def _bit_pack_param(self, qparam: _QuantizedParam, quantized: tp.Any, + pack_fn: tp.Any) -> tp.Any: + """Further bitpack the quantized representation. + This is used to return the quantized state. Should be overriden. + """ + return quantized + + def _bit_unpack_param(self, qparam: _QuantizedParam, packed: tp.Any, + unpack_fn: tp.Any) -> tp.Any: + """Unpack bitpacked representation. Should be overriden + """ + return packed + + def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: + """ + To be overriden. + """ + raise NotImplementedError() + + def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: + """ + To be overriden. + """ + raise NotImplementedError() + + def get_quantized_state(self, packed=True, torch_pack=False): + """ + Return a quantized representation fo the weights. If `packed` is True, + this will also perform bitpacking to ensure optimal store. + If `torck_pack` is true, the bitpacking from `torch_pack` will be used. + It is slower (except maybe on GPU), but is compatible with torchscript. + + You can restore a model from a quantized state either using + `BaseQuantizer.restore_quantized_state` or `diffq.restore_quantized_state` + if you do not have the original quantizer around anymore. + """ + float16_params = [] + for p in self._float16: + q = p.data.half() + float16_params.append(q) + + if torch_pack: + pack_fn = torch_pack_mod.pack + else: + pack_fn = bitpack.pack + + all_quantized = [] + for qparam in self._qparams: + if qparam.other is not None: + continue + quantized = self._quantize_param(qparam) + if packed: + quantized = self._bit_pack_param(qparam, quantized, pack_fn=pack_fn) + all_quantized.append(quantized) + + state = { + "quantized": all_quantized, + "float16": float16_params, + "others": [p.data.clone() for p in self._others], + } + + kwargs = dict(self._init_kwargs) + kwargs.pop("model") + state["meta"] = { + "init_kwargs": kwargs, + "klass": self.__class__, + "packed": packed, + "torch_pack": torch_pack + } + return state + + def restore_quantized_state(self, state) -> None: + """ + Restore the state of the model from the quantized state. + """ + for p, q in zip(self._float16, state["float16"]): + p.data[:] = q.to(p) + + for p, q in zip(self._others, state["others"]): + p.data[:] = q + + meta = state.get("meta", {}) + packed = meta.get("packed", False) + torch_pack = meta.get("torch_pack", False) + + if torch_pack: + unpack_fn = torch_pack_mod.unpack + else: + unpack_fn = bitpack.unpack + + remaining = list(state["quantized"]) + for qparam in self._qparams: + if qparam.other is not None: + # Only unquantize first appearance of nn.Parameter. + continue + quantized = remaining.pop(0) + if packed: + quantized = self._bit_unpack_param(qparam, quantized, unpack_fn) + qparam.param.data[:] = self._unquantize_param(qparam, quantized) + assert not remaining + self._fix_rnns() + + def detach(self) -> None: + """ + Detach from the model, removes hooks and anything else. + """ + self._pre_handle.remove() + self._post_handle.remove() + + def model_size(self) -> torch.Tensor: + """ + Returns an estimate of the quantized model size. + """ + total = torch.tensor(0.) + for p in self._float16: + total += 16 * p.numel() + for p in self._others: + total += 32 * p.numel() + return total / 2**20 / 8 # bits to MegaBytes + + def true_model_size(self) -> float: + """ + Return the true quantized model size, in MB, without extra + compression. + """ + return self.model_size().item() + + def packed_model_size(self) -> float: + """Return the packed model size, when stored with pickle. + This should be mostly equivalent to `true_model_size` up to some + slight overhead for storing metadata. + """ + state = self.get_quantized_state(packed=True) + return len(pickle.dumps(state)) / 2 ** 20 + + def compressed_model_size(self, compress_level=-1, num_workers=8) -> float: + """ + Return the compressed quantized model size, in MB. + + Args: + compress_level (int): compression level used with zlib, + see `zlib.compress` for details. + num_workers (int): will split the final big byte representation in that + many chunks processed in parallels. + """ + out = io.BytesIO() + torch.save(self.get_quantized_state(packed=False), out) + ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers) + return ms / 2 ** 20 + + +def restore_quantized_state(model: torch.nn.Module, state: dict): + assert "meta" in state + quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"]) + quantizer.restore_quantized_state(state) + quantizer.detach() + + +def _compress_len(data, compress_level): + return len(zlib.compress(data, level=compress_level)) + + +def _parallel_compress_len(data, compress_level, num_workers): + num_workers = min(cpu_count(), num_workers) + chunk_size = int(math.ceil(len(data) / num_workers)) + chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)] + with futures.ThreadPoolExecutor(num_workers) as pool: + # thread pool is okay here, zlib calls an external C lib and GIL is released + # before the call. + return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks)) diff --git a/diffq/bitpack.cp39-win_amd64.pyd b/diffq/bitpack.cp39-win_amd64.pyd new file mode 100644 index 0000000..a28d6e9 Binary files /dev/null and b/diffq/bitpack.cp39-win_amd64.pyd differ diff --git a/diffq/diffq.py b/diffq/diffq.py new file mode 100644 index 0000000..5fa5113 --- /dev/null +++ b/diffq/diffq.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Differentiable quantizer based on scaled noise injection. +""" +from dataclasses import dataclass +import math +import typing as tp + +import torch + +from .base import BaseQuantizer +from .uniform import uniform_quantize, uniform_unquantize +from .utils import capture_init, simple_repr + + +class DiffQuantizer(BaseQuantizer): + @dataclass + class _QuantizedParam(BaseQuantizer._QuantizedParam): + logit: torch.nn.Parameter + + @capture_init + def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False, + group_size: int = 1, min_bits: float = 2, max_bits: float = 15, + param="bits", noise="gaussian", + init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq", + exclude: tp.List[str] = [], detect_bound: bool = True): + """ + Differentiable quantizer based on scaled noise injection. + For every parameter `p` in the model, this introduces a number of bits parameter + `b` with the same dimensions (when group_size = 1). + Before each forward, `p` is replaced by `p + U` + with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization + step for `b` bits. + This noise approximates the quantization noise in a differentiable manner, both + with respect to the unquantized parameter `p` and the number of bits `b`. + + At eveluation (as detected with `model.eval()`), the model is replaced + by its true quantized version, and restored when going back to training. + + When doing actual quantization (for serialization, or evaluation), + the number of bits is rounded to the nearest integer, and needs to be stored along. + This will cost a few bits per dimension. To reduce this cost, one can use `group_size`, + which will use a single noise level for multiple weight entries. + + You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the + model size in MB. You can then use this estimate as a penalty in your training loss. + + Args: + model (torch.nn.Module): model to quantize + min_size (float): minimum size in MB of a parameter to be quantized. + float16 (bool): if a layer is smaller than min_size, should we still do float16? + group_size (int): weight entries are groupped together to reduce the number + of noise scales to store. This should divide the size of all parameters + bigger than min_size. + min_bits (float): minimal number of bits. + max_bits (float): maximal number of bits. + init_bits (float): initial number of bits. + extra_bits (float): extra bits to add for actual quantization (before roundoff). + suffix (str): suffix used for the name of the extra noise scale parameters. + exclude (list[str]): list of patterns used to match parameters to exclude. + For instance `['bias']` to exclude all bias terms. + detect_bound (bool): if True, will detect bound parameters and reuse + the same quantized tensor for both, as well as the same number of bits. + + ..Warning:: + You must call `model.training()` and `model.eval()` for `DiffQuantizer` work properly. + + """ + self.group_size = group_size + self.min_bits = min_bits + self.max_bits = max_bits + self.init_bits = init_bits + self.extra_bits = extra_bits + self.suffix = suffix + self.param = param + self.noise = noise + assert noise in ["gaussian", "uniform"] + self._optimizer_setup = False + + self._min_noise = 1 / (2 ** self.max_bits - 1) + self._max_noise = 1 / (2 ** self.min_bits - 1) + + assert group_size >= 0 + assert min_bits < init_bits < max_bits, \ + "init_bits must be between min_bits and max_bits excluded3" + + for name, _ in model.named_parameters(): + if name.endswith(suffix): + raise RuntimeError("The model already has some noise scales parameters, " + "maybe you used twice a DiffQuantizer on the same model?.") + + super().__init__(model, min_size, float16, exclude, detect_bound) + + def _get_bits(self, logit: torch.Tensor): + if self.param == "noise": + return torch.log2(1 + 1 / self._get_noise_scale(logit)) + else: + t = torch.sigmoid(logit) + return self.max_bits * t + (1 - t) * self.min_bits + + def _get_noise_scale(self, logit: torch.Tensor): + if self.param == "noise": + t = torch.sigmoid(logit) + return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise)) + else: + return 1 / (2 ** self._get_bits(logit) - 1) + + def _register_param(self, name, param, module, other): + if other is not None: + return self.__class__._QuantizedParam( + name=name, param=param, module=module, logit=other.logit, other=other) + assert self.group_size == 0 or param.numel() % self.group_size == 0 + # we want the initial number of bits to be init_bits. + if self.param == "noise": + noise_scale = 1 / (2 ** self.init_bits - 1) + t = (math.log(noise_scale) - math.log(self._max_noise)) / ( + math.log(self._min_noise) - math.log(self._max_noise)) + else: + t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits) + assert 0 < t < 1 + logit = torch.logit(torch.tensor(float(t))) + assert abs(self._get_bits(logit) - self.init_bits) < 1e-5 + if self.group_size > 0: + nparam = param.numel() // self.group_size + else: + nparam = 1 + logit = torch.nn.Parameter( + torch.full( + (nparam,), + logit, + device=param.device)) + module.register_parameter(name + self.suffix, logit) + return self.__class__._QuantizedParam( + name=name, param=param, module=module, logit=logit, other=None) + + def clear_optimizer(self, optimizer: torch.optim.Optimizer): + params = [qp.logit for qp in self._qparams] + + for group in optimizer.param_groups: + new_params = [] + for q in list(group["params"]): + matched = False + for p in params: + if p is q: + matched = True + if not matched: + new_params.append(q) + group["params"][:] = new_params + + def setup_optimizer(self, optimizer: torch.optim.Optimizer, + lr: float = 1e-3, **kwargs): + """ + Setup the optimizer to tune the number of bits. In particular, this will deactivate + weight decay for the bits parameters. + + Args: + optimizer (torch.Optimizer): optimizer to use. + lr (float): specific learning rate for the bits parameters. 1e-3 + is perfect for Adam.,w + kwargs (dict): overrides for other optimization parameters for the bits. + """ + assert not self._optimizer_setup + self._optimizer_setup = True + + params = [qp.logit for qp in self._qparams] + + for group in optimizer.param_groups: + for q in list(group["params"]): + for p in params: + if p is q: + raise RuntimeError("You should create the optimizer " + "before the quantizer!") + + group = {"params": params, "lr": lr, "weight_decay": 0} + group.update(kwargs) + optimizer.add_param_group(group) + + def no_optimizer(self): + """ + Call this if you do not want to use an optimizer. + """ + self._optimizer_setup = True + + def check_unused(self): + for qparam in self._qparams: + if qparam.other is not None: + continue + grad = qparam.param.grad + if grad is None or (grad == 0).all(): + if qparam.logit.grad is not None: + qparam.logit.grad.data.zero_() + + def model_size(self, exact=False): + """ + Differentiable estimate of the model size. + The size is returned in MB. + + If `exact` is True, then the output is no longer differentiable but + reflect exactly an achievable size, even without compression, + i.e.same as returned by `naive_model_size()`. + """ + total = super().model_size() + subtotal = 0 + for qparam in self._qparams: + # only count the first appearance of a Parameter + if qparam.other is not None: + continue + bits = self.extra_bits + self._get_bits(qparam.logit) + if exact: + bits = bits.round().clamp(1, 15) + if self.group_size == 0: + group_size = qparam.param.numel() + else: + group_size = self.group_size + subtotal += group_size * bits.sum() + subtotal += 2 * 32 # param scale + + # Number of bits to represent each number of bits + bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits))) + subtotal += 8 # 8 bits for bits_bits + subtotal += bits_bits * bits.numel() + + subtotal /= 2 ** 20 * 8 # bits -> MegaBytes + return total + subtotal + + def true_model_size(self): + """ + Naive model size without zlib compression. + """ + return self.model_size(exact=True).item() + + def _pre_forward_train(self): + if not self._optimizer_setup: + raise RuntimeError("You must call `setup_optimizer()` on your optimizer " + "before starting training.") + for qparam in self._qparams: + if qparam.other is not None: + noisy = qparam.other.module._parameters[qparam.other.name] + else: + bits = self._get_bits(qparam.logit)[:, None] + if self.group_size == 0: + p_flat = qparam.param.view(-1) + else: + p_flat = qparam.param.view(-1, self.group_size) + scale = p_flat.max() - p_flat.min() + unit = 1 / (2**bits - 1) + if self.noise == "uniform": + noise_source = (torch.rand_like(p_flat) - 0.5) + elif self.noise == "gaussian": + noise_source = torch.randn_like(p_flat) / 2 + noise = scale * unit * noise_source + noisy = p_flat + noise + # We bypass the checks by PyTorch on parameters being leafs + qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param) + return True + + def _post_forward_train(self): + for qparam in self._qparams: + qparam.module._parameters[qparam.name] = qparam.param + return True + + def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: + bits = self.extra_bits + self._get_bits(qparam.logit) + bits = bits.round().clamp(1, 15)[:, None].byte() + if self.group_size == 0: + p = qparam.param.data.view(1, -1) + else: + p = qparam.param.data.view(-1, self.group_size) + levels, scales = uniform_quantize(p, bits) + return levels, scales, bits[:, 0] + + def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: + levels, param_scale, bits = quantized + bits = bits[:, None] + return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data) + + def _bit_pack_param(self, qparam, quantized, pack_fn): + levels, scales, bits = quantized + all_packed = [] + for bit in range(1, 15): + sub_levels = levels[bits == bit] + if not sub_levels.numel(): + all_packed.append(None) + else: + packed = pack_fn(sub_levels, bit) + all_packed.append(packed) + packed_bits = pack_fn(bits - self.min_bits) + return (all_packed, scales, packed_bits) + + def _bit_unpack_param(self, qparam, packed, unpack_fn): + """Unpack bitpacked representation. Should be overriden. + """ + packed_all_levels, scales, packed_bits = packed + bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits + bits = bits.to(qparam.param.device) + levels = torch.empty(qparam.logit.numel(), self.group_size, + dtype=torch.short, device=qparam.param.device) + for idx, packed_levels in enumerate(packed_all_levels): + bit = idx + 1 + if packed_levels is None: + continue + sub_levels = levels[bits == bit] + levels[bits == bit] = unpack_fn( + packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels) + return (levels, scales, bits) + + def detach(self): + super().detach() + for qparam in self._qparams: + delattr(qparam.module, qparam.name + self.suffix) + + def __repr__(self): + return simple_repr(self) diff --git a/diffq/lsq.py b/diffq/lsq.py new file mode 100644 index 0000000..95238b0 --- /dev/null +++ b/diffq/lsq.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153. +""" +from dataclasses import dataclass +import typing as tp + +import torch + +from .base import BaseQuantizer +from .utils import capture_init, simple_repr + + +class LSQ(BaseQuantizer): + """Implements weight only quantization based on [Esser et al. 2019]. + https://arxiv.org/abs/1902.08153 + """ + @dataclass + class _QuantizedParam(BaseQuantizer._QuantizedParam): + scale: torch.nn.Parameter + + @capture_init + def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01, + float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True): + assert 0 < bits <= 15 + self.suffix = suffix + self._optimizer_setup = False + self.bits = bits + + for name, _ in model.named_parameters(): + if name.endswith(suffix): + raise RuntimeError("The model already has some noise scales parameters, " + "maybe you used twice a LSQ on the same model?.") + + super().__init__(model, min_size, float16, exclude, detect_bound) + + def _register_param(self, name, param, module, other): + if other is not None: + return self.__class__._QuantizedParam( + name=name, param=param, module=module, scale=other.scale, other=other) + # we want the initial number of bits to be init_bits. + scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5 + scale = torch.nn.Parameter(scale) + module.register_parameter(name + self.suffix, scale) + return self.__class__._QuantizedParam( + name=name, param=param, module=module, scale=scale, other=None) + + def clear_optimizer(self, optimizer: torch.optim.Optimizer): + params = [qp.scale for qp in self._qparams] + + for group in optimizer.param_groups: + new_params = [] + for q in list(group["params"]): + matched = False + for p in params: + if p is q: + matched = True + if not matched: + new_params.append(q) + group["params"][:] = new_params + + def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs): + """ + Setup the optimizer to tune the scale parameter. + Following [Esser et al. 2019], we use the same LR and weight decay + as the base optimizer, unless specified otherwise. + + Args: + optimizer (torch.Optimizer): optimizer to use. + kwargs (dict): overrides for optimization parameters + """ + assert not self._optimizer_setup + self._optimizer_setup = True + + params = [qp.scale for qp in self._qparams] + + for group in optimizer.param_groups: + for q in list(group["params"]): + for p in params: + if p is q: + raise RuntimeError("You should create the optimizer " + "before the quantizer!") + + group = {"params": params} + group.update(kwargs) + optimizer.add_param_group(group) + + def no_optimizer(self): + """ + Call this if you do not want to use an optimizer. + """ + self._optimizer_setup = True + + def model_size(self, exact=False): + """ + Differentiable estimate of the model size. + The size is returned in MB. + + If `exact` is True, then the output is no longer differentiable but + reflect exactly an achievable size, even without compression, + i.e.same as returned by `naive_model_size()`. + """ + total = super().model_size() + subtotal = 0 + for qparam in self._qparams: + # only count the first appearance of a Parameter + if qparam.other is not None: + continue + bits = qparam.param.numel() * self.bits + subtotal += bits + subtotal += 1 * 32 # param scale + + subtotal /= 2 ** 20 * 8 # bits -> MegaBytes + return total + subtotal + + def true_model_size(self): + """ + Naive model size without zlib compression. + """ + return self.model_size(exact=True).item() + + def _pre_forward_train(self): + if not self._optimizer_setup: + raise RuntimeError("You must call `setup_optimizer()` on your optimizer " + "before starting training.") + for qparam in self._qparams: + scale = qparam.scale + quant, _ = quantize(qparam.param, scale, self.bits) + # We bypass the checks by PyTorch on parameters being leafs + qparam.module._parameters[qparam.name] = quant + return True + + def _post_forward_train(self): + for qparam in self._qparams: + qparam.module._parameters[qparam.name] = qparam.param + return True + + def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: + _, index = quantize(qparam.param, qparam.scale, self.bits) + assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max() + assert (index >= (-2 ** (self.bits - 1))).all(), index.min() + return index.detach().short(), qparam.scale.detach() + + def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: + index, scale = quantized + return index.float() * scale + + def _bit_pack_param(self, qparam, quantized, pack_fn): + levels, scale = quantized + packed = pack_fn(levels + 2 ** (self.bits - 1)) + return (packed, scale) + + def _bit_unpack_param(self, qparam, packed, unpack_fn): + """Unpack bitpacked representation. Should be overriden + """ + packed_levels, scale = packed + levels = unpack_fn( + packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param) + levels -= 2 ** (self.bits - 1) + return (levels, scale) + + def detach(self): + super().detach() + for qparam in self._qparams: + delattr(qparam.module, qparam.name + self.suffix) + + def __repr__(self): + return simple_repr(self) + + +def roundpass(x): + return (x.round() - x).detach() + x + + +def gradscale(x, scale): + return (x - x * scale).detach() + x * scale + + +def quantize(tensor, scale, bits): + low = - 2 ** (bits - 1) + high = 2 ** (bits - 1) - 1 + scale = gradscale(scale, 1 / (tensor.numel() * high)**0.5) + + index = tensor / scale + index = index.clamp(low, high) + index = roundpass(index) + return index * scale, index diff --git a/diffq/torch_pack.py b/diffq/torch_pack.py new file mode 100644 index 0000000..926af7d --- /dev/null +++ b/diffq/torch_pack.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Bit packing in pure PyTorch. +Slower than bitpack.pyx but compatible with torchscript. +""" +import math +import typing as tp +import torch +from torch.nn import functional as F + + +def as_rectangle(p: torch.Tensor, side: int): + """Reshape as rectangle, using padding when necessary so that out shape is [*, side]""" + p_flat = p.view(-1) + ideal_length = int(math.ceil(len(p_flat) / side) * side) + p_flat_pad = F.pad(p_flat, (0, ideal_length - len(p_flat))) + return p_flat_pad.view(side, -1) + + +def _storage_size(dtype: torch.dtype): + if dtype == torch.int64: + return 64 + elif dtype == torch.int32: + return 32 + elif dtype == torch.int16: + return 16 + elif dtype == torch.uint8: + return 8 + else: + raise ValueError("Invalid bitpacking storage type") + + +def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16): + """You can think of indexes as a "Tensor" of bits of shape [L, nbits]. + Instead of concatenating naively as [L * nbits], we instead look at it transposed as + [nbits, L]. For L = 16 * G, we get [nbits, G, 16] which is trivial to store + efficiently on int16 integers. + There will be overhead if L is far from a multiple of 16 (e.g. 1) but for large + model layers this is acceptable. Storage type can be changed. + + `nbits` should be the number of bits on which the indexes are coded, and will + actually be determined automatically if set to 0. + """ + assert not indexes.dtype.is_floating_point + if indexes.numel() > 0: + assert indexes.max().item() < 2 ** 15 + assert indexes.min().item() >= 0 + if nbits == 0: + nbits = int(math.ceil(math.log2(1 + (indexes.max())))) + else: + assert indexes.max().item() < 2 ** nbits + + indexes = indexes.reshape(-1) + storage_size = _storage_size(storage_dtype) + rect = as_rectangle(indexes, storage_size) + out = torch.zeros(nbits, rect.shape[1], dtype=storage_dtype, device=indexes.device) + for in_bit in range(nbits): + for out_bit in range(storage_size): + d = ((rect[out_bit] >> in_bit) & 1).to(out.dtype) << out_bit + out[in_bit, :] |= d + return out + + +def unpack(packed: torch.Tensor, length: tp.Optional[int] = None): + """Opposite of `pack`. You might need to specify the original length.""" + storage_size = _storage_size(packed.dtype) + nbits, groups = packed.shape + out = torch.zeros(storage_size, groups, dtype=torch.int16, device=packed.device) + for in_bit in range(storage_size): + for out_bit in range(nbits): + bit_value = (packed[out_bit, :] >> in_bit) & 1 + out[in_bit, :] = out[in_bit, :] | (bit_value.to(out) << out_bit) + out = out.view(-1) + if length is not None: + out = out[:length] + return out diff --git a/diffq/ts_export.py b/diffq/ts_export.py new file mode 100644 index 0000000..d67ab2b --- /dev/null +++ b/diffq/ts_export.py @@ -0,0 +1,209 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""TorchScript export support. +We have to do a lot of black magic for TorchScript to be happy +because we cannot dynamically allocate new weights when loading the model. + +Here is how it works: +- we generate code in a temporary python file for the given model that explicitely + override all the weights on the first forward from their packed version. + This is because TorchScript does not let us iterate over parameters in a generic manner. +- we zero out all the original weights. We cannot simply remove those weights + because TorchScript won't let us recreate them. +- A TorchScript file is just a zip file, but stored without compression. + In order to remove the cost of storing the zeroed out weights, we unzip the file, + and zip it again with compression. +""" +import importlib +import os +from pathlib import Path +import random +import sys +import typing as tp +import tempfile +import zipfile + +import torch +from torch import jit + +from .diffq import DiffQuantizer +from .uniform import uniform_unquantize +from .torch_pack import unpack + +_DiffQPacked = tp.Tuple[ + tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float], + torch.Tensor, tp.List[int]] + +# This is the template for the generated class. +TEMPLATE = ''' +import typing as tp +import torch +from torch import jit + +from diffq.ts_export import _unpack_param, _DiffQPacked + +from {module} import {klass} + + +class DiffQTSModel(torch.nn.Module): + def __init__(self, model: {klass}, group_size: int, min_bits: int, + packed: tp.List[_DiffQPacked]): + super().__init__() + self.group_size = group_size + self.min_bits = min_bits + self.model = model + self._unpacked = False + self._packed = packed + + @jit.export + def unpack(self): + """ + Unpack the weights, automatically called on the first forward, + or explicitely.""" + if self._unpacked: + return +{unpack_assigns} + self._unpacked = True + + def forward(self, x: torch.Tensor): + self.unpack() + return self.model.forward(x) +''' + +# those are the assignments for each quantized weight. +UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = ' + '_unpack_param(self._packed[{index}], ' + 'group_size=self.group_size, min_bits=self.min_bits)') +UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}' + + +def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]): + """Export the given quantized model to the given path. + We must save the quantized model ourselves, as we need to recompress + the zip archive afterwards. + """ + packed: tp.List[_DiffQPacked] = [] + uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)]) + with tempfile.TemporaryDirectory() as tmpdir: + sys.path.insert(0, tmpdir) + try: + code = _codegen(quantizer) + with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f: + f.write(code) + module = importlib.import_module(uniq_name) + ts_klass = module.DiffQTSModel + state = quantizer.get_quantized_state(packed=True, torch_pack=True) + quantized = state["quantized"] + for qparam in quantizer._qparams: + if qparam.other is None: + levels, scales, bits = quantized.pop(0) + size = qparam.param.size() + packed.append((levels, scales, bits, list(size))) + qparam.param.data.zero_() + quantizer.detach() + ts_premodel = ts_klass(quantizer.model, quantizer.group_size, + quantizer.min_bits, packed) + ts_model = jit.script(ts_premodel) + if path is not None: + jit.save(ts_model, path) + recompress(path) + finally: + sys.path.pop(0) + + return ts_model + + +def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor: + """Function called from TorchScript on the first forward to decode the + packed weights to FP32. + """ + packed_all_levels, scales, packed_bits, shape = packed + numel = 1 + for dim in shape: + numel *= dim + bits = unpack(packed_bits, numel // group_size) + min_bits + levels = torch.empty(bits.numel(), group_size, dtype=torch.short) + for idx, packed_levels in enumerate(packed_all_levels): + bit = idx + 1 + if packed_levels is not None: + sub_levels = levels[bits == bit] + levels[bits == bit] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels) + bits = bits[:, None] + unquant = uniform_unquantize(levels, scales, bits) + if len(shape) == 4: + return unquant.view(shape[0], shape[1], shape[2], shape[3]) + elif len(shape) == 3: + return unquant.view(shape[0], shape[1], shape[2]) + elif len(shape) == 2: + return unquant.view(shape[0], shape[1]) + elif len(shape) == 1: + return unquant.view(shape[0]) + else: + raise RuntimeError("Invalid numbr of dim") + + +def recompress(path: tp.Union[str, Path]): + """After having saved the torchscript file, this will recompress it + to make sure all the zeroed out parameters don't actually take any space. + """ + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(path) as zipin: + zipin.extractall(tmpdir) + with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED, + compresslevel=1) as zipout: + for root, folders, files in os.walk(tmpdir): + for file in files: + fp = Path(root) / file + name = fp.relative_to(tmpdir) + zipout.write(fp, name) + + +def _get_full_name_access(full_name): + # When generating code, we need to handle attributes vs. indexing. + parts = [] + for part in full_name.split("."): + try: + index = int(part) + except ValueError: + parts.append("." + part) + else: + parts.append(f"[{index}]") + return "".join(parts) + + +def _codegen(quantizer: DiffQuantizer): + # Generates the code for the given quantizer + module = quantizer.model.__class__.__module__ + klass = quantizer.model.__class__.__name__ + model = quantizer.model + + assert not quantizer.float16 + names = {} + for mod_name, mod in model.named_modules(): + names[mod] = mod_name + unpack_assigns = [] + + index = 0 + for qparam in quantizer._qparams: + mod_name = names[qparam.module] + if mod_name == '': + full_name = qparam.name + else: + full_name = mod_name + '.' + qparam.name + full_name = _get_full_name_access(full_name) + if qparam.other is None: + unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index)) + index += 1 + else: + other_name = names[(qparam.other.module, qparam.other.name)] + other_name = _get_full_name_access(other_name) + unpack_assigns.append( + UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name)) + + return TEMPLATE.format( + module=module, + klass=klass, + unpack_assigns='\n'.join(unpack_assigns)) diff --git a/diffq/uniform.py b/diffq/uniform.py new file mode 100644 index 0000000..7d02a4b --- /dev/null +++ b/diffq/uniform.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Classic uniform quantization over n bits. +""" +from typing import Tuple +import torch + +from .base import BaseQuantizer +from .utils import capture_init, simple_repr + + +def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)): + """ + Quantize the given weights over `bits` bits. + + Returns: + - quantized levels + - (min, max) range. + + """ + assert (bits >= 1).all() and (bits <= 15).all() + num_levels = (2 ** bits.float()).long() + mn = p.min().item() + mx = p.max().item() + p = (p - mn) / (mx - mn) # put p in [0, 1] + unit = 1 / (num_levels - 1) # quantization unit + levels = (p / unit).round() + if (bits <= 8).all(): + levels = levels.byte() + else: + levels = levels.short() + return levels, (mn, mx) + + +def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float], + bits: torch.Tensor = torch.tensor(8.)): + """ + Unquantize the weights from the levels and scale. Return a float32 tensor. + """ + mn, mx = scales + num_levels = 2 ** bits.float() + unit = 1 / (num_levels - 1) + levels = levels.float() + p = levels * unit # in [0, 1] + return p * (mx - mn) + mn + + +class UniformQuantizer(BaseQuantizer): + @capture_init + def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01, + float16: bool = False, qat: bool = False, exclude=[], detect_bound=True): + """ + Args: + model (torch.nn.Module): model to quantize + bits (float): number of bits to quantize over. + min_size (float): minimum size in MB of a parameter to be quantized. + float16 (bool): if a layer is smaller than min_size, should we still do float16? + qat (bool): perform quantized aware training. + exclude (list[str]): list of patterns used to match parameters to exclude. + For instance `['bias']` to exclude all bias terms. + detect_bound (bool): if True, will detect bound parameters and reuse + the same quantized tensor for both. + """ + self.bits = float(bits) + self.qat = qat + + super().__init__(model, min_size, float16, exclude, detect_bound) + + def __repr__(self): + return simple_repr(self, ) + + def _pre_forward_train(self): + if self.qat: + for qparam in self._qparams: + if qparam.other is not None: + new_param = qparam.other.module._parameters[qparam.other.name] + else: + quantized = self._quantize_param(qparam) + qvalue = self._unquantize_param(qparam, quantized) + new_param = qparam.param + (qvalue - qparam.param).detach() + qparam.module._parameters[qparam.name] = new_param + return True + return False + + def _post_forward_train(self): + if self.qat: + for qparam in self._qparams: + qparam.module._parameters[qparam.name] = qparam.param + return True + return False + + def _quantize_param(self, qparam): + levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits)) + return (levels, scales) + + def _unquantize_param(self, qparam, quantized): + levels, scales = quantized + return uniform_unquantize(levels, scales, torch.tensor(self.bits)) + + def _bit_pack_param(self, qparam, quantized, pack_fn): + levels, scales = quantized + packed = pack_fn(levels, self.bits) + return (packed, scales) + + def _bit_unpack_param(self, qparam, packed, unpack_fn): + """Unpack bitpacked representation. Should be overriden + """ + packed_levels, scales = packed + levels = unpack_fn( + packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param) + return (levels, scales) + + def model_size(self): + """ + Non differentiable model size in MB. + """ + total = super().model_size() + subtotal = 0 + for qparam in self._qparams: + if qparam.other is None: # if parameter is bound, count only one copy. + subtotal += self.bits * qparam.param.numel() + 64 # 2 float for the overall scales + subtotal /= 2**20 * 8 # bits to MegaBytes + return total + subtotal + + def true_model_size(self): + """ + Return the true quantized model size, in MB, without extra + compression. + """ + return self.model_size().item() diff --git a/diffq/utils.py b/diffq/utils.py new file mode 100644 index 0000000..4cd810c --- /dev/null +++ b/diffq/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import inspect +from typing import Optional, List + + +def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}): + """ + Return a simple representation string for `obj`. + If `attrs` is not None, it should be a list of attributes to include. + """ + params = inspect.signature(obj.__class__).parameters + attrs_repr = [] + if attrs is None: + attrs = params.keys() + for attr in attrs: + display = False + if attr in overrides: + value = overrides[attr] + elif hasattr(obj, attr): + value = getattr(obj, attr) + else: + continue + if attr in params: + param = params[attr] + if param.default is inspect._empty or value != param.default: + display = True + else: + display = True + + if display: + attrs_repr.append(f"{attr}={value}") + return f"{obj.__class__.__name__}({','.join(attrs_repr)})" + + +def capture_init(init): + """capture_init. + + Decorate `__init__` with this, and you can then + recover the *args and **kwargs passed to it in `self._init_args_kwargs` + """ + signature = inspect.signature(init) + + @functools.wraps(init) + def __init__(self, *args, **kwargs): + bound = signature.bind(self, *args, **kwargs) + actual_kwargs = dict(bound.arguments) + del actual_kwargs['self'] + actual_kwargs.update(bound.kwargs) + self._init_kwargs = actual_kwargs + init(self, *args, **kwargs) + + return __init__