Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git
Synced 2024-11-28 01:10:56 +01:00
Delete demucs directory
parent b5ed278fe4
commit 6a37c04cfb
9 binary files not shown.
237 demucs/apply.py
@@ -1,237 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm

from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs]


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling the forward directly here, for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")


class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def apply_model(model, mix, shifts=1, split=True,
                overlap=0.25, transition_power=1., progress=False, device=None,
                num_workers=0, pool=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8 second chunks
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    kwargs = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
    }
    if isinstance(model, BagOfModels):
        # Special treatment for a bag of models.
        # We explicitly apply `apply_model` multiple times so that the random shifts
        # are different for each model.
        estimates = 0
        totals = [0] * len(model.sources)
        for sub_model, weight in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
                         th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to a linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            offset += segment
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        if hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)
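The file above was the chunked-inference entry point for the deleted package. As a minimal, illustrative sketch (not part of the diff, and assuming the package layout shown here with untrained weights), separating a mixture looked like this:

import torch as th
from demucs.demucs import Demucs
from demucs.apply import apply_model

# Build a model with the four standard stems; weights here are untrained.
model = Demucs(sources=['drums', 'bass', 'other', 'vocals'])
mix = th.randn(1, 2, 44100 * 10)  # (batch, channels, samples): 10 s of stereo
# split=True runs overlapping segments; shifts=1 averages one random time shift.
out = apply_model(model, mix, shifts=1, split=True, overlap=0.25)
print(out.shape)  # (1, 4, 2, 441000): one estimate per source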
256 demucs/audio.py
@@ -1,256 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import subprocess as sp
from pathlib import Path

import lameenc
import julius
import numpy as np
import torch
import torchaudio as ta

from .utils import temp_filenames


def _read_info(path):
    stdout_data = sp.check_output([
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams'
    ])
    return json.loads(stdout_data.decode('utf-8'))


class AudioFile:
    """
    Allows reading audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        return len(self._audio_streams)

    def channels(self, stream=0):
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None,
             temp_folder=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float): seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. The original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scales by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
            temp_folder (str or Path or None): temporary folder to use for decoding.
        """
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            target_size = int((samplerate or self.samplerate()) * duration)
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for filename in filenames:
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                wav = wav.view(-1, self.channels()).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav


def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels."""
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked for 1-channel audio, but the stream has multiple
        # channels, downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file has
        # a single channel, replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file has
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has fewer channels than requested but is not mono.')
    return wav


def convert_audio(wav, from_samplerate, to_samplerate, channels):
    """Convert audio from a given samplerate to a target one and target number of channels."""
    wav = convert_audio_channels(wav, channels)
    return julius.resample_frac(wav, from_samplerate, to_samplerate)


def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format."""
    if wav.dtype.is_floating_point:
        return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
    else:
        return wav


def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format."""
    if wav.dtype.is_floating_point:
        return wav
    else:
        return wav.float() / (2**15 - 1)


def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
    if wav.dtype.is_floating_point:
        return f32_pcm(wav)
    else:
        return i16_pcm(wav)


def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
    """Save given audio as mp3. This should work on all OSes."""
    C, T = wav.shape
    wav = i16_pcm(wav)
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(C)
    encoder.set_quality(2)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()
    wav = wav.transpose(0, 1).numpy()
    mp3_data = encoder.encode(wav.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)


def prevent_clip(wav, mode='rescale'):
    """
    Different strategies for avoiding raw clipping.
    """
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        wav = wav / max(1.01 * wav.abs().max(), 1)
    elif mode == 'clamp':
        wav = wav.clamp(-0.99, 0.99)
    elif mode == 'tanh':
        wav = torch.tanh(wav)
    else:
        raise ValueError(f"Invalid mode {mode}")
    return wav


def save_audio(wav, path, samplerate, bitrate=320, clip='rescale',
               bits_per_sample=16, as_float=False):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate)
    elif suffix == ".wav":
        if as_float:
            bits_per_sample = 32
            encoding = 'PCM_F'
        else:
            encoding = 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")
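For context, a short sketch (illustrative, not part of the diff; "song.mp3" is a hypothetical input file) of the decode/convert/save round trip this file provided:

from demucs.audio import AudioFile, convert_audio, save_audio

f = AudioFile("song.mp3")  # any format ffmpeg can decode
wav = f.read(streams=0, samplerate=44100, channels=2)  # [C, T] stereo at 44.1 kHz
wav = convert_audio(wav, 44100, 22050, 1)  # resample and downmix to mono
save_audio(wav, "song_mono.wav", samplerate=22050, clip='rescale')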
459 demucs/demucs.py
@@ -1,459 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
from torch import nn
from torch.nn import functional as F

from .states import capture_init
from .utils import center_trim, unfold


class BLSTM(nn.Module):
    """
    BiLSTM with the same hidden units as the input dim.
    If `max_steps` is not None, the input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.
    """
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            rescale_conv(sub, reference)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales the residual outputs diagonally, close to 0 initially, then learnt.
    """
    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also, before entering each residual branch, the dimension is projected onto a smaller
    subspace, e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x


class LocalState(nn.Module):
    """Local state allows attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiment trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of layers in the encoder and in the decoder.
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that inputs are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
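A small sketch (illustrative only) of the padding contract documented in `valid_length`: `forward` pads the input up to the nearest valid length, then center-trims back, so the output length always matches the input length:

import torch
from demucs.demucs import Demucs

model = Demucs(sources=['drums', 'bass', 'other', 'vocals'])
length = 44100  # one second at the model samplerate
print(model.valid_length(length))  # smallest length >= input with no leftover conv steps
with torch.no_grad():
    out = model(torch.randn(1, 2, length))
assert out.shape == (1, 4, 2, length)  # (batch, sources, channels, samples)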
100 demucs/distrib.py
@@ -1,100 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities.
"""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def average(metrics, count=1.):
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    if world_size > 1:
        torch.distributed.barrier()


def share(obj=None, src=0):
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # size variable is now set to the length of pickled obj in all processes

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # buffer variable is now set to pickled obj in all processes

    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicates some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
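A sketch of how `loader` behaves (illustrative; run here in the single-process default where world_size == 1):

import torch
from torch.utils.data import TensorDataset
from demucs import distrib

dataset = TensorDataset(torch.arange(100))
# With world_size == 1 this is a plain DataLoader. In a distributed run,
# shuffle=True would use DistributedSampler, while shuffle=False shards manually
# with Subset to avoid the duplicated examples DistributedSampler introduces.
eval_loader = distrib.loader(dataset, batch_size=4, shuffle=False)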
66 demucs/ema.py
@@ -1,66 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Inspired from https://github.com/rwightman/pytorch-image-models
from contextlib import contextmanager

import torch

from .states import swap_state


class ModelEMA:
    """
    Perform EMA on a model. You can switch to the EMA weights temporarily
    with the `swap` method.

        ema = ModelEMA(model)
        with ema.swap():
            # compute valid metrics with averaged model.
    """
    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
        self.decay = decay
        self.model = model
        self.state = {}
        self.count = 0
        self.device = device
        self.unbias = unbias

        self._init()

    def _init(self):
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            if key not in self.state:
                self.state[key] = val.detach().to(device, copy=True)

    def update(self):
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            self.state[key].mul_(1 - w)
            self.state[key].add_(val.detach().to(device), alpha=w)

    @contextmanager
    def swap(self):
        with swap_state(self.model, self.state):
            yield

    def state_dict(self):
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        self.count = state['count']
        for k, v in state['state'].items():
            self.state[k].copy_(v)
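A sketch of the intended training-loop usage (illustrative; `nn.Linear` stands in for a real model). With `unbias=True`, the first update uses weight 1/count = 1, so the running average starts exactly at the current weights instead of being dragged toward the initialization, in the spirit of Adam-style bias correction:

import torch
from torch import nn
from demucs.ema import ModelEMA

model = nn.Linear(4, 4)
ema = ModelEMA(model, decay=0.9999)
for step in range(100):
    # ... optimizer.step() on `model` would go here ...
    ema.update()
with ema.swap():  # temporarily load the averaged weights into `model`
    pass  # run validation here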
173 demucs/evaluate.py
@@ -1,173 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
or the newest SDR definition from the MDX 2021 competition (this one will
be reported as `nsdr` for `new sdr`).
"""

from concurrent import futures
import logging

from dora.log import LogProgress
import numpy as np
import musdb
import museval
import torch as th

from .apply import apply_model
from .audio import convert_audio, save_audio
from . import distrib
from .utils import DummyPoolExecutor


logger = logging.getLogger(__name__)


def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores


def eval_track(references, estimates, win, hop, compute_sdr=True):
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    else:
        references = references.numpy()
        estimates = estimates.numpy()
        scores = museval.metrics.bss_eval(
            references, estimates,
            compute_permutation=False,
            window=win,
            hop=hop,
            framewise_filters=False,
            bsseval_sources_version=False)[:-1]
        return scores, new_scores


def evaluate(solver, compute_sdr=False):
    """
    Evaluate model using museval.
    `compute_sdr=False` means using only the MDX definition of the SDR, which
    is much faster to evaluate.
    """

    args = solver.args

    output_dir = solver.folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = solver.folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    if args.test.nonhq is None:
        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
    else:
        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
    src_rate = args.dset.musdb_samplerate

    eval_device = 'cpu'

    model = solver.model
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    indexes = range(distrib.rank, len(test_set), distrib.world_size)
    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
                          name='Eval')
    pendings = []

    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
    with pool(args.test.workers) as pool:
        for index in indexes:
            track = test_set.tracks[index]

            mix = th.from_numpy(track.audio).t().float()
            if mix.dim() == 1:
                mix = mix[None]
            mix = mix.to(solver.device)
            ref = mix.mean(dim=0)  # mono mixture
            mix = (mix - ref.mean()) / ref.std()
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix[None],
                                    shifts=args.test.shifts, split=args.test.split,
                                    overlap=args.test.overlap)[0]
            estimates = estimates * ref.std() + ref.mean()
            estimates = estimates.to(eval_device)

            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            if references.dim() == 2:
                references = references[:, None]
            references = references.to(eval_device)
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            if args.test.save:
                folder = solver.folder / "wav" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)

            pendings.append((track.name, pool.submit(
                eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))

        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
                               name='Eval (BSS)')
        tracks = {}
        for track_name, pending in pendings:
            pending = pending.result()
            scores, nsdrs = pending
            tracks[track_name] = {}
            for idx, target in enumerate(model.sources):
                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
            if scores is not None:
                (sdr, isr, sir, sar) = scores
                for idx, target in enumerate(model.sources):
                    values = {
                        "SDR": sdr[idx].tolist(),
                        "SIR": sir[idx].tolist(),
                        "ISR": isr[idx].tolist(),
                        "SAR": sar[idx].tolist()
                    }
                    tracks[track_name][target].update(values)

    all_tracks = {}
    for src in range(distrib.world_size):
        all_tracks.update(distrib.share(tracks, src))

    result = {}
    metric_names = next(iter(all_tracks.values()))[model.sources[0]]
    for metric_name in metric_names:
        avg = 0
        avg_of_medians = 0
        for source in model.sources:
            medians = [
                np.nanmedian(all_tracks[track][source][metric_name])
                for track in all_tracks.keys()]
            mean = np.mean(medians)
            median = np.median(medians)
            result[metric_name.lower() + "_" + source] = mean
            result[metric_name.lower() + "_med" + "_" + source] = median
            avg += mean / len(model.sources)
            avg_of_medians += median / len(model.sources)
        result[metric_name.lower()] = avg
        result[metric_name.lower() + "_med"] = avg_of_medians
    return result
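The MDX metric above is just a signal-to-error ratio in dB per source. A toy check (illustrative, not part of the diff):

import torch as th
from demucs.evaluate import new_sdr

references = th.randn(1, 4, 44100, 2)  # (batch, sources, samples, channels)
estimates = references + 0.1 * th.randn_like(references)
# Residual power is ~1% of signal power, and 10 * log10(1 / 0.01) = 20,
# so each score lands near 20 dB.
print(new_sdr(references, estimates))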
761 demucs/hdemucs.py
@@ -1,761 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
"""
|
|
||||||
This code contains the spectrogram and Hybrid version of Demucs.
|
|
||||||
"""
|
|
||||||
from copy import deepcopy
|
|
||||||
import math
|
|
||||||
|
|
||||||
from openunmix.filtering import wiener
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .demucs import DConv, rescale_module
|
|
||||||
from .states import capture_init
|
|
||||||
from .spec import spectro, ispectro
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledEmbedding(nn.Module):
|
|
||||||
"""
|
|
||||||
Boost learning rate for embeddings (with `scale`).
|
|
||||||
Also, can make embeddings continuous with `smooth`.
|
|
||||||
"""
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int,
|
|
||||||
scale: float = 10., smooth=False):
|
|
||||||
super().__init__()
|
|
||||||
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
|
||||||
if smooth:
|
|
||||||
weight = torch.cumsum(self.embedding.weight.data, dim=0)
|
|
||||||
# when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
|
|
||||||
weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
|
|
||||||
self.embedding.weight.data[:] = weight
|
|
||||||
self.embedding.weight.data /= scale
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.embedding.weight * self.scale
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.embedding(x) * self.scale
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branches.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. This is used
                before merging the time and freq. branches.
            freq: this layer acts on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: dict of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add a 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if le % self.stride != 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z

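# --- shape sketch (annotation, not part of the original file) ---
# With the defaults (kernel_size=8, stride=4, freq=True, pad=True), a
# frequency encoder layer divides the frequency axis by the stride and leaves
# time untouched. Channel sizes below are hypothetical; assumes the DConv
# class defined earlier in this file.
import torch

enc = HEncLayer(chin=4, chout=48)
x = torch.randn(2, 4, 512, 100)   # (batch, channels, freqs, time)
y = enc(x)
assert y.shape == (2, 48, 128, 100)
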
class MultiWrap(nn.Module):
    """
    Takes one layer and replicates it N times. Each replica will act
    on a frequency band. Everything is done so that if the N replicas have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of floats indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None

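# --- usage sketch (annotation, not part of the original file) ---
# Give the lowest 25% and the next 25% of the bins their own replicas of an
# encoder layer; the remaining bins share a third one. `MultiWrap` requires
# an un-normalized, padded frequency layer. Sizes are hypothetical.
import torch

base = HEncLayer(chin=4, chout=48, norm=False)
multi = MultiWrap(base, split_ratios=[0.25, 0.5])
x = torch.randn(2, 4, 512, 100)
y = multi(x)   # same output shape as the unwrapped layer
assert y.shape == (2, 48, 128, 100)
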
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for the decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y

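# --- shape sketch (annotation, not part of the original file) ---
# The decoder layer mirrors HEncLayer: it multiplies the frequency axis by
# the stride. `length` is only used by the time branch (freq=False), so it
# can be None here. Sizes are hypothetical.
import torch

dec = HDecLayer(chin=48, chout=4)
x = torch.randn(2, 48, 128, 100)
skip = torch.randn(2, 48, 128, 100)
z, pre = dec(x, skip, None)
assert z.shape == (2, 4, 512, 100)
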
class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch, and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    the Open-Unmix implementation [Stoter et al. 2019].

    The loss is always in the temporal domain, by backpropagating through the above
    output methods and the iSTFT. This makes it possible to define hybrid models nicely.
    However, this somewhat breaks Wiener filtering, as doing more iterations at test time
    will change the spectrogram contribution without changing the one from the waveform,
    which leads to worse performance. I tried using the residual option of the Open-Unmix
    Wiener implementation, but it didn't improve results. CaC on the other hand provides
    similar performance and works naturally with hybrid models.

    This model also uses frequency embeddings to improve the efficiency of convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of FFT bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add a residual source before Wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. No further processing is done before the iSTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add a 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain model, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add a frequency embedding after the first frequency layer if > 0;
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate.
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for the 1x1 conv in the decoder.
            context_enc: context for the 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                Decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of the residual DConv branch.
            dconv_comp: compression of the DConv branch.
            dconv_attn: adds attention layers in the DConv branch starting at this layer.
            dconv_lstm: adds an LSTM layer in the DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec['context_freq'] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft),
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allows us to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                x = F.pad(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
            else:
                x = F.pad(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2:2 + le]
        return z

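    # --- numeric sketch (annotation, not part of the original file) ---
    # Why `le + 4`: with nfft=4096 and hop=1024, a hypothetical input of length
    # 44100 is padded by 1536 on each side (plus right padding up to a multiple
    # of the hop), so the centered STFT yields exactly 4 extra frames, trimmed
    # by `z[..., 2:2 + le]`:
    #   import math
    #   nfft, hl, length = 4096, 1024, 44100
    #   le = int(math.ceil(length / hl))                    # 44 target frames
    #   pad = hl // 2 * 3                                   # 1536
    #   padded = pad + length + (pad + le * hl - length)    # 48128 samples
    #   frames = padded // hl + 1                           # 48 == le + 4
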
    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

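    # --- round-trip sketch (annotation, not part of the original file) ---
    # CaC only reshuffles real/imag parts into channels; `_mask` below inverts
    # it exactly (hypothetical shapes, single source):
    #   z = torch.randn(2, 2, 2048, 100, dtype=torch.complex64)
    #   m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(2, 4, 2048, 100)
    #   back = torch.view_as_complex(
    #       m.view(2, 1, -1, 2, 2048, 100).permute(0, 1, 2, 4, 5, 3).contiguous())
    #   assert torch.equal(back[:, 0], z)
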
    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply Wiener filtering from Open-Unmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs versions, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for the time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged the branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now the time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            xt = torch.zeros_like(x)
        # initialize everything to zero (the signal will go through the u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before the final transposed convolution,
            # which is used when the freq. and time branches separate.

            if self.hybrid:
                offset = self.depth - len(self.tdecoder)
            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all the stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x

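# --- usage sketch (annotation, not part of the original file) ---
# A tiny configuration, similar to the `demucs_unittest` factory in the next
# file; the output stacks one waveform per source.
import torch

model = HDemucs(sources=['drums', 'bass', 'other', 'vocals'], channels=4)
mix = torch.randn(1, 2, 44100)   # (batch, audio_channels, time)
out = model(mix)                 # (batch, n_sources, audio_channels, time)
assert out.shape == (1, 4, 2, 44100)
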
@ -1,66 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""

import logging
from pathlib import Path
import typing as tp

from dora.log import fatal

from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError  # noqa

logger = logging.getLogger(__name__)
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
REMOTE_ROOT = Path(__file__).parent / 'remote'

SOURCES = ["drums", "bass", "other", "vocals"]


def demucs_unittest():
    model = HDemucs(channels=4, sources=SOURCES)
    return model


def add_model_flags(parser):
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("-s", "--sig", help="Locally trained XP signature.")
    group.add_argument("-n", "--name", default="mdx_extra_q",
                       help="Pretrained model name or signature. Default is mdx_extra_q.")
    parser.add_argument("--repo", type=Path,
                        help="Folder containing all pre-trained models for use with -n.")


def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature,
    from the remote AWS model repo, or from the specified local repo if `repo` is not None.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        remote_files = [line.strip()
                        for line in (REMOTE_ROOT / 'files.txt').read_text().split('\n')
                        if line.strip()]
        model_repo = RemoteRepo(ROOT_URL, remote_files)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    return any_repo.get_model(name)


def get_model_from_args(args):
    """
    Load local model package or pre-trained model.
    """
    return get_model(name=args.name, repo=args.repo)

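# --- usage sketch (annotation, not part of the original file) ---
# Loading the default pretrained bag; weights are downloaded on first use.
model = get_model('mdx_extra_q')
print(model.sources)   # ['drums', 'bass', 'other', 'vocals']
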
153
demucs/repo.py
153
demucs/repo.py
@ -1,153 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Represents a model repository, including pre-trained models and bags of models.
A repo can either be the main remote repository stored in AWS, or a local repository
with your own models.
"""

from hashlib import sha256
from pathlib import Path
import typing as tp

import torch
import yaml

from .apply import BagOfModels, Model
from .states import load_model


AnyModel = tp.Union[Model, BagOfModels]


class ModelLoadingError(RuntimeError):
    pass


def check_checksum(path: Path, checksum: str):
    sha = sha256()
    with open(path, 'rb') as file:
        while True:
            buf = file.read(2**20)
            if not buf:
                break
            sha.update(buf)
    actual_checksum = sha.hexdigest()[:len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f'Invalid checksum for file {path}, '
                                f'expected {checksum} but got {actual_checksum}')


class ModelOnlyRepo:
    """Base class for all model-only repos.
    """
    def has_model(self, sig: str) -> bool:
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        raise NotImplementedError()


class RemoteRepo(ModelOnlyRepo):
    def __init__(self, root_url: str, remote_files: tp.List[str]):
        if not root_url.endswith('/'):
            root_url += '/'
        self._models: tp.Dict[str, str] = {}
        for file in remote_files:
            sig, checksum = file.split('.')[0].split('-')
            assert sig not in self._models
            self._models[sig] = root_url + file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            url = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
        pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
        return load_model(pkg)


class LocalRepo(ModelOnlyRepo):
    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix == '.th':
                if '-' in file.stem:
                    xp_sig, checksum = file.stem.split('-')
                    self._checksums[xp_sig] = checksum
                else:
                    xp_sig = file.stem
                if xp_sig in self._models:
                    raise ModelLoadingError(
                        f'Duplicate pre-trained model exists for signature {xp_sig}. '
                        'Please delete all but one.')
                self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)


class BagOnlyRepo:
    """Handles only YAML files containing bags of models, leaving the actual
    model loading to some Repo.
    """
    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == '.yaml':
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f'{name} is neither a single pre-trained model nor '
                                    'a bag of models.')
        bag = yaml.safe_load(open(yaml_file))
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)


class AnyModelRepo:
    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        else:
            return self.bag_repo.get_model(name_or_sig)

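# --- naming sketch (annotation, not part of the original file) ---
# Both repos key models by files named `SIG-CHECKSUM.th`; the checksum is a
# truncated sha256 prefix verified by `check_checksum`. A hypothetical file:
from pathlib import Path

f = Path('955717e8-8726e21a.th')
sig, checksum = f.stem.split('-')
assert (sig, checksum) == ('955717e8', '8726e21a')
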
@ -1,186 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys
from pathlib import Path
import subprocess

from dora.log import fatal
import torch as th
import torchaudio as ta

from .apply import apply_model, BagOfModels
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError


def load_track(track, audio_channels, samplerate):
    errors = {}
    wav = None

    try:
        wav = AudioFile(track).read(
            streams=0,
            samplerate=samplerate,
            channels=audio_channels)
    except FileNotFoundError:
        errors['ffmpeg'] = 'FFmpeg is not installed.'
    except subprocess.CalledProcessError:
        errors['ffmpeg'] = 'FFmpeg could not read the file.'

    if wav is None:
        try:
            wav, sr = ta.load(str(track))
        except RuntimeError as err:
            errors['torchaudio'] = err.args[0]
        else:
            wav = convert_audio(wav, sr, samplerate, audio_channels)

    if wav is None:
        print(f"Could not load file {track}. "
              "Maybe it is not a supported file format? ")
        for backend, error in errors.items():
            print(f"When trying to load using {backend}, got the following error: {error}")
        sys.exit(1)
    return wav


def main():
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
    add_model_flags(parser)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                             "with the model name will be created.")
    parser.add_argument("-d",
                        "--device",
                        default="cuda" if th.cuda.is_available() else "cpu",
                        help="Device to use, default is cuda if available else cpu")
    parser.add_argument("--shifts",
                        default=1,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                             "Increases separation time but improves quality for Demucs. 10 was used "
                             "in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    split_group = parser.add_mutually_exclusive_group()
    split_group.add_argument("--no-split",
                             action="store_false",
                             dest="split",
                             default=True,
                             help="Doesn't split audio in chunks. "
                                  "This can use large amounts of memory.")
    split_group.add_argument("--segment", type=int,
                             help="Set split size of each chunk. "
                                  "This can help save memory of graphic card. ")
    parser.add_argument("--two-stems",
                        dest="stem", metavar="STEM",
                        help="Only separate audio into {STEM} and no_{STEM}. ")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--int24", action="store_true",
                       help="Save wav output as 24 bits wav.")
    group.add_argument("--float32", action="store_true",
                       help="Save wav output as float32 (2x bigger).")
    parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp"],
                        help="Strategy for avoiding clipping: rescaling entire signal "
                             "if necessary (rescale) or hard clipping (clamp).")
    parser.add_argument("--mp3", action="store_true",
                        help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    parser.add_argument("-j", "--jobs",
                        default=0,
                        type=int,
                        help="Number of jobs. This can increase memory usage but will "
                             "be much faster when multiple cores are available.")

    args = parser.parse_args()

    try:
        model = get_model_from_args(args)
    except ModelLoadingError as error:
        fatal(error.args[0])

    if args.segment is not None and args.segment < 8:
        fatal('Segment must be greater than 8.')

    if isinstance(model, BagOfModels):
        if args.segment is not None:
            for sub in model.models:
                sub.segment = args.segment
    else:
        if args.segment is not None:
            # fix: the original referenced an undefined `sub` here.
            model.segment = args.segment

    model.cpu()
    model.eval()

    if args.stem is not None and args.stem not in model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
                stem=args.stem, sources=', '.join(model.sources)))
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")
    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                "please try again after surrounding the entire path with quotes \"\".",
                file=sys.stderr)
            continue
        print(f"Separating track {track}")
        wav = load_track(track, model.audio_channels, model.samplerate)

        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts,
                              split=args.split, overlap=args.overlap, progress=True,
                              num_workers=args.jobs)[0]
        sources = sources * ref.std() + ref.mean()

        track_folder = out / track.name.rsplit(".", 1)[0]
        track_folder.mkdir(exist_ok=True)
        if args.mp3:
            ext = ".mp3"
        else:
            ext = ".wav"
        kwargs = {
            'samplerate': model.samplerate,
            'bitrate': args.mp3_bitrate,
            'clip': args.clip_mode,
            'as_float': args.float32,
            'bits_per_sample': 24 if args.int24 else 16,
        }
        if args.stem is None:
            for source, name in zip(sources, model.sources):
                stem = str(track_folder / (name + ext))
                save_audio(source, stem, **kwargs)
        else:
            sources = list(sources)
            stem = str(track_folder / (args.stem + ext))
            save_audio(sources.pop(model.sources.index(args.stem)), stem, **kwargs)
            # Warning: after popping the stem, the selected stem is no longer in the list `sources`.
            other_stem = th.zeros_like(sources[0])
            for i in sources:
                other_stem += i
            stem = str(track_folder / ("no_" + args.stem + ext))
            save_audio(other_stem, stem, **kwargs)


if __name__ == "__main__":
    main()

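# --- CLI sketch (annotation, not part of the original file) ---
# Typical invocations of this entry point, using the flags defined above:
#   python -m demucs.separate -n mdx_extra_q my_track.mp3
#   python -m demucs.separate --two-stems vocals -o out my_track.mp3
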
404
demucs/solver.py
404
demucs/solver.py
@ -1,404 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
"""Main training loop."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from dora import get_xp
|
|
||||||
from dora.utils import write_and_rename
|
|
||||||
from dora.log import LogProgress, bold
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from . import augment, distrib, states, pretrained
|
|
||||||
from .apply import apply_model
|
|
||||||
from .ema import ModelEMA
|
|
||||||
from .evaluate import evaluate, new_sdr
|
|
||||||
from .svd import svd_penalty
|
|
||||||
from .utils import pull_metric, EMA
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _summary(metrics):
|
|
||||||
return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())
|
|
||||||
|
|
||||||
|
|
||||||
class Solver(object):
|
|
||||||
def __init__(self, loaders, model, optimizer, args):
|
|
||||||
self.args = args
|
|
||||||
self.loaders = loaders
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
|
|
||||||
self.dmodel = distrib.wrap(model)
|
|
||||||
self.device = next(iter(self.model.parameters())).device
|
|
||||||
|
|
||||||
# Exponential moving average of the model, either updated every batch or epoch.
|
|
||||||
# The best model from all the EMAs and the original one is kept based on the valid
|
|
||||||
# loss for the final best model.
|
|
||||||
self.emas = {'batch': [], 'epoch': []}
|
|
||||||
for kind in self.emas.keys():
|
|
||||||
decays = getattr(args.ema, kind)
|
|
||||||
device = self.device if kind == 'batch' else 'cpu'
|
|
||||||
if decays:
|
|
||||||
for decay in decays:
|
|
||||||
self.emas[kind].append(ModelEMA(self.model, decay, device=device))
|
|
||||||
|
|
||||||
# data augment
|
|
||||||
augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
|
|
||||||
same=args.augment.shift_same)]
|
|
||||||
if args.augment.flip:
|
|
||||||
augments += [augment.FlipChannels(), augment.FlipSign()]
|
|
||||||
for aug in ['scale', 'remix']:
|
|
||||||
kw = getattr(args.augment, aug)
|
|
||||||
if kw.proba:
|
|
||||||
augments.append(getattr(augment, aug.capitalize())(**kw))
|
|
||||||
self.augment = torch.nn.Sequential(*augments)
|
|
||||||
|
|
||||||
xp = get_xp()
|
|
||||||
self.folder = xp.folder
|
|
||||||
# Checkpoints
|
|
||||||
self.checkpoint_file = xp.folder / 'checkpoint.th'
|
|
||||||
self.best_file = xp.folder / 'best.th'
|
|
||||||
logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
|
|
||||||
self.best_state = None
|
|
||||||
self.best_changed = False
|
|
||||||
|
|
||||||
self.link = xp.link
|
|
||||||
self.history = self.link.history
|
|
||||||
|
|
||||||
self._reset()
|
|
||||||
|
|
||||||
def _serialize(self, epoch):
|
|
||||||
package = {}
|
|
||||||
package['state'] = self.model.state_dict()
|
|
||||||
package['optimizer'] = self.optimizer.state_dict()
|
|
||||||
package['history'] = self.history
|
|
||||||
package['best_state'] = self.best_state
|
|
||||||
package['args'] = self.args
|
|
||||||
for kind, emas in self.emas.items():
|
|
||||||
for k, ema in enumerate(emas):
|
|
||||||
package[f'ema_{kind}_{k}'] = ema.state_dict()
|
|
||||||
with write_and_rename(self.checkpoint_file) as tmp:
|
|
||||||
torch.save(package, tmp)
|
|
||||||
|
|
||||||
save_every = self.args.save_every
|
|
||||||
if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs:
|
|
||||||
with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp:
|
|
||||||
torch.save(package, tmp)
|
|
||||||
|
|
||||||
if self.best_changed:
|
|
||||||
# Saving only the latest best model.
|
|
||||||
with write_and_rename(self.best_file) as tmp:
|
|
||||||
package = states.serialize_model(self.model, self.args)
|
|
||||||
package['state'] = self.best_state
|
|
||||||
torch.save(package, tmp)
|
|
||||||
self.best_changed = False
|
|
||||||
|
|
||||||
def _reset(self):
|
|
||||||
"""Reset state of the solver, potentially using checkpoint."""
|
|
||||||
if self.checkpoint_file.exists():
|
|
||||||
logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
|
|
||||||
package = torch.load(self.checkpoint_file, 'cpu')
|
|
||||||
self.model.load_state_dict(package['state'])
|
|
||||||
self.optimizer.load_state_dict(package['optimizer'])
|
|
||||||
self.history[:] = package['history']
|
|
||||||
self.best_state = package['best_state']
|
|
||||||
for kind, emas in self.emas.items():
|
|
||||||
for k, ema in enumerate(emas):
|
|
||||||
ema.load_state_dict(package[f'ema_{kind}_{k}'])
|
|
||||||
elif self.args.continue_pretrained:
|
|
||||||
model = pretrained.get_model(
|
|
||||||
name=self.args.continue_pretrained,
|
|
||||||
repo=self.args.pretrained_repo)
|
|
||||||
self.model.load_state_dict(model.state_dict())
|
|
||||||
elif self.args.continue_from:
|
|
||||||
name = 'checkpoint.th'
|
|
||||||
root = self.folder.parent
|
|
||||||
cf = root / str(self.args.continue_from) / name
|
|
||||||
logger.info("Loading from %s", cf)
|
|
||||||
package = torch.load(cf, 'cpu')
|
|
||||||
self.best_state = package['best_state']
|
|
||||||
if self.args.continue_best:
|
|
||||||
self.model.load_state_dict(package['best_state'], strict=False)
|
|
||||||
else:
|
|
||||||
self.model.load_state_dict(package['state'], strict=False)
|
|
||||||
if self.args.continue_opt:
|
|
||||||
self.optimizer.load_state_dict(package['optimizer'])
|
|
||||||
|
|
||||||
def _format_train(self, metrics: dict) -> dict:
|
|
||||||
"""Formatting for train/valid metrics."""
|
|
||||||
losses = {
|
|
||||||
'loss': format(metrics['loss'], ".4f"),
|
|
||||||
'reco': format(metrics['reco'], ".4f"),
|
|
||||||
}
|
|
||||||
if 'nsdr' in metrics:
|
|
||||||
losses['nsdr'] = format(metrics['nsdr'], ".3f")
|
|
||||||
if self.quantizer is not None:
|
|
||||||
losses['ms'] = format(metrics['ms'], ".2f")
|
|
||||||
if 'grad' in metrics:
|
|
||||||
losses['grad'] = format(metrics['grad'], ".4f")
|
|
||||||
if 'best' in metrics:
|
|
||||||
losses['best'] = format(metrics['best'], '.4f')
|
|
||||||
if 'bname' in metrics:
|
|
||||||
losses['bname'] = metrics['bname']
|
|
||||||
if 'penalty' in metrics:
|
|
||||||
losses['penalty'] = format(metrics['penalty'], ".4f")
|
|
||||||
if 'hloss' in metrics:
|
|
||||||
losses['hloss'] = format(metrics['hloss'], ".4f")
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def _format_test(self, metrics: dict) -> dict:
|
|
||||||
"""Formatting for test metrics."""
|
|
||||||
losses = {}
|
|
||||||
if 'sdr' in metrics:
|
|
||||||
losses['sdr'] = format(metrics['sdr'], '.3f')
|
|
||||||
if 'nsdr' in metrics:
|
|
||||||
losses['nsdr'] = format(metrics['nsdr'], '.3f')
|
|
||||||
for source in self.model.sources:
|
|
||||||
key = f'sdr_{source}'
|
|
||||||
if key in metrics:
|
|
||||||
losses[key] = format(metrics[key], '.3f')
|
|
||||||
key = f'nsdr_{source}'
|
|
||||||
if key in metrics:
|
|
||||||
losses[key] = format(metrics[key], '.3f')
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
# Optimizing the model
|
|
||||||
if self.history:
|
|
||||||
logger.info("Replaying metrics from previous run")
|
|
||||||
for epoch, metrics in enumerate(self.history):
|
|
||||||
formatted = self._format_train(metrics['train'])
|
|
||||||
logger.info(
|
|
||||||
bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
|
|
||||||
formatted = self._format_train(metrics['valid'])
|
|
||||||
logger.info(
|
|
||||||
bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
|
|
||||||
if 'test' in metrics:
|
|
||||||
formatted = self._format_test(metrics['test'])
|
|
||||||
if formatted:
|
|
||||||
logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
|
|
||||||
|
|
||||||
epoch = 0
|
|
||||||
for epoch in range(len(self.history), self.args.epochs):
|
|
||||||
# Train one epoch
|
|
||||||
self.model.train() # Turn on BatchNorm & Dropout
|
|
||||||
metrics = {}
|
|
||||||
logger.info('-' * 70)
|
|
||||||
logger.info("Training...")
|
|
||||||
metrics['train'] = self._run_one_epoch(epoch)
|
|
||||||
formatted = self._format_train(metrics['train'])
|
|
||||||
logger.info(
|
|
||||||
bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
|
|
||||||
|
|
||||||
# Cross validation
|
|
||||||
logger.info('-' * 70)
|
|
||||||
logger.info('Cross validation...')
|
|
||||||
self.model.eval() # Turn off Batchnorm & Dropout
|
|
||||||
with torch.no_grad():
|
|
||||||
valid = self._run_one_epoch(epoch, train=False)
|
|
||||||
bvalid = valid
|
|
||||||
bname = 'main'
|
|
||||||
state = states.copy_state(self.model.state_dict())
|
|
||||||
metrics['valid'] = {}
|
|
||||||
metrics['valid']['main'] = valid
|
|
||||||
key = self.args.test.metric
|
|
||||||
for kind, emas in self.emas.items():
|
|
||||||
for k, ema in enumerate(emas):
|
|
||||||
with ema.swap():
|
|
||||||
valid = self._run_one_epoch(epoch, train=False)
|
|
||||||
name = f'ema_{kind}_{k}'
|
|
||||||
metrics['valid'][name] = valid
|
|
||||||
a = valid[key]
|
|
||||||
b = bvalid[key]
|
|
||||||
if key.startswith('nsdr'):
|
|
||||||
a = -a
|
|
||||||
b = -b
|
|
||||||
if a < b:
|
|
||||||
bvalid = valid
|
|
||||||
state = ema.state
|
|
||||||
bname = name
|
|
||||||
metrics['valid'].update(bvalid)
|
|
||||||
metrics['valid']['bname'] = bname
|
|
||||||
|
|
||||||
valid_loss = metrics['valid'][key]
|
|
||||||
mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
|
|
||||||
if key.startswith('nsdr'):
|
|
||||||
best_loss = max(mets)
|
|
||||||
else:
|
|
||||||
best_loss = min(mets)
|
|
||||||
metrics['valid']['best'] = best_loss
|
|
||||||
if self.args.svd.penalty > 0:
|
|
||||||
kw = dict(self.args.svd)
|
|
||||||
kw.pop('penalty')
|
|
||||||
with torch.no_grad():
|
|
||||||
penalty = svd_penalty(self.model, exact=True, **kw)
|
|
||||||
metrics['valid']['penalty'] = penalty
|
|
||||||
|
|
||||||
formatted = self._format_train(metrics['valid'])
|
|
||||||
logger.info(
|
|
||||||
bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
|
|
||||||
|
|
||||||
# Save the best model
|
|
||||||
if valid_loss == best_loss or self.args.dset.train_valid:
|
|
||||||
logger.info(bold('New best valid loss %.4f'), valid_loss)
|
|
||||||
self.best_state = states.copy_state(state)
|
|
||||||
self.best_changed = True
|
|
||||||
|
|
||||||
# Eval model every `test.every` epoch or on last epoch
|
|
||||||
should_eval = (epoch + 1) % self.args.test.every == 0
|
|
||||||
is_last = epoch == self.args.epochs - 1
|
|
||||||
reco = metrics['valid']['main']['reco']
|
|
||||||
# Tries to detect divergence in a reliable way and finish job
|
|
||||||
# not to waste compute.
|
|
||||||
div = epoch >= 180 and reco > 0.18
|
|
||||||
div = div or epoch >= 100 and reco > 0.25
|
|
||||||
div = div and self.args.optim.loss == 'l1'
|
|
||||||
if div:
|
|
||||||
logger.warning("Finishing training early because valid loss is too high.")
|
|
||||||
is_last = True
|
|
||||||
if should_eval or is_last:
|
|
||||||
# Evaluate on the testset
|
|
||||||
logger.info('-' * 70)
|
|
||||||
logger.info('Evaluating on the test set...')
|
|
||||||
# We switch to the best known model for testing
|
|
||||||
if self.args.test.best:
|
|
||||||
state = self.best_state
|
|
||||||
else:
|
|
||||||
state = states.copy_state(self.model.state_dict())
|
|
||||||
compute_sdr = self.args.test.sdr and is_last
|
|
||||||
with states.swap_state(self.model, state):
|
|
||||||
with torch.no_grad():
|
|
||||||
metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
|
|
||||||
formatted = self._format_test(metrics['test'])
|
|
||||||
logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
|
|
||||||
self.link.push_metrics(metrics)
|
|
||||||
|
|
||||||
if distrib.rank == 0:
|
|
||||||
# Save model each epoch
|
|
||||||
self._serialize(epoch)
|
|
||||||
logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
|
|
||||||
if is_last:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _run_one_epoch(self, epoch, train=True):
|
|
||||||
args = self.args
|
|
||||||
data_loader = self.loaders['train'] if train else self.loaders['valid']
|
|
||||||
# get a different order for distributed training, otherwise this will get ignored
|
|
||||||
data_loader.sampler.epoch = epoch
|
|
||||||
|
|
||||||
label = ["Valid", "Train"][train]
|
|
||||||
name = label + f" | Epoch {epoch + 1}"
|
|
||||||
total = len(data_loader)
|
|
||||||
if args.max_batches:
|
|
||||||
total = min(total, args.max_batches)
|
|
||||||
logprog = LogProgress(logger, data_loader, total=total,
|
|
||||||
updates=self.args.misc.num_prints, name=name)
|
|
||||||
averager = EMA()
|
|
||||||
|
|
||||||
for idx, sources in enumerate(logprog):
|
|
||||||
sources = sources.to(self.device)
|
|
||||||
if train:
|
|
||||||
sources = self.augment(sources)
|
|
||||||
mix = sources.sum(dim=1)
|
|
||||||
else:
|
|
||||||
mix = sources[:, 0]
|
|
||||||
sources = sources[:, 1:]
|
|
||||||
|
|
||||||
if not train and self.args.valid_apply:
|
|
||||||
estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
|
|
||||||
else:
|
|
||||||
estimate = self.dmodel(mix)
|
|
||||||
if train and hasattr(self.model, 'transform_target'):
|
|
||||||
sources = self.model.transform_target(mix, sources)
|
|
||||||
assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
|
|
||||||
dims = tuple(range(2, sources.dim()))
|
|
||||||
|
|
||||||
if args.optim.loss == 'l1':
|
|
||||||
loss = F.l1_loss(estimate, sources, reduction='none')
|
|
||||||
loss = loss.mean(dims).mean(0)
|
|
||||||
reco = loss
|
|
||||||
elif args.optim.loss == 'mse':
|
|
||||||
loss = F.mse_loss(estimate, sources, reduction='none')
|
|
||||||
loss = loss.mean(dims)
|
|
||||||
reco = loss**0.5
|
|
||||||
reco = reco.mean(0)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid loss {self.args.loss}")
|
|
||||||
weights = torch.tensor(args.weights).to(sources)
|
|
||||||
loss = (loss * weights).sum() / weights.sum()
|
|
||||||
|
|
||||||
ms = 0
|
|
||||||
if self.quantizer is not None:
|
|
||||||
ms = self.quantizer.model_size()
|
|
||||||
if args.quant.diffq:
|
|
||||||
loss += args.quant.diffq * ms
|
|
||||||
|
|
||||||
losses = {}
|
|
||||||
losses['reco'] = (reco * weights).sum() / weights.sum()
|
|
||||||
losses['ms'] = ms
|
|
||||||
|
|
||||||
if not train:
|
|
||||||
nsdrs = new_sdr(sources, estimate.detach()).mean(0)
|
|
||||||
total = 0
|
|
||||||
for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
|
|
||||||
losses[f'nsdr_{source}'] = nsdr
|
|
||||||
total += w * nsdr
|
|
||||||
losses['nsdr'] = total / weights.sum()
|
|
||||||
|
|
||||||
if train and args.svd.penalty > 0:
|
|
||||||
kw = dict(args.svd)
|
|
||||||
kw.pop('penalty')
|
|
||||||
penalty = svd_penalty(self.model, **kw)
|
|
||||||
losses['penalty'] = penalty
|
|
||||||
loss += args.svd.penalty * penalty
|
|
||||||
|
|
||||||
losses['loss'] = loss
|
|
||||||
|
|
||||||
for k, source in enumerate(self.model.sources):
|
|
||||||
losses[f'reco_{source}'] = reco[k]
|
|
||||||
|
|
||||||
# optimize model in training mode
|
|
||||||
if train:
|
|
||||||
loss.backward()
|
|
||||||
grad_norm = 0
|
|
||||||
grads = []
|
|
||||||
for p in self.model.parameters():
|
|
||||||
if p.grad is not None:
|
|
||||||
grad_norm += p.grad.data.norm()**2
|
|
||||||
grads.append(p.grad.data)
|
|
||||||
losses['grad'] = grad_norm ** 0.5
|
|
||||||
if args.optim.clip_grad:
|
|
||||||
torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.model.parameters(),
|
|
||||||
args.optim.clip_grad)
|
|
||||||
|
|
||||||
if self.args.flag == 'uns':
|
|
||||||
for n, p in self.model.named_parameters():
|
|
||||||
if p.grad is None:
|
|
||||||
print('no grad', n)
|
|
||||||
self.optimizer.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
for ema in self.emas['batch']:
|
|
||||||
ema.update()
|
|
||||||
losses = averager(losses)
|
|
||||||
logs = self._format_train(losses)
|
|
||||||
logprog.update(**logs)
|
|
||||||
# Just in case, clear some memory
|
|
||||||
del loss, estimate, reco, ms
|
|
||||||
if args.max_batches == idx:
|
|
||||||
break
|
|
||||||
if self.args.debug and train:
|
|
||||||
break
|
|
||||||
if self.args.flag == 'debug':
|
|
||||||
break
|
|
||||||
if train:
|
|
||||||
for ema in self.emas['epoch']:
|
|
||||||
ema.update()
|
|
||||||
return distrib.average(losses, idx + 1)
|
|
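
# A minimal sketch (not from the diff) of the per-source weighting used in
# `_run_one_epoch` above: `loss` holds one value per source and `weights`
# rescales each source's contribution before reducing to a scalar. The shapes
# and weight values below are illustrative assumptions only.
import torch
import torch.nn.functional as F

estimate = torch.randn(8, 4, 2, 44100)   # (batch, sources, channels, time)
sources = torch.randn(8, 4, 2, 44100)
dims = tuple(range(2, sources.dim()))    # reduce over channels and time
loss = F.l1_loss(estimate, sources, reduction='none').mean(dims).mean(0)
weights = torch.tensor([1., 1., 1., 2.]).to(sources)   # e.g. emphasize one source
scalar = (loss * weights).sum() / weights.sum()
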
demucs/spectro.py
@ -1,41 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convenience wrapper to perform STFT and iSTFT"""

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0):
    *other, length = x.shape
    x = x.reshape(-1, length)
    z = th.stft(x,
                n_fft * (1 + pad),
                hop_length or n_fft // 4,
                window=th.hann_window(n_fft).to(x),
                win_length=n_fft,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)


def ispectro(z, hop_length=None, length=None, pad=0):
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)
    x = th.istft(z,
                 n_fft,
                 hop_length,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)
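
# A quick round-trip check (not from the deleted file): with the default hop
# of n_fft // 4, `ispectro` inverts `spectro` up to numerical error when given
# the matching hop length and target length.
import torch as th

x = th.randn(2, 44100)                         # (channels, time)
z = spectro(x, n_fft=512)                      # complex, shape (2, 257, frames)
y = ispectro(z, hop_length=128, length=x.shape[-1])
assert (x - y).abs().max() < 1e-4
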
demucs/states.py
@ -1,148 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities to save and load models.
"""
from contextlib import contextmanager

import functools
import hashlib
import inspect
import io
from pathlib import Path
import warnings

from omegaconf import OmegaConf
from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
import torch


def get_quantizer(model, args, optimizer=None):
    """Return the quantizer given the XP quantization args."""
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(
            model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(
            model, bits=args.qat, min_size=args.min_size)
    return quantizer


def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict (already loaded)
    or a path to a file on disk."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            package = torch.load(path, 'cpu')
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping nonexistent parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]

    set_state(model, state)
    return model


def get_state(model, quantizer, half=False):
    """Get the state from a model, potentially with quantization applied.
    If `half` is True, weights are stored in half precision, which shouldn't
    impact performance but halves the state size."""
    if quantizer is None:
        dtype = torch.half if half else None
        state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
    else:
        state = quantizer.get_quantized_state()
        state['__quantized'] = True
    return state


def set_state(model, state, quantizer=None):
    """Set the state on a given model."""
    if state.get('__quantized'):
        if quantizer is not None:
            quantizer.restore_quantized_state(model, state['quantized'])
        else:
            restore_quantized_state(model, state)
    else:
        model.load_state_dict(state)
    return state


def save_with_checksum(content, path):
    """Save the given value on disk, along with a sha256 hash.
    Should be used with the output of either `serialize_model` or `get_state`."""
    buf = io.BytesIO()
    torch.save(content, buf)
    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]

    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(buf.getvalue())


def serialize_model(model, training_args, quantizer=None, half=True):
    args, kwargs = model._init_args_kwargs
    klass = model.__class__

    state = get_state(model, quantizer, half)
    return {
        'klass': klass,
        'args': args,
        'kwargs': kwargs,
        'state': state,
        'training_args': OmegaConf.to_container(training_args, resolve=True),
    }


def copy_state(state):
    return {k: v.cpu().clone() for k, v in state.items()}


@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state
    """
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state, strict=False)
    try:
        yield
    finally:
        model.load_state_dict(old_state)


def capture_init(init):
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
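
# A small round-trip sketch (not from the deleted file): `capture_init` records
# the constructor arguments, which is what lets `serialize_model` and
# `load_model` rebuild the same module later. `Tiny` is a made-up class for
# illustration.
import torch
from omegaconf import OmegaConf

class Tiny(torch.nn.Module):
    @capture_init
    def __init__(self, dim=4):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

model = Tiny(dim=8)
package = serialize_model(model, OmegaConf.create({}), half=False)
clone = load_model(package)   # rebuilds Tiny(dim=8) and restores its weights
assert torch.equal(clone.linear.weight, model.linear.weight)
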
demucs/svd.py
@ -1,83 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Ways to make the model stronger."""
import random
import torch


def power_iteration(m, niters=1, bs=1):
    """This is the power method. The batch size is used to try multiple starting points in parallel."""
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    dim = m.shape[0]
    b = torch.randn(dim, bs, device=m.device, dtype=m.dtype)

    for _ in range(niters):
        n = m.mm(b)
        norm = n.norm(dim=0, keepdim=True)
        b = n / (1e-10 + norm)

    return norm.mean()


# We need a shared RNG to make sure all the distributed workers will skip the penalty together,
# as otherwise we wouldn't get any speed up.
penalty_rng = random.Random(1234)


def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
                proba=1, conv_only=False, exact=False, bs=1):
    """
    Penalty on the largest singular value for a layer.
    Args:
        - model: model to penalize
        - min_size: minimum size in MB of a layer to penalize.
        - dim: projection dimension for the svd_lowrank. Higher is better but slower.
        - niters: number of iterations in the algorithm used by svd_lowrank.
        - powm: use the power method instead of lowrank SVD; my own experience
          is that it is both slower and less stable.
        - convtr: when True, differentiate between Conv and Transposed Conv.
          This is kept for compatibility with older experiments.
        - proba: probability to apply the penalty.
        - conv_only: only apply to conv and conv transposed, not LSTM
          (might not be reliable for models other than Demucs).
        - exact: use exact SVD (slow but useful at validation).
        - bs: batch_size for the power method.
    """
    total = 0
    if penalty_rng.random() > proba:
        return 0.

    for m in model.modules():
        for name, p in m.named_parameters(recurse=False):
            if p.numel() / 2**18 < min_size:
                continue
            if convtr:
                if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
                    if p.dim() in [3, 4]:
                        p = p.transpose(0, 1).contiguous()
                if p.dim() == 3:
                    p = p.view(len(p), -1)
                elif p.dim() == 4:
                    p = p.view(len(p), -1)
                elif p.dim() == 1:
                    continue
            elif conv_only:
                continue
            assert p.dim() == 2, (name, p.shape)
            if exact:
                estimate = torch.svd(p, compute_uv=False)[1].pow(2).max()
            elif powm:
                a, b = p.shape
                if a < b:
                    n = p.mm(p.t())
                else:
                    n = p.t().mm(p)
                estimate = power_iteration(n, niters, bs)
            else:
                estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2)
            total += estimate
    return total / proba
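
# A hedged usage sketch (not from the deleted file): adding the penalty to a
# task loss, mirroring the `args.svd.penalty` weighting in the solver above.
# The toy model and the 1e-4 weight are illustrative assumptions.
import torch

net = torch.nn.Sequential(torch.nn.Conv1d(4, 8, 3), torch.nn.ReLU(), torch.nn.Conv1d(8, 4, 3))
x = torch.randn(2, 4, 1000)
task_loss = net(x).abs().mean()
penalty = svd_penalty(net, min_size=0, proba=1.)   # default low-rank SVD path stays differentiable
loss = task_loss + 1e-4 * penalty
loss.backward()
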
demucs/utils.py
@ -1,131 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp

import torch
from torch.nn import functional as F


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
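
# A tiny check (not from the deleted file): with stride equal to kernel_size
# this is plain non-overlapping framing, zero-padded to a whole number of frames.
import torch

a = torch.arange(10.)
frames = unfold(a, kernel_size=4, stride=4)   # shape (3, 4); last frame padded
assert frames.shape == (3, 4)
assert frames[2, 2].item() == 0
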
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    ref_size: int
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta:
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor
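
# Quick illustration (not from the deleted file): the extra sample is dropped
# on the right when the length difference is odd.
import torch

t = torch.arange(7.)
assert center_trim(t, 4).tolist() == [1., 2., 3., 4.]
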
def pull_metric(history: tp.List[dict], name: str):
    out = []
    for metrics in history:
        metric = metrics
        for part in name.split("."):
            metric = metric[part]
        out.append(metric)
    return out


def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal total, fix
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}
    return _update
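
# A short usage sketch (not from the deleted file): with the default beta=1 the
# callback is a plain running average, which is how the solver averages its
# per-batch losses above.
avg = EMA()
assert avg({'loss': 2.0}) == {'loss': 2.0}
assert avg({'loss': 4.0}) == {'loss': 3.0}
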
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return a human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, 'Yi', suffix)
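
# Example outputs (not from the deleted file):
assert sizeof_fmt(512) == "512.0B"
assert sizeof_fmt(3_500_000) == "3.3MiB"
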
@contextmanager
def temp_filenames(count: int, delete=True):
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)


class DummyPoolExecutor:
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
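
# Usage sketch (not from the deleted file): the dummy executor runs the
# callable lazily in the calling thread, a drop-in for ThreadPoolExecutor
# when no worker threads are wanted.
with DummyPoolExecutor() as pool:
    future = pool.submit(pow, 2, 10)
    assert future.result() == 1024
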
demucs/wav.py
@ -1,242 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading wav based datasets, including MusdbHQ."""

from collections import OrderedDict
import hashlib
import math
import json
import os
from pathlib import Path
import tqdm

import musdb
import julius
import torch as th
from torch import distributed
import torchaudio as ta
from torch.nn import functional as F

from .audio import convert_audio_channels
from . import distrib

MIXTURE = "mixture"
EXT = ".wav"


def _track_metadata(track, sources, normalize=True, ext=EXT):
    track_length = None
    track_samplerate = None
    mean = 0
    std = 1
    for source in sources + [MIXTURE]:
        file = track / f"{source}{ext}"
        try:
            info = ta.info(str(file))
        except RuntimeError:
            print(file)
            raise
        length = info.num_frames
        if track_length is None:
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE and normalize:
            try:
                wav, _ = ta.load(str(file))
            except RuntimeError:
                print(file)
                raise
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
def build_metadata(path, sources, normalize=True, ext=EXT):
    """
    Build the metadata for `Wavset`.

    Args:
        path (str or Path): path to dataset.
        sources (list[str]): list of sources to look for.
        normalize (bool): if True, loads the full track and stores the
            normalization values based on the mixture file.
        ext (str): extension of audio files (default is .wav).
    """

    meta = {}
    path = Path(path)
    pendings = []
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(8) as pool:
        for root, folders, files in os.walk(path, followlinks=True):
            root = Path(root)
            if root.name.startswith('.') or folders or root == path:
                continue
            name = str(root.relative_to(path))
            pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
            # meta[name] = _track_metadata(root, sources, normalize, ext)
        for name, pending in tqdm.tqdm(pendings, ncols=120):
            meta[name] = pending.result()
    return meta
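
# A hedged sketch of the expected layout (paths are hypothetical): each track
# is one folder holding a wav per source plus the mixture, e.g.
#   /data/musdbhq/train/SomeTrack/{drums,bass,other,vocals,mixture}.wav
# meta = build_metadata('/data/musdbhq/train', ['drums', 'bass', 'other', 'vocals'])
# meta['SomeTrack'] -> {'length': ..., 'mean': ..., 'std': ..., 'samplerate': 44100}
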
class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            segment=None, shift=None, normalize=True,
            samplerate=44100, channels=2, ext=EXT):
        """
        Wavset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `path`.
        The folder should contain files named `{source}.{ext}`.

        Args:
            root (Path or str): root folder for the dataset.
            metadata (dict): output from `build_metadata`.
            sources (list[str]): list of source names.
            segment (None or float): segment length in seconds. If `None`, returns entire tracks.
            shift (None or float): stride in seconds between samples.
            normalize (bool): normalizes input audio, **based on the metadata content**,
                i.e. the entire track is normalized, not individual extracts.
            samplerate (int): target sample rate. If the file sample rate
                is different, it will be resampled on the fly.
            channels (int): target number of channels. If different, will be
                changed on the fly.
            ext (str): extension for audio files (default is .wav).

        samplerate and channels are converted on the fly.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.segment = segment
        self.shift = shift or segment
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        self.ext = ext
        self.num_examples = []
        for name, meta in self.metadata.items():
            track_duration = meta['length'] / meta['samplerate']
            if segment is None or track_duration < segment:
                examples = 1
            else:
                examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        return self.root / name / f"{source}{self.ext}"

    def __getitem__(self, index):
        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1
            offset = 0
            if self.segment is not None:
                offset = int(meta['samplerate'] * self.shift * index)
                num_frames = int(math.ceil(meta['samplerate'] * self.segment))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.segment:
                length = int(self.segment * self.samplerate)
                example = example[..., :length]
                example = F.pad(example, (0, length - example.shape[-1]))
            return example
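
# A hedged usage sketch (hypothetical paths and metadata): each item stacks the
# requested sources into a tensor of shape (num_sources, channels, time).
# dataset = Wavset('/data/musdbhq/train', meta, ['drums', 'bass', 'other', 'vocals'],
#                  segment=6, shift=1, samplerate=44100, channels=2)
# example = dataset[0]   # shape (4, 2, 6 * 44100)
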
def get_wav_datasets(args):
    """Extract the wav datasets from the XP arguments."""
    sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
    train_path = Path(args.wav) / "train"
    valid_path = Path(args.wav) / "valid"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        train = build_metadata(train_path, args.sources)
        valid = build_metadata(valid_path, args.sources)
        json.dump([train, valid], open(metadata_file, "w"))
    if distrib.world_size > 1:
        distributed.barrier()
    train, valid = json.load(open(metadata_file))
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(train_path, train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
def _get_musdb_valid():
    # Return musdb valid set.
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    setup = yaml.safe_load(open(setup_path, 'r'))
    return setup['validation_tracks']


def get_musdb_wav_datasets(args):
    """Extract the musdb dataset from the XP arguments."""
    sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
    root = Path(args.musdb) / "train"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        metadata = build_metadata(root, args.sources)
        json.dump(metadata, open(metadata_file, "w"))
    if distrib.world_size > 1:
        distributed.barrier()
    metadata = json.load(open(metadata_file))

    valid_tracks = _get_musdb_valid()
    if args.train_valid:
        metadata_train = metadata
    else:
        metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(root, metadata_train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set
demucs/wdemucs.py
@ -1,9 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# For compat
from .hdemucs import HDemucs

WDemucs = HDemucs