Delete demucs directory

This commit is contained in:
Anjok07 2022-06-13 02:10:01 -05:00 committed by GitHub
parent 50533ec73c
commit b16a422b39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 0 additions and 2580 deletions

View File

@ -1,7 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "2.0.3"

View File

@ -1,317 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import math
import os
import sys
import time
from dataclasses import dataclass, field
import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel
from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
from .compressed import get_compressed_datasets
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .repitch import RepitchedWrapper
from .pretrained import load_pretrained, SOURCES
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import (human_seconds, load_model, save_model, get_state,
save_state, sizeof_fmt, get_quantizer)
from .wav import get_wav_datasets, get_musdb_wav_datasets
@dataclass
class SavedState:
metrics: list = field(default_factory=list)
last_state: dict = None
best_state: dict = None
optimizer: dict = None
def main():
parser = get_parser()
args = parser.parse_args()
name = get_name(parser, args)
print(f"Experiment {name}")
if args.musdb is None and args.rank == 0:
print(
"You must provide the path to the MusDB dataset with the --musdb flag. "
"To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
file=sys.stderr)
sys.exit(1)
eval_folder = args.evals / name
eval_folder.mkdir(exist_ok=True, parents=True)
args.logs.mkdir(exist_ok=True)
metrics_path = args.logs / f"{name}.json"
eval_folder.mkdir(exist_ok=True, parents=True)
args.checkpoints.mkdir(exist_ok=True, parents=True)
args.models.mkdir(exist_ok=True, parents=True)
if args.device is None:
device = "cpu"
if th.cuda.is_available():
device = "cuda"
else:
device = args.device
th.manual_seed(args.seed)
# Prevents too many threads to be started when running `museval` as it can be quite
# inefficient on NUMA architectures.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
if args.world_size > 1:
if device != "cuda" and args.rank == 0:
print("Error: distributed training is only available with cuda device", file=sys.stderr)
sys.exit(1)
th.cuda.set_device(args.rank % th.cuda.device_count())
distributed.init_process_group(backend="nccl",
init_method="tcp://" + args.master,
rank=args.rank,
world_size=args.world_size)
checkpoint = args.checkpoints / f"{name}.th"
checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
if args.restart and checkpoint.exists() and args.rank == 0:
checkpoint.unlink()
if args.test or args.test_pretrained:
args.epochs = 1
args.repeat = 0
if args.test:
model = load_model(args.models / args.test)
else:
model = load_pretrained(args.test_pretrained)
elif args.tasnet:
model = ConvTasNet(audio_channels=args.audio_channels,
samplerate=args.samplerate, X=args.X,
segment_length=4 * args.samples,
sources=SOURCES)
else:
model = Demucs(
audio_channels=args.audio_channels,
channels=args.channels,
context=args.context,
depth=args.depth,
glu=args.glu,
growth=args.growth,
kernel_size=args.kernel_size,
lstm_layers=args.lstm_layers,
rescale=args.rescale,
rewrite=args.rewrite,
stride=args.conv_stride,
resample=args.resample,
normalize=args.normalize,
samplerate=args.samplerate,
segment_length=4 * args.samples,
sources=SOURCES,
)
model.to(device)
if args.init:
model.load_state_dict(load_pretrained(args.init).state_dict())
if args.show:
print(model)
size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")
return
try:
saved = th.load(checkpoint, map_location='cpu')
except IOError:
saved = SavedState()
optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
quantizer = None
quantizer = get_quantizer(model, args, optimizer)
if saved.last_state is not None:
model.load_state_dict(saved.last_state, strict=False)
if saved.optimizer is not None:
optimizer.load_state_dict(saved.optimizer)
model_name = f"{name}.th"
if args.save_model:
if args.rank == 0:
model.to("cpu")
model.load_state_dict(saved.best_state)
save_model(model, quantizer, args, args.models / model_name)
return
elif args.save_state:
model_name = f"{args.save_state}.th"
if args.rank == 0:
model.to("cpu")
model.load_state_dict(saved.best_state)
state = get_state(model, quantizer)
save_state(state, args.models / model_name)
return
if args.rank == 0:
done = args.logs / f"{name}.done"
if done.exists():
done.unlink()
augment = [Shift(args.data_stride)]
if args.augment:
augment += [FlipSign(), FlipChannels(), Scale(),
Remix(group_size=args.remix_group_size)]
augment = nn.Sequential(*augment).to(device)
print("Agumentation pipeline:", augment)
if args.mse:
criterion = nn.MSELoss()
else:
criterion = nn.L1Loss()
# Setting number of samples so that all convolution windows are full.
# Prevents hard to debug mistake with the prediction being shifted compared
# to the input mixture.
samples = model.valid_length(args.samples)
print(f"Number of training samples adjusted to {samples}")
samples = samples + args.data_stride
if args.repitch:
# We need a bit more audio samples, to account for potential
# tempo change.
samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))
args.metadata.mkdir(exist_ok=True, parents=True)
if args.raw:
train_set = Rawset(args.raw / "train",
samples=samples,
channels=args.audio_channels,
streams=range(1, len(model.sources) + 1),
stride=args.data_stride)
valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
elif args.wav:
train_set, valid_set = get_wav_datasets(args, samples, model.sources)
elif args.is_wav:
train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
else:
train_set, valid_set = get_compressed_datasets(args, samples)
if args.repitch:
train_set = RepitchedWrapper(
train_set,
proba=args.repitch,
max_tempo=args.max_tempo)
best_loss = float("inf")
for epoch, metrics in enumerate(saved.metrics):
print(f"Epoch {epoch:03d}: "
f"train={metrics['train']:.8f} "
f"valid={metrics['valid']:.8f} "
f"best={metrics['best']:.4f} "
f"ms={metrics.get('true_model_size', 0):.2f}MB "
f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
f"duration={human_seconds(metrics['duration'])}")
best_loss = metrics['best']
if args.world_size > 1:
dmodel = DistributedDataParallel(model,
device_ids=[th.cuda.current_device()],
output_device=th.cuda.current_device())
else:
dmodel = model
for epoch in range(len(saved.metrics), args.epochs):
begin = time.time()
model.train()
train_loss, model_size = train_model(
epoch, train_set, dmodel, criterion, optimizer, augment,
quantizer=quantizer,
batch_size=args.batch_size,
device=device,
repeat=args.repeat,
seed=args.seed,
diffq=args.diffq,
workers=args.workers,
world_size=args.world_size)
model.eval()
valid_loss = validate_model(
epoch, valid_set, model, criterion,
device=device,
rank=args.rank,
split=args.split_valid,
overlap=args.overlap,
world_size=args.world_size)
ms = 0
cms = 0
if quantizer and args.rank == 0:
ms = quantizer.true_model_size()
cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10))
duration = time.time() - begin
if valid_loss < best_loss and ms <= args.ms_target:
best_loss = valid_loss
saved.best_state = {
key: value.to("cpu").clone()
for key, value in model.state_dict().items()
}
saved.metrics.append({
"train": train_loss,
"valid": valid_loss,
"best": best_loss,
"duration": duration,
"model_size": model_size,
"true_model_size": ms,
"compressed_model_size": cms,
})
if args.rank == 0:
json.dump(saved.metrics, open(metrics_path, "w"))
saved.last_state = model.state_dict()
saved.optimizer = optimizer.state_dict()
if args.rank == 0 and not args.test:
th.save(saved, checkpoint_tmp)
checkpoint_tmp.rename(checkpoint)
print(f"Epoch {epoch:03d}: "
f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
f"cms={cms:.2f}MB "
f"duration={human_seconds(duration)}")
if args.world_size > 1:
distributed.barrier()
del dmodel
model.load_state_dict(saved.best_state)
if args.eval_cpu:
device = "cpu"
model.to(device)
model.eval()
evaluate(model, args.musdb, eval_folder,
is_wav=args.is_wav,
rank=args.rank,
world_size=args.world_size,
device=device,
save=args.save,
split=args.split_valid,
shifts=args.shifts,
overlap=args.overlap,
workers=args.eval_workers)
model.to("cpu")
if args.rank == 0:
if not (args.test or args.test_pretrained):
save_model(model, quantizer, args, args.models / model_name)
print("done")
done.write_text("done")
if __name__ == "__main__":
main()

