@ -0,0 +1,237 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor
Model = tp.Union[Demucs, HDemucs]
class BagOfModels(nn.Module):
    """Container for several separation models combined with per-source weights.

    The class only stores the sub-models and their weights. The weighted
    combination itself is performed by `apply_model`, which is why `forward`
    is intentionally not implemented.
    """

    def __init__(self, models: tp.List["Model"],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        # nn.Module must be initialised before any sub-module (the ModuleList
        # below) can be assigned as an attribute.
        super().__init__()
        assert len(models) > 0
        first = models[0]
        # All models must agree on their output layout, otherwise their
        # predictions could not be summed source by source.
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")
class TensorChunk:
    """Lightweight view on a slice of the last dimension of a tensor.

    No data is copied until `padded` is called. The object exposes `shape`
    and `device` so that it can be handed to `apply_model` in place of a
    real tensor.
    """

    def __init__(self, tensor, offset=0, length=None):
        """
        Args:
            tensor: tensor to view; the chunk applies to its last dimension.
            offset (int): first index of the chunk, must lie inside the tensor.
            length (int or None): requested chunk length. It is clamped so that
                the chunk never runs past the end of `tensor`; None means
                "until the end".
        """
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the chunk as a list (last dimension is the chunk length).

        This is a property so that `batch, channels, length = chunk.shape`
        works exactly as it does on a tensor.
        """
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        """Return a tensor of length `target_length` centred on the chunk.

        The extra `target_length - self.length` samples are taken from the
        underlying tensor around the chunk when available, and zero-padded
        where the window runs past either end of the tensor.
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        # Part of the window that actually exists in the underlying tensor.
        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out
def tensor_chunk(tensor_or_chunk):
    """Return the argument as a `TensorChunk`, wrapping plain tensors as needed."""
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)
    # Already a chunk: hand it back untouched.
    return tensor_or_chunk
def apply_model(model, mix, shifts=1, split=True,
                overlap=0.25, transition_power=1., progress=False, device=None,
                num_workers=0, pool=None):
    """
    Apply model to a given mixture.

    Args:
        model: a single model or a `BagOfModels`.
        mix: tensor (or `TensorChunk`) whose shape unpacks as
            `[batch, channels, length]`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of a segment shared with the next one when
            `split` is True (the stride is `(1 - overlap) * segment`).
        transition_power (float): exponent applied to the triangular
            cross-fade weight, must be >= 1. Larger values give sharper
            transitions between segments.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if > 0, `device` is a CPU and no `pool` is given,
            segments are evaluated through a `ThreadPoolExecutor` with that
            many workers.
        pool: executor used to evaluate the segments. When None, one is
            selected from `num_workers` and `device`.

    Returns:
        Tensor of shape `[batch, sources, channels, length]`.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    # Arguments forwarded to the recursive calls. The pool is passed along so
    # that nested calls reuse it instead of creating their own.
    kwargs = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
    }
    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitly apply multiple times `apply_model` so that the random shifts
        # are different for each model.
        estimates = 0
        totals = [0] * len(model.sources)
        for sub_model, weight in zip(model.models, model.weights):
            # Move the sub-model to the compute device only for the duration of
            # its evaluation, then put it back where the caller had it.
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            try:
                out = apply_model(sub_model, mix, **kwargs)
            finally:
                sub_model.to(original_model_device)
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        # Normalise each source by the sum of the weights it received.
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
                         th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            chunk_out = future.result()
            # The last chunk can be shorter than `segment`.
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            # Undo the shift so that all predictions are aligned before averaging.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        if hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)
Normal file
Normal file
@ -0,0 +1,256 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import subprocess as sp
from pathlib import Path
import lameenc
import julius
import numpy as np
import torch
import torchaudio as ta
from .utils import temp_filenames
def _read_info(path):
    """Run `ffprobe` on `path` and return its JSON report as Python objects.

    The report contains the `format` and `streams` sections requested on the
    command line. `subprocess.check_output` raises if `ffprobe` fails.
    """
    command = [
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams',
    ]
    raw_report = sp.check_output(command)
    return json.loads(raw_report.decode('utf-8'))
class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """

    def __init__(self, path: Path):
        self.path = Path(path)
        # ffprobe report, loaded lazily on first access to `info`.
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        """ffprobe report for the file, computed once and cached."""
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    # NOTE(review): exposed as a property like `info` and `_audio_streams`;
    # confirm against callers.
    @property
    def duration(self):
        """Duration in seconds, as reported in the `format` section."""
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        """Indexes, within all the streams of the file, of the audio streams."""
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        """Number of audio streams in the file."""
        return len(self._audio_streams)

    def channels(self, stream=0):
        """Number of channels of the given audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        """Sample rate of the given audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    # NOTE(review): the parameter names and order come from the docstring; the
    # defaults (in particular `streams=slice(None)`) should be confirmed.
    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None,
             temp_folder=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float):  seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
            temp_folder (str or Path or None): temporary folder to use for decoding.
                Accepted for compatibility; this implementation does not use it.
        """
        streams = np.array(range(len(self)))[streams]
        # Indexing with an int gives a scalar rather than an array.
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            target_size = int((samplerate or self.samplerate()) * duration)
            # Ask ffmpeg for one extra sample, the output is trimmed below.
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for stream, filename in zip(streams, filenames):
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # Raw output is interleaved: de-interleave with the channel
                # count of the stream that was written to this file.
                wav = wav.view(-1, self.channels(stream)).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels.

    Args:
        wav: tensor whose last two dimensions are `[channels, length]`; any
            leading dimensions are preserved.
        channels (int): requested number of channels.

    Returns:
        The input itself when it already has `channels` channels, otherwise a
        tensor with `channels` channels.

    Raises:
        ValueError: if the input has several channels but fewer than requested.
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        # Nothing to do.
        return wav
    if channels == 1:
        # Case 1:
        # The caller asked 1-channel audio, but the stream have multiple
        # channels, downmix all channels.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file have
        # one single channel, replicate the audio over all channels.
        return wav.expand(*shape, channels, length)
    if src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file have
        # more channels than requested. In that case return the first channels.
        return wav[..., :channels, :]
    # Case 4: What is a reasonable choice here?
    raise ValueError('The audio file has less channels than requested but is not mono.')
def convert_audio(wav, from_samplerate, to_samplerate, channels):
    """Remix `wav` to `channels` channels, then resample it to `to_samplerate`."""
    remixed = convert_audio_channels(wav, channels)
    resampled = julius.resample_frac(remixed, from_samplerate, to_samplerate)
    return resampled
def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Floating point input is clipped to [-1, 1] and scaled to the int16 range.
    The clipping is done out of place so the caller's tensor is left
    untouched. Non floating point input is returned as is.
    """
    if wav.dtype.is_floating_point:
        return (wav.clamp(-1, 1) * (2**15 - 1)).short()
    return wav
def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned unchanged; anything else is cast to
    float and divided by the int16 full scale (2**15 - 1).
    """
    if not wav.dtype.is_floating_point:
        full_scale = 2**15 - 1
        return wav.float() / full_scale
    return wav
def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.

    The decision is made on the requested `dtype`, not on the current dtype
    of `wav`: a floating point `dtype` yields f32 PCM, any other yields
    i16 PCM.
    """
    if dtype.is_floating_point:
        return f32_pcm(wav)
    return i16_pcm(wav)
def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
    """Save given audio as mp3. This should work on all OSes.

    Args:
        wav: tensor of shape `[channels, length]`.
        path: destination file.
        samplerate (int): sample rate of `wav`.
        bitrate (int): mp3 bit rate passed to the encoder.
        verbose (bool): if False, the encoder output is silenced.
    """
    C, T = wav.shape
    wav = i16_pcm(wav)
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(C)
    encoder.set_quality(2)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()
    # The encoder expects interleaved samples, i.e. `[length, channels]`.
    wav = wav.transpose(0, 1).numpy()
    mp3_data = encoder.encode(wav.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)
def prevent_clip(wav, mode='rescale'):
    """
    different strategies for avoiding raw clipping.

    Args:
        wav: floating point audio tensor.
        mode (str): one of
            - 'rescale': divide by `1.01 * max(|wav|)` when that exceeds 1,
              otherwise leave the signal unchanged;
            - 'clamp': hard clip to [-0.99, 0.99];
            - 'tanh': soft clip with `tanh`.

    Raises:
        ValueError: if `mode` is not one of the values above.
    """
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        wav = wav / max(1.01 * wav.abs().max(), 1)
    elif mode == 'clamp':
        wav = wav.clamp(-0.99, 0.99)
    elif mode == 'tanh':
        wav = torch.tanh(wav)
    else:
        raise ValueError(f"Invalid mode {mode}")
    return wav
def save_audio(wav, path, samplerate, bitrate=320, clip='rescale',
               bits_per_sample=16, as_float=False):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`.

    For `.wav` output, `as_float=True` forces 32 bits floating point samples
    and overrides `bits_per_sample`; otherwise signed integer PCM with
    `bits_per_sample` bits is written.

    Raises:
        ValueError: if the suffix of `path` is neither `.mp3` nor `.wav`.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate)
    elif suffix == ".wav":
        if as_float:
            bits_per_sample = 32
            encoding = 'PCM_F'
        else:
            encoding = 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")
Normal file
Normal file
@ -0,0 +1,459 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import julius
import torch
from torch import nn
from torch.nn import functional as F
from .states import capture_init
from .utils import center_trim, unfold
class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be splitting in overlapping
    chunks and the LSTM applied separately on each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        """
        Args:
            dim (int): input, hidden and output size.
            layers (int): number of LSTM layers.
            max_steps (int or None): frame width used to split long inputs,
                must be a multiple of 4.
            skip (bool): if True, the input is added to the output.
        """
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            # Split into frames of `width` steps overlapping by half.
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        # The LSTM expects [time, batch, features].
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            # Drop a quarter of a frame on each overlapping side, so that
            # consecutive frames tile the time axis without overlap.
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x
def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.

    Weight and bias are divided in place by `sqrt(std / reference)`, where
    `std` is the standard deviation of the weight before rescaling.
    """
    spread = conv.weight.std().detach()
    factor = (spread / reference)**0.5
    conv.weight.data.div_(factor)
    if conv.bias is None:
        return
    conv.bias.data.div_(factor)
def rescale_module(module, reference):
    """Apply `rescale_conv` to every 1d/2d (transposed) convolution inside `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    convolutions = (sub for sub in module.modules() if isinstance(sub, conv_types))
    for conv in convolutions:
        rescale_conv(conv, reference)
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        """
        Args:
            channels: number of channels, one learnt scale per channel.
            init: initial value of every scale.
        """
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        # The scale is broadcast over the last (time) dimension.
        return self.scale[:, None] * x
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """
        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # NOTE(review): dilation is driven by the sign of `depth`; the `dilate`
        # argument is overwritten here and has no effect. Kept as is so that
        # existing configurations keep producing the same network.
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            # Keeps the time dimension unchanged for an odd kernel size.
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            # Optional blocks go right after the first activation, operating
            # on the compressed `hidden` channels.
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x
class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        """
        Args:
            channels: input/output channels, must be divisible by `heads`.
            heads: number of attention heads.
            nfreqs: number of frequency kernels (0 disables them).
            ndecay: number of decay kernels (0 disables them).
        """
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            # Penalty growing with the distance between key and query.
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        # Normalise over the keys (dimension `t`).
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)
class Demucs(nn.Module):
def __init__(self,
# Channels
# Main structure
# Convolutions
# Activations
# Normalization
# DConv residual branch
# Pre/post processing
# Weight init
# Metadata
segment=4 * 10):
sources (list[str]): list of source names
audio_channels (int): stereo or mono
channels (int): first convolution channels
depth (int): number of encoder/decoder layers
growth (float): multiply (resp divide) number of channels by that
for each layer of the encoder (resp decoder)
depth (int): number of layers in the encoder and in the decoder.
rewrite (bool): add 1x1 convolution to each layer.
lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
by default, as this is now replaced by the smaller and faster small LSTMs
in the DConv branches.
kernel_size (int): kernel size for convolutions
stride (int): stride for convolutions
context (int): kernel size of the convolution in the
decoder before the transposed convolution. If > 1,
will provide some context from neighboring time steps.
gelu: use GELU activation function.
glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
norm_starts: layer at which group norm starts being used.
decoder layers are numbered in reverse order.
norm_groups: number of groups for group norm.
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
dconv_depth: depth of residual DConv branch.
dconv_comp: compression of DConv branch.
dconv_attn: adds attention layers in DConv branch starting at this layer.
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
dconv_init: initial scale for the DConv branch LayerScale.
normalize (bool): normalizes the input audio on the fly, and scales back
the output by the same amount.
resample (bool): upsample x2 the input and downsample /2 the output.
rescale (int): rescale initial weights of convolutions
to get their standard deviation closer to `rescale`.
samplerate (int): stored as meta information for easing
future evaluations of the model.
segment (float): duration of the chunks of audio to ideally evaluate the model on.
This is used by `demucs.apply.apply_model`.
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.resample = resample
self.channels = channels
self.normalize = normalize
self.samplerate = samplerate
self.segment = segment
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.skip_scales = nn.ModuleList()
if glu:
activation = nn.GLU(dim=1)
ch_scale = 2
activation = nn.ReLU()
ch_scale = 1
if gelu:
act2 = nn.GELU
act2 = nn.ReLU
in_channels = audio_channels
padding = 0
for index in range(depth):
norm_fn = lambda d: nn.Identity() # noqa
if index >= norm_starts:
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
encode = []
encode += [
nn.Conv1d(in_channels, channels, kernel_size, stride),
attn = index >= dconv_attn
lstm = index >= dconv_lstm
if dconv_mode & 1:
encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
compress=dconv_comp, attn=attn, lstm=lstm)]
if rewrite:
encode += [
nn.Conv1d(channels, ch_scale * channels, 1),
norm_fn(ch_scale * channels), activation]
decode = []
if index > 0:
out_channels = in_channels
out_channels = len(self.sources) * audio_channels
if rewrite:
decode += [
nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
norm_fn(ch_scale * channels), activation]
if dconv_mode & 2:
decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
compress=dconv_comp, attn=attn, lstm=lstm)]
decode += [nn.ConvTranspose1d(channels, out_channels,
kernel_size, stride, padding=padding)]
if index > 0:
decode += [norm_fn(out_channels), act2()]
self.decoder.insert(0, nn.Sequential(*decode))
in_channels = channels
channels = int(growth * channels)
channels = in_channels
if lstm_layers:
self.lstm = BLSTM(channels, lstm_layers)
self.lstm = None
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
    """
    Return the nearest valid length to use with the model so that
    there is no time steps left over in a convolution, e.g. for all
    layers, size of the input - kernel_size % stride = 0.

    Note that input are automatically padded if necessary to ensure that the output
    has the same length as the input.
    """
    kernel, stride, depth = self.kernel_size, self.stride, self.depth
    steps = length * 2 if self.resample else length

    # Number of time steps left after each strided convolution (at least 1).
    for _ in range(depth):
        steps = max(1, math.ceil((steps - kernel) / stride) + 1)

    # Length produced by the matching stack of transposed convolutions.
    for _ in range(depth):
        steps = (steps - 1) * stride + kernel

    if self.resample:
        steps = math.ceil(steps / 2)
    return int(steps)
def forward(self, mix):
    """Separate `mix` into one waveform per source.

    Args:
        mix: input tensor; the channel mean is taken over dim 1 and time
            statistics over the last dim.

    Returns:
        Tensor viewed as `[batch, sources, audio_channels, length]` with the
        same length as the input.
    """
    x = mix
    length = x.shape[-1]

    if self.normalize:
        # Normalise with the statistics of the mono mix; the inverse
        # transform is applied to the output below.
        mono = mix.mean(dim=1, keepdim=True)
        mean = mono.mean(dim=-1, keepdim=True)
        std = mono.std(dim=-1, keepdim=True)
        x = (x - mean) / (1e-5 + std)
    else:
        mean = 0
        std = 1

    # Pad symmetrically up to a length the convolutions can consume exactly.
    delta = self.valid_length(length) - length
    x = F.pad(x, (delta // 2, delta - delta // 2))

    if self.resample:
        x = julius.resample_frac(x, 1, 2)

    # Encoder outputs are kept for the skip connections of the decoder.
    saved = []
    for encode in self.encoder:
        x = encode(x)
        saved.append(x)

    if self.lstm is not None:
        x = self.lstm(x)

    for decode in self.decoder:
        # Consume the skips from the deepest encoder layer outwards.
        skip = saved.pop(-1)
        skip = center_trim(skip, x)
        x = decode(x + skip)

    if self.resample:
        x = julius.resample_frac(x, 2, 1)
    x = x * std + mean
    x = center_trim(x, length)
    x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
    return x
def load_state_dict(self, state, strict=True):
    """Load `state`, renaming keys saved by previous generation Demucs models.

    For every encoder/decoder layer, entries stored under sub-module index 2
    are moved to index 3 when the latter is absent. `state` is modified in
    place. The result of the parent `load_state_dict` is returned so that the
    missing/unexpected keys report stays available with `strict=False`.
    """
    # fix a mismatch with previous generation Demucs models.
    for idx in range(self.depth):
        for a in ['encoder', 'decoder']:
            for b in ['bias', 'weight']:
                new = f'{a}.{idx}.3.{b}'
                old = f'{a}.{idx}.2.{b}'
                if old in state and new not in state:
                    state[new] = state.pop(old)
    return super().load_state_dict(state, strict=strict)
Normal file
Normal file
@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities.
import logging
import pickle
import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel
from dora import distrib as dora_distrib
logger = logging.getLogger(__name__)
rank = 0
world_size = 1
def init():
    """Initialise distributed training and refresh the module level `rank` and `world_size`.

    The process group is only created when torch.distributed is not already
    initialised. `rank` and `world_size` are always re-read afterwards, so they
    are correct even if the group was set up elsewhere.
    """
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()
def average(metrics, count=1.):
    """Average `metrics` over all workers, each worker weighted by `count`.

    Args:
        metrics: either a dict of name -> value, or an iterable of values.
        count (float): weight of this worker in the average.

    Returns:
        A dict with the same keys when `metrics` is a dict, otherwise the
        averaged values as a list. With a single worker, non-dict `metrics`
        are returned unchanged.
    """
    if isinstance(metrics, dict):
        if not metrics:
            # `zip(*...)` cannot be unpacked into keys/values for an empty dict.
            return {}
        # Keys are sorted so that every worker reduces values in the same order.
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    # The trailing 1 accumulates the sum of the counts over workers, used as
    # the normalisation factor after the reduction.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
def wrap(model):
    """Wrap `model` in `DistributedDataParallel` when running on several workers.

    With a single worker the model is returned unchanged.
    """
    if world_size == 1:
        return model
    # NOTE(review): the device arguments were not visible in the reviewed
    # source; the current CUDA device is assumed, in line with the other
    # helpers of this module which hard-code 'cuda'. Confirm.
    return DistributedDataParallel(
        model,
        # find_unused_parameters=True,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device())
def barrier():
    """Wait for all workers; does nothing when running with a single worker."""
    if world_size > 1:
        torch.distributed.barrier()
def share(obj=None, src=0):
    """Broadcast a picklable object from worker `src` to all the workers.

    Args:
        obj: object to share; only its value on worker `src` is used.
        src (int): rank of the worker owning the object.

    Returns:
        `obj` itself on worker `src` (or with a single worker), and an
        unpickled copy of it on the other workers.
    """
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # size variable is now set to the length of pickled obj in all processes

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        # Receivers only need room for the incoming bytes.
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # buffer variable is now set to pickled obj in all processes

    if rank != src:
        # NOTE: unpickling executes arbitrary code; this is only acceptable
        # because the bytes come from a trusted worker of the same job.
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    """
    distributed = world_size != 1
    if distributed and shuffle:
        # train means we will compute backward, we use DistributedSampler
        # We ignore shuffle, DistributedSampler already shuffles
        sampler = DistributedSampler(dataset)
        return klass(dataset, *args, **kwargs, sampler=sampler)
    if distributed:
        # We make a manual shard, as DistributedSampler otherwise replicate some examples
        shard = list(range(rank, len(dataset), world_size))
        dataset = Subset(dataset, shard)
    return klass(dataset, *args, shuffle=shuffle, **kwargs)
Normal file
Normal file
@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Inspired from https://github.com/rwightman/pytorch-image-models
from contextlib import contextmanager
import torch
from .states import swap_state
class ModelEMA:
    """
    Perform EMA on a model. You can switch to the EMA weights temporarily
    with the `swap` method.

        ema = ModelEMA(model)
        with ema.swap():
            # compute valid metrics with averaged model.
    """

    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
        """
        Args:
            model: model whose float32 state entries are averaged.
            decay (float): EMA decay.
            unbias (bool): if True, the weight of each update is `1 / count`
                with `count` a decayed update counter, instead of the
                constant `1 - decay`.
            device: device on which the averaged state is stored. A falsy
                value keeps each entry on the device of the model entry.
        """
        self.decay = decay
        self.model = model
        self.state = {}
        self.count = 0
        self.device = device
        self.unbias = unbias

        # `update` and `load_state_dict` work in place on `self.state`, so it
        # has to be populated up front.
        self._init()

    def _init(self):
        """Copy the float32 entries of the model state into `self.state`."""
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            if key not in self.state:
                self.state[key] = val.detach().to(device, copy=True)

    def update(self):
        """Fold the current model weights into the running average."""
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                # Only float32 entries are tracked, see `_init`.
                continue
            device = self.device or val.device
            self.state[key].mul_(1 - w)
            self.state[key].add_(val.detach().to(device), alpha=w)

    @contextmanager
    def swap(self):
        """Context manager running its body with the averaged weights in the model."""
        with swap_state(self.model, self.state):
            yield

    def state_dict(self):
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        self.count = state['count']
        # NOTE(review): the loop body was not visible in the reviewed source;
        # an in-place copy into the existing entries is assumed. Confirm.
        for k, v in state['state'].items():
            self.state[k].copy_(v)