View File

@ -1,172 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import subprocess as sp
from pathlib import Path
import julius
import numpy as np
import torch
from .utils import temp_filenames
def _read_info(path):
stdout_data = sp.check_output([
'ffprobe', "-loglevel", "panic",
str(path), '-print_format', 'json', '-show_format', '-show_streams'
])
return json.loads(stdout_data.decode('utf-8'))
class AudioFile:
"""
Allows to read audio from any format supported by ffmpeg, as well as resampling or
converting to mono on the fly. See :method:`read` for more details.
"""
def __init__(self, path: Path):
self.path = Path(path)
self._info = None
def __repr__(self):
features = [("path", self.path)]
features.append(("samplerate", self.samplerate()))
features.append(("channels", self.channels()))
features.append(("streams", len(self)))
features_str = ", ".join(f"{name}={value}" for name, value in features)
return f"AudioFile({features_str})"
@property
def info(self):
if self._info is None:
self._info = _read_info(self.path)
return self._info
@property
def duration(self):
return float(self.info['format']['duration'])
@property
def _audio_streams(self):
return [
index for index, stream in enumerate(self.info["streams"])
if stream["codec_type"] == "audio"
]
def __len__(self):
return len(self._audio_streams)
def channels(self, stream=0):
return int(self.info['streams'][self._audio_streams[stream]]['channels'])
def samplerate(self, stream=0):
return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
def read(self,
seek_time=None,
duration=None,
streams=slice(None),
samplerate=None,
channels=None,
temp_folder=None):
"""
Slightly more efficient implementation than stempeg,
in particular, this will extract all stems at once
rather than having to loop over one file multiple times
for each stream.
Args:
seek_time (float): seek time in seconds or None if no seeking is needed.
duration (float): duration in seconds to extract or None to extract until the end.
streams (slice, int or list): streams to extract, can be a single int, a list or
a slice. If it is a slice or list, the output will be of size [S, C, T]
with S the number of streams, C the number of channels and T the number of samples.
If it is an int, the output will be [C, T].
samplerate (int): if provided, will resample on the fly. If None, no resampling will
be done. Original sampling rate can be obtained with :method:`samplerate`.
channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
See https://sound.stackexchange.com/a/42710.
Our definition of mono is simply the average of the two channels. Any other
value will be ignored.
temp_folder (str or Path or None): temporary folder to use for decoding.
"""
streams = np.array(range(len(self)))[streams]
single = not isinstance(streams, np.ndarray)
if single:
streams = [streams]
if duration is None:
target_size = None
query_duration = None
else:
target_size = int((samplerate or self.samplerate()) * duration)
query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
with temp_filenames(len(streams)) as filenames:
command = ['ffmpeg', '-y']
command += ['-loglevel', 'panic']
if seek_time:
command += ['-ss', str(seek_time)]
command += ['-i', str(self.path)]
for stream, filename in zip(streams, filenames):
command += ['-map', f'0:{self._audio_streams[stream]}']
if query_duration is not None:
command += ['-t', str(query_duration)]
command += ['-threads', '1']
command += ['-f', 'f32le']
if samplerate is not None:
command += ['-ar', str(samplerate)]
command += [filename]
sp.run(command, check=True)
wavs = []
for filename in filenames:
wav = np.fromfile(filename, dtype=np.float32)
wav = torch.from_numpy(wav)
wav = wav.view(-1, self.channels()).t()
if channels is not None:
wav = convert_audio_channels(wav, channels)
if target_size is not None:
wav = wav[..., :target_size]
wavs.append(wav)
wav = torch.stack(wavs, dim=0)
if single:
wav = wav[0]
return wav
def convert_audio_channels(wav, channels=2):
"""Convert audio to the given number of channels."""
*shape, src_channels, length = wav.shape
if src_channels == channels:
pass
elif channels == 1:
# Case 1:
# The caller asked 1-channel audio, but the stream have multiple
# channels, downmix all channels.
wav = wav.mean(dim=-2, keepdim=True)
elif src_channels == 1:
# Case 2:
# The caller asked for multiple channels, but the input file have
# one single channel, replicate the audio over all channels.
wav = wav.expand(*shape, channels, length)
elif src_channels >= channels:
# Case 3:
# The caller asked for multiple channels, and the input file have
# more channels than requested. In that case return the first channels.
wav = wav[..., :channels, :]
else:
# Case 4: What is a reasonable choice here?
raise ValueError('The audio file has less channels than requested but is not mono.')
return wav
def convert_audio(wav, from_samplerate, to_samplerate, channels):
wav = convert_audio_channels(wav, channels)
return julius.resample_frac(wav, from_samplerate, to_samplerate)

View File

@ -1,115 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from fractions import Fraction
from concurrent import futures
import musdb
from torch import distributed
from .audio import AudioFile
def get_musdb_tracks(root, *args, **kwargs):
mus = musdb.DB(root, *args, **kwargs)
return {track.name: track.path for track in mus}
class StemsSet:
def __init__(self, tracks, metadata, duration=None, stride=1,
samplerate=44100, channels=2, streams=slice(None)):
self.metadata = []
for name, path in tracks.items():
meta = dict(metadata[name])
meta["path"] = path
meta["name"] = name
self.metadata.append(meta)
if duration is not None and meta["duration"] < duration:
raise ValueError(f"Track {name} duration is too small {meta['duration']}")
self.metadata.sort(key=lambda x: x["name"])
self.duration = duration
self.stride = stride
self.channels = channels
self.samplerate = samplerate
self.streams = streams
def __len__(self):
return sum(self._examples_count(m) for m in self.metadata)
def _examples_count(self, meta):
if self.duration is None:
return 1
else:
return int((meta["duration"] - self.duration) // self.stride + 1)
def track_metadata(self, index):
for meta in self.metadata:
examples = self._examples_count(meta)
if index >= examples:
index -= examples
continue
return meta
def __getitem__(self, index):
for meta in self.metadata:
examples = self._examples_count(meta)
if index >= examples:
index -= examples
continue
streams = AudioFile(meta["path"]).read(seek_time=index * self.stride,
duration=self.duration,
channels=self.channels,
samplerate=self.samplerate,
streams=self.streams)
return (streams - meta["mean"]) / meta["std"]
def _get_track_metadata(path):
# use mono at 44kHz as reference. For any other settings data won't be perfectly
# normalized but it should be good enough.
audio = AudioFile(path)
mix = audio.read(streams=0, channels=1, samplerate=44100)
return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()}
def _build_metadata(tracks, workers=10):
pendings = []
with futures.ProcessPoolExecutor(workers) as pool:
for name, path in tracks.items():
pendings.append((name, pool.submit(_get_track_metadata, path)))
return {name: p.result() for name, p in pendings}
def _build_musdb_metadata(path, musdb, workers):
tracks = get_musdb_tracks(musdb)
metadata = _build_metadata(tracks, workers)
path.parent.mkdir(exist_ok=True, parents=True)
json.dump(metadata, open(path, "w"))
def get_compressed_datasets(args, samples):
metadata_file = args.metadata / "musdb.json"
if not metadata_file.is_file() and args.rank == 0:
_build_musdb_metadata(metadata_file, args.musdb, args.workers)
if args.world_size > 1:
distributed.barrier()
metadata = json.load(open(metadata_file))
duration = Fraction(samples, args.samplerate)
stride = Fraction(args.data_stride, args.samplerate)
train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
metadata,
duration=duration,
stride=stride,
streams=slice(1, None),
samplerate=args.samplerate,
channels=args.audio_channels)
valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
metadata,
samplerate=args.samplerate,
channels=args.audio_channels)
return train_set, valid_set