Normal file
Normal file
@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
or the newest SDR definition from the MDX 2021 competition (this one will
be reported as `nsdr` for `new sdr`).
from concurrent import futures
import logging
from dora.log import LogProgress
import numpy as np
import musdb
import museval
import torch as th
from .apply import apply_model
from .audio import convert_audio, save_audio
from . import distrib
from .utils import DummyPoolExecutor
logger = logging.getLogger(__name__)
def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)

    Args:
        references: 4-D tensor with the ground truth signals.
        estimates: 4-D tensor with the estimated signals, broadcastable
            against `references`.

    Returns:
        Tensor of scores in dB; the last two dimensions of the inputs are
        reduced, the first two are kept.
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores
# Score one track: always computes the MDX-style SDR through `new_sdr`, and
# additionally the museval BSS metrics when `compute_sdr` is true.
# Returns a `(scores, new_scores)` pair; `scores` is None when `compute_sdr`
# is false.
def eval_track(references, estimates, win, hop, compute_sdr=True):
# Swap dimensions 1 and 2 and promote to float64 before scoring.
references = references.transpose(1, 2).double()
estimates = estimates.transpose(1, 2).double()
# `new_sdr` asserts 4-D inputs, hence the added leading axis; `[0]` removes it again.
new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]
if not compute_sdr:
return None, new_scores
references = references.numpy()
estimates = estimates.numpy()
# NOTE(review): the remaining arguments of this call (presumably the window/hop
# settings taken from `win` and `hop`) are not visible in this view — confirm.
scores = museval.metrics.bss_eval(
references, estimates,
return scores, new_scores
# Evaluate `solver.model` on the musdb test subset and return a flat dict of
# aggregated metrics (per-source mean/median plus averages over sources).
# NOTE(review): several lines of this function (docstring delimiters, `else:`
# branches and the tails of some calls) are not visible in this view.
def evaluate(solver, compute_sdr=False):
Evaluate model using museval.
`new_only` means using only the MDX definition of the SDR, which is much faster to evaluate.
args = solver.args
output_dir = solver.folder / "results"
output_dir.mkdir(exist_ok=True, parents=True)
json_folder = solver.folder / "results/test"
json_folder.mkdir(exist_ok=True, parents=True)
# we load tracks from the original musdb set
if args.test.nonhq is None:
test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
# presumably the alternative branch, used when `args.test.nonhq` is set — confirm.
test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
src_rate = args.dset.musdb_samplerate
eval_device = 'cpu'
model = solver.model
# Window and hop are both one second worth of samples at the model sample rate.
win = int(1. * model.samplerate)
hop = int(1. * model.samplerate)
# Each distributed worker takes every `world_size`-th track, starting at its rank.
indexes = range(distrib.rank, len(test_set), distrib.world_size)
indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
pendings = []
# Use a process pool only when workers were requested, otherwise the dummy executor.
pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
with pool(args.test.workers) as pool:
for index in indexes:
track = test_set.tracks[index]
mix = th.from_numpy(track.audio).t().float()
if mix.dim() == 1:
mix = mix[None]
mix = mix.to(solver.device)
# Normalize the mixture with the statistics of its channel average; the same
# statistics are used below to undo the normalization on the estimates.
ref = mix.mean(dim=0) # mono mixture
mix = (mix - ref.mean()) / ref.std()
mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
estimates = apply_model(model, mix[None],
shifts=args.test.shifts, split=args.test.split,
estimates = estimates * ref.std() + ref.mean()
estimates = estimates.to(eval_device)
# Stack the ground truth stems in the order of `model.sources`.
references = th.stack(
[th.from_numpy(track.targets[name].audio).t() for name in model.sources])
if references.dim() == 2:
references = references[:, None]
references = references.to(eval_device)
references = convert_audio(references, src_rate,
model.samplerate, model.audio_channels)
if args.test.save:
folder = solver.folder / "wav" / track.name
folder.mkdir(exist_ok=True, parents=True)
for name, estimate in zip(model.sources, estimates):
save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)
# Scoring is submitted to the pool; results are collected after the loop.
pendings.append((track.name, pool.submit(
eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))
pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
name='Eval (BSS)')
tracks = {}
for track_name, pending in pendings:
pending = pending.result()
scores, nsdrs = pending
tracks[track_name] = {}
for idx, target in enumerate(model.sources):
tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
if scores is not None:
(sdr, isr, sir, sar) = scores
for idx, target in enumerate(model.sources):
# NOTE(review): the closing brace and the statement storing `values` are not
# visible in this view.
values = {
"SDR": sdr[idx].tolist(),
"SIR": sir[idx].tolist(),
"ISR": isr[idx].tolist(),
"SAR": sar[idx].tolist()
# Merge the per-track results from every distributed worker.
all_tracks = {}
for src in range(distrib.world_size):
all_tracks.update(distrib.share(tracks, src))
result = {}
# Metric names are read off the first track's entry for the first source.
metric_names = next(iter(all_tracks.values()))[model.sources[0]]
for metric_name in metric_names:
avg = 0
avg_of_medians = 0
for source in model.sources:
# NOTE(review): the element expression of this comprehension is not visible here.
medians = [
for track in all_tracks.keys()]
mean = np.mean(medians)
median = np.median(medians)
result[metric_name.lower() + "_" + source] = mean
result[metric_name.lower() + "_med" + "_" + source] = median
avg += mean / len(model.sources)
avg_of_medians += median / len(model.sources)
result[metric_name.lower()] = avg
result[metric_name.lower() + "_med"] = avg_of_medians
return result
Normal file
Normal file
@ -0,0 +1,761 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
This code contains the spectrogram and Hybrid version of Demucs.
from copy import deepcopy
import math
from openunmix.filtering import wiener
import torch
from torch import nn
from torch.nn import functional as F
from .demucs import DConv, rescale_module
from .states import capture_init
from .spec import spectro, ispectro
class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.

    The stored weights are divided by `scale` at construction and multiplied
    back by `scale` on every read, so the effective embedding is unchanged at
    initialization.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussian, overscale raises as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        """Effective (rescaled) embedding matrix."""
        return self.embedding.weight * self.scale

    def forward(self, x):
        """Look up the indices `x` and return the rescaled embeddings."""
        out = self.embedding(x) * self.scale
        return out
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Frequency layers convolve over the frequency axis only.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            # An empty layer is just the first convolution.
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            # A time layer fed a 4-D tensor: fold channels and frequencies together.
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Right-pad so that the length is a multiple of the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                # Broadcast the 3-D injected tensor over the frequency axis.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                # DConv is applied on 3-D tensors: fold frequencies into the batch.
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
# NOTE(review): docstring delimiters, `else:` branches and some statements of this
# class (e.g. where each replica is stored and where band outputs are collected)
# are not visible in this view; comments below only describe what is shown.
class MultiWrap(nn.Module):
Takes one layer and replicate it N times. each replica will act
on a frequency band. All is done so that if the N replica have the same weights,
then this is exactly equivalent to applying the original module on all frequencies.
This is a bit over-engineered to avoid edge artifacts when splitting
the frequency bands, but it is possible the naive implementation would work as well...
def __init__(self, layer, split_ratios):
layer: module to clone, must be either HEncLayer or HDecLayer.
split_ratios: list of float indicating which ratio to keep for each band.
self.split_ratios = split_ratios
self.layers = nn.ModuleList()
# `self.conv` is True when wrapping an encoder layer.
self.conv = isinstance(layer, HEncLayer)
# Only un-normalized, padded frequency layers are supported.
assert not layer.norm
assert layer.freq
assert layer.pad
if not self.conv:
assert not layer.context_freq
# One replica per split ratio, plus one for the last band.
for k in range(len(split_ratios) + 1):
lay = deepcopy(layer)
if self.conv:
# Padding is applied manually in `forward` for encoder replicas.
lay.conv.padding = (0, 0)
lay.pad = False
for m in lay.modules():
if hasattr(m, 'reset_parameters'):
def forward(self, x, skip=None, length=None):
B, C, Fr, T = x.shape
# The last band always extends to the final frequency (ratio 1).
ratios = list(self.split_ratios) + [1]
start = 0
outs = []
for ratio, layer in zip(ratios, self.layers):
if self.conv:
pad = layer.kernel_size // 4
if ratio == 1:
limit = Fr
frames = -1
limit = int(round(Fr * ratio))
le = limit - start
if start == 0:
le += pad
# Round the band so that it covers a whole number of convolution frames.
frames = round((le - layer.kernel_size) / layer.stride + 1)
limit = start + (frames - 1) * layer.stride + layer.kernel_size
if start == 0:
limit -= pad
assert limit - start > 0, (limit, start)
assert limit <= Fr, (limit, Fr)
y = x[:, :, start:limit, :]
# Zero-pad only at the outer edges of the full frequency range.
if start == 0:
y = F.pad(y, (0, 0, pad, 0))
if ratio == 1:
y = F.pad(y, (0, 0, 0, pad))
# The next band starts so that consecutive bands overlap by kernel_size - stride.
start = limit - layer.kernel_size + layer.stride
if ratio == 1:
limit = Fr
limit = int(round(Fr * ratio))
# Temporarily force `last` on the replica; the saved value is restored below.
last = layer.last
layer.last = True
y = x[:, :, start:limit]
s = skip[:, :, start:limit]
out, _ = layer(y, s, None)
if outs:
# Add the overlapping head of this band to the tail of the previous one,
# subtracting the transposed-conv bias from the added part.
outs[-1][:, :, -layer.stride:] += (
out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
out = out[:, :, layer.stride:]
if ratio == 1:
out = out[:, :, :-layer.stride // 2, :]
if start == 0:
out = out[:, :, layer.stride // 2:, :]
layer.last = last
start = limit
out = torch.cat(outs, dim=2)
# The activation skipped by forcing `last` is applied once on the merged output.
if not self.conv and not last:
out = F.gelu(out)
if self.conv:
return out
return out, None
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        Additional args:
            last: if True, no GELU is applied on the output.
            context_freq: if True, the rewrite conv has context along every
                convolved axis, otherwise only along the last axis.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            # An empty layer is just the transposed convolution.
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Decode `x`, returning `(z, y)` with `y` the tensor fed to the transposed conv.

        `skip` is added to `x` (it must be None for an empty layer). `length` is
        the output length to trim to for time layers.
        """
        if self.freq and x.dim() == 3:
            # A frequency layer fed a 3-D tensor: unfold the frequency axis.
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # Remove the padding along the frequency axis.
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
class HDemucs(nn.Module):
Spectrogram and hybrid Demucs model.
The spectrogram model has the same structure as Demucs, except the first few layers are over the
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
Frequency layers can still access information across time steps thanks to the DConv residual.
Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
as the frequency branch and then the two are combined. The opposite happens in the decoder.
Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
Open Unmix implementation [Stoter et al. 2019].
The loss is always on the temporal domain, by backpropagating through the above
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
contribution, without changing the one from the waveform, which will lead to worse performance.
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
CaC on the other hand provides similar performance for hybrid, and works naturally with
hybrid models.
This model also uses frequency embeddings to improve efficiency on convolutions
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
# Build the encoder/decoder stacks of the frequency branch and, for hybrid models,
# the parallel time branch.
# NOTE(review): the parameter lines of the signature (only the section comments
# remain), the docstring delimiters and several `else:` branches / statements are
# not visible in this view; defaults therefore cannot be documented from here.
def __init__(self,
# Channels
# Main structure
# Frequency branch
# Convolutions
# Normalization
# DConv residual branch
# Weight init
# Metadata
segment=4 * 10):
sources (list[str]): list of source names.
audio_channels (int): input/output audio channels.
channels (int): initial number of hidden channels.
channels_time: if not None, use a different `channels` value for the time branch.
growth: increase the number of hidden channels by this factor at each layer.
nfft: number of fft bins. Note that changing this require careful computation of
various shape parameters and will not work out of the box for hybrid models.
wiener_iters: when using Wiener filtering, number of iterations at test time.
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
wiener_residual: add residual source before wiener filtering.
cac: uses complex as channels, i.e. complex numbers are 2 channels each
in input and output. no further processing is done before ISTFT.
depth (int): number of layers in the encoder and in the decoder.
rewrite (bool): add 1x1 convolution to each layer.
hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
hybrid_old: some models trained for MDX had a padding bug. This replicates
this bug to avoid retraining them.
multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
layers will be wrapped.
freq_emb: add frequency embedding after the first frequency layer if > 0,
the actual value controls the weight of the embedding.
emb_scale: equivalent to scaling the embedding learning rate
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
kernel_size: kernel_size for encoder and decoder layers.
stride: stride for encoder and decoder layers.
time_stride: stride for the final time layer, after the merge.
context: context for 1x1 conv in the decoder.
context_enc: context for 1x1 conv in the encoder.
norm_starts: layer at which group norm starts being used.
decoder layers are numbered in reverse order.
norm_groups: number of groups for group norm.
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
dconv_depth: depth of residual DConv branch.
dconv_comp: compression of DConv branch.
dconv_attn: adds attention layers in DConv branch starting at this layer.
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
dconv_init: initial scale for the DConv branch LayerScale.
rescale: weight recaling trick
self.cac = cac
self.wiener_residual = wiener_residual
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.channels = channels
self.samplerate = samplerate
self.segment = segment
self.nfft = nfft
# The STFT hop is fixed to a quarter of the FFT size.
self.hop_length = nfft // 4
self.wiener_iters = wiener_iters
self.end_iters = end_iters
self.freq_emb = None
self.hybrid = hybrid
self.hybrid_old = hybrid_old
if hybrid_old:
assert hybrid, "hybrid_old must come with hybrid=True"
if hybrid:
assert wiener_iters == end_iters
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
if hybrid:
self.tencoder = nn.ModuleList()
self.tdecoder = nn.ModuleList()
chin = audio_channels
chin_z = chin # number of channels for the freq branch
if self.cac:
# Complex-as-channels doubles the channel count of the frequency branch.
chin_z *= 2
chout = channels_time or channels
chout_z = channels
freqs = nfft // 2
for index in range(depth):
lstm = index >= dconv_lstm
attn = index >= dconv_attn
norm = index >= norm_starts
# The layer acts on frequencies as long as more than one frequency is left.
freq = freqs > 1
stri = stride
ker = kernel_size
if not freq:
assert freqs == 1
ker = time_stride * 2
stri = time_stride
pad = True
last_freq = False
if freq and freqs <= kernel_size:
# Last frequency layer: a single unpadded kernel spanning all remaining frequencies.
ker = freqs
pad = False
last_freq = True
kw = {
'kernel_size': ker,
'stride': stri,
'freq': freq,
'pad': pad,
'norm': norm,
'rewrite': rewrite,
'norm_groups': norm_groups,
'dconv_kw': {
'lstm': lstm,
'attn': attn,
'depth': dconv_depth,
'compress': dconv_comp,
'init': dconv_init,
'gelu': True,
# Time-branch layers reuse `kw` but with the default kernel/stride and padding.
kwt = dict(kw)
kwt['freq'] = 0
kwt['kernel_size'] = kernel_size
kwt['stride'] = stride
kwt['pad'] = True
kw_dec = dict(kw)
multi = False
if multi_freqs and index < multi_freqs_depth:
multi = True
kw_dec['context_freq'] = False
if last_freq:
# Both branches get the same width at the layer where they are merged.
chout_z = max(chout, chout_z)
chout = chout_z
enc = HEncLayer(chin_z, chout_z,
dconv=dconv_mode & 1, context=context_enc, **kw)
if hybrid and freq:
tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
empty=last_freq, **kwt)
if multi:
enc = MultiWrap(enc, multi_freqs)
if index == 0:
# The outermost decoder layer outputs one set of audio channels per source.
chin = self.audio_channels * len(self.sources)
chin_z = chin
if self.cac:
chin_z *= 2
dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
last=index == 0, context=context, **kw_dec)
if multi:
dec = MultiWrap(dec, multi_freqs)
if hybrid and freq:
tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
last=index == 0, context=context, **kwt)
# Decoder layers are inserted at the front, so they end up in reverse order.
self.tdecoder.insert(0, tdec)
self.decoder.insert(0, dec)
chin = chout
chin_z = chout_z
chout = int(growth * chout)
chout_z = int(growth * chout_z)
if freq:
if freqs <= kernel_size:
freqs = 1
freqs //= stride
if index == 0 and freq_emb:
self.freq_emb = ScaledEmbedding(
freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
self.freq_emb_scale = freq_emb
if rescale:
rescale_module(self, reference=rescale)
def _spec(self, x):
    """Compute the spectrogram of `x` with `spectro`, dropping the last frequency bin.

    For hybrid models the signal is re-padded first, so that the number of
    output frames is `ceil(length / hop_length)`.
    """
    hl = self.hop_length
    nfft = self.nfft

    if self.hybrid:
        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        if not self.hybrid_old:
            x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
        else:
            # `hybrid_old` keeps the zero padding of the older models.
            x = F.pad(x, (pad, pad + le * hl - x.shape[-1]))

    z = spectro(x, nfft, hl)[..., :-1, :]
    if self.hybrid:
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        # Drop the 2 extra frames on each side that come from the padding.
        z = z[..., 2:2+le]
    return z
def _ispec(self, z, length=None, scale=0):
    """Invert a spectrogram with `ispectro`, undoing the padding done in `_spec`.

    Args:
        z: spectrogram without its last frequency bin.
        length: number of output samples to keep.
        scale: the hop length is divided by `4 ** scale`.
    """
    hl = self.hop_length // (4 ** scale)
    # Add back the frequency bin that `_spec` removed.
    z = F.pad(z, (0, 0, 0, 1))
    if self.hybrid:
        # Add back the 2 frames that `_spec` removed on each side.
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        if not self.hybrid_old:
            le = hl * int(math.ceil(length / hl)) + 2 * pad
        else:
            le = hl * int(math.ceil(length / hl))
        x = ispectro(z, hl, length=le)
        if not self.hybrid_old:
            x = x[..., pad:pad + length]
        else:
            x = x[..., :length]
    else:
        x = ispectro(z, hl, length)
    return x
def _magnitude(self, z):
    """Return the network input derived from the complex spectrogram `z`.

    This is the magnitude of `z`, except when `cac` is True, in which case the
    real and imaginary parts are moved to the channel dimension (each channel
    becoming two consecutive channels).
    """
    # return the magnitude of the spectrogram, except when cac is True,
    # in which case we just move the complex dimension to the channel one.
    if self.cac:
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
    else:
        m = z.abs()
    return m
def _mask(self, z, m):
    """Turn the network output `m` into one complex spectrogram per source.

    Args:
        z: mixture spectrogram (ignored when `cac` is True).
        m: estimated mask, or the full spectrogram with real and imaginary
            parts as consecutive channels when `cac` is True.
    """
    # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
    # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
    niters = self.wiener_iters
    if self.cac:
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out
    if self.training:
        niters = self.end_iters
    if niters < 0:
        # Negative iteration count: apply `m` on the normalized mixture
        # spectrogram, without Wiener filtering.
        z = z[:, None]
        return z / (1e-8 + z.abs()) * m
    else:
        return self._wiener(m, z, niters)
def _wiener(self, mag_out, mix_stft, niters):
    """Apply Wiener filtering (OpenUnmix `wiener`) in windows of 300 frames.

    Args:
        mag_out: tensor unpacked as `[B, S, C, Fq, T]`.
        mix_stft: complex mixture spectrogram.
        niters: passed to `wiener` as its third positional argument.

    Returns:
        Complex tensor of shape `[B, S, C, Fq, T]`, cast to the dtype of `mix_stft`.
    """
    # apply wiener filtering from OpenUnmix.
    init = mix_stft.dtype
    wiener_win_len = 300
    residual = self.wiener_residual

    B, S, C, Fq, T = mag_out.shape
    # Reverse the dimension order (time first) for `wiener`.
    mag_out = mag_out.permute(0, 4, 3, 2, 1)
    mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

    outs = []
    for sample in range(B):
        out = []
        # Process the frames window by window.
        for pos in range(0, T, wiener_win_len):
            frame = slice(pos, pos + wiener_win_len)
            z_out = wiener(
                mag_out[sample, frame], mix_stft[sample, frame], niters,
                residual=residual)
            out.append(z_out.transpose(-1, -2))
        outs.append(torch.cat(out, dim=0))
    out = torch.view_as_complex(torch.stack(outs, 0))
    out = out.permute(0, 4, 3, 2, 1).contiguous()
    if residual:
        # Drop the extra residual source.
        out = out[:, :-1]
    assert list(out.shape) == [B, S, C, Fq, T]
    return out.to(init)
def forward(self, mix):
    """Separate `mix`, returning a tensor with a source dimension after the batch one.

    The frequency branch runs on the normalized spectrogram of `mix`; for hybrid
    models a time branch runs in parallel on the normalized waveform and both
    outputs are summed at the end.
    """
    x = mix
    length = x.shape[-1]

    z = self._spec(mix)
    mag = self._magnitude(z)
    x = mag

    B, C, Fq, T = x.shape

    # unlike previous Demucs, we always normalize because it is easier.
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    std = x.std(dim=(1, 2, 3), keepdim=True)
    x = (x - mean) / (1e-5 + std)
    # x will be the freq. branch input.

    if self.hybrid:
        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

    # okay, this is a giant mess I know...
    saved = []  # skip connections, freq.
    saved_t = []  # skip connections, time.
    lengths = []  # saved lengths to properly remove padding, freq branch.
    lengths_t = []  # saved lengths for time branch.
    for idx, encode in enumerate(self.encoder):
        # Input length of each layer, popped by the matching decoder layer.
        lengths.append(x.shape[-1])
        inject = None
        if self.hybrid and idx < len(self.tencoder):
            # we have not yet merged branches.
            lengths_t.append(xt.shape[-1])
            tenc = self.tencoder[idx]
            xt = tenc(xt)
            if not tenc.empty:
                # save for skip connection
                saved_t.append(xt)
            else:
                # tenc contains just the first conv., so that now time and freq.
                # branches have the same shape and can be merged.
                inject = xt
        x = encode(x, inject)
        if idx == 0 and self.freq_emb is not None:
            # add frequency embedding to allow for non equivariant convolutions
            # over the frequency axis.
            frs = torch.arange(x.shape[-2], device=x.device)
            emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
            x = x + self.freq_emb_scale * emb

        saved.append(x)

    x = torch.zeros_like(x)
    if self.hybrid:
        xt = torch.zeros_like(x)
    # initialize everything to zero (signal will go through u-net skips).

    for idx, decode in enumerate(self.decoder):
        skip = saved.pop(-1)
        x, pre = decode(x, skip, lengths.pop(-1))
        # `pre` contains the output just before final transposed convolution,
        # which is used when the freq. and time branch separate.

        if self.hybrid:
            offset = self.depth - len(self.tdecoder)
        if self.hybrid and idx >= offset:
            tdec = self.tdecoder[idx - offset]
            length_t = lengths_t.pop(-1)
            if tdec.empty:
                assert pre.shape[2] == 1, pre.shape
                pre = pre[:, :, 0]
                xt, _ = tdec(pre, None, length_t)
            else:
                skip = saved_t.pop(-1)
                xt, _ = tdec(xt, skip, length_t)

    # Let's make sure we used all stored skip connections.
    assert len(saved) == 0
    assert len(lengths_t) == 0
    assert len(saved_t) == 0

    S = len(self.sources)
    x = x.view(B, S, -1, Fq, T)
    # Undo the input normalization.
    x = x * std[:, None] + mean[:, None]

    zout = self._mask(z, x)
    x = self._ispec(zout, length)

    if self.hybrid:
        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
    return x
Normal file
Normal file
@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
import logging
from pathlib import Path
import typing as tp
from dora.log import fatal
from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
logger = logging.getLogger(__name__)
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
REMOTE_ROOT = Path(__file__).parent / 'remote'
SOURCES = ["drums", "bass", "other", "vocals"]
def demucs_unittest():
    """Build the small `HDemucs` (4 channels, sources `SOURCES`) used for unit tests."""
    return HDemucs(channels=4, sources=SOURCES)
def add_model_flags(parser):
    """Register the model selection flags on `parser`.

    `-s/--sig` and `-n/--name` are mutually exclusive; `--repo` is parsed as a `Path`.
    """
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument("-s", "--sig", help="Locally trained XP signature.")
    exclusive.add_argument(
        "-n", "--name", default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.")
    parser.add_argument(
        "--repo", type=Path,
        help="Folder containing all pre-trained models for use with -n.")
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    The special name `demucs_unittest` returns the unit test model directly.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        # Remote files are listed, one per line, in `files.txt` under REMOTE_ROOT.
        remote_files = [line.strip()
                        for line in (REMOTE_ROOT / 'files.txt').read_text().split('\n')
                        if line.strip()]
        model_repo = RemoteRepo(ROOT_URL, remote_files)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    return any_repo.get_model(name)
def get_model_from_args(args):
    """
    Load local model package or pre-trained model.

    Reads `args.name` and `args.repo` and forwards them to `get_model`.
    """
    name, repo = args.name, args.repo
    return get_model(name=name, repo=repo)
Normal file
Normal file
@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Represents a model repository, including pre-trained models and bags of models.
A repo can either be the main remote repository stored in AWS, or a local repository
with your own models.
from hashlib import sha256
from pathlib import Path
import typing as tp
import torch
import yaml
from .apply import BagOfModels, Model
from .states import load_model
AnyModel = tp.Union[Model, BagOfModels]
class ModelLoadingError(RuntimeError):
    """Raised when a model cannot be found or fails its checksum verification."""
def check_checksum(path: Path, checksum: str):
    """Verify that the SHA-256 hex digest of the file at `path` starts with `checksum`.

    Only the first `len(checksum)` characters of the digest are compared, so a
    truncated checksum is accepted.

    Raises:
        ModelLoadingError: if the digest prefix does not match `checksum`.
    """
    sha = sha256()
    with open(path, 'rb') as file:
        # Hash in 1 MiB chunks to avoid loading the whole file in memory.
        for buf in iter(lambda: file.read(2**20), b''):
            sha.update(buf)
    actual_checksum = sha.hexdigest()[:len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f'Invalid checksum for file {path}, '
                                f'expected {checksum} but got {actual_checksum}')
class ModelOnlyRepo:
    """Base class for all model only repos.

    Subclasses must implement both `has_model` and `get_model`.
    """
    def has_model(self, sig: str) -> bool:
        """Return True if a model with signature `sig` is available."""
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        """Load and return the model with signature `sig`."""
        raise NotImplementedError()
class RemoteRepo(ModelOnlyRepo):
    """Repo of models downloaded from `root_url`.

    Each entry of `remote_files` must be named `<sig>-<checksum>.<ext>`.
    """
    def __init__(self, root_url: str, remote_files: tp.List[str]):
        if not root_url.endswith('/'):
            root_url += '/'
        self._models: tp.Dict[str, str] = {}
        for file in remote_files:
            # The unpacking also enforces the `<sig>-<checksum>` naming scheme.
            sig, _checksum = file.split('.')[0].split('-')
            assert sig not in self._models
            self._models[sig] = root_url + file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Download the package for `sig` and load it with `load_model`.

        Raises:
            ModelLoadingError: if `sig` is not a known signature.
        """
        try:
            url = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
        pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
        return load_model(pkg)
class LocalRepo(ModelOnlyRepo):
    """Repo of `.th` model files stored in the local folder `root`.

    Files are named either `<sig>.th` or `<sig>-<checksum>.th`; in the latter
    case the checksum is verified when the model is loaded.
    """
    def __init__(self, root: Path):
        self.root = root
        # `has_model` / `get_model` rely on the tables built by `scan`.
        self.scan()

    def scan(self):
        """(Re)build the signature -> file and signature -> checksum tables."""
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix == '.th':
                if '-' in file.stem:
                    xp_sig, checksum = file.stem.split('-')
                    self._checksums[xp_sig] = checksum
                else:
                    xp_sig = file.stem
                if xp_sig in self._models:
                    raise ModelLoadingError(
                        f'Duplicate pre-trained model exist for signature {xp_sig}. '
                        'Please delete all but one.')
                self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Load the model with signature `sig`, checking its checksum if it has one.

        Raises:
            ModelLoadingError: if `sig` is unknown or the checksum does not match.
        """
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)
class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.
    """
    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        # `has_model` / `get_model` rely on the table built by `scan`.
        self.scan()

    def scan(self):
        """(Re)build the bag name -> YAML file table from the files in `root`."""
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == '.yaml':
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        """Build the `BagOfModels` described by the YAML file named `name`.

        The YAML file must define `models` (list of signatures) and may define
        `weights` and `segment`.

        Raises:
            ModelLoadingError: if no bag is named `name`.
        """
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
                                    'a bag of models.')
        # Close the file deterministically instead of leaking the handle.
        with open(yaml_file) as stream:
            bag = yaml.safe_load(stream)
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)
class AnyModelRepo:
    """Front-end over a model repo and a bag repo; single models take precedence."""
    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        """Return True if either underlying repo knows `name_or_sig`."""
        if self.model_repo.has_model(name_or_sig):
            return True
        return self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        """Return the single model if the model repo has it, else ask the bag repo."""
        repo = self.model_repo if self.model_repo.has_model(name_or_sig) else self.bag_repo
        return repo.get_model(name_or_sig)
Normal file
Normal file
@ -0,0 +1,186 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import subprocess
from dora.log import fatal
import torch as th
import torchaudio as ta
from .apply import apply_model, BagOfModels
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError
# Load the audio file `track`, first through `AudioFile` (ffmpeg) and, if that
# produced nothing, through torchaudio followed by `convert_audio`.
# Errors from each backend are collected in `errors` and printed if both fail.
# NOTE(review): the `try:`/`else:` lines, the arguments of `AudioFile(...).read(`
# and whatever follows the error report are not visible in this view.
def load_track(track, audio_channels, samplerate):
errors = {}
wav = None
wav = AudioFile(track).read(
except FileNotFoundError:
errors['ffmpeg'] = 'Ffmpeg is not installed.'
except subprocess.CalledProcessError:
errors['ffmpeg'] = 'FFmpeg could not read the file.'
# Fall back to torchaudio only when the first backend did not return audio.
if wav is None:
wav, sr = ta.load(str(track))
except RuntimeError as err:
errors['torchaudio'] = err.args[0]
# Convert the torchaudio output to the requested sample rate and channel count.
wav = convert_audio(wav, sr, samplerate, audio_channels)
if wav is None:
print(f"Could not load file {track}. "
"Maybe it is not a supported file format? ")
for backend, error in errors.items():
print(f"When trying to load using {backend}, got the following error: {error}")
return wav
def main():
# Command-line entry point: parse arguments, load the model, separate every track
# given on the command line and write the resulting stems to disk.
# NOTE(review): several `add_argument(` opening lines, `try:`/`else:` lines and
# closing brackets are not visible in this extract; only the visible lines are described.
parser = argparse.ArgumentParser("demucs.separate",
description="Separate the sources for the given tracks")
parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
parser.add_argument("-v", "--verbose", action="store_true")
help="Folder where to put extracted tracks. A subfolder "
"with the model name will be created.")
default="cuda" if th.cuda.is_available() else "cpu",
help="Device to use, default is cuda if available else cpu")
help="Number of random shifts for equivariant stabilization."
"Increase separation time but improves quality for Demucs. 10 was used "
"in the original paper.")
help="Overlap between the splits.")
# Options declared on this group cannot be combined with each other.
split_group = parser.add_mutually_exclusive_group()
help="Doesn't split audio in chunks. "
"This can use large amounts of memory.")
split_group.add_argument("--segment", type=int,
help="Set split size of each chunk. "
"This can help save memory of graphic card. ")
dest="stem", metavar="STEM",
help="Only separate audio into {STEM} and no_{STEM}. ")
# `--int24` and `--float32` select the wav sample format and are mutually exclusive.
group = parser.add_mutually_exclusive_group()
group.add_argument("--int24", action="store_true",
help="Save wav output as 24 bits wav.")
group.add_argument("--float32", action="store_true",
help="Save wav output as float32 (2x bigger).")
parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp"],
help="Strategy for avoiding clipping: rescaling entire signal "
"if necessary (rescale) or hard clipping (clamp).")
parser.add_argument("--mp3", action="store_true",
help="Convert the output wavs to mp3.")
help="Bitrate of converted mp3.")
parser.add_argument("-j", "--jobs",
help="Number of jobs. This can increase memory usage but will "
"be much faster when multiple cores are available.")
args = parser.parse_args()
model = get_model_from_args(args)
except ModelLoadingError as error:
# Segments below 8 are rejected before the override is applied.
if args.segment is not None and args.segment < 8:
fatal('Segment must greater than 8. ')
# For a bag, the segment override is applied to every sub-model.
if isinstance(model, BagOfModels):
if args.segment is not None:
for sub in model.models:
sub.segment = args.segment
# NOTE(review): if the next two lines are the non-bag branch, `sub` is not bound
# there and `model.segment` was presumably intended; confirm against the full file.
if args.segment is not None:
sub.segment = args.segment
if args.stem is not None and args.stem not in model.sources:
'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
stem=args.stem, sources=', '.join(model.sources)))
# Output folder is `<out>/<name>`, created on demand.
out = args.out / args.name
out.mkdir(parents=True, exist_ok=True)
print(f"Separated tracks will be stored in {out.resolve()}")
for track in args.tracks:
if not track.exists():
f"File {track} does not exist. If the path contains spaces, "
"please try again after surrounding the entire path with quotes \"\".",
print(f"Separating track {track}")
wav = load_track(track, model.audio_channels, model.samplerate)
# Normalize with the mean/std of the average over dim 0; undone after separation.
ref = wav.mean(0)
wav = (wav - ref.mean()) / ref.std()
sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts,
split=args.split, overlap=args.overlap, progress=True,
sources = sources * ref.std() + ref.mean()
# One folder per track, named after the file name without its last extension.
track_folder = out / track.name.rsplit(".", 1)[0]
if args.mp3:
ext = ".mp3"
ext = ".wav"
# Keyword arguments forwarded to every `save_audio` call below.
kwargs = {
'samplerate': model.samplerate,
'bitrate': args.mp3_bitrate,
'clip': args.clip_mode,
'as_float': args.float32,
'bits_per_sample': 24 if args.int24 else 16,
# Without `--stem`, one file is written per model source.
if args.stem is None:
for source, name in zip(sources, model.sources):
stem = str(track_folder / (name + ext))
save_audio(source, stem, **kwargs)
# With a stem selected: write that stem, then the sum of all remaining sources as `no_<stem>`.
sources = list(sources)
stem = str(track_folder / (args.stem + ext))
save_audio(sources.pop(model.sources.index(args.stem)), stem, **kwargs)
# Warning: after popping the stem, the selected stem is no longer in the list 'sources'
other_stem = th.zeros_like(sources[0])
for i in sources:
other_stem += i
stem = str(track_folder / ("no_" + args.stem + ext))
save_audio(other_stem, stem, **kwargs)
if __name__ == "__main__":
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Main training loop."""
import logging
from dora import get_xp
from dora.utils import write_and_rename
from dora.log import LogProgress, bold
import torch
import torch.nn.functional as F
from . import augment, distrib, states, pretrained
from .apply import apply_model
from .ema import ModelEMA
from .evaluate import evaluate, new_sdr
from .svd import svd_penalty
from .utils import pull_metric, EMA
logger = logging.getLogger(__name__)
def _summary(metrics):
return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())
class Solver(object):
def __init__(self, loaders, model, optimizer, args):
# Store the training components, build the quantizer, EMA trackers and the
# augmentation pipeline, and resolve checkpoint paths from the current experiment.
self.args = args
self.loaders = loaders
self.model = model
self.optimizer = optimizer
self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
self.dmodel = distrib.wrap(model)
# The device of the first parameter is taken as the model device.
self.device = next(iter(self.model.parameters())).device
# Exponential moving average of the model, either updated every batch or epoch.
# The best model from all the EMAs and the original one is kept based on the valid
# loss for the final best model.
self.emas = {'batch': [], 'epoch': []}
for kind in self.emas.keys():
decays = getattr(args.ema, kind)
# Batch EMAs live on the model device, epoch EMAs on CPU.
device = self.device if kind == 'batch' else 'cpu'
if decays:
for decay in decays:
self.emas[kind].append(ModelEMA(self.model, decay, device=device))
# data augment
# NOTE(review): the rest of the `augment.Shift(...)` call and the closing `]` are
# not visible in this extract.
augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
if args.augment.flip:
augments += [augment.FlipChannels(), augment.FlipSign()]
# `Scale` and `Remix` are looked up by capitalized name on the `augment` module.
for aug in ['scale', 'remix']:
kw = getattr(args.augment, aug)
if kw.proba:
augments.append(getattr(augment, aug.capitalize())(**kw))
self.augment = torch.nn.Sequential(*augments)
xp = get_xp()
self.folder = xp.folder
# Checkpoints
self.checkpoint_file = xp.folder / 'checkpoint.th'
self.best_file = xp.folder / 'best.th'
logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
self.best_state = None
self.best_changed = False
self.link = xp.link
self.history = self.link.history
def _serialize(self, epoch):
    """Write the rolling checkpoint, the periodic snapshot and the best model for `epoch`."""
    checkpoint = {
        'state': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'history': self.history,
        'best_state': self.best_state,
        'args': self.args,
    }
    for kind, trackers in self.emas.items():
        for index, tracker in enumerate(trackers):
            checkpoint[f'ema_{kind}_{index}'] = tracker.state_dict()
    with write_and_rename(self.checkpoint_file) as tmp:
        torch.save(checkpoint, tmp)

    # Extra numbered snapshot every `save_every` epochs, except on the final epoch.
    save_every = self.args.save_every
    completed = epoch + 1
    if save_every and completed % save_every == 0 and completed != self.args.epochs:
        with write_and_rename(self.folder / f'checkpoint_{completed}.th') as tmp:
            torch.save(checkpoint, tmp)

    if not self.best_changed:
        return
    # Saving only the latest best model.
    with write_and_rename(self.best_file) as tmp:
        best_package = states.serialize_model(self.model, self.args)
        best_package['state'] = self.best_state
        torch.save(best_package, tmp)
    self.best_changed = False
def _reset(self):
"""Reset state of the solver, potentially using checkpoint."""
# Three sources are considered in order: this experiment's own checkpoint file,
# a pretrained model (`continue_pretrained`), or another experiment (`continue_from`).
# NOTE(review): several statement bodies (state loading, EMA restore, the
# `pretrained.get_model(` arguments, optimizer restore) are not visible in this extract.
if self.checkpoint_file.exists():
logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
package = torch.load(self.checkpoint_file, 'cpu')
# Slice assignment replaces the contents of the existing history list in place.
self.history[:] = package['history']
self.best_state = package['best_state']
for kind, emas in self.emas.items():
for k, ema in enumerate(emas):
elif self.args.continue_pretrained:
model = pretrained.get_model(
elif self.args.continue_from:
# The other experiment's checkpoint is looked up next to this experiment's folder.
name = 'checkpoint.th'
root = self.folder.parent
cf = root / str(self.args.continue_from) / name
logger.info("Loading from %s", cf)
package = torch.load(cf, 'cpu')
self.best_state = package['best_state']
# Weights are loaded non-strictly from the checkpoint's 'best_state' or 'state' entry.
if self.args.continue_best:
self.model.load_state_dict(package['best_state'], strict=False)
self.model.load_state_dict(package['state'], strict=False)
if self.args.continue_opt:
def _format_train(self, metrics: dict) -> dict:
"""Formatting for train/valid metrics."""
# 'loss' and 'reco' are always formatted; the remaining keys only when present.
# NOTE(review): the closing `}` of the dict literal is not visible in this extract.
losses = {
'loss': format(metrics['loss'], ".4f"),
'reco': format(metrics['reco'], ".4f"),
if 'nsdr' in metrics:
losses['nsdr'] = format(metrics['nsdr'], ".3f")
# 'ms' is read from `metrics` whenever a quantizer is configured.
if self.quantizer is not None:
losses['ms'] = format(metrics['ms'], ".2f")
if 'grad' in metrics:
losses['grad'] = format(metrics['grad'], ".4f")
if 'best' in metrics:
losses['best'] = format(metrics['best'], '.4f')
# 'bname' is copied through without numeric formatting.
if 'bname' in metrics:
losses['bname'] = metrics['bname']
if 'penalty' in metrics:
losses['penalty'] = format(metrics['penalty'], ".4f")
if 'hloss' in metrics:
losses['hloss'] = format(metrics['hloss'], ".4f")
return losses
def _format_test(self, metrics: dict) -> dict:
"""Formatting for test metrics."""
losses = {}
if 'sdr' in metrics:
losses['sdr'] = format(metrics['sdr'], '.3f')
if 'nsdr' in metrics:
losses['nsdr'] = format(metrics['nsdr'], '.3f')
for source in self.model.sources:
key = f'sdr_{source}'
if key in metrics:
losses[key] = format(metrics[key], '.3f')
key = f'nsdr_{source}'
if key in metrics:
losses[key] = format(metrics[key], '.3f')
return losses
def train(self):
# Training driver: replay the summaries of already recorded epochs, then for each
# remaining epoch train, validate the main model and every EMA, track the best
# state, and evaluate on the test set periodically or on the last epoch.
# NOTE(review): some `logger.info(` openers, `else:` lines and trailing calls are
# not visible in this extract; only the visible statements are described.
# Optimizing the model
if self.history:
logger.info("Replaying metrics from previous run")
for epoch, metrics in enumerate(self.history):
formatted = self._format_train(metrics['train'])
bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
formatted = self._format_train(metrics['valid'])
bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
if 'test' in metrics:
formatted = self._format_test(metrics['test'])
if formatted:
logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
epoch = 0
# Resume after the epochs already present in the history.
for epoch in range(len(self.history), self.args.epochs):
# Train one epoch
self.model.train() # Turn on BatchNorm & Dropout
metrics = {}
logger.info('-' * 70)
metrics['train'] = self._run_one_epoch(epoch)
formatted = self._format_train(metrics['train'])
bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
# Cross validation
logger.info('-' * 70)
logger.info('Cross validation...')
self.model.eval() # Turn off Batchnorm & Dropout
with torch.no_grad():
valid = self._run_one_epoch(epoch, train=False)
# `bvalid`/`bname`/`state` track the best candidate so far, starting with the main model.
bvalid = valid
bname = 'main'
state = states.copy_state(self.model.state_dict())
metrics['valid'] = {}
metrics['valid']['main'] = valid
# Name of the metric used to compare candidates.
key = self.args.test.metric
for kind, emas in self.emas.items():
for k, ema in enumerate(emas):
# Validate with the EMA weights swapped in.
with ema.swap():
valid = self._run_one_epoch(epoch, train=False)
name = f'ema_{kind}_{k}'
metrics['valid'][name] = valid
a = valid[key]
b = bvalid[key]
# Metrics starting with 'nsdr' are negated so that "smaller is better" holds below.
if key.startswith('nsdr'):
a = -a
b = -b
if a < b:
bvalid = valid
state = ema.state
bname = name
metrics['valid']['bname'] = bname
valid_loss = metrics['valid'][key]
# Best value over the recorded history plus the current epoch.
mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
if key.startswith('nsdr'):
best_loss = max(mets)
best_loss = min(mets)
metrics['valid']['best'] = best_loss
if self.args.svd.penalty > 0:
kw = dict(self.args.svd)
with torch.no_grad():
penalty = svd_penalty(self.model, exact=True, **kw)
metrics['valid']['penalty'] = penalty
formatted = self._format_train(metrics['valid'])
bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
# Save the best model
if valid_loss == best_loss or self.args.dset.train_valid:
logger.info(bold('New best valid loss %.4f'), valid_loss)
self.best_state = states.copy_state(state)
self.best_changed = True
# Eval model every `test.every` epoch or on last epoch
should_eval = (epoch + 1) % self.args.test.every == 0
is_last = epoch == self.args.epochs - 1
reco = metrics['valid']['main']['reco']
# Tries to detect divergence in a reliable way and finish job
# not to waste compute.
# The thresholds below only apply when the optimized loss is 'l1'.
div = epoch >= 180 and reco > 0.18
div = div or epoch >= 100 and reco > 0.25
div = div and self.args.optim.loss == 'l1'
if div:
logger.warning("Finishing training early because valid loss is too high.")
is_last = True
if should_eval or is_last:
# Evaluate on the testset
logger.info('-' * 70)
logger.info('Evaluating on the test set...')
# We switch to the best known model for testing
if self.args.test.best:
state = self.best_state
state = states.copy_state(self.model.state_dict())
# SDR is only computed when `test.sdr` is set and this is the last epoch.
compute_sdr = self.args.test.sdr and is_last
with states.swap_state(self.model, state):
with torch.no_grad():
metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
formatted = self._format_test(metrics['test'])
logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
if distrib.rank == 0:
# Save model each epoch
logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
if is_last:
def _run_one_epoch(self, epoch, train=True):
# Run one pass over the train or valid loader, computing the weighted
# reconstruction loss and auxiliary metrics per batch, and return the metrics
# passed through `distrib.average`.
# NOTE(review): `else:` lines, the backward/optimizer/EMA update calls and some
# `break` statements are not visible in this extract.
args = self.args
data_loader = self.loaders['train'] if train else self.loaders['valid']
# get a different order for distributed training, otherwise this will get ignored
data_loader.sampler.epoch = epoch
# Indexing with the boolean: False -> "Valid", True -> "Train".
label = ["Valid", "Train"][train]
name = label + f" | Epoch {epoch + 1}"
total = len(data_loader)
if args.max_batches:
total = min(total, args.max_batches)
logprog = LogProgress(logger, data_loader, total=total,
updates=self.args.misc.num_prints, name=name)
# Running average of the per-batch metrics.
averager = EMA()
for idx, sources in enumerate(logprog):
sources = sources.to(self.device)
# Training: augment, then build the mix by summing over dim 1.
# Otherwise index 0 along dim 1 is used as the mix and the rest as the sources.
if train:
sources = self.augment(sources)
mix = sources.sum(dim=1)
mix = sources[:, 0]
sources = sources[:, 1:]
if not train and self.args.valid_apply:
estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
estimate = self.dmodel(mix)
if train and hasattr(self.model, 'transform_target'):
sources = self.model.transform_target(mix, sources)
assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
# Reduce over every dimension from index 2 onwards.
dims = tuple(range(2, sources.dim()))
if args.optim.loss == 'l1':
loss = F.l1_loss(estimate, sources, reduction='none')
loss = loss.mean(dims).mean(0)
reco = loss
elif args.optim.loss == 'mse':
loss = F.mse_loss(estimate, sources, reduction='none')
loss = loss.mean(dims)
# For MSE, `reco` is the square root of the per-item loss, averaged over dim 0.
reco = loss**0.5
reco = reco.mean(0)
# NOTE(review): the checks above read `args.optim.loss` while this message reads
# `self.args.loss` -- confirm that attribute exists.
raise ValueError(f"Invalid loss {self.args.loss}")
# Weighted average of the loss using `args.weights`.
weights = torch.tensor(args.weights).to(sources)
loss = (loss * weights).sum() / weights.sum()
ms = 0
if self.quantizer is not None:
ms = self.quantizer.model_size()
# With diffq enabled, the model size is added to the loss, scaled by `args.quant.diffq`.
if args.quant.diffq:
loss += args.quant.diffq * ms
losses = {}
losses['reco'] = (reco * weights).sum() / weights.sum()
losses['ms'] = ms
if not train:
nsdrs = new_sdr(sources, estimate.detach()).mean(0)
total = 0
for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
losses[f'nsdr_{source}'] = nsdr
total += w * nsdr
losses['nsdr'] = total / weights.sum()
if train and args.svd.penalty > 0:
kw = dict(args.svd)
penalty = svd_penalty(self.model, **kw)
losses['penalty'] = penalty
loss += args.svd.penalty * penalty
losses['loss'] = loss
for k, source in enumerate(self.model.sources):
losses[f'reco_{source}'] = reco[k]
# optimize model in training mode
if train:
# Accumulate squared gradient norms; the square root is stored as 'grad'.
grad_norm = 0
grads = []
for p in self.model.parameters():
if p.grad is not None:
grad_norm += p.grad.data.norm()**2
losses['grad'] = grad_norm ** 0.5
if args.optim.clip_grad:
# Debug aid under flag 'uns': print parameters that received no gradient.
if self.args.flag == 'uns':
for n, p in self.model.named_parameters():
if p.grad is None:
print('no grad', n)
for ema in self.emas['batch']:
losses = averager(losses)
logs = self._format_train(losses)
# Just in case, clear some memory
del loss, estimate, reco, ms
if args.max_batches == idx:
if self.args.debug and train:
if self.args.flag == 'debug':
if train:
for ema in self.emas['epoch']:
return distrib.average(losses, idx + 1)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Conveniance wrapper to perform STFT and iSTFT"""
import torch as th
def spectro(x, n_fft=512, hop_length=None, pad=0):
# Compute `th.stft` over the last dimension of `x`: leading dimensions are
# flattened before the call and restored on the result.
# NOTE(review): the remaining `th.stft` arguments and the closing `)` are not
# visible in this extract.
*other, length = x.shape
x = x.reshape(-1, length)
# FFT size is `n_fft * (1 + pad)`; the hop defaults to `n_fft // 4`.
z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
_, freqs, frame = z.shape
return z.view(*other, freqs, frame)
def ispectro(z, hop_length=None, length=None, pad=0):
# Apply `th.istft` to `z`, whose last two dimensions are (freqs, frames); leading
# dimensions are flattened before the call and restored on the result.
# NOTE(review): the `th.istft` arguments are not visible in this extract.
*other, freqs, frames = z.shape
# FFT size is derived from the number of frequency bins.
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
x = th.istft(z,
_, length = x.shape
return x.view(*other, length)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Utilities to save and load models.
from contextlib import contextmanager
import functools
import hashlib
import inspect
import io
from pathlib import Path
import warnings
from omegaconf import OmegaConf
from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
import torch
def get_quantizer(model, args, optimizer=None):
"""Return the quantizer given the XP quantization args."""
# `args.diffq` selects a DiffQuantizer, `args.qat` a UniformQuantizer with
# `bits=args.qat`; `quantizer` is initialised to None before these checks.
quantizer = None
if args.diffq:
quantizer = DiffQuantizer(
model, min_size=args.min_size, group_size=args.group_size)
# NOTE(review): the statement guarded by this `if` is not visible in this extract.
if optimizer is not None:
elif args.qat:
quantizer = UniformQuantizer(
model, bits=args.qat, min_size=args.min_size)
return quantizer
def load_model(path_or_package, strict=False):
"""Load a model from the given serialized model, either given as a dict (already loaded)
or a path to a file on disk."""
# NOTE(review): the `else:` lines (before the `raise` and before the non-strict
# branch) and the body of `catch_warnings` beyond `torch.load` are not visible
# in this extract.
if isinstance(path_or_package, dict):
package = path_or_package
elif isinstance(path_or_package, (str, Path)):
with warnings.catch_warnings():
path = path_or_package
package = torch.load(path, 'cpu')
raise ValueError(f"Invalid type for {path_or_package}.")
# The package stores the class plus the constructor arguments used to build it.
klass = package["klass"]
args = package["args"]
kwargs = package["kwargs"]
if strict:
model = klass(*args, **kwargs)
# Drop stored keyword arguments the class signature does not list, with a
# warning, then instantiate.
sig = inspect.signature(klass)
for key in list(kwargs):
if key not in sig.parameters:
warnings.warn("Dropping inexistant parameter " + key)
del kwargs[key]
model = klass(*args, **kwargs)
state = package["state"]
set_state(model, state)
return model
def get_state(model, quantizer, half=False):
"""Get the state from a model, potentially with quantization applied.
If `half` is True, model are stored as half precision, which shouldn't impact performance
but half the state size."""
# NOTE(review): the `else:` separating the two branches is not visible in this extract.
if quantizer is None:
# `dtype=None` leaves each tensor's dtype unchanged; tensors are moved to CPU either way.
dtype = torch.half if half else None
state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
state = quantizer.get_quantized_state()
# Marker read back by `set_state` to detect a quantized state.
state['__quantized'] = True
return state
def set_state(model, state, quantizer=None):
"""Set the state on a given model."""
# A state flagged with '__quantized' is restored through the quantizer when one is
# given, or through the module-level `restore_quantized_state`.
# NOTE(review): the `else:` lines and the non-quantized branch are not visible
# in this extract.
if state.get('__quantized'):
if quantizer is not None:
quantizer.restore_quantized_state(model, state['quantized'])
restore_quantized_state(model, state)
return state
def save_with_checksum(content, path):
"""Save the given value on disk, along with a sha256 hash.
Should be used with the output of either `serialize_model` or `get_state`."""
# Serialize into memory first so the hash covers the exact serialized bytes.
buf = io.BytesIO()
torch.save(content, buf)
# Only the first 8 hex characters of the digest are kept.
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
# Target name is `<stem>-<sig><suffix>` in the same directory.
# NOTE(review): the statement that writes `buf` to `path` is not visible in this extract.
path = path.parent / (path.stem + "-" + sig + path.suffix)
def serialize_model(model, training_args, quantizer=None, half=True):
# Build the dict consumed by `load_model`: the model class, the constructor
# arguments recorded in `model._init_args_kwargs`, the state from `get_state`,
# and the training args converted to a plain container.
# NOTE(review): the closing `}` of the returned dict is not visible in this extract.
args, kwargs = model._init_args_kwargs
klass = model.__class__
state = get_state(model, quantizer, half)
return {
'klass': klass,
'args': args,
'kwargs': kwargs,
'state': state,
'training_args': OmegaConf.to_container(training_args, resolve=True),
def copy_state(state):
    """Return a new dict holding a CPU clone of every tensor in `state`."""
    snapshot = {}
    for name, tensor in state.items():
        snapshot[name] = tensor.cpu().clone()
    return snapshot
def swap_state(model, state):
# NOTE(review): the decorator, the docstring delimiters, and the yield/restore
# part of this context manager are not visible in this extract.
Context manager that swaps the state of a model, e.g:
# model is in old state
with swap_state(model, new_state):
# model in new state
# model back to old state
# Keep a CPU copy of the current weights before loading `state` non-strictly.
old_state = copy_state(model.state_dict())
model.load_state_dict(state, strict=False)
def capture_init(init):
    """Decorate an `__init__` so each instance records its constructor arguments.

    The positional and keyword arguments are stored on the instance as
    `self._init_args_kwargs = (args, kwargs)` before `init` itself runs.

    `functools.wraps` keeps the wrapped `__init__`'s name, docstring and
    signature visible, so `inspect.signature` on the decorated class reports
    the real parameters rather than `(*args, **kwargs)`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Ways to make the model stronger."""
import random
import torch
def power_iteration(m, niters=1, bs=1):
    """Power method on the square matrix `m`.

    `bs` random starting vectors are iterated in parallel; the returned value is
    the mean over them of the norm of `m @ b` at the last iteration.
    """
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    size = m.shape[0]
    vectors = torch.randn(size, bs, device=m.device, dtype=m.dtype)
    for _ in range(niters):
        product = torch.mm(m, vectors)
        norm = product.norm(dim=0, keepdim=True)
        # Small constant avoids a division by zero when a column collapses.
        vectors = product / (1e-10 + norm)
    return norm.mean()
# We need a shared RNG to make sure all the distributed worker will skip the penalty together,
# as otherwise we wouldn't get any speed up.
penalty_rng = random.Random(1234)
def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
proba=1, conv_only=False, exact=False, bs=1):
# NOTE(review): the docstring delimiters and several `continue`/`else:` lines are
# not visible in this extract; the text below is the original docstring content.
Penalty on the largest singular value for a layer.
- model: model to penalize
- min_size: minimum size in MB of a layer to penalize.
- dim: projection dimension for the svd_lowrank. Higher is better but slower.
- niters: number of iterations in the algorithm used by svd_lowrank.
- powm: use power method instead of lowrank SVD, my own experience
is that it is both slower and less stable.
- convtr: when True, differentiate between Conv and Transposed Conv.
this is kept for compatibility with older experiments.
- proba: probability to apply the penalty.
- conv_only: only apply to conv and conv transposed, not LSTM
(might not be reliable for other models than Demucs).
- exact: use exact SVD (slow but useful at validation).
- bs: batch_size for power method.
total = 0
# Uses the module-level seeded `penalty_rng`, not the global `random` state.
if penalty_rng.random() > proba:
return 0.
for m in model.modules():
# `recurse=False`: only parameters owned directly by this module.
for name, p in m.named_parameters(recurse=False):
# Size test on the element count: `numel / 2**18` compared to `min_size`.
if p.numel() / 2**18 < min_size:
if convtr:
if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
if p.dim() in [3, 4]:
p = p.transpose(0, 1).contiguous()
# 3D and 4D parameters are flattened to 2D, keeping the first dimension.
if p.dim() == 3:
p = p.view(len(p), -1)
elif p.dim() == 4:
p = p.view(len(p), -1)
elif p.dim() == 1:
elif conv_only:
assert p.dim() == 2, (name, p.shape)
# Three estimators of the squared largest singular value: exact SVD, power
# iteration on the smaller Gram matrix, or low-rank SVD.
if exact:
estimate = torch.svd(p, compute_uv=False)[1].pow(2).max()
elif powm:
a, b = p.shape
if a < b:
n = p.mm(p.t())
n = p.t().mm(p)
estimate = power_iteration(n, niters, bs)
estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2)
total += estimate
# Dividing by `proba` rescales the total when the penalty is only applied with probability `proba`.
return total / proba
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp
import torch
from torch.nn import functional as F
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    The input is padded on the right (constant padding) so that
    `F = ceil(T / stride)` frames can be extracted.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *leading, total = a.shape
    frame_count = math.ceil(total / stride)
    padded_length = (frame_count - 1) * stride + kernel_size
    padded = F.pad(a, (0, padded_length - total))
    elem_strides = list(padded.stride())
    assert elem_strides[-1] == 1, 'data should be contiguous'
    # Frames advance by `stride` elements; samples inside a frame are adjacent.
    frame_strides = elem_strides[:-1] + [stride, 1]
    return padded.as_strided([*leading, frame_count, kernel_size], frame_strides)
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
# NOTE(review): the docstring delimiters and the `else:` between the two
# `ref_size` assignments are not visible in this extract.
Center trim `tensor` with respect to `reference`, along the last dimension.
`reference` can also be a number, representing the length to trim to.
If the size difference != 0 mod 2, the extra sample is removed on the right side.
ref_size: int
if isinstance(reference, torch.Tensor):
ref_size = reference.size(-1)
ref_size = reference
# Number of samples to remove in total.
delta = tensor.size(-1) - ref_size
if delta < 0:
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
# `delta // 2` samples are dropped on the left and the remainder on the right.
# The `if delta:` guard matters: a `-0` end index would produce an empty slice.
if delta:
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
return tensor
def pull_metric(history: tp.List[dict], name: str):
# Walk every entry of `history` along the dotted path `name`
# (e.g. "valid.loss" reads `metrics["valid"]["loss"]`).
# NOTE(review): the statement that adds `metric` to `out` is not visible in this extract.
out = []
for metrics in history:
metric = metrics
for part in name.split("."):
metric = metric[part]
return out
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    weighted_sums: tp.Dict[str, float] = defaultdict(float)
    normalizers: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for name, value in metrics.items():
            weighted_sums[name] = weighted_sums[name] * beta + weight * float(value)
            normalizers[name] = normalizers[name] * beta + weight
        # Every metric seen so far is reported, not only those in this call.
        return {name: acc / normalizers[name] for name, acc in weighted_sums.items()}
    return _update
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    for prefix in prefixes:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, prefix, suffix)
        num = num / 1024.0
    # Anything that survived every division is expressed with the largest prefix.
    return "%.1f%s%s" % (num, 'Yi', suffix)
def temp_filenames(count: int, delete=True):
# Generator that yields a list `names` once; the `count` loop and the `delete`
# branch iterate over it before and after the yield.
# NOTE(review): the decorator, the loop body that fills `names`, and the body of
# the final `for` are not visible in this extract.
names = []
for _ in range(count):
yield names
if delete:
for name in names:
class DummyPoolExecutor:
# Executor look-alike exposing `submit` and the context-manager protocol;
# nothing is run at submit time.
# NOTE(review): the bodies of `__init__` and `__exit__` are not visible in this extract.
class DummyResult:
# Holds the call until `result()` is invoked.
def __init__(self, func, *args, **kwargs):
self.func = func
self.args = args
self.kwargs = kwargs
def result(self):
# The function is called here, in the caller's thread, on every `result()` call.
return self.func(*self.args, **self.kwargs)
def __init__(self, workers=0):
def submit(self, func, *args, **kwargs):
return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading wav based datasets, including MusdbHQ."""
from collections import OrderedDict
import hashlib
import math
import json
import os
from pathlib import Path
import tqdm
import musdb
import julius
import torch as th
from torch import distributed
import torchaudio as ta
from torch.nn import functional as F
from .audio import convert_audio_channels
from . import distrib
MIXTURE = "mixture"
EXT = ".wav"
def _track_metadata(track, sources, normalize=True, ext=EXT):
# Inspect every source file plus the mixture in the `track` folder, check that
# all files share the same length and sample rate, and optionally compute the
# mixture's mean/std for normalization.
# NOTE(review): the `try:` lines and the bodies of the `except RuntimeError`
# handlers are not visible in this extract.
track_length = None
track_samplerate = None
# Defaults used when no normalization statistics are computed.
mean = 0
std = 1
for source in sources + [MIXTURE]:
file = track / f"{source}{ext}"
info = ta.info(str(file))
except RuntimeError:
length = info.num_frames
# The first file sets the reference length and sample rate; later files must match.
if track_length is None:
track_length = length
track_samplerate = info.sample_rate
elif track_length != length:
raise ValueError(
f"Invalid length for file {file}: "
f"expecting {track_length} but got {length}.")
elif info.sample_rate != track_samplerate:
raise ValueError(
f"Invalid sample rate for file {file}: "
f"expecting {track_samplerate} but got {info.sample_rate}.")
# Statistics come from the mixture file only, after averaging over dim 0.
if source == MIXTURE and normalize:
wav, _ = ta.load(str(file))
except RuntimeError:
wav = wav.mean(0)
mean = wav.mean().item()
std = wav.std().item()
return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
def build_metadata(path, sources, normalize=True, ext=EXT):
# NOTE(review): the docstring delimiters and the statement guarded by the
# directory filter below are not visible in this extract.
Build the metadata for `Wavset`.
path (str or Path): path to dataset.
sources (list[str]): list of sources to look for.
normalize (bool): if True, loads full track and store normalization
values based on the mixture file.
ext (str): extension of audio files (default is .wav).
meta = {}
path = Path(path)
# (track name, future) pairs, resolved after the directory walk.
pendings = []
from concurrent.futures import ThreadPoolExecutor
# Per-track metadata is computed on a pool of 8 threads.
with ThreadPoolExecutor(8) as pool:
for root, folders, files in os.walk(path, followlinks=True):
root = Path(root)
# Filter on hidden folders, folders that have sub-folders, and the dataset root itself.
if root.name.startswith('.') or folders or root == path:
# Tracks are keyed by their path relative to the dataset root.
name = str(root.relative_to(path))
pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
# meta[name] = _track_metadata(root, sources, normalize, ext)
for name, pending in tqdm.tqdm(pendings, ncols=120):
meta[name] = pending.result()
return meta
class Wavset:
# Dataset of audio tracks stored one folder per track, indexed by segment.
# NOTE(review): the `self,` parameter line, the docstring delimiters, the `else:`
# lines, and the statements appending to `num_examples` / `wavs` are not visible
# in this extract.
def __init__(
root, metadata, sources,
segment=None, shift=None, normalize=True,
samplerate=44100, channels=2, ext=EXT):
Waveset (or mp3 set for that matter). Can be used to train
with arbitrary sources. Each track should be one folder inside of `path`.
The folder should contain files named `{source}.{ext}`.
root (Path or str): root folder for the dataset.
metadata (dict): output from `build_metadata`.
sources (list[str]): list of source names.
segment (None or float): segment length in seconds. If `None`, returns entire tracks.
shift (None or float): stride in seconds bewteen samples.
normalize (bool): normalizes input audio, **based on the metadata content**,
i.e. the entire track is normalized, not individual extracts.
samplerate (int): target sample rate. if the file sample rate
is different, it will be resampled on the fly.
channels (int): target nb of channels. if different, will be
changed onthe fly.
ext (str): extension for audio files (default is .wav).
samplerate and channels are converted on the fly.
self.root = Path(root)
self.metadata = OrderedDict(metadata)
self.segment = segment
# The stride falls back to the segment length when `shift` is falsy.
self.shift = shift or segment
self.normalize = normalize
self.sources = sources
self.channels = channels
self.samplerate = samplerate
self.ext = ext
# Number of examples per track, in the iteration order of `self.metadata`.
self.num_examples = []
for name, meta in self.metadata.items():
track_duration = meta['length'] / meta['samplerate']
# Whole-track mode, or a track shorter than one segment, counts as one example.
if segment is None or track_duration < segment:
examples = 1
examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
def __len__(self):
return sum(self.num_examples)
def get_file(self, name, source):
return self.root / name / f"{source}{self.ext}"
def __getitem__(self, index):
# Walk the tracks, subtracting each track's example count until `index` falls inside one.
for name, examples in zip(self.metadata, self.num_examples):
if index >= examples:
index -= examples
meta = self.metadata[name]
# `num_frames = -1` and `offset = 0` are the values passed to `ta.load` when no segment is set.
num_frames = -1
offset = 0
if self.segment is not None:
# Offset and length are expressed in frames at the file's own sample rate.
offset = int(meta['samplerate'] * self.shift * index)
num_frames = int(math.ceil(meta['samplerate'] * self.segment))
wavs = []
for source in self.sources:
file = self.get_file(name, source)
wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
wav = convert_audio_channels(wav, self.channels)
example = th.stack(wavs)
example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
# Normalization uses the mean/std stored in the metadata entry.
if self.normalize:
example = (example - meta['mean']) / meta['std']
# Crop, then right-pad, so segment examples have exactly `segment * samplerate` samples.
if self.segment:
length = int(self.segment * self.samplerate)
example = example[..., :length]
example = F.pad(example, (0, length - example.shape[-1]))
return example
def get_wav_datasets(args):
"""Extract the wav datasets from the XP arguments."""
# NOTE(review): the `else:` before the second `kw_cv`, the statement under the
# `world_size` check, and the end of the `train_set` call are not visible in this extract.
# The metadata cache file name is derived from a hash of the dataset path.
sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
train_path = Path(args.wav) / "train"
valid_path = Path(args.wav) / "valid"
# Only rank 0 builds the metadata, and only when the cache file is absent.
if not metadata_file.is_file() and distrib.rank == 0:
metadata_file.parent.mkdir(exist_ok=True, parents=True)
train = build_metadata(train_path, args.sources)
valid = build_metadata(valid_path, args.sources)
# NOTE(review): the file objects opened here and for `json.load` below are never
# explicitly closed.
json.dump([train, valid], open(metadata_file, "w"))
if distrib.world_size > 1:
train, valid = json.load(open(metadata_file))
# `full_cv` gives an empty `kw_cv`; the other assignment passes the training segment and shift.
if args.full_cv:
kw_cv = {}
kw_cv = {'segment': args.segment, 'shift': args.shift}
train_set = Wavset(train_path, train, args.sources,
segment=args.segment, shift=args.shift,
samplerate=args.samplerate, channels=args.channels,
# The validation set lists the mixture first, followed by the sources.
valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
samplerate=args.samplerate, channels=args.channels,
normalize=args.normalize, **kw_cv)
return train_set, valid_set
def _get_musdb_valid():
    """Return the `validation_tracks` entry of the `mus.yaml` config shipped with `musdb`.

    The config file is located through `musdb.__path__`, so the `musdb` package
    must be installed. `yaml` is imported lazily inside the function.
    """
    # Return musdb valid set.
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    # Context manager so the config file handle is closed deterministically.
    with open(setup_path, 'r') as setup_file:
        setup = yaml.safe_load(setup_file)
    return setup['validation_tracks']
def get_musdb_wav_datasets(args):
"""Extract the musdb dataset from the XP arguments."""
# NOTE(review): the `else:` lines, the statement under the `world_size` check, and
# the end of the `train_set` call are not visible in this extract.
# The metadata cache file name is derived from a hash of the dataset path.
sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
root = Path(args.musdb) / "train"
# Only rank 0 builds the metadata, and only when the cache file is absent.
if not metadata_file.is_file() and distrib.rank == 0:
metadata_file.parent.mkdir(exist_ok=True, parents=True)
metadata = build_metadata(root, args.sources)
# NOTE(review): the file objects opened here and for `json.load` below are never
# explicitly closed.
json.dump(metadata, open(metadata_file, "w"))
if distrib.world_size > 1:
metadata = json.load(open(metadata_file))
# Both sets are carved out of the "train" folder using the musdb validation track list.
valid_tracks = _get_musdb_valid()
if args.train_valid:
metadata_train = metadata
metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
# `full_cv` gives an empty `kw_cv`; the other assignment passes the training segment and shift.
if args.full_cv:
kw_cv = {}
kw_cv = {'segment': args.segment, 'shift': args.shift}
train_set = Wavset(root, metadata_train, args.sources,
segment=args.segment, shift=args.shift,
samplerate=args.samplerate, channels=args.channels,
# The validation set lists the mixture first, followed by the sources.
valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
samplerate=args.samplerate, channels=args.channels,
normalize=args.normalize, **kw_cv)
return train_set, valid_set
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# For compat
from .hdemucs import HDemucs
WDemucs = HDemucs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
This package implements different quantization strategies:
- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits.
- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection.
- `diffq.lsq.LSQ`: Learnt Step size Quantizer based on [Esser et al. 2019] https://arxiv.org/abs/1902.08153
- `diffq.bitpack`: efficient CPU bit-packing for returning quantized states.
- `diffq.torch_pack`: torch based bit-packing compatible with torchscript.
- `diffq.ts_export`: support exporting DiffQ based models to torchscript.
Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers.
from .uniform import UniformQuantizer
from .diffq import DiffQuantizer
from .lsq import LSQ
from .base import restore_quantized_state
from . import ts_export
__version__ = "0.2.2"
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base class for all quantizers."""
from contextlib import contextmanager
from dataclasses import dataclass
from concurrent import futures
from fnmatch import fnmatch
from functools import partial
import io
import math
from multiprocessing import cpu_count
import pickle
import typing as tp
import zlib
import torch
from . import bitpack
from . import torch_pack as torch_pack_mod
class BaseQuantizer:
    """Base class for all quantizers."""

    # `_register_param` builds this record with positional arguments, so it needs
    # a generated `__init__`: hence the dataclass.
    @dataclass
    class _QuantizedParam:
        """Book-keeping record for one parameter handled by the quantizer."""
        name: str
        param: torch.nn.Parameter
        module: torch.nn.Module
        # If a Parameter is used multiple times, `other` can be used
        # to share state between the different Quantizers
        other: tp.Optional[tp.Any]