View File

@ -1,202 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import julius
from torch import nn
from .utils import capture_init, center_trim
class BLSTM(nn.Module):
def __init__(self, dim, layers=1):
super().__init__()
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x):
x = x.permute(2, 0, 1)
x = self.lstm(x)[0]
x = self.linear(x)
x = x.permute(1, 2, 0)
return x
def rescale_conv(conv, reference):
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
class Demucs(nn.Module):
@capture_init
def __init__(self,
sources,
audio_channels=2,
channels=64,
depth=6,
rewrite=True,
glu=True,
rescale=0.1,
resample=True,
kernel_size=8,
stride=4,
growth=2.,
lstm_layers=2,
context=3,
normalize=False,
samplerate=44100,
segment_length=4 * 10 * 44100):
"""
Args:
sources (list[str]): list of source names
audio_channels (int): stereo or mono
channels (int): first convolution channels
depth (int): number of encoder/decoder layers
rewrite (bool): add 1x1 convolution to each encoder layer
and a convolution to each decoder layer.
For the decoder layer, `context` gives the kernel size.
glu (bool): use glu instead of ReLU
resample_input (bool): upsample x2 the input and downsample /2 the output.
rescale (int): rescale initial weights of convolutions
to get their standard deviation closer to `rescale`
kernel_size (int): kernel size for convolutions
stride (int): stride for convolutions
growth (float): multiply (resp divide) number of channels by that
for each layer of the encoder (resp decoder)
lstm_layers (int): number of lstm layers, 0 = no lstm
context (int): kernel size of the convolution in the
decoder before the transposed convolution. If > 1,
will provide some context from neighboring time
steps.
samplerate (int): stored as meta information for easing
future evaluations of the model.
segment_length (int): stored as meta information for easing
future evaluations of the model. Length of the segments on which
the model was trained.
"""
super().__init__()
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.resample = resample
self.channels = channels
self.normalize = normalize
self.samplerate = samplerate
self.segment_length = segment_length
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
if glu:
activation = nn.GLU(dim=1)
ch_scale = 2
else:
activation = nn.ReLU()
ch_scale = 1
in_channels = audio_channels
for index in range(depth):
encode = []
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
if rewrite:
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
self.encoder.append(nn.Sequential(*encode))
decode = []
if index > 0:
out_channels = in_channels
else:
out_channels = len(self.sources) * audio_channels
if rewrite:
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
in_channels = channels
channels = int(growth * channels)
channels = in_channels
if lstm_layers:
self.lstm = BLSTM(channels, lstm_layers)
else:
self.lstm = None
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
"""
Return the nearest valid length to use with the model so that
there is no time steps left over in a convolutions, e.g. for all
layers, size of the input - kernel_size % stride = 0.
If the mixture has a valid length, the estimated sources
will have exactly the same length when context = 1. If context > 1,
the two signals can be center trimmed to match.
For training, extracts should have a valid length.For evaluation
on full tracks we recommend passing `pad = True` to :method:`forward`.
"""
if self.resample:
length *= 2
for _ in range(self.depth):
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(1, length)
length += self.context - 1
for _ in range(self.depth):
length = (length - 1) * self.stride + self.kernel_size
if self.resample:
length = math.ceil(length / 2)
return int(length)
def forward(self, mix):
x = mix
if self.normalize:
mono = mix.mean(dim=1, keepdim=True)
mean = mono.mean(dim=-1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
else:
mean = 0
std = 1
x = (x - mean) / (1e-5 + std)
if self.resample:
x = julius.resample_frac(x, 1, 2)
saved = []
for encode in self.encoder:
x = encode(x)
saved.append(x)
if self.lstm:
x = self.lstm(x)
for decode in self.decoder:
skip = center_trim(saved.pop(-1), x)
x = x + skip
x = decode(x)
if self.resample:
x = julius.resample_frac(x, 2, 1)
x = x * std + mean
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
return x

View File

@ -1,244 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from pathlib import Path
def get_parser():
parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
default_raw = None
default_musdb = None
if 'DEMUCS_RAW' in os.environ:
default_raw = Path(os.environ['DEMUCS_RAW'])
if 'DEMUCS_MUSDB' in os.environ:
default_musdb = Path(os.environ['DEMUCS_MUSDB'])
parser.add_argument(
"--raw",
type=Path,
default=default_raw,
help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
parser.add_argument("-m",
"--musdb",
type=Path,
default=default_musdb,
help="Path to musdb root")
parser.add_argument("--is_wav", action="store_true",
help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
help="Folder where metadata information is stored.")
parser.add_argument("--wav", type=Path,
help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
"subfolder.")
parser.add_argument("--samplerate", type=int, default=44100)
parser.add_argument("--audio_channels", type=int, default=2)
parser.add_argument("--samples",
default=44100 * 10,
type=int,
help="number of samples to feed in")
parser.add_argument("--data_stride",
default=44100,
type=int,
help="Stride for chunks, shorter = longer epochs")
parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
parser.add_argument("-d",
"--device",
help="Device to train on, default is cuda if available else cpu")
parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
parser.add_argument("--test", help="Just run the test pipeline + one validation. "
"This should be a filename relative to the models/ folder.")
parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
"on a pretrained model. ")
parser.add_argument("--rank", default=0, type=int)
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument("--master")
parser.add_argument("--checkpoints",
type=Path,
default=Path("checkpoints"),
help="Folder where to store checkpoints etc")
parser.add_argument("--evals",
type=Path,
default=Path("evals"),
help="Folder where to store evals and waveforms")
parser.add_argument("--save",
action="store_true",
help="Save estimated for the test set waveforms")
parser.add_argument("--logs",
type=Path,
default=Path("logs"),
help="Folder where to store logs")
parser.add_argument("--models",
type=Path,
default=Path("models"),
help="Folder where to store trained models")
parser.add_argument("-R",
"--restart",
action='store_true',
help='Restart training, ignoring previous run')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
parser.add_argument("-r",
"--repeat",
type=int,
default=2,
help="Repeat the train set, longer epochs")
parser.add_argument("-b", "--batch_size", type=int, default=64)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
parser.add_argument("--init", help="Initialize from a pre-trained model.")
# Augmentation options
parser.add_argument("--no_augment",
action="store_false",
dest="augment",
default=True,
help="No basic data augmentation.")
parser.add_argument("--repitch", type=float, default=0.2,
help="Probability to do tempo/pitch change")
parser.add_argument("--max_tempo", type=float, default=12,
help="Maximum relative tempo change in %% when using repitch.")
parser.add_argument("--remix_group_size",
type=int,
default=4,
help="Shuffle sources using group of this size. Useful to somewhat "
"replicate multi-gpu training "
"on less GPUs.")
parser.add_argument("--shifts",
type=int,
default=10,
help="Number of random shifts used for the shift trick.")
parser.add_argument("--overlap",
type=float,
default=0.25,
help="Overlap when --split_valid is passed.")
# See model.py for doc
parser.add_argument("--growth",
type=float,
default=2.,
help="Number of channels between two layers will increase by this factor")
parser.add_argument("--depth",
type=int,
default=6,
help="Number of layers for the encoder and decoder")
parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
parser.add_argument("--channels",
type=int,
default=64,
help="Number of channels for the first encoder layer")
parser.add_argument("--kernel_size",
type=int,
default=8,
help="Kernel size for the (transposed) convolutions")
parser.add_argument("--conv_stride",
type=int,
default=4,
help="Stride for the (transposed) convolutions")
parser.add_argument("--context",
type=int,
default=3,
help="Context size for the decoder convolutions "
"before the transposed convolutions")
parser.add_argument("--rescale",
type=float,
default=0.1,
help="Initial weight rescale reference")
parser.add_argument("--no_resample", action="store_false",
default=True, dest="resample",
help="No Resampling of the input/output x2")
parser.add_argument("--no_glu",
action="store_false",
default=True,
dest="glu",
help="Replace all GLUs by ReLUs")
parser.add_argument("--no_rewrite",
action="store_false",
default=True,
dest="rewrite",
help="No 1x1 rewrite convolutions")
parser.add_argument("--normalize", action="store_true")
parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)
# Tasnet options
parser.add_argument("--tasnet", action="store_true")
parser.add_argument("--split_valid",
action="store_true",
help="Predict chunks by chunks for valid and test. Required for tasnet")
parser.add_argument("--X", type=int, default=8)
# Other options
parser.add_argument("--show",
action="store_true",
help="Show model architecture, size and exit")
parser.add_argument("--save_model", action="store_true",
help="Skip traning, just save final model "
"for the current checkpoint value.")
parser.add_argument("--save_state",
help="Skip training, just save state "
"for the current checkpoint value. You should "
"provide a model name as argument.")
# Quantization options
parser.add_argument("--q-min-size", type=float, default=1,
help="Only quantize layers over this size (in MB)")
parser.add_argument(
"--qat", type=int, help="If provided, use QAT training with that many bits.")
parser.add_argument("--diffq", type=float, default=0)
parser.add_argument(
"--ms-target", type=float, default=162,
help="Model size target in MB, when using DiffQ. Best model will be kept "
"only if it is smaller than this target.")
return parser
def get_name(parser, args):
"""
Return the name of an experiment given the args. Some parameters are ignored,
for instance --workers, as they do not impact the final result.
"""
ignore_args = set([
"checkpoints",
"deterministic",
"eval",
"evals",
"eval_cpu",
"eval_workers",
"logs",
"master",
"rank",
"restart",
"save",
"save_model",
"save_state",
"show",
"workers",
"world_size",
])
parts = []
name_args = dict(args.__dict__)
for name, value in name_args.items():
if name in ignore_args:
continue
if value != parser.get_default(name):
if isinstance(value, Path):
parts.append(f"{name}={value.name}")
else:
parts.append(f"{name}={value}")
if parts:
name = " ".join(parts)
else:
name = "default"
return name

View File

@ -1,107 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import logging
from diffq import DiffQuantizer
import torch.hub
from .model import Demucs
from .tasnet import ConvTasNet
from .utils import set_state
logger = logging.getLogger(__name__)
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
PRETRAINED_MODELS = {
'demucs': 'e07c671f',
'demucs48_hq': '28a1282c',
'demucs_extra': '3646af93',
'demucs_quantized': '07afea75',
'tasnet': 'beb46fac',
'tasnet_extra': 'df3777b2',
'demucs_unittest': '09ebc15f',
}
SOURCES = ["drums", "bass", "other", "vocals"]
def get_url(name):
sig = PRETRAINED_MODELS[name]
return ROOT + name + "-" + sig[:8] + ".th"
def is_pretrained(name):
return name in PRETRAINED_MODELS
def load_pretrained(name):
if name == "demucs":
return demucs(pretrained=True)
elif name == "demucs48_hq":
return demucs(pretrained=True, hq=True, channels=48)
elif name == "demucs_extra":
return demucs(pretrained=True, extra=True)
elif name == "demucs_quantized":
return demucs(pretrained=True, quantized=True)
elif name == "demucs_unittest":
return demucs_unittest(pretrained=True)
elif name == "tasnet":
return tasnet(pretrained=True)
elif name == "tasnet_extra":
return tasnet(pretrained=True, extra=True)
else:
raise ValueError(f"Invalid pretrained name {name}")
def _load_state(name, model, quantizer=None):
url = get_url(name)
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
set_state(model, quantizer, state)
if quantizer:
quantizer.detach()
def demucs_unittest(pretrained=True):
model = Demucs(channels=4, sources=SOURCES)
if pretrained:
_load_state('demucs_unittest', model)
return model
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
if not pretrained and (extra or quantized or hq):
raise ValueError("if extra or quantized is True, pretrained must be True.")
model = Demucs(sources=SOURCES, channels=channels)
if pretrained:
name = 'demucs'
if channels != 64:
name += str(channels)
quantizer = None
if sum([extra, quantized, hq]) > 1:
raise ValueError("Only one of extra, quantized, hq, can be True.")
if quantized:
quantizer = DiffQuantizer(model, group_size=8, min_size=1)
name += '_quantized'
if extra:
name += '_extra'
if hq:
name += '_hq'
_load_state(name, model, quantizer)
return model
def tasnet(pretrained=True, extra=False):
if not pretrained and extra:
raise ValueError("if extra is True, pretrained must be True.")
model = ConvTasNet(X=10, sources=SOURCES)
if pretrained:
name = 'tasnet'
if extra:
name = 'tasnet_extra'
_load_state(name, model)
return model

View File