def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
self.model = model
self.min_size = min_size
self.float16 = float16
self.exclude = exclude
self.detect_bound = detect_bound
self._quantized = False
self._need_unquantize = None
self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
self._post_handle = self.model.register_forward_hook(self._forward_hook)
self._qparams = []
self._float16 = []
self._others = []
self._rnns = []
self._saved = []
def _find_params(self):
# Walks every module of `self.model` and inspects each of its direct parameters.
# NOTE(review): several `if` headers below have no visible body in this copy (the RNN
# check, `id(param) in previous`, `self.float16`), and nothing here fills `_rnns`,
# `_float16`, `_others` or `_qparams`. Statements appear to be missing -- confirm
# against the reference implementation before relying on this method.
# Threshold on the number of elements derived from `min_size`
# (presumably MB -> number of 4-byte entries -- confirm).
min_params = self.min_size * 2**20 // 4
# Maps id(param) to the record created for its first appearance.
previous = {}
for module_name, module in self.model.named_modules():
if isinstance(module, torch.nn.RNNBase):
# `recurse=False`: only the parameters owned directly by this module.
for name, param in list(module.named_parameters(recurse=False)):
full_name = f"{module_name}.{name}"
matched = False
# A parameter is excluded when either its full or its short name matches a pattern.
for pattern in self.exclude:
if fnmatch(full_name, pattern) or fnmatch(name, pattern):
matched = True
if param.numel() <= min_params or matched:
if id(param) in previous:
if self.detect_bound:
previous[id(param)] = None
if self.float16:
# `previous.get(...)` passes the earlier record (or None) as `other`.
qparam = self._register_param(name, param, module, previous.get(id(param)))
if self.detect_bound:
previous[id(param)] = qparam
def _register_param(self, name, param, module, other):
return self.__class__._QuantizedParam(name, param, module, other)
def _forward_pre_hook(self, module, input):
# Forward pre-hook registered on `self.model` in `__init__`.
# NOTE(review): the two branches of this hook look merged in this copy (an `else:` and
# some statements appear to be missing between the training path and the
# `_need_unquantize` bookkeeping) -- confirm against the reference implementation.
if self.model.training:
self._quantized_state = None
if self._pre_forward_train():
assert self._need_unquantize is None
# `quantize()` returns True only when it actually switched to the quantized state.
self._need_unquantize = self.quantize()
def _forward_hook(self, module, input, output):
# Forward hook registered on `self.model` in `__init__`.
# NOTE(review): an `else:` and the statement that reverts the quantization when
# `_need_unquantize` is set appear to be missing from this copy -- confirm.
if self.model.training:
if self._post_forward_train():
self._fix_rnns(flatten=False) # Hacky, next forward will flatten
if self._need_unquantize:
self._need_unquantize = None
def quantize(self):
"""
Immediately apply quantization to the model parameters.
Model parameters are saved to later allow restoring the unquantized state.
Note that you shouldn't need to call this for model evaluation, as long as
you properly call `model.train()` and `model.eval()`, but this can be
useful for weight inspection.

Returns False if the model was already quantized, True otherwise.
"""
if self._quantized:
return False
# CPU copies of the original values; only the first appearance of each parameter
# (`other is None`) is saved.
self._saved = [qp.param.data.to('cpu', copy=True)
for qp in self._qparams if qp.other is None]
# NOTE(review): no statement visible here actually overwrites the parameters with
# their quantized values; a line appears to be missing -- confirm.
self._quantized = True
return True
def enter_quantize(self):
"""Context manager for entering quantized state."""
# NOTE(review): neither a decorator nor a body is visible in this copy
# (`contextmanager` is imported at the top of the module) -- confirm.
def unquantize(self):
"""
Revert a previous call to `quantize()`.
"""
# NOTE(review): the body of this first check is not visible in this copy -- confirm.
if not self._quantized:
if not self._saved:
raise RuntimeError("Nothing to restore. This shouldn't happen")
# Saved values are consumed in the same order as `quantize()` stored them:
# one per parameter whose `other` is None.
for qparam in self._qparams:
if qparam.other is None:
qparam.param.data[:] = self._saved.pop(0)
assert len(self._saved) == 0
self._quantized = False
def _pre_forward_train(self) -> bool:
Called once before each forward for continuous quantization.
Should return True if parameters were changed.
return False
def _post_forward_train(self) -> bool:
Called once after each forward (to restore state for instance).
Should return True if parameters were changed.
return False
def _fix_rnns(self, flatten=True):
"""
To be called after quantization happened to fix RNNs.
"""
for rnn in self._rnns:
# Rebuild the flat weight list from the current attributes of the RNN, using None
# for names that are not set.
rnn._flat_weights = [
(lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
for wn in rnn._flat_weights_names]
# NOTE(review): the body of this `if` is not visible in this copy -- confirm.
if flatten:
def _bit_pack_param(self, qparam: _QuantizedParam, quantized: tp.Any,
pack_fn: tp.Any) -> tp.Any:
"""Further bitpack the quantized representation.
This is used to return the quantized state. Should be overriden.
return quantized
def _bit_unpack_param(self, qparam: _QuantizedParam, packed: tp.Any,
unpack_fn: tp.Any) -> tp.Any:
"""Unpack bitpacked representation. Should be overriden
return packed
def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
To be overriden.
raise NotImplementedError()
def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
To be overriden.
raise NotImplementedError()
def get_quantized_state(self, packed=True, torch_pack=False):
"""
Return a quantized representation of the weights. If `packed` is True,
this will also perform bitpacking to ensure optimal store.
If `torch_pack` is true, the bitpacking from `torch_pack` will be used.
It is slower (except maybe on GPU), but is compatible with torchscript.
You can restore a model from a quantized state either using
`BaseQuantizer.restore_quantized_state` or `diffq.restore_quantized_state`
if you do not have the original quantizer around anymore.
"""
# NOTE(review): this copy is missing lines. `float16_params` and `all_quantized` are
# never filled, both dict literals are left unclosed, and the `torch_pack` /
# `qparam.other` checks have lost a branch -- confirm against the reference
# implementation before relying on this method.
float16_params = []
for p in self._float16:
q = p.data.half()
if torch_pack:
pack_fn = torch_pack_mod.pack
pack_fn = bitpack.pack
all_quantized = []
for qparam in self._qparams:
if qparam.other is not None:
quantized = self._quantize_param(qparam)
if packed:
quantized = self._bit_pack_param(qparam, quantized, pack_fn=pack_fn)
state = {
"quantized": all_quantized,
"float16": float16_params,
"others": [p.data.clone() for p in self._others],
# `_init_kwargs` is not set anywhere in this class; presumably it is recorded by a
# decorator on the subclass constructors -- confirm.
kwargs = dict(self._init_kwargs)
state["meta"] = {
"init_kwargs": kwargs,
"klass": self.__class__,
"packed": packed,
"torch_pack": torch_pack
return state
def restore_quantized_state(self, state) -> None:
    """Restore the state of the model from the quantized state.

    Args:
        state (dict): state with the `"float16"`, `"others"` and `"quantized"`
            entries, plus an optional `"meta"` dict whose `"packed"` and
            `"torch_pack"` flags (both defaulting to False) say whether and how
            the quantized entries were bit-packed.

    Raises:
        IndexError: if `state["quantized"]` holds fewer entries than there are
            first-appearance parameters.
    """
    for p, q in zip(self._float16, state["float16"]):
        p.data[:] = q.to(p)

    for p, q in zip(self._others, state["others"]):
        p.data[:] = q

    meta = state.get("meta", {})
    packed = meta.get("packed", False)
    torch_pack = meta.get("torch_pack", False)
    # The unpacking function is only needed (and only looked up) for packed states.
    unpack_fn = None
    if packed:
        unpack_fn = torch_pack_mod.unpack if torch_pack else bitpack.unpack

    remaining = list(state["quantized"])
    for qparam in self._qparams:
        if qparam.other is not None:
            # Only unquantize first appearance of nn.Parameter.
            continue
        quantized = remaining.pop(0)
        if packed:
            quantized = self._bit_unpack_param(qparam, quantized, unpack_fn)
        qparam.param.data[:] = self._unquantize_param(qparam, quantized)
    # Every stored entry must have been consumed.
    assert not remaining
def detach(self) -> None:
"""
Detach from the model, removes hooks and anything else.
"""
# NOTE(review): no statement is visible in this copy; presumably the handles stored in
# `_pre_handle` / `_post_handle` by `__init__` are released here -- confirm.
def model_size(self) -> torch.Tensor:
    """Return an estimate of the quantized model size, in MB.

    Counts 16 bits per entry of the parameters in `self._float16` and 32 bits
    per entry of those in `self._others`.
    """
    total = torch.tensor(0.)
    for width, params in ((16, self._float16), (32, self._others)):
        for p in params:
            total += width * p.numel()
    return total / 2**20 / 8  # bits to MegaBytes
def true_model_size(self) -> float:
    """Return the value of `model_size()` as a plain Python number.

    This is the true quantized model size, in MB.
    """
    size = self.model_size()
    return size.item()
def packed_model_size(self) -> float:
    """Return the size in MB of the packed quantized state once pickled.

    This should be mostly equivalent to `true_model_size` up to some
    slight overhead for storing metadata.
    """
    packed_state = self.get_quantized_state(packed=True)
    pickled = pickle.dumps(packed_state)
    return len(pickled) / 2 ** 20
def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
    """Return the compressed quantized model size, in MB.

    The unpacked quantized state is serialized with `torch.save` and the
    resulting bytes are measured with `_parallel_compress_len`.

    Args:
        compress_level (int): compression level used with zlib,
            see `zlib.compress` for details.
        num_workers (int): will split the final big byte representation in that
            many chunks processed in parallel.
    """
    buffer = io.BytesIO()
    state = self.get_quantized_state(packed=False)
    torch.save(state, buffer)
    n_bytes = _parallel_compress_len(buffer.getvalue(), compress_level, num_workers)
    return n_bytes / 2 ** 20
def restore_quantized_state(model: torch.nn.Module, state: dict):
# Module-level helper: rebuilds a quantizer for `model` from the class and constructor
# arguments stored under `state["meta"]`.
assert "meta" in state
quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"])
# NOTE(review): the new quantizer is not used after construction in this copy; the
# statements that apply `state` to it appear to be missing -- confirm.
def _compress_len(data, compress_level):
return len(zlib.compress(data, level=compress_level))
def _parallel_compress_len(data, compress_level, num_workers):
num_workers = min(cpu_count(), num_workers)
chunk_size = int(math.ceil(len(data) / num_workers))
chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
with futures.ThreadPoolExecutor(num_workers) as pool:
# thread pool is okay here, zlib calls an external C lib and GIL is released
# before the call.
return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Differentiable quantizer based on scaled noise injection.
from dataclasses import dataclass
import math
import typing as tp
import torch
from .base import BaseQuantizer
from .uniform import uniform_quantize, uniform_unquantize
from .utils import capture_init, simple_repr
class DiffQuantizer(BaseQuantizer):
# Differentiable quantizer based on scaled noise injection.
class _QuantizedParam(BaseQuantizer._QuantizedParam):
# Base record extended with `logit`, the tensor that `_get_bits` and
# `_get_noise_scale` turn into a number of bits / a noise scale.
# NOTE(review): `_register_param` builds this record with keyword arguments, which
# requires a generated `__init__` (e.g. the `dataclass` decorator imported above);
# no decorator is visible in this copy -- confirm.
logit: torch.nn.Parameter
def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
             group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
             param="bits", noise="gaussian",
             init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
             exclude: tp.Optional[tp.List[str]] = None, detect_bound: bool = True):
    """
    Differentiable quantizer based on scaled noise injection.
    For every parameter `p` in the model, this introduces a number of bits parameter
    `b` with the same dimensions (when group_size = 1).
    Before each forward, `p` is replaced by `p + U`
    with U iid noise (uniform with range [-d/2, d/2] when `noise="uniform"`),
    with `d` the uniform quantization step for `b` bits.
    This noise approximates the quantization noise in a differentiable manner, both
    with respect to the unquantized parameter `p` and the number of bits `b`.

    At evaluation (as detected with `model.eval()`), the model is replaced
    by its true quantized version, and restored when going back to training.

    When doing actual quantization (for serialization, or evaluation),
    the number of bits is rounded to the nearest integer, and needs to be stored along.
    This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
    which will use a single noise level for multiple weight entries.

    You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
    model size in MB. You can then use this estimate as a penalty in your training loss.

    Args:
        model (torch.nn.Module): model to quantize
        min_size (float): minimum size in MB of a parameter to be quantized.
        float16 (bool): if a layer is smaller than min_size, should we still do float16?
        group_size (int): weight entries are grouped together to reduce the number
            of noise scales to store. This should divide the size of all parameters
            bigger than min_size.
        min_bits (float): minimal number of bits.
        max_bits (float): maximal number of bits.
        param (str): `"noise"` makes `_get_bits` derive the bits from a noise scale;
            any other value (default `"bits"`) interpolates between `min_bits`
            and `max_bits`.
        noise (str): noise distribution, either `"gaussian"` or `"uniform"`.
        init_bits (float): initial number of bits.
        extra_bits (float): extra bits to add for actual quantization (before roundoff).
        suffix (str): suffix used for the name of the extra noise scale parameters.
        exclude (list[str] or None): list of patterns used to match parameters to exclude.
            For instance `['bias']` to exclude all bias terms. `None` means no pattern.
        detect_bound (bool): if True, will detect bound parameters and reuse
            the same quantized tensor for both, as well as the same number of bits.

    ..Warning::
        You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

    Raises:
        RuntimeError: if `model` already owns a parameter whose name ends with `suffix`.
    """
    # NOTE(review): `get_quantized_state` reads `self._init_kwargs`, which nothing visible
    # here sets; the imported `capture_init` presumably decorates this method -- confirm.
    self.group_size = group_size
    self.min_bits = min_bits
    self.max_bits = max_bits
    self.init_bits = init_bits
    self.extra_bits = extra_bits
    self.suffix = suffix
    self.param = param
    self.noise = noise
    assert noise in ["gaussian", "uniform"]
    self._optimizer_setup = False

    # Noise scales reached at `max_bits` and `min_bits` respectively.
    self._min_noise = 1 / (2 ** self.max_bits - 1)
    self._max_noise = 1 / (2 ** self.min_bits - 1)

    assert group_size >= 0
    assert min_bits < init_bits < max_bits, \
        "init_bits must be between min_bits and max_bits excluded"

    for name, _ in model.named_parameters():
        if name.endswith(suffix):
            raise RuntimeError("The model already has some noise scales parameters, "
                               "maybe you used twice a DiffQuantizer on the same model?.")

    # A fresh list replaces the former shared `[]` default.
    patterns = [] if exclude is None else exclude
    super().__init__(model, min_size, float16, patterns, detect_bound)
def _get_bits(self, logit: torch.Tensor):
if self.param == "noise":
return torch.log2(1 + 1 / self._get_noise_scale(logit))
t = torch.sigmoid(logit)
return self.max_bits * t + (1 - t) * self.min_bits
def _get_noise_scale(self, logit: torch.Tensor):
if self.param == "noise":
t = torch.sigmoid(logit)
return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
return 1 / (2 ** self._get_bits(logit) - 1)
def _register_param(self, name, param, module, other):
# Builds the record for one parameter and, for a first appearance, registers its
# bits logit on `module` under `name + self.suffix`.
# A parameter already seen (`other` given) reuses the logit of its first record.
if other is not None:
return self.__class__._QuantizedParam(
name=name, param=param, module=module, logit=other.logit, other=other)
assert self.group_size == 0 or param.numel() % self.group_size == 0
# we want the initial number of bits to be init_bits.
# NOTE(review): an `else:` appears to be missing before the second assignment of `t`;
# as written the "noise" computation is always overwritten -- confirm.
if self.param == "noise":
noise_scale = 1 / (2 ** self.init_bits - 1)
t = (math.log(noise_scale) - math.log(self._max_noise)) / (
math.log(self._min_noise) - math.log(self._max_noise))
t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
assert 0 < t < 1
logit = torch.logit(torch.tensor(float(t)))
assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
# NOTE(review): an `else:` appears to be missing before `nparam = 1` as well, and the
# `torch.nn.Parameter(` call below is never closed (its argument is missing) -- confirm.
if self.group_size > 0:
nparam = param.numel() // self.group_size
nparam = 1
logit = torch.nn.Parameter(
module.register_parameter(name + self.suffix, logit)
return self.__class__._QuantizedParam(
name=name, param=param, module=module, logit=logit, other=None)
def clear_optimizer(self, optimizer: torch.optim.Optimizer):
    """Remove the logit parameters of this quantizer from every group of `optimizer`.

    Each `group["params"]` list is updated in place; other parameters keep their order.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to clean up.
    """
    # Match by identity: `==` / `in` on tensors would compare element-wise.
    owned = {id(qp.logit) for qp in self._qparams}
    for group in optimizer.param_groups:
        group["params"][:] = [q for q in group["params"] if id(q) not in owned]
def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                    lr: float = 1e-3, **kwargs):
    """
    Setup the optimizer to tune the number of bits. In particular, this will deactivate
    weight decay for the bits parameters.

    Args:
        optimizer (torch.Optimizer): optimizer to use.
        lr (float): specific learning rate for the bits parameters. 1e-3
            is perfect for Adam.
        kwargs (dict): overrides for other optimization parameters for the bits.

    Raises:
        RuntimeError: if `optimizer` already holds one of the bits parameters.
    """
    assert not self._optimizer_setup
    self._optimizer_setup = True

    # Bound parameters share the logit of their first appearance: keep each logit once,
    # compared by identity (`==` on tensors is element-wise).
    params = []
    known = set()
    for qp in self._qparams:
        if id(qp.logit) not in known:
            known.add(id(qp.logit))
            params.append(qp.logit)

    for group in optimizer.param_groups:
        if any(id(q) in known for q in group["params"]):
            raise RuntimeError("You should create the optimizer "
                               "before the quantizer!")

    group = {"params": params, "lr": lr, "weight_decay": 0}
    # Caller overrides win over the defaults above; the group is then actually
    # registered, otherwise the bits parameters would never be optimized.
    group.update(kwargs)
    optimizer.add_param_group(group)
def no_optimizer(self):
    """Declare that no optimizer will be used.

    Sets the flag that `_pre_forward_train` checks before training.
    """
    self._optimizer_setup = True
def check_unused(self):
# Inspects the gradients of the quantized parameters after a backward pass.
# NOTE(review): the bodies of the first and last `if` are not visible in this copy, so
# what happens to parameters with a missing or all-zero gradient cannot be told from
# here -- confirm against the reference implementation.
for qparam in self._qparams:
if qparam.other is not None:
grad = qparam.param.grad
if grad is None or (grad == 0).all():
if qparam.logit.grad is not None:
def model_size(self, exact=False):
    """
    Differentiable estimate of the model size.
    The size is returned in MB.

    If `exact` is True, then the output is no longer differentiable but
    reflects exactly an achievable size, even without compression,
    i.e. same as returned by `true_model_size()`.
    """
    total = super().model_size()
    subtotal = 0
    for qparam in self._qparams:
        # only count the first appearance of a Parameter
        if qparam.other is not None:
            continue
        bits = self.extra_bits + self._get_bits(qparam.logit)
        if exact:
            bits = bits.round().clamp(1, 15)
        # `group_size == 0` means one single group spanning the whole parameter.
        if self.group_size == 0:
            group_size = qparam.param.numel()
        else:
            group_size = self.group_size
        subtotal += group_size * bits.sum()
        subtotal += 2 * 32  # param scale

        # Number of bits to represent each number of bits
        bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
        subtotal += 8  # 8 bits for bits_bits
        subtotal += bits_bits * bits.numel()

    subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
    return total + subtotal
def true_model_size(self):
    """Return `model_size(exact=True)` as a plain Python number.

    This is the naive model size, without zlib compression.
    """
    exact_size = self.model_size(exact=True)
    return exact_size.item()
def _pre_forward_train(self):
if not self._optimizer_setup:
raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
"before starting training.")
for qparam in self._qparams:
if qparam.other is not None:
noisy = qparam.other.module._parameters[qparam.other.name]
bits = self._get_bits(qparam.logit)[:, None]
if self.group_size == 0:
p_flat = qparam.param.view(-1)
p_flat = qparam.param.view(-1, self.group_size)
scale = p_flat.max() - p_flat.min()
unit = 1 / (2**bits - 1)
if self.noise == "uniform":
noise_source = (torch.rand_like(p_flat) - 0.5)
elif self.noise == "gaussian":
noise_source = torch.randn_like(p_flat) / 2
noise = scale * unit * noise_source
noisy = p_flat + noise
# We bypass the checks by PyTorch on parameters being leafs
qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
return True
def _post_forward_train(self):
for qparam in self._qparams:
qparam.module._parameters[qparam.name] = qparam.param
return True
def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
    """Quantize one parameter with its rounded number of bits.

    Returns:
        tuple: `(levels, scales, bits)`, with `levels` and `scales` as returned by
            `uniform_quantize` and `bits` holding one byte entry per group.
    """
    bits = self.extra_bits + self._get_bits(qparam.logit)
    bits = bits.round().clamp(1, 15)[:, None].byte()
    # `group_size == 0` means one single group spanning the whole parameter.
    if self.group_size == 0:
        p = qparam.param.data.view(1, -1)
    else:
        p = qparam.param.data.view(-1, self.group_size)
    levels, scales = uniform_quantize(p, bits)
    return levels, scales, bits[:, 0]
def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
    """Rebuild a tensor shaped like `qparam.param` from `(levels, scale, bits)`."""
    levels, param_scale, bits = quantized
    bits_column = bits[:, None]
    restored = uniform_unquantize(levels, param_scale, bits_column)
    return restored.view_as(qparam.param.data)
def _bit_pack_param(self, qparam, quantized, pack_fn):
levels, scales, bits = quantized
all_packed = []
for bit in range(1, 15):
sub_levels = levels[bits == bit]
if not sub_levels.numel():
packed = pack_fn(sub_levels, bit)
packed_bits = pack_fn(bits - self.min_bits)
return (all_packed, scales, packed_bits)
def _bit_unpack_param(self, qparam, packed, unpack_fn):
"""Unpack bitpacked representation. Should be overriden.
packed_all_levels, scales, packed_bits = packed
bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
bits = bits.to(qparam.param.device)
levels = torch.empty(qparam.logit.numel(), self.group_size,
dtype=torch.short, device=qparam.param.device)
for idx, packed_levels in enumerate(packed_all_levels):
bit = idx + 1
if packed_levels is None:
sub_levels = levels[bits == bit]
levels[bits == bit] = unpack_fn(
packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
return (levels, scales, bits)
def detach(self):
    """Delete from each module the extra attribute named `<param name> + suffix`."""
    for entry in self._qparams:
        attr_name = entry.name + self.suffix
        delattr(entry.module, attr_name)
def __repr__(self):
    """Return the representation built by `simple_repr` for this quantizer."""
    return simple_repr(self)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.
from dataclasses import dataclass
import typing as tp
import torch
from .base import BaseQuantizer
from .utils import capture_init, simple_repr
class LSQ(BaseQuantizer):
"""Implements weight only quantization based on [Esser et al. 2019]."""
class _QuantizedParam(BaseQuantizer._QuantizedParam):
# Base record extended with `scale`, the learnt step size passed to `quantize`.
# NOTE(review): `_register_param` builds this record with keyword arguments, which
# requires a generated `__init__` (e.g. the `dataclass` decorator imported above);
# no decorator is visible in this copy -- confirm.
scale: torch.nn.Parameter
def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01,
             float16: bool = False, suffix: str = "_lsq", exclude=None, detect_bound=True):
    """Attach a Learnt Step size Quantizer to `model`.

    Args:
        model (torch.nn.Module): model to quantize.
        bits (int): number of bits used for the weights, in [1, 15].
        min_size (float): size threshold forwarded to `BaseQuantizer`.
        float16 (bool): flag forwarded to `BaseQuantizer`.
        suffix (str): suffix used for the name of the extra scale parameters.
        exclude (list[str] or None): patterns of parameter names to exclude.
            `None` (the default) means no pattern.
        detect_bound (bool): flag forwarded to `BaseQuantizer`.

    Raises:
        RuntimeError: if `model` already owns a parameter whose name ends with `suffix`.
    """
    assert 0 < bits <= 15
    self.suffix = suffix
    self._optimizer_setup = False
    self.bits = bits

    for name, _ in model.named_parameters():
        if name.endswith(suffix):
            raise RuntimeError("The model already has some noise scales parameters, "
                               "maybe you used twice a LSQ on the same model?.")
    # A fresh list replaces the former shared `[]` default.
    patterns = [] if exclude is None else exclude
    super().__init__(model, min_size, float16, patterns, detect_bound)
def _register_param(self, name, param, module, other):
if other is not None:
return self.__class__._QuantizedParam(
name=name, param=param, module=module, scale=other.scale, other=other)
# we want the initial number of bits to be init_bits.
scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5
scale = torch.nn.Parameter(scale)
module.register_parameter(name + self.suffix, scale)
return self.__class__._QuantizedParam(
name=name, param=param, module=module, scale=scale, other=None)
def clear_optimizer(self, optimizer: torch.optim.Optimizer):
    """Remove the scale parameters of this quantizer from every group of `optimizer`.

    Each `group["params"]` list is updated in place; other parameters keep their order.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to clean up.
    """
    # Match by identity: `==` / `in` on tensors would compare element-wise.
    owned = {id(qp.scale) for qp in self._qparams}
    for group in optimizer.param_groups:
        group["params"][:] = [q for q in group["params"] if id(q) not in owned]
def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
    """
    Setup the optimizer to tune the scale parameter.
    Following [Esser et al. 2019], we use the same LR and weight decay
    as the base optimizer, unless specified otherwise.

    Args:
        optimizer (torch.Optimizer): optimizer to use.
        kwargs (dict): overrides for optimization parameters

    Raises:
        RuntimeError: if `optimizer` already holds one of the scale parameters.
    """
    assert not self._optimizer_setup
    self._optimizer_setup = True

    # Bound parameters share the scale of their first appearance: keep each scale once,
    # compared by identity (`==` on tensors is element-wise).
    params = []
    known = set()
    for qp in self._qparams:
        if id(qp.scale) not in known:
            known.add(id(qp.scale))
            params.append(qp.scale)

    for group in optimizer.param_groups:
        if any(id(q) in known for q in group["params"]):
            raise RuntimeError("You should create the optimizer "
                               "before the quantizer!")

    group = {"params": params}
    # Options not overridden here are filled by the optimizer with its own defaults.
    # The group is then actually registered, otherwise the scales would never be tuned.
    group.update(kwargs)
    optimizer.add_param_group(group)
def no_optimizer(self):
    """Declare that no optimizer will be used.

    Sets the flag that `_pre_forward_train` checks before training.
    """
    self._optimizer_setup = True
def model_size(self, exact=False):
    """
    Estimate of the model size.
    The size is returned in MB.

    `exact` is accepted for compatibility with `true_model_size()` but is not read:
    the number of bits is fixed, so the estimate is already exact.
    """
    total = super().model_size()
    subtotal = 0
    for qparam in self._qparams:
        # only count the first appearance of a Parameter
        if qparam.other is not None:
            continue
        bits = qparam.param.numel() * self.bits
        subtotal += bits
        subtotal += 1 * 32  # param scale

    subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
    return total + subtotal
def true_model_size(self):
    """Return `model_size(exact=True)` as a plain Python number.

    This is the naive model size, without zlib compression.
    """
    exact_size = self.model_size(exact=True)
    return exact_size.item()
def _pre_forward_train(self):
    """Replace every parameter by its quantized version before a training forward.

    Returns:
        bool: always True, as parameters were changed.

    Raises:
        RuntimeError: if neither `setup_optimizer()` nor `no_optimizer()` was called.
    """
    if not self._optimizer_setup:
        raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                           "before starting training.")
    for entry in self._qparams:
        quantized, _ = quantize(entry.param, entry.scale, self.bits)
        # We bypass the checks by PyTorch on parameters being leafs
        entry.module._parameters[entry.name] = quantized
    return True
def _post_forward_train(self):
for qparam in self._qparams:
qparam.module._parameters[qparam.name] = qparam.param
return True
def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
    """Return the integer indexes (as `short`) and the detached scale of one parameter.

    Asserts that every index fits in the signed range of `self.bits` bits.
    """
    _, index = quantize(qparam.param, qparam.scale, self.bits)
    upper = 2 ** (self.bits - 1) - 1
    lower = -2 ** (self.bits - 1)
    assert (index <= upper).all(), index.max()
    assert (index >= lower).all(), index.min()
    return index.detach().short(), qparam.scale.detach()
def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
index, scale = quantized
return index.float() * scale
def _bit_pack_param(self, qparam, quantized, pack_fn):
levels, scale = quantized
packed = pack_fn(levels + 2 ** (self.bits - 1))
return (packed, scale)
def _bit_unpack_param(self, qparam, packed, unpack_fn):
"""Unpack bitpacked representation. Should be overriden
packed_levels, scale = packed
levels = unpack_fn(
packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
levels -= 2 ** (self.bits - 1)
return (levels, scale)
def detach(self):
    """Delete from each module the extra attribute named `<param name> + suffix`."""
    for entry in self._qparams:
        attr_name = entry.name + self.suffix
        delattr(entry.module, attr_name)
def __repr__(self):
    """Return the representation built by `simple_repr` for this quantizer."""
    return simple_repr(self)
def roundpass(x):
    """Round `x` in value while keeping the gradient of the identity.

    The rounding offset is detached, so only the `+ x` term carries gradient.
    """
    offset = (x.round() - x).detach()
    return offset + x
def gradscale(x, scale):
    """Return `x` in value, with its gradient multiplied by `scale`.

    The `x - x * scale` part is detached, so only `x * scale` carries gradient.
    """
    scaled = x * scale
    return (x - scaled).detach() + scaled
def quantize(tensor, scale, bits):
    """Quantize `tensor` on `bits` signed bits with step size `scale`.

    The gradient reaching `scale` is rescaled by `1 / sqrt(tensor.numel() * high)`
    through `gradscale`, and rounding goes through `roundpass`.

    Returns:
        tuple: `(index * scale, index)`, with `index` clamped to
            `[-2 ** (bits - 1), 2 ** (bits - 1) - 1]` and rounded.
    """
    half_range = 2 ** (bits - 1)
    low, high = -half_range, half_range - 1
    scale = gradscale(scale, 1 / (tensor.numel() * high)**0.5)
    index = roundpass((tensor / scale).clamp(low, high))
    return index * scale, index
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Bit packing in pure PyTorch.
Slower than bitpack.pyx but compatible with torchscript.
import math
import typing as tp
import torch
from torch.nn import functional as F
def as_rectangle(p: torch.Tensor, side: int):
    """Flatten `p`, zero-pad it to a multiple of `side` and view it as [side, *]."""
    flat = p.view(-1)
    padded_length = int(math.ceil(len(flat) / side) * side)
    missing = padded_length - len(flat)
    padded = F.pad(flat, (0, missing))
    return padded.view(side, -1)
def _storage_size(dtype: torch.dtype):
if dtype == torch.int64:
return 64
elif dtype == torch.int32:
return 32
elif dtype == torch.int16:
return 16
elif dtype == torch.uint8:
return 8
raise ValueError("Invalid bitpacking storage type")
def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16):
    """You can think of indexes as a "Tensor" of bits of shape [L, nbits].
    Instead of concatenating naively as [L * nbits], we instead look at it transposed as
    [nbits, L]. For L = 16 * G, we get [nbits, G, 16] which is trivial to store
    efficiently on int16 integers.
    There will be overhead if L is far from a multiple of 16 (e.g. 1) but for large
    model layers this is acceptable. Storage type can be changed.

    `nbits` should be the number of bits on which the indexes are coded, and will
    actually be determined automatically if set to 0.
    """
    assert not indexes.dtype.is_floating_point
    if indexes.numel() > 0:
        assert indexes.max().item() < 2 ** 15
        assert indexes.min().item() >= 0
    if nbits == 0:
        nbits = int(math.ceil(math.log2(1 + (indexes.max()))))
    assert indexes.max().item() < 2 ** nbits

    flat = indexes.reshape(-1)
    word_bits = _storage_size(storage_dtype)
    rows = as_rectangle(flat, word_bits)
    packed = torch.zeros(nbits, rows.shape[1], dtype=storage_dtype, device=flat.device)
    for src_bit in range(nbits):
        for dst_bit in range(word_bits):
            # Bit `src_bit` of row `dst_bit` lands on bit `dst_bit` of the storage word.
            plane = ((rows[dst_bit] >> src_bit) & 1).to(packed.dtype) << dst_bit
            packed[src_bit, :] |= plane
    return packed
def unpack(packed: torch.Tensor, length: tp.Optional[int] = None):
    """Opposite of `pack`. You might need to specify the original length.

    Args:
        packed (torch.Tensor): tensor of shape [nbits, groups], as returned by `pack`.
        length (int or None): if given, the flat output is cut to that many entries.
    """
    word_bits = _storage_size(packed.dtype)
    nbits, groups = packed.shape
    decoded = torch.zeros(word_bits, groups, dtype=torch.int16, device=packed.device)
    for row in range(word_bits):
        for plane in range(nbits):
            bit_value = (packed[plane, :] >> row) & 1
            decoded[row, :] = decoded[row, :] | (bit_value.to(decoded) << plane)
    decoded = decoded.view(-1)
    if length is not None:
        decoded = decoded[:length]
    return decoded
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""TorchScript export support.
We have to do a lot of black magic for TorchScript to be happy
because we cannot dynamically allocate new weights when loading the model.
Here is how it works:
- we generate code in a temporary Python file for the given model that explicitly
overrides all the weights on the first forward from their packed version.
This is because TorchScript does not let us iterate over parameters in a generic manner.
- we zero out all the original weights. We cannot simply remove those weights
because TorchScript won't let us recreate them.
- A TorchScript file is just a zip file, but stored without compression.
In order to remove the cost of storing the zeroed out weights, we unzip the file,
and zip it again with compression.
import importlib
import os
from pathlib import Path
import random
import sys
import typing as tp
import tempfile
import zipfile
import torch
from torch import jit
from .diffq import DiffQuantizer
from .uniform import uniform_unquantize
from .torch_pack import unpack
_DiffQPacked = tp.Tuple[
tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float],
torch.Tensor, tp.List[int]]
# This is the template for the generated class.
import typing as tp
import torch
from torch import jit
from diffq.ts_export import _unpack_param, _DiffQPacked
from {module} import {klass}
class DiffQTSModel(torch.nn.Module):
def __init__(self, model: {klass}, group_size: int, min_bits: int,
packed: tp.List[_DiffQPacked]):
self.group_size = group_size
self.min_bits = min_bits
self.model = model
self._unpacked = False
self._packed = packed
def unpack(self):
Unpack the weights, automatically called on the first forward,
or explicitely."""
if self._unpacked:
self._unpacked = True
def forward(self, x: torch.Tensor):
return self.model.forward(x)
# those are the assignments for each quantized weight.
UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = '
'_unpack_param(self._packed[{index}], '
'group_size=self.group_size, min_bits=self.min_bits)')
UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}'
def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]):
"""Export the given quantized model to the given path.
We must save the quantized model ourselves, as we need to recompress
the zip archive afterwards.
"""
packed: tp.List[_DiffQPacked] = []
# Random 12-letter module name for the generated code.
uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)])
with tempfile.TemporaryDirectory() as tmpdir:
# The temporary directory is put first on `sys.path` so that the generated
# module can be imported by name.
sys.path.insert(0, tmpdir)
code = _codegen(quantizer)
# NOTE(review): the body of this `with` (presumably writing `code` to the file) is not
# visible in this copy, and no statement removes `tmpdir` from `sys.path` -- confirm.
with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f:
module = importlib.import_module(uniq_name)
ts_klass = module.DiffQTSModel
state = quantizer.get_quantized_state(packed=True, torch_pack=True)
quantized = state["quantized"]
# One packed entry per first-appearance parameter, in `_qparams` order, extended with
# the shape of the parameter.
for qparam in quantizer._qparams:
if qparam.other is None:
levels, scales, bits = quantized.pop(0)
size = qparam.param.size()
packed.append((levels, scales, bits, list(size)))
ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
quantizer.min_bits, packed)
ts_model = jit.script(ts_premodel)
# NOTE(review): the docstring mentions recompressing the archive, but no call to
# `recompress` is visible after saving -- confirm.
if path is not None:
jit.save(ts_model, path)
return ts_model
def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Function called from TorchScript on the first forward to decode the
    packed weights to FP32.

    Args:
        packed: `(packed levels per bit width, scales, packed bits, shape)`.
        group_size (int): number of weight entries per group.
        min_bits (int): offset added back to the unpacked bits.

    Raises:
        RuntimeError: if `shape` has fewer than 1 or more than 4 dimensions.
    """
    level_chunks, scales, packed_bits, shape = packed
    total = 1
    for dim in shape:
        total *= dim
    bits = unpack(packed_bits, total // group_size) + min_bits
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    for idx, chunk in enumerate(level_chunks):
        bit = idx + 1
        if chunk is not None:
            selected = levels[bits == bit]
            levels[bits == bit] = unpack(chunk, selected.numel()).view_as(selected)
    bits = bits[:, None]
    unquant = uniform_unquantize(levels, scales, bits)
    # Explicit per-rank views, as in the original, since this function is scripted.
    rank = len(shape)
    if rank == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    elif rank == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    elif rank == 2:
        return unquant.view(shape[0], shape[1])
    elif rank == 1:
        return unquant.view(shape[0])
    raise RuntimeError("Invalid numbr of dim")
def recompress(path: tp.Union[str, Path]):
    """After having saved the torchscript file, this will recompress it
    to make sure all the zeroed out parameters don't actually take any space.

    The archive is rewritten in place: its content is first extracted to a
    temporary folder, then `path` is overwritten with a deflate-compressed
    archive holding the same files under the same relative names.

    Args:
        path: path to an existing zip archive.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Extract everything *before* reopening `path` in write mode, as
        # opening with "w" truncates the archive.
        with zipfile.ZipFile(path) as zipin:
            zipin.extractall(tmpdir)
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED,
                             compresslevel=1) as zipout:
            for root, _folders, files in os.walk(tmpdir):
                for file in files:
                    fp = Path(root) / file
                    # Store entries relative to the extraction root so the
                    # archive layout is unchanged.
                    name = fp.relative_to(tmpdir)
                    zipout.write(fp, name)
def _get_full_name_access(full_name):
    """Turn a dotted parameter name into a Python access expression suffix.

    When generating code, we need to handle attributes vs. indexing:
    purely numeric parts (e.g. entries of a sequential container) become
    `[index]`, any other part becomes `.name`.

    Args:
        full_name (str): dotted name, e.g. `"encoder.0.weight"`.

    Returns:
        str: access suffix, e.g. `".encoder[0].weight"`.
    """
    parts = []
    for part in full_name.split("."):
        try:
            index = int(part)
        except ValueError:
            parts.append("." + part)
        else:
            parts.append(f"[{index}]")
    return "".join(parts)
def _codegen(quantizer: DiffQuantizer):
# Generates the code for the given quantizer
# NOTE(review): this block reached review with its indentation and several
# lines missing (see the notes below); the statements are left untouched.
# Class location of the wrapped model; presumably substituted into TEMPLATE
# by the final `format` call -- its arguments are not visible here.
module = quantizer.model.__class__.__module__
klass = quantizer.model.__class__.__name__
model = quantizer.model
# Export of float16 quantizers is rejected.
assert not quantizer.float16
# Maps each sub-module object to its dotted name inside `model`.
names = {}
for mod_name, mod in model.named_modules():
names[mod] = mod_name
# One generated assignment per quantized parameter.
unpack_assigns = []
# Counts the parameters with `other is None`, in `_qparams` order.
index = 0
for qparam in quantizer._qparams:
mod_name = names[qparam.module]
# The root module has an empty name: the parameter name is used alone.
if mod_name == '':
full_name = qparam.name
# NOTE(review): an `else:` is presumably missing before the next line,
# otherwise the assignment above is always overwritten -- confirm.
full_name = mod_name + '.' + qparam.name
full_name = _get_full_name_access(full_name)
if qparam.other is None:
unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index))
index += 1
# NOTE(review): the lines below look like the `else:` branch (parameter
# bound to another one). `names` is only filled with module keys above,
# while this lookup uses a `(module, name)` tuple -- verify.
other_name = names[(qparam.other.module, qparam.other.name)]
other_name = _get_full_name_access(other_name)
# NOTE(review): the trailing `)` suggests a lost `unpack_assigns.append(`
# opening line just above -- confirm.
UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name))
# NOTE(review): the arguments of this call are missing.
return TEMPLATE.format(
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Classic uniform quantization over n bits."""
from typing import Tuple
import torch
from .base import BaseQuantizer
from .utils import capture_init, simple_repr
def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    Args:
        p (torch.Tensor): weights to quantize.
        bits (torch.Tensor): number of bits, every entry must be in [1, 15].

    Returns:
        - quantized levels (`uint8` if all `bits` are <= 8, `int16` otherwise)
        - (min, max) range.
    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    if mx == mn:
        # Constant tensor: the range is empty and normalizing would divide by
        # zero. Every weight maps to level 0, which unquantizes back to `mn`.
        p = torch.zeros_like(p)
    else:
        p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    # Use the smallest integer type able to hold the largest level.
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)
def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Map integer levels back to weights using the (min, max) range they were
    quantized with. Return a float32 tensor.

    Args:
        levels (torch.Tensor): quantized levels.
        scales (tuple[float, float]): (min, max) range of the original weights.
        bits (torch.Tensor): number of bits the levels were quantized over.
    """
    low, high = scales
    # Width of one quantization step once the weights are mapped to [0, 1].
    step = 1 / (2 ** bits.float() - 1)
    normalized = levels.float() * step  # in [0, 1]
    return normalized * (high - low) + low
class UniformQuantizer(BaseQuantizer):
    """Quantizer applying uniform quantization over a fixed number of bits."""

    # NOTE: `exclude=[]` is a shared mutable default. It is kept as is for
    # signature compatibility and is never mutated in this class.
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            qat (bool): perform quantized aware training.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        # Set before calling the parent constructor, as in the original ordering.
        self.bits = float(bits)
        self.qat = qat

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def __repr__(self):
        return simple_repr(self)

    def _pre_forward_train(self):
        """Swap in quantized parameters before a training forward pass.

        Returns:
            bool: True if `qat` is enabled and the parameters were replaced.
        """
        if self.qat:
            for qparam in self._qparams:
                if qparam.other is not None:
                    # Bound parameter: reuse the tensor set up for the other one.
                    new_param = qparam.other.module._parameters[qparam.other.name]
                else:
                    quantized = self._quantize_param(qparam)
                    qvalue = self._unquantize_param(qparam, quantized)
                    # Forward uses the quantized value, while the detached
                    # difference lets the gradient flow to the raw parameter.
                    new_param = qparam.param + (qvalue - qparam.param).detach()
                qparam.module._parameters[qparam.name] = new_param
            return True
        return False

    def _post_forward_train(self):
        """Put the raw parameters back after a training forward pass.

        Returns:
            bool: True if `qat` is enabled and the parameters were restored.
        """
        if self.qat:
            for qparam in self._qparams:
                qparam.module._parameters[qparam.name] = qparam.param
            return True
        return False

    def _quantize_param(self, qparam):
        """Return `(levels, scales)` for the parameter, over `self.bits` bits."""
        levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
        return (levels, scales)

    def _unquantize_param(self, qparam, quantized):
        """Return the float tensor decoded from a `(levels, scales)` pair."""
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        """Bit-pack the levels with `pack_fn`, the scales are kept as is."""
        levels, scales = quantized
        packed = pack_fn(levels, self.bits)
        return (packed, scales)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack the bitpacked representation produced by `_bit_pack_param`.

        The levels are moved to the device of the parameter and given its shape.
        """
        packed_levels, scales = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        return (levels, scales)

    def model_size(self):
        """
        Non differentiable model size in MB.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            if qparam.other is None:  # if parameter is bound, count only one copy.
                subtotal += self.bits * qparam.param.numel() + 64  # 2 float for the overall scales
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import functools
import inspect
from typing import Optional, List
def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}):
    """
    Return a simple representation string for `obj`.
    If `attrs` is not None, it should be a list of attributes to include.

    Args:
        obj: object to represent.
        attrs: names of the attributes to consider. Defaults to the parameters
            of the constructor of `obj`.
        overrides: mapping from attribute name to the value to display instead
            of the one read on `obj`. It is only read, never modified.

    Returns:
        str: `ClassName(attr=value,...)`. Constructor parameters are shown only
        if they have no default or differ from it; attributes that are not
        constructor parameters are always shown. Names found neither in
        `overrides` nor on `obj` are skipped.
    """
    params = inspect.signature(obj.__class__).parameters
    attrs_repr = []
    if attrs is None:
        attrs = params.keys()
    for attr in attrs:
        if attr in overrides:
            value = overrides[attr]
        elif hasattr(obj, attr):
            value = getattr(obj, attr)
        else:
            # Nothing to show for this name.
            continue
        if attr in params:
            default = params[attr].default
            display = default is inspect.Parameter.empty or value != default
        else:
            display = True
        if display:
            attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the arguments passed to it in `self._init_kwargs`.

    Positional arguments are stored under their parameter name; `self` is
    not stored. Parameters left to their default value are not included.

    Args:
        init: the `__init__` function to wrap.

    Returns:
        The wrapped `__init__`.
    """
    signature = inspect.signature(init)

    # `wraps` keeps the name, docstring and signature of the original
    # `__init__`, so that `inspect.signature` on the class stays accurate.
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        actual_kwargs = dict(bound.arguments)
        del actual_kwargs['self']
        # Stored before calling `init`, so the body of `init` can already read it.
        self._init_kwargs = actual_kwargs
        init(self, *args, **kwargs)

    return __init__