@ -1,173 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from collections import defaultdict, namedtuple
from pathlib import Path
import musdb
import numpy as np
import torch as th
import tqdm
from torch.utils.data import DataLoader
from .audio import AudioFile
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
class Rawset:
"""
Dataset of raw, normalized, float32 audio files
"""
def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
self.path = Path(path)
self.channels = channels
self.samples = samples
if stride is None:
stride = samples if samples is not None else 0
self.stride = stride
entries = defaultdict(list)
for root, folders, files in os.walk(self.path, followlinks=True):
folders.sort()
files.sort()
for file in files:
if file.endswith(".raw"):
path = Path(root) / file
name, stream = path.stem.rsplit('.', 1)
entries[(path.parent.relative_to(self.path), name)].append(int(stream))
self._entries = list(entries.keys())
sizes = []
self._lengths = []
ref_streams = sorted(entries[self._entries[0]])
assert ref_streams == list(range(len(ref_streams)))
if streams is None:
self.streams = ref_streams
else:
self.streams = streams
for entry in sorted(entries.keys()):
streams = entries[entry]
assert sorted(streams) == ref_streams
file = self._path(*entry)
length = file.stat().st_size // (4 * channels)
if samples is None:
sizes.append(1)
else:
if length < samples:
self._entries.remove(entry)
continue
sizes.append((length - samples) // stride + 1)
self._lengths.append(length)
if not sizes:
raise ValueError(f"Empty dataset {self.path}")
self._cumulative_sizes = np.cumsum(sizes)
self._sizes = sizes
def __len__(self):
return self._cumulative_sizes[-1]
@property
def total_length(self):
return sum(self._lengths)
def chunk_info(self, index):
file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
if file_index == 0:
local_index = index
else:
local_index = index - self._cumulative_sizes[file_index - 1]
return ChunkInfo(offset=local_index * self.stride,
file_index=file_index,
local_index=local_index)
def _path(self, folder, name, stream=0):
return self.path / folder / (name + f'.{stream}.raw')
def __getitem__(self, index):
chunk = self.chunk_info(index)
entry = self._entries[chunk.file_index]
length = self.samples or self._lengths[chunk.file_index]
streams = []
to_read = length * self.channels * 4
for stream_index, stream in enumerate(self.streams):
offset = chunk.offset * 4 * self.channels
file = open(self._path(*entry, stream=stream), 'rb')
file.seek(offset)
content = file.read(to_read)
assert len(content) == to_read
content = np.frombuffer(content, dtype=np.float32)
content = content.copy() # make writable
streams.append(th.from_numpy(content).view(length, self.channels).t())
return th.stack(streams, dim=0)
def name(self, index):
chunk = self.chunk_info(index)
folder, name = self._entries[chunk.file_index]
return folder / name
class MusDBSet:
def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
self.mus = mus
self.streams = streams
self.samplerate = samplerate
self.channels = channels
def __len__(self):
return len(self.mus.tracks)
def __getitem__(self, index):
track = self.mus.tracks[index]
return (track.name, AudioFile(track.path).read(channels=self.channels,
seek_time=0,
streams=self.streams,
samplerate=self.samplerate))
def build_raw(mus, destination, normalize, workers, samplerate, channels):
destination.mkdir(parents=True, exist_ok=True)
loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
batch_size=1,
num_workers=workers,
collate_fn=lambda x: x[0])
for name, streams in tqdm.tqdm(loader):
if normalize:
ref = streams[0].mean(dim=0) # use mono mixture as reference
streams = (streams - ref.mean()) / ref.std()
for index, stream in enumerate(streams):
open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes())
def main():
parser = argparse.ArgumentParser('rawset')
parser.add_argument('--workers', type=int, default=10)
parser.add_argument('--samplerate', type=int, default=44100)
parser.add_argument('--channels', type=int, default=2)
parser.add_argument('musdb', type=Path)
parser.add_argument('destination', type=Path)
args = parser.parse_args()
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
args.destination / "train",
normalize=True,
channels=args.channels,
samplerate=args.samplerate,
workers=args.workers)
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
args.destination / "valid",
normalize=True,
samplerate=args.samplerate,
channels=args.channels,
workers=args.workers)
if __name__ == "__main__":
main()

View File

@ -1,185 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import subprocess
import julius
import torch as th
import torchaudio as ta
from .audio import AudioFile, convert_audio_channels
from .pretrained import is_pretrained, load_pretrained
from .utils import apply_model, load_model
def load_track(track, device, audio_channels, samplerate):
errors = {}
wav = None
try:
wav = AudioFile(track).read(
streams=0,
samplerate=samplerate,
channels=audio_channels).to(device)
except FileNotFoundError:
errors['ffmpeg'] = 'Ffmpeg is not installed.'
except subprocess.CalledProcessError:
errors['ffmpeg'] = 'FFmpeg could not read the file.'
if wav is None:
try:
wav, sr = ta.load(str(track))
except RuntimeError as err:
errors['torchaudio'] = err.args[0]
else:
wav = convert_audio_channels(wav, audio_channels)
wav = wav.to(device)
wav = julius.resample_frac(wav, sr, samplerate)
if wav is None:
print(f"Could not load file {track}. "
"Maybe it is not a supported file format? ")
for backend, error in errors.items():
print(f"When trying to load using {backend}, got the following error: {error}")
sys.exit(1)
return wav
def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False):
try:
import lameenc
except ImportError:
print("Failed to call lame encoder. Maybe it is not installed? "
"On windows, run `python.exe -m pip install -U lameenc`, "
"on OSX/Linux, run `python3 -m pip install -U lameenc`, "
"then try again.", file=sys.stderr)
sys.exit(1)
encoder = lameenc.Encoder()
encoder.set_bit_rate(bitrate)
encoder.set_in_sample_rate(samplerate)
encoder.set_channels(channels)
encoder.set_quality(2) # 2-highest, 7-fastest
if not verbose:
encoder.silence()
wav = wav.transpose(0, 1).numpy()
mp3_data = encoder.encode(wav.tobytes())
mp3_data += encoder.flush()
with open(path, "wb") as f:
f.write(mp3_data)
def main():
parser = argparse.ArgumentParser("demucs.separate",
description="Separate the sources for the given tracks")
parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
parser.add_argument("-n",
"--name",
default="demucs_quantized",
help="Model name. See README.md for the list of pretrained models. "
"Default is demucs_quantized.")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-o",
"--out",
type=Path,
default=Path("separated"),
help="Folder where to put extracted tracks. A subfolder "
"with the model name will be created.")
parser.add_argument("--models",
type=Path,
default=Path("models"),
help="Path to trained models. "
"Also used to store downloaded pretrained models")
parser.add_argument("-d",
"--device",
default="cuda" if th.cuda.is_available() else "cpu",
help="Device to use, default is cuda if available else cpu")
parser.add_argument("--shifts",
default=0,
type=int,
help="Number of random shifts for equivariant stabilization."
"Increase separation time but improves quality for Demucs. 10 was used "
"in the original paper.")
parser.add_argument("--overlap",
default=0.25,
type=float,
help="Overlap between the splits.")
parser.add_argument("--no-split",
action="store_false",
dest="split",
default=True,
help="Doesn't split audio in chunks. This can use large amounts of memory.")
parser.add_argument("--float32",
action="store_true",
help="Convert the output wavefile to use pcm f32 format instead of s16. "
"This should not make a difference if you just plan on listening to the "
"audio but might be needed to compute exactly metrics like SDR etc.")
parser.add_argument("--int16",
action="store_false",
dest="float32",
help="Opposite of --float32, here for compatibility.")
parser.add_argument("--mp3", action="store_true",
help="Convert the output wavs to mp3.")
parser.add_argument("--mp3-bitrate",
default=320,
type=int,
help="Bitrate of converted mp3.")
args = parser.parse_args()
name = args.name + ".th"
model_path = args.models / name
if model_path.is_file():
model = load_model(model_path)
else:
if is_pretrained(args.name):
model = load_pretrained(args.name)
else:
print(f"No pre-trained model {args.name}", file=sys.stderr)
sys.exit(1)
model.to(args.device)
out = args.out / args.name
out.mkdir(parents=True, exist_ok=True)
print(f"Separated tracks will be stored in {out.resolve()}")
for track in args.tracks:
if not track.exists():
print(
f"File {track} does not exist. If the path contains spaces, "
"please try again after surrounding the entire path with quotes \"\".",
file=sys.stderr)
continue
print(f"Separating track {track}")
wav = load_track(track, args.device, model.audio_channels, model.samplerate)
ref = wav.mean(0)
wav = (wav - ref.mean()) / ref.std()
sources = apply_model(model, wav, shifts=args.shifts, split=args.split,
overlap=args.overlap, progress=True)
sources = sources * ref.std() + ref.mean()
track_folder = out / track.name.rsplit(".", 1)[0]
track_folder.mkdir(exist_ok=True)
for source, name in zip(sources, model.sources):
source = source / max(1.01 * source.abs().max(), 1)
if args.mp3 or not args.float32:
source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
source = source.cpu()
stem = str(track_folder / name)
if args.mp3:
encode_mp3(source, stem + ".mp3",
bitrate=args.mp3_bitrate,
samplerate=model.samplerate,
channels=model.audio_channels,
verbose=args.verbose)
else:
wavname = str(track_folder / f"{name}.wav")
ta.save(wavname, source, sample_rate=model.samplerate)
if __name__ == "__main__":
main()

View File

@ -1,452 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import capture_init
EPS = 1e-8
def overlap_and_add(signal, frame_step):
outer_dimensions = signal.size()[:-2]
frames, frame_length = signal.size()[-2:]
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
subframe_step = frame_step // subframe_length
subframes_per_frame = frame_length // subframe_length
output_size = frame_step * (frames - 1) + frame_length
output_subframes = output_size // subframe_length
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
frame = torch.arange(0, output_subframes,
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
frame = frame.long() # signal may in GPU or CPU
frame = frame.contiguous().view(-1)
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
result.index_add_(-2, frame, subframe_signal)
result = result.view(*outer_dimensions, -1)
return result
class ConvTasNet(nn.Module):
@capture_init
def __init__(self,
sources,
N=256,
L=20,
B=256,
H=512,
P=3,
X=8,
R=4,
audio_channels=2,
norm_type="gLN",
causal=False,
mask_nonlinear='relu',
samplerate=44100,
segment_length=44100 * 2 * 4):
"""
Args:
sources: list of sources
N: Number of filters in autoencoder
L: Length of the filters (in samples)
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(ConvTasNet, self).__init__()
# Hyper-parameter
self.sources = sources
self.C = len(sources)
self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
self.norm_type = norm_type
self.causal = causal
self.mask_nonlinear = mask_nonlinear
self.audio_channels = audio_channels
self.samplerate = samplerate
self.segment_length = segment_length
# Components
self.encoder = Encoder(L, N, audio_channels)
self.separator = TemporalConvNet(
N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
self.decoder = Decoder(N, L, audio_channels)
# init
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def valid_length(self, length):
return length
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
est_source: [M, C, T]
"""
mixture_w = self.encoder(mixture)
est_mask = self.separator(mixture_w)
est_source = self.decoder(mixture_w, est_mask)
# T changed after conv1d in encoder, fix it here
T_origin = mixture.size(-1)
T_conv = est_source.size(-1)
est_source = F.pad(est_source, (0, T_origin - T_conv))
return est_source
class Encoder(nn.Module):
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
"""
def __init__(self, L, N, audio_channels):
super(Encoder, self).__init__()
# Hyper-parameter
self.L, self.N = L, N
# Components
# 50% overlap
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
def forward(self, mixture):
"""
Args:
mixture: [M, T], M is batch size, T is #samples
Returns:
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
"""
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
return mixture_w
class Decoder(nn.Module):
def __init__(self, N, L, audio_channels):
super(Decoder, self).__init__()
# Hyper-parameter
self.N, self.L = N, L
self.audio_channels = audio_channels
# Components
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
def forward(self, mixture_w, est_mask):
"""
Args:
mixture_w: [M, N, K]
est_mask: [M, C, N, K]
Returns:
est_source: [M, C, T]
"""
# D = W * M
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
# S = DV
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
m, c, k, _ = est_source.size()
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
return est_source
class TemporalConvNet(nn.Module):
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
"""
Args:
N: Number of filters in autoencoder
B: Number of channels in bottleneck 1 × 1-conv block
H: Number of channels in convolutional blocks
P: Kernel size in convolutional blocks
X: Number of convolutional blocks in each repeat
R: Number of repeats
C: Number of speakers
norm_type: BN, gLN, cLN
causal: causal or non-causal
mask_nonlinear: use which non-linear function to generate mask
"""
super(TemporalConvNet, self).__init__()
# Hyper-parameter
self.C = C
self.mask_nonlinear = mask_nonlinear
# Components
# [M, N, K] -> [M, N, K]
layer_norm = ChannelwiseLayerNorm(N)
# [M, N, K] -> [M, B, K]
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
# [M, B, K] -> [M, B, K]
repeats = []
for r in range(R):
blocks = []
for x in range(X):
dilation = 2**x
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
blocks += [
TemporalBlock(B,
H,
P,
stride=1,
padding=padding,
dilation=dilation,
norm_type=norm_type,
causal=causal)
]
repeats += [nn.Sequential(*blocks)]
temporal_conv_net = nn.Sequential(*repeats)
# [M, B, K] -> [M, C*N, K]
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
# Put together
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
mask_conv1x1)
def forward(self, mixture_w):
"""
Keep this API same with TasNet
Args:
mixture_w: [M, N, K], M is batch size
returns:
est_mask: [M, C, N, K]
"""
M, N, K = mixture_w.size()
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
if self.mask_nonlinear == 'softmax':
est_mask = F.softmax(score, dim=1)
elif self.mask_nonlinear == 'relu':
est_mask = F.relu(score)
else:
raise ValueError("Unsupported mask non-linear function")
return est_mask
class TemporalBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(TemporalBlock, self).__init__()
# [M, B, K] -> [M, H, K]
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
prelu = nn.PReLU()
norm = chose_norm(norm_type, out_channels)
# [M, H, K] -> [M, B, K]
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
dilation, norm_type, causal)
# Put together
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
def forward(self, x):
"""
Args:
x: [M, B, K]
Returns:
[M, B, K]
"""
residual = x
out = self.net(x)
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
return out + residual # look like w/o F.relu is better than w/ F.relu
# return F.relu(out + residual)
class DepthwiseSeparableConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False):
super(DepthwiseSeparableConv, self).__init__()
# Use `groups` option to implement depthwise convolution
# [M, H, K] -> [M, H, K]
depthwise_conv = nn.Conv1d(in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False)
if causal:
chomp = Chomp1d(padding)
prelu = nn.PReLU()
norm = chose_norm(norm_type, in_channels)
# [M, H, K] -> [M, B, K]
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
# Put together
if causal:
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
else:
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
def forward(self, x):
"""
Args:
x: [M, H, K]
Returns:
result: [M, B, K]
"""
return self.net(x)
class Chomp1d(nn.Module):
"""To ensure the output length is the same as the input.
"""
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
Args:
x: [M, H, Kpad]
Returns:
[M, H, K]
"""
return x[:, :, :-self.chomp_size].contiguous()
def chose_norm(norm_type, channel_size):
"""The input of normlization will be (M, C, K), where M is batch size,
C is channel size and K is sequence length.
"""
if norm_type == "gLN":
return GlobalLayerNorm(channel_size)
elif norm_type == "cLN":
return ChannelwiseLayerNorm(channel_size)
elif norm_type == "id":
return nn.Identity()
else: # norm_type == "BN":
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
# along M and K, so this BN usage is right.
return nn.BatchNorm1d(channel_size)
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
"""Channel-wise Layer Normalization (cLN)"""
def __init__(self, channel_size):
super(ChannelwiseLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
cLN_y: [M, N, K]
"""
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return cLN_y
class GlobalLayerNorm(nn.Module):
"""Global Layer Normalization (gLN)"""
def __init__(self, channel_size):
super(GlobalLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
self.reset_parameters()
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Args:
y: [M, N, K], M is batch size, N is channel size, K is length
Returns:
gLN_y: [M, N, K]
"""
# TODO: in torch 1.0, torch.mean() support dim list
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return gLN_y
if __name__ == "__main__":
torch.manual_seed(123)
M, N, L, T = 2, 3, 4, 12
K = 2 * T // L - 1
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
mixture = torch.randint(3, (M, T))
# test Encoder
encoder = Encoder(L, N)
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
mixture_w = encoder(mixture)
print('mixture', mixture)
print('U', encoder.conv1d_U.weight)
print('mixture_w', mixture_w)
print('mixture_w size', mixture_w.size())
# test TemporalConvNet
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
est_mask = separator(mixture_w)
print('est_mask', est_mask)
# test Decoder
decoder = Decoder(N, L)
est_mask = torch.randint(2, (B, K, C, N))
est_source = decoder(mixture_w, est_mask)
print('est_source', est_source)
# test Conv-TasNet
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
est_source = conv_tasnet(mixture)
print('est_source', est_source)
print('est_source size', est_source.size())

View File

@ -1,109 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import sys
from concurrent import futures
import musdb
import museval
import torch as th
import tqdm
from scipy.io import wavfile
from torch import distributed
from .audio import convert_audio
from .utils import apply_model
def evaluate(model,
musdb_path,
eval_folder,
workers=2,
device="cpu",
rank=0,
save=False,
shifts=0,
split=False,
overlap=0.25,
is_wav=False,
world_size=1):
"""
Evaluate model using museval. Run the model
on a single GPU, the bottleneck being the call to museval.
"""
output_dir = eval_folder / "results"
output_dir.mkdir(exist_ok=True, parents=True)
json_folder = eval_folder / "results/test"
json_folder.mkdir(exist_ok=True, parents=True)
# we load tracks from the original musdb set
test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
src_rate = 44100 # hardcoded for now...
for p in model.parameters():
p.requires_grad = False
p.grad = None
pendings = []
with futures.ProcessPoolExecutor(workers or 1) as pool:
for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
track = test_set.tracks[index]
out = json_folder / f"{track.name}.json.gz"
if out.exists():
continue
mix = th.from_numpy(track.audio).t().float()
ref = mix.mean(dim=0) # mono mixture
mix = (mix - ref.mean()) / ref.std()
mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
estimates = apply_model(model, mix.to(device),
shifts=shifts, split=split, overlap=overlap)
estimates = estimates * ref.std() + ref.mean()
estimates = estimates.transpose(1, 2)
references = th.stack(
[th.from_numpy(track.targets[name].audio).t() for name in model.sources])
references = convert_audio(references, src_rate,
model.samplerate, model.audio_channels)
references = references.transpose(1, 2).numpy()
estimates = estimates.cpu().numpy()
win = int(1. * model.samplerate)
hop = int(1. * model.samplerate)
if save:
folder = eval_folder / "wav/test" / track.name
folder.mkdir(exist_ok=True, parents=True)
for name, estimate in zip(model.sources, estimates):
wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
if workers:
pendings.append((track.name, pool.submit(
museval.evaluate, references, estimates, win=win, hop=hop)))
else:
pendings.append((track.name, museval.evaluate(
references, estimates, win=win, hop=hop)))
del references, mix, estimates, track
for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
if workers:
pending = pending.result()
sdr, isr, sir, sar = pending
track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
for idx, target in enumerate(model.sources):
values = {
"SDR": sdr[idx].tolist(),
"SIR": sir[idx].tolist(),
"ISR": isr[idx].tolist(),
"SAR": sar[idx].tolist()
}
track_store.add_target(target_name=target, values=values)
json_path = json_folder / f"{track_name}.json.gz"
gzip.open(json_path, "w").write(track_store.json.encode('utf-8'))
if world_size > 1:
distributed.barrier()

View File

@ -1,323 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import errno
import functools
import hashlib
import inspect
import io
import os
import random
import socket
import tempfile
import warnings
import zlib
from contextlib import contextmanager
from diffq import UniformQuantizer, DiffQuantizer
import torch as th
import tqdm
from torch import distributed
from torch.nn import functional as F
def center_trim(tensor, reference):
"""
Center trim `tensor` with respect to `reference`, along the last dimension.
`reference` can also be a number, representing the length to trim to.
If the size difference != 0 mod 2, the extra sample is removed on the right side.
"""
if hasattr(reference, "size"):
reference = reference.size(-1)
delta = tensor.size(-1) - reference
if delta < 0:
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
if delta:
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
return tensor
def average_metric(metric, count=1.):
"""
Average `metric` which should be a float across all hosts. `count` should be
the weight for this particular host (i.e. number of examples).
"""
metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
return metric[1].item() / metric[0].item()
def free_port(host='', low=20000, high=40000):
"""
Return a port number that is most likely free.
This could suffer from a race condition although
it should be quite rare.
"""
sock = socket.socket()
while True:
port = random.randint(low, high)
try:
sock.bind((host, port))
except OSError as error:
if error.errno == errno.EADDRINUSE:
continue
raise
return port
def sizeof_fmt(num, suffix='B'):
"""
Given `num` bytes, return human readable size.
Taken from https://stackoverflow.com/a/1094933
"""
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
if abs(num) < 1024.0:
return "%3.1f%s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
def human_seconds(seconds, display='.2f'):
"""
Given `seconds` seconds, return human readable duration.
"""
value = seconds * 1e6
ratios = [1e3, 1e3, 60, 60, 24]
names = ['us', 'ms', 's', 'min', 'hrs', 'days']
last = names.pop(0)
for name, ratio in zip(names, ratios):
if value / ratio < 0.3:
break
value /= ratio
last = name
return f"{format(value, display)} {last}"
class TensorChunk:
def __init__(self, tensor, offset=0, length=None):
total_length = tensor.shape[-1]
assert offset >= 0
assert offset < total_length
if length is None:
length = total_length - offset
else:
length = min(total_length - offset, length)
self.tensor = tensor
self.offset = offset
self.length = length
self.device = tensor.device
@property
def shape(self):
shape = list(self.tensor.shape)
shape[-1] = self.length
return shape
def padded(self, target_length):
delta = target_length - self.length
total_length = self.tensor.shape[-1]
assert delta >= 0
start = self.offset - delta // 2
end = start + target_length
correct_start = max(0, start)
correct_end = min(total_length, end)
pad_left = correct_start - start
pad_right = end - correct_end
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
assert out.shape[-1] == target_length
return out
def tensor_chunk(tensor_or_chunk):
if isinstance(tensor_or_chunk, TensorChunk):
return tensor_or_chunk
else:
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model(model, mix, shifts=None, split=False,
overlap=0.25, transition_power=1., progress=False):
"""
Apply model to a given mixture.
Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
"""
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
device = mix.device
channels, length = mix.shape
if split:
out = th.zeros(len(model.sources), channels, length, device=device)
sum_weight = th.zeros(length, device=device)
segment = model.segment_length
stride = int((1 - overlap) * segment)
offsets = range(0, length, stride)
scale = stride / model.samplerate
if progress:
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
# We start from a triangle shaped weight, with maximal weight in the middle
# of the segment. Then we normalize and take to the power `transition_power`.
# Large values of transition power will lead to sharper transitions.
weight = th.cat([th.arange(1, segment // 2 + 1),
th.arange(segment - segment // 2, 0, -1)]).to(device)
assert len(weight) == segment
# If the overlap < 50%, this will translate to linear transition when
# transition_power is 1.
weight = (weight / weight.max())**transition_power
for offset in offsets:
chunk = TensorChunk(mix, offset, segment)
chunk_out = apply_model(model, chunk, shifts=shifts)
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
sum_weight[offset:offset + segment] += weight[:chunk_length]
offset += segment
assert sum_weight.min() > 0
out /= sum_weight
return out
elif shifts:
max_shift = int(0.5 * model.samplerate)
mix = tensor_chunk(mix)
padded_mix = mix.padded(length + 2 * max_shift)
out = 0
for _ in range(shifts):
offset = random.randint(0, max_shift)
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
shifted_out = apply_model(model, shifted)
out += shifted_out[..., max_shift - offset:]
out /= shifts
return out
else:
valid_length = model.valid_length(length)
mix = tensor_chunk(mix)
padded_mix = mix.padded(valid_length)
with th.no_grad():
out = model(padded_mix.unsqueeze(0))[0]
return center_trim(out, length)
@contextmanager
def temp_filenames(count, delete=True):
names = []
try:
for _ in range(count):
names.append(tempfile.NamedTemporaryFile(delete=False).name)
yield names
finally:
if delete:
for name in names:
os.unlink(name)
def get_quantizer(model, args, optimizer=None):
quantizer = None
if args.diffq:
quantizer = DiffQuantizer(
model, min_size=args.q_min_size, group_size=8)
if optimizer is not None:
quantizer.setup_optimizer(optimizer)
elif args.qat:
quantizer = UniformQuantizer(
model, bits=args.qat, min_size=args.q_min_size)
return quantizer
def load_model(path, strict=False):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
load_from = path
package = th.load(load_from, 'cpu')
klass = package["klass"]
args = package["args"]
kwargs = package["kwargs"]
if strict:
model = klass(*args, **kwargs)
else:
sig = inspect.signature(klass)
for key in list(kwargs):
if key not in sig.parameters:
warnings.warn("Dropping inexistant parameter " + key)
del kwargs[key]
model = klass(*args, **kwargs)
state = package["state"]
training_args = package["training_args"]
quantizer = get_quantizer(model, training_args)
set_state(model, quantizer, state)
return model
def get_state(model, quantizer):
if quantizer is None:
state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
else:
state = quantizer.get_quantized_state()
buf = io.BytesIO()
th.save(state, buf)
state = {'compressed': zlib.compress(buf.getvalue())}
return state
def set_state(model, quantizer, state):
if quantizer is None:
model.load_state_dict(state)
else:
buf = io.BytesIO(zlib.decompress(state["compressed"]))
state = th.load(buf, "cpu")
quantizer.restore_quantized_state(state)
return state
def save_state(state, path):
buf = io.BytesIO()
th.save(state, buf)
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
path = path.parent / (path.stem + "-" + sig + path.suffix)
path.write_bytes(buf.getvalue())
def save_model(model, quantizer, training_args, path):
args, kwargs = model._init_args_kwargs
klass = model.__class__
state = get_state(model, quantizer)
save_to = path
package = {
'klass': klass,
'args': args,
'kwargs': kwargs,
'state': state,
'training_args': training_args,
}
th.save(package, save_to)
def capture_init(init):
@functools.wraps(init)
def __init__(self, *args, **kwargs):
self._init_args_kwargs = (args, kwargs)
init(self, *args, **kwargs)
return __init__

View File

@ -1,174 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import hashlib
import math
import json
from pathlib import Path
import julius
import torch as th
from torch import distributed
import torchaudio as ta
from torch.nn import functional as F
from .audio import convert_audio_channels
from .compressed import get_musdb_tracks
MIXTURE = "mixture"
EXT = ".wav"
def _track_metadata(track, sources):
track_length = None
track_samplerate = None
for source in sources + [MIXTURE]:
file = track / f"{source}{EXT}"
info = ta.info(str(file))
length = info.num_frames
if track_length is None:
track_length = length
track_samplerate = info.sample_rate
elif track_length != length:
raise ValueError(
f"Invalid length for file {file}: "
f"expecting {track_length} but got {length}.")
elif info.sample_rate != track_samplerate:
raise ValueError(
f"Invalid sample rate for file {file}: "
f"expecting {track_samplerate} but got {info.sample_rate}.")
if source == MIXTURE:
wav, _ = ta.load(str(file))
wav = wav.mean(0)
mean = wav.mean().item()
std = wav.std().item()
return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
def _build_metadata(path, sources):
meta = {}
path = Path(path)
for file in path.iterdir():
meta[file.name] = _track_metadata(file, sources)
return meta
class Wavset:
def __init__(
self,
root, metadata, sources,
length=None, stride=None, normalize=True,
samplerate=44100, channels=2):
"""
Waveset (or mp3 set for that matter). Can be used to train
with arbitrary sources. Each track should be one folder inside of `path`.
The folder should contain files named `{source}.{ext}`.
Files will be grouped according to `sources` (each source is a list of
filenames).
Sample rate and channels will be converted on the fly.
`length` is the sample size to extract (in samples, not duration).
`stride` is how many samples to move by between each example.
"""
self.root = Path(root)
self.metadata = OrderedDict(metadata)
self.length = length
self.stride = stride or length
self.normalize = normalize
self.sources = sources
self.channels = channels
self.samplerate = samplerate
self.num_examples = []
for name, meta in self.metadata.items():
track_length = int(self.samplerate * meta['length'] / meta['samplerate'])
if length is None or track_length < length:
examples = 1
else:
examples = int(math.ceil((track_length - self.length) / self.stride) + 1)
self.num_examples.append(examples)
def __len__(self):
return sum(self.num_examples)
def get_file(self, name, source):
return self.root / name / f"{source}{EXT}"
def __getitem__(self, index):
for name, examples in zip(self.metadata, self.num_examples):
if index >= examples:
index -= examples
continue
meta = self.metadata[name]
num_frames = -1
offset = 0
if self.length is not None:
offset = int(math.ceil(
meta['samplerate'] * self.stride * index / self.samplerate))
num_frames = int(math.ceil(
meta['samplerate'] * self.length / self.samplerate))
wavs = []
for source in self.sources:
file = self.get_file(name, source)
wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
wav = convert_audio_channels(wav, self.channels)
wavs.append(wav)
example = th.stack(wavs)
example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
if self.normalize:
example = (example - meta['mean']) / meta['std']
if self.length:
example = example[..., :self.length]
example = F.pad(example, (0, self.length - example.shape[-1]))
return example
def get_wav_datasets(args, samples, sources):
sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
metadata_file = args.metadata / (sig + ".json")
train_path = args.wav / "train"
valid_path = args.wav / "valid"
if not metadata_file.is_file() and args.rank == 0:
train = _build_metadata(train_path, sources)
valid = _build_metadata(valid_path, sources)
json.dump([train, valid], open(metadata_file, "w"))
if args.world_size > 1:
distributed.barrier()
train, valid = json.load(open(metadata_file))
train_set = Wavset(train_path, train, sources,
length=samples, stride=args.data_stride,
samplerate=args.samplerate, channels=args.audio_channels,
normalize=args.norm_wav)
valid_set = Wavset(valid_path, valid, [MIXTURE] + sources,
samplerate=args.samplerate, channels=args.audio_channels,
normalize=args.norm_wav)
return train_set, valid_set
def get_musdb_wav_datasets(args, samples, sources):
metadata_file = args.metadata / "musdb_wav.json"
root = args.musdb / "train"
if not metadata_file.is_file() and args.rank == 0:
metadata = _build_metadata(root, sources)
json.dump(metadata, open(metadata_file, "w"))
if args.world_size > 1:
distributed.barrier()
metadata = json.load(open(metadata_file))
train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train")
metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks}
metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks}
train_set = Wavset(root, metadata_train, sources,
length=samples, stride=args.data_stride,
samplerate=args.samplerate, channels=args.audio_channels,
normalize=args.norm_wav)
valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources,
samplerate=args.samplerate, channels=args.audio_channels,
normalize=args.norm_wav)
return train_set, valid_set