mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2025-01-18 17:04:04 +01:00
Add files via upload
This commit is contained in:
parent
bc5195e68c
commit
58fa38822a
7
demucs/__init__.py
Normal file
7
demucs/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
__version__ = "2.0.3"
|
317
demucs/__main__.py
Normal file
317
demucs/__main__.py
Normal file
@ -0,0 +1,317 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch as th
|
||||
from torch import distributed, nn
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
|
||||
from .compressed import get_compressed_datasets
|
||||
from .model import Demucs
|
||||
from .parser import get_name, get_parser
|
||||
from .raw import Rawset
|
||||
from .repitch import RepitchedWrapper
|
||||
from .pretrained import load_pretrained, SOURCES
|
||||
from .tasnet import ConvTasNet
|
||||
from .test import evaluate
|
||||
from .train import train_model, validate_model
|
||||
from .utils import (human_seconds, load_model, save_model, get_state,
|
||||
save_state, sizeof_fmt, get_quantizer)
|
||||
from .wav import get_wav_datasets, get_musdb_wav_datasets
|
||||
|
||||
|
||||
@dataclass
|
||||
class SavedState:
|
||||
metrics: list = field(default_factory=list)
|
||||
last_state: dict = None
|
||||
best_state: dict = None
|
||||
optimizer: dict = None
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
name = get_name(parser, args)
|
||||
print(f"Experiment {name}")
|
||||
|
||||
if args.musdb is None and args.rank == 0:
|
||||
print(
|
||||
"You must provide the path to the MusDB dataset with the --musdb flag. "
|
||||
"To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
eval_folder = args.evals / name
|
||||
eval_folder.mkdir(exist_ok=True, parents=True)
|
||||
args.logs.mkdir(exist_ok=True)
|
||||
metrics_path = args.logs / f"{name}.json"
|
||||
eval_folder.mkdir(exist_ok=True, parents=True)
|
||||
args.checkpoints.mkdir(exist_ok=True, parents=True)
|
||||
args.models.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if args.device is None:
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = args.device
|
||||
|
||||
th.manual_seed(args.seed)
|
||||
# Prevents too many threads to be started when running `museval` as it can be quite
|
||||
# inefficient on NUMA architectures.
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
if args.world_size > 1:
|
||||
if device != "cuda" and args.rank == 0:
|
||||
print("Error: distributed training is only available with cuda device", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
th.cuda.set_device(args.rank % th.cuda.device_count())
|
||||
distributed.init_process_group(backend="nccl",
|
||||
init_method="tcp://" + args.master,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size)
|
||||
|
||||
checkpoint = args.checkpoints / f"{name}.th"
|
||||
checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
|
||||
if args.restart and checkpoint.exists() and args.rank == 0:
|
||||
checkpoint.unlink()
|
||||
|
||||
if args.test or args.test_pretrained:
|
||||
args.epochs = 1
|
||||
args.repeat = 0
|
||||
if args.test:
|
||||
model = load_model(args.models / args.test)
|
||||
else:
|
||||
model = load_pretrained(args.test_pretrained)
|
||||
elif args.tasnet:
|
||||
model = ConvTasNet(audio_channels=args.audio_channels,
|
||||
samplerate=args.samplerate, X=args.X,
|
||||
segment_length=4 * args.samples,
|
||||
sources=SOURCES)
|
||||
else:
|
||||
model = Demucs(
|
||||
audio_channels=args.audio_channels,
|
||||
channels=args.channels,
|
||||
context=args.context,
|
||||
depth=args.depth,
|
||||
glu=args.glu,
|
||||
growth=args.growth,
|
||||
kernel_size=args.kernel_size,
|
||||
lstm_layers=args.lstm_layers,
|
||||
rescale=args.rescale,
|
||||
rewrite=args.rewrite,
|
||||
stride=args.conv_stride,
|
||||
resample=args.resample,
|
||||
normalize=args.normalize,
|
||||
samplerate=args.samplerate,
|
||||
segment_length=4 * args.samples,
|
||||
sources=SOURCES,
|
||||
)
|
||||
model.to(device)
|
||||
if args.init:
|
||||
model.load_state_dict(load_pretrained(args.init).state_dict())
|
||||
|
||||
if args.show:
|
||||
print(model)
|
||||
size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
|
||||
print(f"Model size {size}")
|
||||
return
|
||||
|
||||
try:
|
||||
saved = th.load(checkpoint, map_location='cpu')
|
||||
except IOError:
|
||||
saved = SavedState()
|
||||
|
||||
optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
quantizer = None
|
||||
quantizer = get_quantizer(model, args, optimizer)
|
||||
|
||||
if saved.last_state is not None:
|
||||
model.load_state_dict(saved.last_state, strict=False)
|
||||
if saved.optimizer is not None:
|
||||
optimizer.load_state_dict(saved.optimizer)
|
||||
|
||||
model_name = f"{name}.th"
|
||||
if args.save_model:
|
||||
if args.rank == 0:
|
||||
model.to("cpu")
|
||||
model.load_state_dict(saved.best_state)
|
||||
save_model(model, quantizer, args, args.models / model_name)
|
||||
return
|
||||
elif args.save_state:
|
||||
model_name = f"{args.save_state}.th"
|
||||
if args.rank == 0:
|
||||
model.to("cpu")
|
||||
model.load_state_dict(saved.best_state)
|
||||
state = get_state(model, quantizer)
|
||||
save_state(state, args.models / model_name)
|
||||
return
|
||||
|
||||
if args.rank == 0:
|
||||
done = args.logs / f"{name}.done"
|
||||
if done.exists():
|
||||
done.unlink()
|
||||
|
||||
augment = [Shift(args.data_stride)]
|
||||
if args.augment:
|
||||
augment += [FlipSign(), FlipChannels(), Scale(),
|
||||
Remix(group_size=args.remix_group_size)]
|
||||
augment = nn.Sequential(*augment).to(device)
|
||||
print("Agumentation pipeline:", augment)
|
||||
|
||||
if args.mse:
|
||||
criterion = nn.MSELoss()
|
||||
else:
|
||||
criterion = nn.L1Loss()
|
||||
|
||||
# Setting number of samples so that all convolution windows are full.
|
||||
# Prevents hard to debug mistake with the prediction being shifted compared
|
||||
# to the input mixture.
|
||||
samples = model.valid_length(args.samples)
|
||||
print(f"Number of training samples adjusted to {samples}")
|
||||
samples = samples + args.data_stride
|
||||
if args.repitch:
|
||||
# We need a bit more audio samples, to account for potential
|
||||
# tempo change.
|
||||
samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))
|
||||
|
||||
args.metadata.mkdir(exist_ok=True, parents=True)
|
||||
if args.raw:
|
||||
train_set = Rawset(args.raw / "train",
|
||||
samples=samples,
|
||||
channels=args.audio_channels,
|
||||
streams=range(1, len(model.sources) + 1),
|
||||
stride=args.data_stride)
|
||||
|
||||
valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
|
||||
elif args.wav:
|
||||
train_set, valid_set = get_wav_datasets(args, samples, model.sources)
|
||||
elif args.is_wav:
|
||||
train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
|
||||
else:
|
||||
train_set, valid_set = get_compressed_datasets(args, samples)
|
||||
|
||||
if args.repitch:
|
||||
train_set = RepitchedWrapper(
|
||||
train_set,
|
||||
proba=args.repitch,
|
||||
max_tempo=args.max_tempo)
|
||||
|
||||
best_loss = float("inf")
|
||||
for epoch, metrics in enumerate(saved.metrics):
|
||||
print(f"Epoch {epoch:03d}: "
|
||||
f"train={metrics['train']:.8f} "
|
||||
f"valid={metrics['valid']:.8f} "
|
||||
f"best={metrics['best']:.4f} "
|
||||
f"ms={metrics.get('true_model_size', 0):.2f}MB "
|
||||
f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
|
||||
f"duration={human_seconds(metrics['duration'])}")
|
||||
best_loss = metrics['best']
|
||||
|
||||
if args.world_size > 1:
|
||||
dmodel = DistributedDataParallel(model,
|
||||
device_ids=[th.cuda.current_device()],
|
||||
output_device=th.cuda.current_device())
|
||||
else:
|
||||
dmodel = model
|
||||
|
||||
for epoch in range(len(saved.metrics), args.epochs):
|
||||
begin = time.time()
|
||||
model.train()
|
||||
train_loss, model_size = train_model(
|
||||
epoch, train_set, dmodel, criterion, optimizer, augment,
|
||||
quantizer=quantizer,
|
||||
batch_size=args.batch_size,
|
||||
device=device,
|
||||
repeat=args.repeat,
|
||||
seed=args.seed,
|
||||
diffq=args.diffq,
|
||||
workers=args.workers,
|
||||
world_size=args.world_size)
|
||||
model.eval()
|
||||
valid_loss = validate_model(
|
||||
epoch, valid_set, model, criterion,
|
||||
device=device,
|
||||
rank=args.rank,
|
||||
split=args.split_valid,
|
||||
overlap=args.overlap,
|
||||
world_size=args.world_size)
|
||||
|
||||
ms = 0
|
||||
cms = 0
|
||||
if quantizer and args.rank == 0:
|
||||
ms = quantizer.true_model_size()
|
||||
cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10))
|
||||
|
||||
duration = time.time() - begin
|
||||
if valid_loss < best_loss and ms <= args.ms_target:
|
||||
best_loss = valid_loss
|
||||
saved.best_state = {
|
||||
key: value.to("cpu").clone()
|
||||
for key, value in model.state_dict().items()
|
||||
}
|
||||
|
||||
saved.metrics.append({
|
||||
"train": train_loss,
|
||||
"valid": valid_loss,
|
||||
"best": best_loss,
|
||||
"duration": duration,
|
||||
"model_size": model_size,
|
||||
"true_model_size": ms,
|
||||
"compressed_model_size": cms,
|
||||
})
|
||||
if args.rank == 0:
|
||||
json.dump(saved.metrics, open(metrics_path, "w"))
|
||||
|
||||
saved.last_state = model.state_dict()
|
||||
saved.optimizer = optimizer.state_dict()
|
||||
if args.rank == 0 and not args.test:
|
||||
th.save(saved, checkpoint_tmp)
|
||||
checkpoint_tmp.rename(checkpoint)
|
||||
|
||||
print(f"Epoch {epoch:03d}: "
|
||||
f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
|
||||
f"cms={cms:.2f}MB "
|
||||
f"duration={human_seconds(duration)}")
|
||||
|
||||
if args.world_size > 1:
|
||||
distributed.barrier()
|
||||
|
||||
del dmodel
|
||||
model.load_state_dict(saved.best_state)
|
||||
if args.eval_cpu:
|
||||
device = "cpu"
|
||||
model.to(device)
|
||||
model.eval()
|
||||
evaluate(model, args.musdb, eval_folder,
|
||||
is_wav=args.is_wav,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
device=device,
|
||||
save=args.save,
|
||||
split=args.split_valid,
|
||||
shifts=args.shifts,
|
||||
overlap=args.overlap,
|
||||
workers=args.eval_workers)
|
||||
model.to("cpu")
|
||||
if args.rank == 0:
|
||||
if not (args.test or args.test_pretrained):
|
||||
save_model(model, quantizer, args, args.models / model_name)
|
||||
print("done")
|
||||
done.write_text("done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
172
demucs/audio.py
Normal file
172
demucs/audio.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import json
|
||||
import subprocess as sp
|
||||
from pathlib import Path
|
||||
|
||||
import julius
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils import temp_filenames
|
||||
|
||||
|
||||
def _read_info(path):
|
||||
stdout_data = sp.check_output([
|
||||
'ffprobe', "-loglevel", "panic",
|
||||
str(path), '-print_format', 'json', '-show_format', '-show_streams'
|
||||
])
|
||||
return json.loads(stdout_data.decode('utf-8'))
|
||||
|
||||
|
||||
class AudioFile:
|
||||
"""
|
||||
Allows to read audio from any format supported by ffmpeg, as well as resampling or
|
||||
converting to mono on the fly. See :method:`read` for more details.
|
||||
"""
|
||||
def __init__(self, path: Path):
|
||||
self.path = Path(path)
|
||||
self._info = None
|
||||
|
||||
def __repr__(self):
|
||||
features = [("path", self.path)]
|
||||
features.append(("samplerate", self.samplerate()))
|
||||
features.append(("channels", self.channels()))
|
||||
features.append(("streams", len(self)))
|
||||
features_str = ", ".join(f"{name}={value}" for name, value in features)
|
||||
return f"AudioFile({features_str})"
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
if self._info is None:
|
||||
self._info = _read_info(self.path)
|
||||
return self._info
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
return float(self.info['format']['duration'])
|
||||
|
||||
@property
|
||||
def _audio_streams(self):
|
||||
return [
|
||||
index for index, stream in enumerate(self.info["streams"])
|
||||
if stream["codec_type"] == "audio"
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._audio_streams)
|
||||
|
||||
def channels(self, stream=0):
|
||||
return int(self.info['streams'][self._audio_streams[stream]]['channels'])
|
||||
|
||||
def samplerate(self, stream=0):
|
||||
return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
|
||||
|
||||
def read(self,
|
||||
seek_time=None,
|
||||
duration=None,
|
||||
streams=slice(None),
|
||||
samplerate=None,
|
||||
channels=None,
|
||||
temp_folder=None):
|
||||
"""
|
||||
Slightly more efficient implementation than stempeg,
|
||||
in particular, this will extract all stems at once
|
||||
rather than having to loop over one file multiple times
|
||||
for each stream.
|
||||
|
||||
Args:
|
||||
seek_time (float): seek time in seconds or None if no seeking is needed.
|
||||
duration (float): duration in seconds to extract or None to extract until the end.
|
||||
streams (slice, int or list): streams to extract, can be a single int, a list or
|
||||
a slice. If it is a slice or list, the output will be of size [S, C, T]
|
||||
with S the number of streams, C the number of channels and T the number of samples.
|
||||
If it is an int, the output will be [C, T].
|
||||
samplerate (int): if provided, will resample on the fly. If None, no resampling will
|
||||
be done. Original sampling rate can be obtained with :method:`samplerate`.
|
||||
channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
|
||||
as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
|
||||
See https://sound.stackexchange.com/a/42710.
|
||||
Our definition of mono is simply the average of the two channels. Any other
|
||||
value will be ignored.
|
||||
temp_folder (str or Path or None): temporary folder to use for decoding.
|
||||
|
||||
|
||||
"""
|
||||
streams = np.array(range(len(self)))[streams]
|
||||
single = not isinstance(streams, np.ndarray)
|
||||
if single:
|
||||
streams = [streams]
|
||||
|
||||
if duration is None:
|
||||
target_size = None
|
||||
query_duration = None
|
||||
else:
|
||||
target_size = int((samplerate or self.samplerate()) * duration)
|
||||
query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
|
||||
|
||||
with temp_filenames(len(streams)) as filenames:
|
||||
command = ['ffmpeg', '-y']
|
||||
command += ['-loglevel', 'panic']
|
||||
if seek_time:
|
||||
command += ['-ss', str(seek_time)]
|
||||
command += ['-i', str(self.path)]
|
||||
for stream, filename in zip(streams, filenames):
|
||||
command += ['-map', f'0:{self._audio_streams[stream]}']
|
||||
if query_duration is not None:
|
||||
command += ['-t', str(query_duration)]
|
||||
command += ['-threads', '1']
|
||||
command += ['-f', 'f32le']
|
||||
if samplerate is not None:
|
||||
command += ['-ar', str(samplerate)]
|
||||
command += [filename]
|
||||
|
||||
sp.run(command, check=True)
|
||||
wavs = []
|
||||
for filename in filenames:
|
||||
wav = np.fromfile(filename, dtype=np.float32)
|
||||
wav = torch.from_numpy(wav)
|
||||
wav = wav.view(-1, self.channels()).t()
|
||||
if channels is not None:
|
||||
wav = convert_audio_channels(wav, channels)
|
||||
if target_size is not None:
|
||||
wav = wav[..., :target_size]
|
||||
wavs.append(wav)
|
||||
wav = torch.stack(wavs, dim=0)
|
||||
if single:
|
||||
wav = wav[0]
|
||||
return wav
|
||||
|
||||
|
||||
def convert_audio_channels(wav, channels=2):
|
||||
"""Convert audio to the given number of channels."""
|
||||
*shape, src_channels, length = wav.shape
|
||||
if src_channels == channels:
|
||||
pass
|
||||
elif channels == 1:
|
||||
# Case 1:
|
||||
# The caller asked 1-channel audio, but the stream have multiple
|
||||
# channels, downmix all channels.
|
||||
wav = wav.mean(dim=-2, keepdim=True)
|
||||
elif src_channels == 1:
|
||||
# Case 2:
|
||||
# The caller asked for multiple channels, but the input file have
|
||||
# one single channel, replicate the audio over all channels.
|
||||
wav = wav.expand(*shape, channels, length)
|
||||
elif src_channels >= channels:
|
||||
# Case 3:
|
||||
# The caller asked for multiple channels, and the input file have
|
||||
# more channels than requested. In that case return the first channels.
|
||||
wav = wav[..., :channels, :]
|
||||
else:
|
||||
# Case 4: What is a reasonable choice here?
|
||||
raise ValueError('The audio file has less channels than requested but is not mono.')
|
||||
return wav
|
||||
|
||||
|
||||
def convert_audio(wav, from_samplerate, to_samplerate, channels):
|
||||
wav = convert_audio_channels(wav, channels)
|
||||
return julius.resample_frac(wav, from_samplerate, to_samplerate)
|
115
demucs/compressed.py
Normal file
115
demucs/compressed.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from fractions import Fraction
|
||||
from concurrent import futures
|
||||
|
||||
import musdb
|
||||
from torch import distributed
|
||||
|
||||
from .audio import AudioFile
|
||||
|
||||
|
||||
def get_musdb_tracks(root, *args, **kwargs):
|
||||
mus = musdb.DB(root, *args, **kwargs)
|
||||
return {track.name: track.path for track in mus}
|
||||
|
||||
|
||||
class StemsSet:
|
||||
def __init__(self, tracks, metadata, duration=None, stride=1,
|
||||
samplerate=44100, channels=2, streams=slice(None)):
|
||||
|
||||
self.metadata = []
|
||||
for name, path in tracks.items():
|
||||
meta = dict(metadata[name])
|
||||
meta["path"] = path
|
||||
meta["name"] = name
|
||||
self.metadata.append(meta)
|
||||
if duration is not None and meta["duration"] < duration:
|
||||
raise ValueError(f"Track {name} duration is too small {meta['duration']}")
|
||||
self.metadata.sort(key=lambda x: x["name"])
|
||||
self.duration = duration
|
||||
self.stride = stride
|
||||
self.channels = channels
|
||||
self.samplerate = samplerate
|
||||
self.streams = streams
|
||||
|
||||
def __len__(self):
|
||||
return sum(self._examples_count(m) for m in self.metadata)
|
||||
|
||||
def _examples_count(self, meta):
|
||||
if self.duration is None:
|
||||
return 1
|
||||
else:
|
||||
return int((meta["duration"] - self.duration) // self.stride + 1)
|
||||
|
||||
def track_metadata(self, index):
|
||||
for meta in self.metadata:
|
||||
examples = self._examples_count(meta)
|
||||
if index >= examples:
|
||||
index -= examples
|
||||
continue
|
||||
return meta
|
||||
|
||||
def __getitem__(self, index):
|
||||
for meta in self.metadata:
|
||||
examples = self._examples_count(meta)
|
||||
if index >= examples:
|
||||
index -= examples
|
||||
continue
|
||||
streams = AudioFile(meta["path"]).read(seek_time=index * self.stride,
|
||||
duration=self.duration,
|
||||
channels=self.channels,
|
||||
samplerate=self.samplerate,
|
||||
streams=self.streams)
|
||||
return (streams - meta["mean"]) / meta["std"]
|
||||
|
||||
|
||||
def _get_track_metadata(path):
|
||||
# use mono at 44kHz as reference. For any other settings data won't be perfectly
|
||||
# normalized but it should be good enough.
|
||||
audio = AudioFile(path)
|
||||
mix = audio.read(streams=0, channels=1, samplerate=44100)
|
||||
return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()}
|
||||
|
||||
|
||||
def _build_metadata(tracks, workers=10):
|
||||
pendings = []
|
||||
with futures.ProcessPoolExecutor(workers) as pool:
|
||||
for name, path in tracks.items():
|
||||
pendings.append((name, pool.submit(_get_track_metadata, path)))
|
||||
return {name: p.result() for name, p in pendings}
|
||||
|
||||
|
||||
def _build_musdb_metadata(path, musdb, workers):
|
||||
tracks = get_musdb_tracks(musdb)
|
||||
metadata = _build_metadata(tracks, workers)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
json.dump(metadata, open(path, "w"))
|
||||
|
||||
|
||||
def get_compressed_datasets(args, samples):
|
||||
metadata_file = args.metadata / "musdb.json"
|
||||
if not metadata_file.is_file() and args.rank == 0:
|
||||
_build_musdb_metadata(metadata_file, args.musdb, args.workers)
|
||||
if args.world_size > 1:
|
||||
distributed.barrier()
|
||||
metadata = json.load(open(metadata_file))
|
||||
duration = Fraction(samples, args.samplerate)
|
||||
stride = Fraction(args.data_stride, args.samplerate)
|
||||
train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
|
||||
metadata,
|
||||
duration=duration,
|
||||
stride=stride,
|
||||
streams=slice(1, None),
|
||||
samplerate=args.samplerate,
|
||||
channels=args.audio_channels)
|
||||
valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
|
||||
metadata,
|
||||
samplerate=args.samplerate,
|
||||
channels=args.audio_channels)
|
||||
return train_set, valid_set
|
202
demucs/model.py
Normal file
202
demucs/model.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import julius
|
||||
from torch import nn
|
||||
|
||||
from .utils import capture_init, center_trim
|
||||
|
||||
|
||||
class BLSTM(nn.Module):
|
||||
def __init__(self, dim, layers=1):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
||||
self.linear = nn.Linear(2 * dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(2, 0, 1)
|
||||
x = self.lstm(x)[0]
|
||||
x = self.linear(x)
|
||||
x = x.permute(1, 2, 0)
|
||||
return x
|
||||
|
||||
|
||||
def rescale_conv(conv, reference):
|
||||
std = conv.weight.std().detach()
|
||||
scale = (std / reference)**0.5
|
||||
conv.weight.data /= scale
|
||||
if conv.bias is not None:
|
||||
conv.bias.data /= scale
|
||||
|
||||
|
||||
def rescale_module(module, reference):
|
||||
for sub in module.modules():
|
||||
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
||||
rescale_conv(sub, reference)
|
||||
|
||||
|
||||
class Demucs(nn.Module):
|
||||
@capture_init
|
||||
def __init__(self,
|
||||
sources,
|
||||
audio_channels=2,
|
||||
channels=64,
|
||||
depth=6,
|
||||
rewrite=True,
|
||||
glu=True,
|
||||
rescale=0.1,
|
||||
resample=True,
|
||||
kernel_size=8,
|
||||
stride=4,
|
||||
growth=2.,
|
||||
lstm_layers=2,
|
||||
context=3,
|
||||
normalize=False,
|
||||
samplerate=44100,
|
||||
segment_length=4 * 10 * 44100):
|
||||
"""
|
||||
Args:
|
||||
sources (list[str]): list of source names
|
||||
audio_channels (int): stereo or mono
|
||||
channels (int): first convolution channels
|
||||
depth (int): number of encoder/decoder layers
|
||||
rewrite (bool): add 1x1 convolution to each encoder layer
|
||||
and a convolution to each decoder layer.
|
||||
For the decoder layer, `context` gives the kernel size.
|
||||
glu (bool): use glu instead of ReLU
|
||||
resample_input (bool): upsample x2 the input and downsample /2 the output.
|
||||
rescale (int): rescale initial weights of convolutions
|
||||
to get their standard deviation closer to `rescale`
|
||||
kernel_size (int): kernel size for convolutions
|
||||
stride (int): stride for convolutions
|
||||
growth (float): multiply (resp divide) number of channels by that
|
||||
for each layer of the encoder (resp decoder)
|
||||
lstm_layers (int): number of lstm layers, 0 = no lstm
|
||||
context (int): kernel size of the convolution in the
|
||||
decoder before the transposed convolution. If > 1,
|
||||
will provide some context from neighboring time
|
||||
steps.
|
||||
samplerate (int): stored as meta information for easing
|
||||
future evaluations of the model.
|
||||
segment_length (int): stored as meta information for easing
|
||||
future evaluations of the model. Length of the segments on which
|
||||
the model was trained.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.audio_channels = audio_channels
|
||||
self.sources = sources
|
||||
self.kernel_size = kernel_size
|
||||
self.context = context
|
||||
self.stride = stride
|
||||
self.depth = depth
|
||||
self.resample = resample
|
||||
self.channels = channels
|
||||
self.normalize = normalize
|
||||
self.samplerate = samplerate
|
||||
self.segment_length = segment_length
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
if glu:
|
||||
activation = nn.GLU(dim=1)
|
||||
ch_scale = 2
|
||||
else:
|
||||
activation = nn.ReLU()
|
||||
ch_scale = 1
|
||||
in_channels = audio_channels
|
||||
for index in range(depth):
|
||||
encode = []
|
||||
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
|
||||
if rewrite:
|
||||
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
|
||||
self.encoder.append(nn.Sequential(*encode))
|
||||
|
||||
decode = []
|
||||
if index > 0:
|
||||
out_channels = in_channels
|
||||
else:
|
||||
out_channels = len(self.sources) * audio_channels
|
||||
if rewrite:
|
||||
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
|
||||
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
|
||||
if index > 0:
|
||||
decode.append(nn.ReLU())
|
||||
self.decoder.insert(0, nn.Sequential(*decode))
|
||||
in_channels = channels
|
||||
channels = int(growth * channels)
|
||||
|
||||
channels = in_channels
|
||||
|
||||
if lstm_layers:
|
||||
self.lstm = BLSTM(channels, lstm_layers)
|
||||
else:
|
||||
self.lstm = None
|
||||
|
||||
if rescale:
|
||||
rescale_module(self, reference=rescale)
|
||||
|
||||
def valid_length(self, length):
|
||||
"""
|
||||
Return the nearest valid length to use with the model so that
|
||||
there is no time steps left over in a convolutions, e.g. for all
|
||||
layers, size of the input - kernel_size % stride = 0.
|
||||
|
||||
If the mixture has a valid length, the estimated sources
|
||||
will have exactly the same length when context = 1. If context > 1,
|
||||
the two signals can be center trimmed to match.
|
||||
|
||||
For training, extracts should have a valid length.For evaluation
|
||||
on full tracks we recommend passing `pad = True` to :method:`forward`.
|
||||
"""
|
||||
if self.resample:
|
||||
length *= 2
|
||||
for _ in range(self.depth):
|
||||
length = math.ceil((length - self.kernel_size) / self.stride) + 1
|
||||
length = max(1, length)
|
||||
length += self.context - 1
|
||||
for _ in range(self.depth):
|
||||
length = (length - 1) * self.stride + self.kernel_size
|
||||
|
||||
if self.resample:
|
||||
length = math.ceil(length / 2)
|
||||
return int(length)
|
||||
|
||||
def forward(self, mix):
|
||||
x = mix
|
||||
|
||||
if self.normalize:
|
||||
mono = mix.mean(dim=1, keepdim=True)
|
||||
mean = mono.mean(dim=-1, keepdim=True)
|
||||
std = mono.std(dim=-1, keepdim=True)
|
||||
else:
|
||||
mean = 0
|
||||
std = 1
|
||||
|
||||
x = (x - mean) / (1e-5 + std)
|
||||
|
||||
if self.resample:
|
||||
x = julius.resample_frac(x, 1, 2)
|
||||
|
||||
saved = []
|
||||
for encode in self.encoder:
|
||||
x = encode(x)
|
||||
saved.append(x)
|
||||
if self.lstm:
|
||||
x = self.lstm(x)
|
||||
for decode in self.decoder:
|
||||
skip = center_trim(saved.pop(-1), x)
|
||||
x = x + skip
|
||||
x = decode(x)
|
||||
|
||||
if self.resample:
|
||||
x = julius.resample_frac(x, 2, 1)
|
||||
x = x * std + mean
|
||||
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
|
||||
return x
|
244
demucs/parser.py
Normal file
244
demucs/parser.py
Normal file
@ -0,0 +1,244 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
|
||||
default_raw = None
|
||||
default_musdb = None
|
||||
if 'DEMUCS_RAW' in os.environ:
|
||||
default_raw = Path(os.environ['DEMUCS_RAW'])
|
||||
if 'DEMUCS_MUSDB' in os.environ:
|
||||
default_musdb = Path(os.environ['DEMUCS_MUSDB'])
|
||||
parser.add_argument(
|
||||
"--raw",
|
||||
type=Path,
|
||||
default=default_raw,
|
||||
help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
|
||||
parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
|
||||
parser.add_argument("-m",
|
||||
"--musdb",
|
||||
type=Path,
|
||||
default=default_musdb,
|
||||
help="Path to musdb root")
|
||||
parser.add_argument("--is_wav", action="store_true",
|
||||
help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
|
||||
parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
|
||||
help="Folder where metadata information is stored.")
|
||||
parser.add_argument("--wav", type=Path,
|
||||
help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
|
||||
"subfolder.")
|
||||
parser.add_argument("--samplerate", type=int, default=44100)
|
||||
parser.add_argument("--audio_channels", type=int, default=2)
|
||||
parser.add_argument("--samples",
|
||||
default=44100 * 10,
|
||||
type=int,
|
||||
help="number of samples to feed in")
|
||||
parser.add_argument("--data_stride",
|
||||
default=44100,
|
||||
type=int,
|
||||
help="Stride for chunks, shorter = longer epochs")
|
||||
parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
|
||||
parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
|
||||
parser.add_argument("-d",
|
||||
"--device",
|
||||
help="Device to train on, default is cuda if available else cpu")
|
||||
parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
|
||||
parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
|
||||
parser.add_argument("--test", help="Just run the test pipeline + one validation. "
|
||||
"This should be a filename relative to the models/ folder.")
|
||||
parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
|
||||
"on a pretrained model. ")
|
||||
|
||||
parser.add_argument("--rank", default=0, type=int)
|
||||
parser.add_argument("--world_size", default=1, type=int)
|
||||
parser.add_argument("--master")
|
||||
|
||||
parser.add_argument("--checkpoints",
|
||||
type=Path,
|
||||
default=Path("checkpoints"),
|
||||
help="Folder where to store checkpoints etc")
|
||||
parser.add_argument("--evals",
|
||||
type=Path,
|
||||
default=Path("evals"),
|
||||
help="Folder where to store evals and waveforms")
|
||||
parser.add_argument("--save",
|
||||
action="store_true",
|
||||
help="Save estimated for the test set waveforms")
|
||||
parser.add_argument("--logs",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Folder where to store logs")
|
||||
parser.add_argument("--models",
|
||||
type=Path,
|
||||
default=Path("models"),
|
||||
help="Folder where to store trained models")
|
||||
parser.add_argument("-R",
|
||||
"--restart",
|
||||
action='store_true',
|
||||
help='Restart training, ignoring previous run')
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
|
||||
parser.add_argument("-r",
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Repeat the train set, longer epochs")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=64)
|
||||
parser.add_argument("--lr", type=float, default=3e-4)
|
||||
parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
|
||||
parser.add_argument("--init", help="Initialize from a pre-trained model.")
|
||||
|
||||
# Augmentation options
|
||||
parser.add_argument("--no_augment",
|
||||
action="store_false",
|
||||
dest="augment",
|
||||
default=True,
|
||||
help="No basic data augmentation.")
|
||||
parser.add_argument("--repitch", type=float, default=0.2,
|
||||
help="Probability to do tempo/pitch change")
|
||||
parser.add_argument("--max_tempo", type=float, default=12,
|
||||
help="Maximum relative tempo change in %% when using repitch.")
|
||||
|
||||
parser.add_argument("--remix_group_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Shuffle sources using group of this size. Useful to somewhat "
|
||||
"replicate multi-gpu training "
|
||||
"on less GPUs.")
|
||||
parser.add_argument("--shifts",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of random shifts used for the shift trick.")
|
||||
parser.add_argument("--overlap",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="Overlap when --split_valid is passed.")
|
||||
|
||||
# See model.py for doc
|
||||
parser.add_argument("--growth",
|
||||
type=float,
|
||||
default=2.,
|
||||
help="Number of channels between two layers will increase by this factor")
|
||||
parser.add_argument("--depth",
|
||||
type=int,
|
||||
default=6,
|
||||
help="Number of layers for the encoder and decoder")
|
||||
parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
|
||||
parser.add_argument("--channels",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of channels for the first encoder layer")
|
||||
parser.add_argument("--kernel_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Kernel size for the (transposed) convolutions")
|
||||
parser.add_argument("--conv_stride",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Stride for the (transposed) convolutions")
|
||||
parser.add_argument("--context",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Context size for the decoder convolutions "
|
||||
"before the transposed convolutions")
|
||||
parser.add_argument("--rescale",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Initial weight rescale reference")
|
||||
parser.add_argument("--no_resample", action="store_false",
|
||||
default=True, dest="resample",
|
||||
help="No Resampling of the input/output x2")
|
||||
parser.add_argument("--no_glu",
|
||||
action="store_false",
|
||||
default=True,
|
||||
dest="glu",
|
||||
help="Replace all GLUs by ReLUs")
|
||||
parser.add_argument("--no_rewrite",
|
||||
action="store_false",
|
||||
default=True,
|
||||
dest="rewrite",
|
||||
help="No 1x1 rewrite convolutions")
|
||||
parser.add_argument("--normalize", action="store_true")
|
||||
parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)
|
||||
|
||||
# Tasnet options
|
||||
parser.add_argument("--tasnet", action="store_true")
|
||||
parser.add_argument("--split_valid",
|
||||
action="store_true",
|
||||
help="Predict chunks by chunks for valid and test. Required for tasnet")
|
||||
parser.add_argument("--X", type=int, default=8)
|
||||
|
||||
# Other options
|
||||
parser.add_argument("--show",
|
||||
action="store_true",
|
||||
help="Show model architecture, size and exit")
|
||||
parser.add_argument("--save_model", action="store_true",
|
||||
help="Skip traning, just save final model "
|
||||
"for the current checkpoint value.")
|
||||
parser.add_argument("--save_state",
|
||||
help="Skip training, just save state "
|
||||
"for the current checkpoint value. You should "
|
||||
"provide a model name as argument.")
|
||||
|
||||
# Quantization options
|
||||
parser.add_argument("--q-min-size", type=float, default=1,
|
||||
help="Only quantize layers over this size (in MB)")
|
||||
parser.add_argument(
|
||||
"--qat", type=int, help="If provided, use QAT training with that many bits.")
|
||||
|
||||
parser.add_argument("--diffq", type=float, default=0)
|
||||
parser.add_argument(
|
||||
"--ms-target", type=float, default=162,
|
||||
help="Model size target in MB, when using DiffQ. Best model will be kept "
|
||||
"only if it is smaller than this target.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_name(parser, args):
|
||||
"""
|
||||
Return the name of an experiment given the args. Some parameters are ignored,
|
||||
for instance --workers, as they do not impact the final result.
|
||||
"""
|
||||
ignore_args = set([
|
||||
"checkpoints",
|
||||
"deterministic",
|
||||
"eval",
|
||||
"evals",
|
||||
"eval_cpu",
|
||||
"eval_workers",
|
||||
"logs",
|
||||
"master",
|
||||
"rank",
|
||||
"restart",
|
||||
"save",
|
||||
"save_model",
|
||||
"save_state",
|
||||
"show",
|
||||
"workers",
|
||||
"world_size",
|
||||
])
|
||||
parts = []
|
||||
name_args = dict(args.__dict__)
|
||||
for name, value in name_args.items():
|
||||
if name in ignore_args:
|
||||
continue
|
||||
if value != parser.get_default(name):
|
||||
if isinstance(value, Path):
|
||||
parts.append(f"{name}={value.name}")
|
||||
else:
|
||||
parts.append(f"{name}={value}")
|
||||
if parts:
|
||||
name = " ".join(parts)
|
||||
else:
|
||||
name = "default"
|
||||
return name
|
107
demucs/pretrained.py
Normal file
107
demucs/pretrained.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# author: adefossez
|
||||
|
||||
import logging
|
||||
|
||||
from diffq import DiffQuantizer
|
||||
import torch.hub
|
||||
|
||||
from .model import Demucs
|
||||
from .tasnet import ConvTasNet
|
||||
from .utils import set_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
|
||||
|
||||
PRETRAINED_MODELS = {
|
||||
'demucs': 'e07c671f',
|
||||
'demucs48_hq': '28a1282c',
|
||||
'demucs_extra': '3646af93',
|
||||
'demucs_quantized': '07afea75',
|
||||
'tasnet': 'beb46fac',
|
||||
'tasnet_extra': 'df3777b2',
|
||||
'demucs_unittest': '09ebc15f',
|
||||
}
|
||||
|
||||
SOURCES = ["drums", "bass", "other", "vocals"]
|
||||
|
||||
|
||||
def get_url(name):
|
||||
sig = PRETRAINED_MODELS[name]
|
||||
return ROOT + name + "-" + sig[:8] + ".th"
|
||||
|
||||
|
||||
def is_pretrained(name):
|
||||
return name in PRETRAINED_MODELS
|
||||
|
||||
|
||||
def load_pretrained(name):
|
||||
if name == "demucs":
|
||||
return demucs(pretrained=True)
|
||||
elif name == "demucs48_hq":
|
||||
return demucs(pretrained=True, hq=True, channels=48)
|
||||
elif name == "demucs_extra":
|
||||
return demucs(pretrained=True, extra=True)
|
||||
elif name == "demucs_quantized":
|
||||
return demucs(pretrained=True, quantized=True)
|
||||
elif name == "demucs_unittest":
|
||||
return demucs_unittest(pretrained=True)
|
||||
elif name == "tasnet":
|
||||
return tasnet(pretrained=True)
|
||||
elif name == "tasnet_extra":
|
||||
return tasnet(pretrained=True, extra=True)
|
||||
else:
|
||||
raise ValueError(f"Invalid pretrained name {name}")
|
||||
|
||||
|
||||
def _load_state(name, model, quantizer=None):
|
||||
url = get_url(name)
|
||||
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
|
||||
set_state(model, quantizer, state)
|
||||
if quantizer:
|
||||
quantizer.detach()
|
||||
|
||||
|
||||
def demucs_unittest(pretrained=True):
|
||||
model = Demucs(channels=4, sources=SOURCES)
|
||||
if pretrained:
|
||||
_load_state('demucs_unittest', model)
|
||||
return model
|
||||
|
||||
|
||||
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
|
||||
if not pretrained and (extra or quantized or hq):
|
||||
raise ValueError("if extra or quantized is True, pretrained must be True.")
|
||||
model = Demucs(sources=SOURCES, channels=channels)
|
||||
if pretrained:
|
||||
name = 'demucs'
|
||||
if channels != 64:
|
||||
name += str(channels)
|
||||
quantizer = None
|
||||
if sum([extra, quantized, hq]) > 1:
|
||||
raise ValueError("Only one of extra, quantized, hq, can be True.")
|
||||
if quantized:
|
||||
quantizer = DiffQuantizer(model, group_size=8, min_size=1)
|
||||
name += '_quantized'
|
||||
if extra:
|
||||
name += '_extra'
|
||||
if hq:
|
||||
name += '_hq'
|
||||
_load_state(name, model, quantizer)
|
||||
return model
|
||||
|
||||
|
||||
def tasnet(pretrained=True, extra=False):
|
||||
if not pretrained and extra:
|
||||
raise ValueError("if extra is True, pretrained must be True.")
|
||||
model = ConvTasNet(X=10, sources=SOURCES)
|
||||
if pretrained:
|
||||
name = 'tasnet'
|
||||
if extra:
|
||||
name = 'tasnet_extra'
|
||||
_load_state(name, model)
|
||||
return model
|
173
demucs/raw.py
Normal file
173
demucs/raw.py
Normal file
@ -0,0 +1,173 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from collections import defaultdict, namedtuple
|
||||
from pathlib import Path
|
||||
|
||||
import musdb
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .audio import AudioFile
|
||||
|
||||
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
|
||||
|
||||
|
||||
class Rawset:
|
||||
"""
|
||||
Dataset of raw, normalized, float32 audio files
|
||||
"""
|
||||
def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
|
||||
self.path = Path(path)
|
||||
self.channels = channels
|
||||
self.samples = samples
|
||||
if stride is None:
|
||||
stride = samples if samples is not None else 0
|
||||
self.stride = stride
|
||||
entries = defaultdict(list)
|
||||
for root, folders, files in os.walk(self.path, followlinks=True):
|
||||
folders.sort()
|
||||
files.sort()
|
||||
for file in files:
|
||||
if file.endswith(".raw"):
|
||||
path = Path(root) / file
|
||||
name, stream = path.stem.rsplit('.', 1)
|
||||
entries[(path.parent.relative_to(self.path), name)].append(int(stream))
|
||||
|
||||
self._entries = list(entries.keys())
|
||||
|
||||
sizes = []
|
||||
self._lengths = []
|
||||
ref_streams = sorted(entries[self._entries[0]])
|
||||
assert ref_streams == list(range(len(ref_streams)))
|
||||
if streams is None:
|
||||
self.streams = ref_streams
|
||||
else:
|
||||
self.streams = streams
|
||||
for entry in sorted(entries.keys()):
|
||||
streams = entries[entry]
|
||||
assert sorted(streams) == ref_streams
|
||||
file = self._path(*entry)
|
||||
length = file.stat().st_size // (4 * channels)
|
||||
if samples is None:
|
||||
sizes.append(1)
|
||||
else:
|
||||
if length < samples:
|
||||
self._entries.remove(entry)
|
||||
continue
|
||||
sizes.append((length - samples) // stride + 1)
|
||||
self._lengths.append(length)
|
||||
if not sizes:
|
||||
raise ValueError(f"Empty dataset {self.path}")
|
||||
self._cumulative_sizes = np.cumsum(sizes)
|
||||
self._sizes = sizes
|
||||
|
||||
def __len__(self):
|
||||
return self._cumulative_sizes[-1]
|
||||
|
||||
@property
|
||||
def total_length(self):
|
||||
return sum(self._lengths)
|
||||
|
||||
def chunk_info(self, index):
|
||||
file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
|
||||
if file_index == 0:
|
||||
local_index = index
|
||||
else:
|
||||
local_index = index - self._cumulative_sizes[file_index - 1]
|
||||
return ChunkInfo(offset=local_index * self.stride,
|
||||
file_index=file_index,
|
||||
local_index=local_index)
|
||||
|
||||
def _path(self, folder, name, stream=0):
|
||||
return self.path / folder / (name + f'.{stream}.raw')
|
||||
|
||||
def __getitem__(self, index):
|
||||
chunk = self.chunk_info(index)
|
||||
entry = self._entries[chunk.file_index]
|
||||
|
||||
length = self.samples or self._lengths[chunk.file_index]
|
||||
streams = []
|
||||
to_read = length * self.channels * 4
|
||||
for stream_index, stream in enumerate(self.streams):
|
||||
offset = chunk.offset * 4 * self.channels
|
||||
file = open(self._path(*entry, stream=stream), 'rb')
|
||||
file.seek(offset)
|
||||
content = file.read(to_read)
|
||||
assert len(content) == to_read
|
||||
content = np.frombuffer(content, dtype=np.float32)
|
||||
content = content.copy() # make writable
|
||||
streams.append(th.from_numpy(content).view(length, self.channels).t())
|
||||
return th.stack(streams, dim=0)
|
||||
|
||||
def name(self, index):
|
||||
chunk = self.chunk_info(index)
|
||||
folder, name = self._entries[chunk.file_index]
|
||||
return folder / name
|
||||
|
||||
|
||||
class MusDBSet:
|
||||
def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
|
||||
self.mus = mus
|
||||
self.streams = streams
|
||||
self.samplerate = samplerate
|
||||
self.channels = channels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.mus.tracks)
|
||||
|
||||
def __getitem__(self, index):
|
||||
track = self.mus.tracks[index]
|
||||
return (track.name, AudioFile(track.path).read(channels=self.channels,
|
||||
seek_time=0,
|
||||
streams=self.streams,
|
||||
samplerate=self.samplerate))
|
||||
|
||||
|
||||
def build_raw(mus, destination, normalize, workers, samplerate, channels):
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
|
||||
batch_size=1,
|
||||
num_workers=workers,
|
||||
collate_fn=lambda x: x[0])
|
||||
for name, streams in tqdm.tqdm(loader):
|
||||
if normalize:
|
||||
ref = streams[0].mean(dim=0) # use mono mixture as reference
|
||||
streams = (streams - ref.mean()) / ref.std()
|
||||
for index, stream in enumerate(streams):
|
||||
open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes())
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser('rawset')
|
||||
parser.add_argument('--workers', type=int, default=10)
|
||||
parser.add_argument('--samplerate', type=int, default=44100)
|
||||
parser.add_argument('--channels', type=int, default=2)
|
||||
parser.add_argument('musdb', type=Path)
|
||||
parser.add_argument('destination', type=Path)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
|
||||
args.destination / "train",
|
||||
normalize=True,
|
||||
channels=args.channels,
|
||||
samplerate=args.samplerate,
|
||||
workers=args.workers)
|
||||
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
|
||||
args.destination / "valid",
|
||||
normalize=True,
|
||||
samplerate=args.samplerate,
|
||||
channels=args.channels,
|
||||
workers=args.workers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
185
demucs/separate.py
Normal file
185
demucs/separate.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
import julius
|
||||
import torch as th
|
||||
import torchaudio as ta
|
||||
|
||||
from .audio import AudioFile, convert_audio_channels
|
||||
from .pretrained import is_pretrained, load_pretrained
|
||||
from .utils import apply_model, load_model
|
||||
|
||||
|
||||
def load_track(track, device, audio_channels, samplerate):
|
||||
errors = {}
|
||||
wav = None
|
||||
|
||||
try:
|
||||
wav = AudioFile(track).read(
|
||||
streams=0,
|
||||
samplerate=samplerate,
|
||||
channels=audio_channels).to(device)
|
||||
except FileNotFoundError:
|
||||
errors['ffmpeg'] = 'Ffmpeg is not installed.'
|
||||
except subprocess.CalledProcessError:
|
||||
errors['ffmpeg'] = 'FFmpeg could not read the file.'
|
||||
|
||||
if wav is None:
|
||||
try:
|
||||
wav, sr = ta.load(str(track))
|
||||
except RuntimeError as err:
|
||||
errors['torchaudio'] = err.args[0]
|
||||
else:
|
||||
wav = convert_audio_channels(wav, audio_channels)
|
||||
wav = wav.to(device)
|
||||
wav = julius.resample_frac(wav, sr, samplerate)
|
||||
|
||||
if wav is None:
|
||||
print(f"Could not load file {track}. "
|
||||
"Maybe it is not a supported file format? ")
|
||||
for backend, error in errors.items():
|
||||
print(f"When trying to load using {backend}, got the following error: {error}")
|
||||
sys.exit(1)
|
||||
return wav
|
||||
|
||||
|
||||
def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False):
|
||||
try:
|
||||
import lameenc
|
||||
except ImportError:
|
||||
print("Failed to call lame encoder. Maybe it is not installed? "
|
||||
"On windows, run `python.exe -m pip install -U lameenc`, "
|
||||
"on OSX/Linux, run `python3 -m pip install -U lameenc`, "
|
||||
"then try again.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
encoder = lameenc.Encoder()
|
||||
encoder.set_bit_rate(bitrate)
|
||||
encoder.set_in_sample_rate(samplerate)
|
||||
encoder.set_channels(channels)
|
||||
encoder.set_quality(2) # 2-highest, 7-fastest
|
||||
if not verbose:
|
||||
encoder.silence()
|
||||
wav = wav.transpose(0, 1).numpy()
|
||||
mp3_data = encoder.encode(wav.tobytes())
|
||||
mp3_data += encoder.flush()
|
||||
with open(path, "wb") as f:
|
||||
f.write(mp3_data)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("demucs.separate",
|
||||
description="Separate the sources for the given tracks")
|
||||
parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
|
||||
parser.add_argument("-n",
|
||||
"--name",
|
||||
default="demucs_quantized",
|
||||
help="Model name. See README.md for the list of pretrained models. "
|
||||
"Default is demucs_quantized.")
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
parser.add_argument("-o",
|
||||
"--out",
|
||||
type=Path,
|
||||
default=Path("separated"),
|
||||
help="Folder where to put extracted tracks. A subfolder "
|
||||
"with the model name will be created.")
|
||||
parser.add_argument("--models",
|
||||
type=Path,
|
||||
default=Path("models"),
|
||||
help="Path to trained models. "
|
||||
"Also used to store downloaded pretrained models")
|
||||
parser.add_argument("-d",
|
||||
"--device",
|
||||
default="cuda" if th.cuda.is_available() else "cpu",
|
||||
help="Device to use, default is cuda if available else cpu")
|
||||
parser.add_argument("--shifts",
|
||||
default=0,
|
||||
type=int,
|
||||
help="Number of random shifts for equivariant stabilization."
|
||||
"Increase separation time but improves quality for Demucs. 10 was used "
|
||||
"in the original paper.")
|
||||
parser.add_argument("--overlap",
|
||||
default=0.25,
|
||||
type=float,
|
||||
help="Overlap between the splits.")
|
||||
parser.add_argument("--no-split",
|
||||
action="store_false",
|
||||
dest="split",
|
||||
default=True,
|
||||
help="Doesn't split audio in chunks. This can use large amounts of memory.")
|
||||
parser.add_argument("--float32",
|
||||
action="store_true",
|
||||
help="Convert the output wavefile to use pcm f32 format instead of s16. "
|
||||
"This should not make a difference if you just plan on listening to the "
|
||||
"audio but might be needed to compute exactly metrics like SDR etc.")
|
||||
parser.add_argument("--int16",
|
||||
action="store_false",
|
||||
dest="float32",
|
||||
help="Opposite of --float32, here for compatibility.")
|
||||
parser.add_argument("--mp3", action="store_true",
|
||||
help="Convert the output wavs to mp3.")
|
||||
parser.add_argument("--mp3-bitrate",
|
||||
default=320,
|
||||
type=int,
|
||||
help="Bitrate of converted mp3.")
|
||||
|
||||
args = parser.parse_args()
|
||||
name = args.name + ".th"
|
||||
model_path = args.models / name
|
||||
if model_path.is_file():
|
||||
model = load_model(model_path)
|
||||
else:
|
||||
if is_pretrained(args.name):
|
||||
model = load_pretrained(args.name)
|
||||
else:
|
||||
print(f"No pre-trained model {args.name}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
model.to(args.device)
|
||||
|
||||
out = args.out / args.name
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Separated tracks will be stored in {out.resolve()}")
|
||||
for track in args.tracks:
|
||||
if not track.exists():
|
||||
print(
|
||||
f"File {track} does not exist. If the path contains spaces, "
|
||||
"please try again after surrounding the entire path with quotes \"\".",
|
||||
file=sys.stderr)
|
||||
continue
|
||||
print(f"Separating track {track}")
|
||||
wav = load_track(track, args.device, model.audio_channels, model.samplerate)
|
||||
|
||||
ref = wav.mean(0)
|
||||
wav = (wav - ref.mean()) / ref.std()
|
||||
sources = apply_model(model, wav, shifts=args.shifts, split=args.split,
|
||||
overlap=args.overlap, progress=True)
|
||||
sources = sources * ref.std() + ref.mean()
|
||||
|
||||
track_folder = out / track.name.rsplit(".", 1)[0]
|
||||
track_folder.mkdir(exist_ok=True)
|
||||
for source, name in zip(sources, model.sources):
|
||||
source = source / max(1.01 * source.abs().max(), 1)
|
||||
if args.mp3 or not args.float32:
|
||||
source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
|
||||
source = source.cpu()
|
||||
stem = str(track_folder / name)
|
||||
if args.mp3:
|
||||
encode_mp3(source, stem + ".mp3",
|
||||
bitrate=args.mp3_bitrate,
|
||||
samplerate=model.samplerate,
|
||||
channels=model.audio_channels,
|
||||
verbose=args.verbose)
|
||||
else:
|
||||
wavname = str(track_folder / f"{name}.wav")
|
||||
ta.save(wavname, source, sample_rate=model.samplerate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
452
demucs/tasnet.py
Normal file
452
demucs/tasnet.py
Normal file
@ -0,0 +1,452 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Created on 2018/12
|
||||
# Author: Kaituo XU
|
||||
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
||||
# Here is the original license:
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Copyright (c) 2018 Kaituo XU
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import capture_init
|
||||
|
||||
EPS = 1e-8
|
||||
|
||||
|
||||
def overlap_and_add(signal, frame_step):
|
||||
outer_dimensions = signal.size()[:-2]
|
||||
frames, frame_length = signal.size()[-2:]
|
||||
|
||||
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
|
||||
subframe_step = frame_step // subframe_length
|
||||
subframes_per_frame = frame_length // subframe_length
|
||||
output_size = frame_step * (frames - 1) + frame_length
|
||||
output_subframes = output_size // subframe_length
|
||||
|
||||
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
|
||||
|
||||
frame = torch.arange(0, output_subframes,
|
||||
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
|
||||
frame = frame.long() # signal may in GPU or CPU
|
||||
frame = frame.contiguous().view(-1)
|
||||
|
||||
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
|
||||
result.index_add_(-2, frame, subframe_signal)
|
||||
result = result.view(*outer_dimensions, -1)
|
||||
return result
|
||||
|
||||
|
||||
class ConvTasNet(nn.Module):
|
||||
@capture_init
|
||||
def __init__(self,
|
||||
sources,
|
||||
N=256,
|
||||
L=20,
|
||||
B=256,
|
||||
H=512,
|
||||
P=3,
|
||||
X=8,
|
||||
R=4,
|
||||
audio_channels=2,
|
||||
norm_type="gLN",
|
||||
causal=False,
|
||||
mask_nonlinear='relu',
|
||||
samplerate=44100,
|
||||
segment_length=44100 * 2 * 4):
|
||||
"""
|
||||
Args:
|
||||
sources: list of sources
|
||||
N: Number of filters in autoencoder
|
||||
L: Length of the filters (in samples)
|
||||
B: Number of channels in bottleneck 1 × 1-conv block
|
||||
H: Number of channels in convolutional blocks
|
||||
P: Kernel size in convolutional blocks
|
||||
X: Number of convolutional blocks in each repeat
|
||||
R: Number of repeats
|
||||
norm_type: BN, gLN, cLN
|
||||
causal: causal or non-causal
|
||||
mask_nonlinear: use which non-linear function to generate mask
|
||||
"""
|
||||
super(ConvTasNet, self).__init__()
|
||||
# Hyper-parameter
|
||||
self.sources = sources
|
||||
self.C = len(sources)
|
||||
self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
|
||||
self.norm_type = norm_type
|
||||
self.causal = causal
|
||||
self.mask_nonlinear = mask_nonlinear
|
||||
self.audio_channels = audio_channels
|
||||
self.samplerate = samplerate
|
||||
self.segment_length = segment_length
|
||||
# Components
|
||||
self.encoder = Encoder(L, N, audio_channels)
|
||||
self.separator = TemporalConvNet(
|
||||
N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
|
||||
self.decoder = Decoder(N, L, audio_channels)
|
||||
# init
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_normal_(p)
|
||||
|
||||
def valid_length(self, length):
|
||||
return length
|
||||
|
||||
def forward(self, mixture):
|
||||
"""
|
||||
Args:
|
||||
mixture: [M, T], M is batch size, T is #samples
|
||||
Returns:
|
||||
est_source: [M, C, T]
|
||||
"""
|
||||
mixture_w = self.encoder(mixture)
|
||||
est_mask = self.separator(mixture_w)
|
||||
est_source = self.decoder(mixture_w, est_mask)
|
||||
|
||||
# T changed after conv1d in encoder, fix it here
|
||||
T_origin = mixture.size(-1)
|
||||
T_conv = est_source.size(-1)
|
||||
est_source = F.pad(est_source, (0, T_origin - T_conv))
|
||||
return est_source
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
|
||||
"""
|
||||
def __init__(self, L, N, audio_channels):
|
||||
super(Encoder, self).__init__()
|
||||
# Hyper-parameter
|
||||
self.L, self.N = L, N
|
||||
# Components
|
||||
# 50% overlap
|
||||
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
|
||||
|
||||
def forward(self, mixture):
|
||||
"""
|
||||
Args:
|
||||
mixture: [M, T], M is batch size, T is #samples
|
||||
Returns:
|
||||
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
|
||||
"""
|
||||
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
|
||||
return mixture_w
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, N, L, audio_channels):
|
||||
super(Decoder, self).__init__()
|
||||
# Hyper-parameter
|
||||
self.N, self.L = N, L
|
||||
self.audio_channels = audio_channels
|
||||
# Components
|
||||
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
|
||||
|
||||
def forward(self, mixture_w, est_mask):
|
||||
"""
|
||||
Args:
|
||||
mixture_w: [M, N, K]
|
||||
est_mask: [M, C, N, K]
|
||||
Returns:
|
||||
est_source: [M, C, T]
|
||||
"""
|
||||
# D = W * M
|
||||
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
|
||||
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
|
||||
# S = DV
|
||||
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
|
||||
m, c, k, _ = est_source.size()
|
||||
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
|
||||
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
|
||||
return est_source
|
||||
|
||||
|
||||
class TemporalConvNet(nn.Module):
|
||||
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
|
||||
"""
|
||||
Args:
|
||||
N: Number of filters in autoencoder
|
||||
B: Number of channels in bottleneck 1 × 1-conv block
|
||||
H: Number of channels in convolutional blocks
|
||||
P: Kernel size in convolutional blocks
|
||||
X: Number of convolutional blocks in each repeat
|
||||
R: Number of repeats
|
||||
C: Number of speakers
|
||||
norm_type: BN, gLN, cLN
|
||||
causal: causal or non-causal
|
||||
mask_nonlinear: use which non-linear function to generate mask
|
||||
"""
|
||||
super(TemporalConvNet, self).__init__()
|
||||
# Hyper-parameter
|
||||
self.C = C
|
||||
self.mask_nonlinear = mask_nonlinear
|
||||
# Components
|
||||
# [M, N, K] -> [M, N, K]
|
||||
layer_norm = ChannelwiseLayerNorm(N)
|
||||
# [M, N, K] -> [M, B, K]
|
||||
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
|
||||
# [M, B, K] -> [M, B, K]
|
||||
repeats = []
|
||||
for r in range(R):
|
||||
blocks = []
|
||||
for x in range(X):
|
||||
dilation = 2**x
|
||||
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
|
||||
blocks += [
|
||||
TemporalBlock(B,
|
||||
H,
|
||||
P,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
norm_type=norm_type,
|
||||
causal=causal)
|
||||
]
|
||||
repeats += [nn.Sequential(*blocks)]
|
||||
temporal_conv_net = nn.Sequential(*repeats)
|
||||
# [M, B, K] -> [M, C*N, K]
|
||||
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
|
||||
# Put together
|
||||
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
|
||||
mask_conv1x1)
|
||||
|
||||
def forward(self, mixture_w):
|
||||
"""
|
||||
Keep this API same with TasNet
|
||||
Args:
|
||||
mixture_w: [M, N, K], M is batch size
|
||||
returns:
|
||||
est_mask: [M, C, N, K]
|
||||
"""
|
||||
M, N, K = mixture_w.size()
|
||||
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
|
||||
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
|
||||
if self.mask_nonlinear == 'softmax':
|
||||
est_mask = F.softmax(score, dim=1)
|
||||
elif self.mask_nonlinear == 'relu':
|
||||
est_mask = F.relu(score)
|
||||
else:
|
||||
raise ValueError("Unsupported mask non-linear function")
|
||||
return est_mask
|
||||
|
||||
|
||||
class TemporalBlock(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
norm_type="gLN",
|
||||
causal=False):
|
||||
super(TemporalBlock, self).__init__()
|
||||
# [M, B, K] -> [M, H, K]
|
||||
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
prelu = nn.PReLU()
|
||||
norm = chose_norm(norm_type, out_channels)
|
||||
# [M, H, K] -> [M, B, K]
|
||||
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
|
||||
dilation, norm_type, causal)
|
||||
# Put together
|
||||
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: [M, B, K]
|
||||
Returns:
|
||||
[M, B, K]
|
||||
"""
|
||||
residual = x
|
||||
out = self.net(x)
|
||||
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
|
||||
return out + residual # look like w/o F.relu is better than w/ F.relu
|
||||
# return F.relu(out + residual)
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
norm_type="gLN",
|
||||
causal=False):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
# Use `groups` option to implement depthwise convolution
|
||||
# [M, H, K] -> [M, H, K]
|
||||
depthwise_conv = nn.Conv1d(in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False)
|
||||
if causal:
|
||||
chomp = Chomp1d(padding)
|
||||
prelu = nn.PReLU()
|
||||
norm = chose_norm(norm_type, in_channels)
|
||||
# [M, H, K] -> [M, B, K]
|
||||
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
# Put together
|
||||
if causal:
|
||||
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
|
||||
else:
|
||||
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: [M, H, K]
|
||||
Returns:
|
||||
result: [M, B, K]
|
||||
"""
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Chomp1d(nn.Module):
|
||||
"""To ensure the output length is the same as the input.
|
||||
"""
|
||||
def __init__(self, chomp_size):
|
||||
super(Chomp1d, self).__init__()
|
||||
self.chomp_size = chomp_size
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: [M, H, Kpad]
|
||||
Returns:
|
||||
[M, H, K]
|
||||
"""
|
||||
return x[:, :, :-self.chomp_size].contiguous()
|
||||
|
||||
|
||||
def chose_norm(norm_type, channel_size):
|
||||
"""The input of normlization will be (M, C, K), where M is batch size,
|
||||
C is channel size and K is sequence length.
|
||||
"""
|
||||
if norm_type == "gLN":
|
||||
return GlobalLayerNorm(channel_size)
|
||||
elif norm_type == "cLN":
|
||||
return ChannelwiseLayerNorm(channel_size)
|
||||
elif norm_type == "id":
|
||||
return nn.Identity()
|
||||
else: # norm_type == "BN":
|
||||
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
|
||||
# along M and K, so this BN usage is right.
|
||||
return nn.BatchNorm1d(channel_size)
|
||||
|
||||
|
||||
# TODO: Use nn.LayerNorm to impl cLN to speed up
|
||||
class ChannelwiseLayerNorm(nn.Module):
|
||||
"""Channel-wise Layer Normalization (cLN)"""
|
||||
def __init__(self, channel_size):
|
||||
super(ChannelwiseLayerNorm, self).__init__()
|
||||
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
||||
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.gamma.data.fill_(1)
|
||||
self.beta.data.zero_()
|
||||
|
||||
def forward(self, y):
|
||||
"""
|
||||
Args:
|
||||
y: [M, N, K], M is batch size, N is channel size, K is length
|
||||
Returns:
|
||||
cLN_y: [M, N, K]
|
||||
"""
|
||||
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
|
||||
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
|
||||
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
||||
return cLN_y
|
||||
|
||||
|
||||
class GlobalLayerNorm(nn.Module):
|
||||
"""Global Layer Normalization (gLN)"""
|
||||
def __init__(self, channel_size):
|
||||
super(GlobalLayerNorm, self).__init__()
|
||||
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
||||
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.gamma.data.fill_(1)
|
||||
self.beta.data.zero_()
|
||||
|
||||
def forward(self, y):
|
||||
"""
|
||||
Args:
|
||||
y: [M, N, K], M is batch size, N is channel size, K is length
|
||||
Returns:
|
||||
gLN_y: [M, N, K]
|
||||
"""
|
||||
# TODO: in torch 1.0, torch.mean() support dim list
|
||||
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
|
||||
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
|
||||
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
||||
return gLN_y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(123)
|
||||
M, N, L, T = 2, 3, 4, 12
|
||||
K = 2 * T // L - 1
|
||||
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
|
||||
mixture = torch.randint(3, (M, T))
|
||||
# test Encoder
|
||||
encoder = Encoder(L, N)
|
||||
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
|
||||
mixture_w = encoder(mixture)
|
||||
print('mixture', mixture)
|
||||
print('U', encoder.conv1d_U.weight)
|
||||
print('mixture_w', mixture_w)
|
||||
print('mixture_w size', mixture_w.size())
|
||||
|
||||
# test TemporalConvNet
|
||||
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
|
||||
est_mask = separator(mixture_w)
|
||||
print('est_mask', est_mask)
|
||||
|
||||
# test Decoder
|
||||
decoder = Decoder(N, L)
|
||||
est_mask = torch.randint(2, (B, K, C, N))
|
||||
est_source = decoder(mixture_w, est_mask)
|
||||
print('est_source', est_source)
|
||||
|
||||
# test Conv-TasNet
|
||||
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
|
||||
est_source = conv_tasnet(mixture)
|
||||
print('est_source', est_source)
|
||||
print('est_source size', est_source.size())
|
109
demucs/test.py
Normal file
109
demucs/test.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import gzip
|
||||
import sys
|
||||
from concurrent import futures
|
||||
|
||||
import musdb
|
||||
import museval
|
||||
import torch as th
|
||||
import tqdm
|
||||
from scipy.io import wavfile
|
||||
from torch import distributed
|
||||
|
||||
from .audio import convert_audio
|
||||
from .utils import apply_model
|
||||
|
||||
|
||||
def evaluate(model,
|
||||
musdb_path,
|
||||
eval_folder,
|
||||
workers=2,
|
||||
device="cpu",
|
||||
rank=0,
|
||||
save=False,
|
||||
shifts=0,
|
||||
split=False,
|
||||
overlap=0.25,
|
||||
is_wav=False,
|
||||
world_size=1):
|
||||
"""
|
||||
Evaluate model using museval. Run the model
|
||||
on a single GPU, the bottleneck being the call to museval.
|
||||
"""
|
||||
|
||||
output_dir = eval_folder / "results"
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
json_folder = eval_folder / "results/test"
|
||||
json_folder.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# we load tracks from the original musdb set
|
||||
test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
|
||||
src_rate = 44100 # hardcoded for now...
|
||||
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
p.grad = None
|
||||
|
||||
pendings = []
|
||||
with futures.ProcessPoolExecutor(workers or 1) as pool:
|
||||
for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
|
||||
track = test_set.tracks[index]
|
||||
|
||||
out = json_folder / f"{track.name}.json.gz"
|
||||
if out.exists():
|
||||
continue
|
||||
|
||||
mix = th.from_numpy(track.audio).t().float()
|
||||
ref = mix.mean(dim=0) # mono mixture
|
||||
mix = (mix - ref.mean()) / ref.std()
|
||||
mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
|
||||
estimates = apply_model(model, mix.to(device),
|
||||
shifts=shifts, split=split, overlap=overlap)
|
||||
estimates = estimates * ref.std() + ref.mean()
|
||||
|
||||
estimates = estimates.transpose(1, 2)
|
||||
references = th.stack(
|
||||
[th.from_numpy(track.targets[name].audio).t() for name in model.sources])
|
||||
references = convert_audio(references, src_rate,
|
||||
model.samplerate, model.audio_channels)
|
||||
references = references.transpose(1, 2).numpy()
|
||||
estimates = estimates.cpu().numpy()
|
||||
win = int(1. * model.samplerate)
|
||||
hop = int(1. * model.samplerate)
|
||||
if save:
|
||||
folder = eval_folder / "wav/test" / track.name
|
||||
folder.mkdir(exist_ok=True, parents=True)
|
||||
for name, estimate in zip(model.sources, estimates):
|
||||
wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
|
||||
|
||||
if workers:
|
||||
pendings.append((track.name, pool.submit(
|
||||
museval.evaluate, references, estimates, win=win, hop=hop)))
|
||||
else:
|
||||
pendings.append((track.name, museval.evaluate(
|
||||
references, estimates, win=win, hop=hop)))
|
||||
del references, mix, estimates, track
|
||||
|
||||
for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
|
||||
if workers:
|
||||
pending = pending.result()
|
||||
sdr, isr, sir, sar = pending
|
||||
track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
|
||||
for idx, target in enumerate(model.sources):
|
||||
values = {
|
||||
"SDR": sdr[idx].tolist(),
|
||||
"SIR": sir[idx].tolist(),
|
||||
"ISR": isr[idx].tolist(),
|
||||
"SAR": sar[idx].tolist()
|
||||
}
|
||||
|
||||
track_store.add_target(target_name=target, values=values)
|
||||
json_path = json_folder / f"{track_name}.json.gz"
|
||||
gzip.open(json_path, "w").write(track_store.json.encode('utf-8'))
|
||||
if world_size > 1:
|
||||
distributed.barrier()
|
323
demucs/utils.py
Normal file
323
demucs/utils.py
Normal file
@ -0,0 +1,323 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import tempfile
|
||||
import warnings
|
||||
import zlib
|
||||
from contextlib import contextmanager
|
||||
|
||||
from diffq import UniformQuantizer, DiffQuantizer
|
||||
import torch as th
|
||||
import tqdm
|
||||
from torch import distributed
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def center_trim(tensor, reference):
|
||||
"""
|
||||
Center trim `tensor` with respect to `reference`, along the last dimension.
|
||||
`reference` can also be a number, representing the length to trim to.
|
||||
If the size difference != 0 mod 2, the extra sample is removed on the right side.
|
||||
"""
|
||||
if hasattr(reference, "size"):
|
||||
reference = reference.size(-1)
|
||||
delta = tensor.size(-1) - reference
|
||||
if delta < 0:
|
||||
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
|
||||
if delta:
|
||||
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
|
||||
return tensor
|
||||
|
||||
|
||||
def average_metric(metric, count=1.):
|
||||
"""
|
||||
Average `metric` which should be a float across all hosts. `count` should be
|
||||
the weight for this particular host (i.e. number of examples).
|
||||
"""
|
||||
metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
|
||||
distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
|
||||
return metric[1].item() / metric[0].item()
|
||||
|
||||
|
||||
def free_port(host='', low=20000, high=40000):
|
||||
"""
|
||||
Return a port number that is most likely free.
|
||||
This could suffer from a race condition although
|
||||
it should be quite rare.
|
||||
"""
|
||||
sock = socket.socket()
|
||||
while True:
|
||||
port = random.randint(low, high)
|
||||
try:
|
||||
sock.bind((host, port))
|
||||
except OSError as error:
|
||||
if error.errno == errno.EADDRINUSE:
|
||||
continue
|
||||
raise
|
||||
return port
|
||||
|
||||
|
||||
def sizeof_fmt(num, suffix='B'):
|
||||
"""
|
||||
Given `num` bytes, return human readable size.
|
||||
Taken from https://stackoverflow.com/a/1094933
|
||||
"""
|
||||
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
|
||||
if abs(num) < 1024.0:
|
||||
return "%3.1f%s%s" % (num, unit, suffix)
|
||||
num /= 1024.0
|
||||
return "%.1f%s%s" % (num, 'Yi', suffix)
|
||||
|
||||
|
||||
def human_seconds(seconds, display='.2f'):
|
||||
"""
|
||||
Given `seconds` seconds, return human readable duration.
|
||||
"""
|
||||
value = seconds * 1e6
|
||||
ratios = [1e3, 1e3, 60, 60, 24]
|
||||
names = ['us', 'ms', 's', 'min', 'hrs', 'days']
|
||||
last = names.pop(0)
|
||||
for name, ratio in zip(names, ratios):
|
||||
if value / ratio < 0.3:
|
||||
break
|
||||
value /= ratio
|
||||
last = name
|
||||
return f"{format(value, display)} {last}"
|
||||
|
||||
|
||||
class TensorChunk:
|
||||
def __init__(self, tensor, offset=0, length=None):
|
||||
total_length = tensor.shape[-1]
|
||||
assert offset >= 0
|
||||
assert offset < total_length
|
||||
|
||||
if length is None:
|
||||
length = total_length - offset
|
||||
else:
|
||||
length = min(total_length - offset, length)
|
||||
|
||||
self.tensor = tensor
|
||||
self.offset = offset
|
||||
self.length = length
|
||||
self.device = tensor.device
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
shape = list(self.tensor.shape)
|
||||
shape[-1] = self.length
|
||||
return shape
|
||||
|
||||
def padded(self, target_length):
|
||||
delta = target_length - self.length
|
||||
total_length = self.tensor.shape[-1]
|
||||
assert delta >= 0
|
||||
|
||||
start = self.offset - delta // 2
|
||||
end = start + target_length
|
||||
|
||||
correct_start = max(0, start)
|
||||
correct_end = min(total_length, end)
|
||||
|
||||
pad_left = correct_start - start
|
||||
pad_right = end - correct_end
|
||||
|
||||
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
|
||||
assert out.shape[-1] == target_length
|
||||
return out
|
||||
|
||||
|
||||
def tensor_chunk(tensor_or_chunk):
|
||||
if isinstance(tensor_or_chunk, TensorChunk):
|
||||
return tensor_or_chunk
|
||||
else:
|
||||
assert isinstance(tensor_or_chunk, th.Tensor)
|
||||
return TensorChunk(tensor_or_chunk)
|
||||
|
||||
|
||||
def apply_model(model, mix, shifts=None, split=False,
|
||||
overlap=0.25, transition_power=1., progress=False):
|
||||
"""
|
||||
Apply model to a given mixture.
|
||||
|
||||
Args:
|
||||
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
|
||||
and apply the oppositve shift to the output. This is repeated `shifts` time and
|
||||
all predictions are averaged. This effectively makes the model time equivariant
|
||||
and improves SDR by up to 0.2 points.
|
||||
split (bool): if True, the input will be broken down in 8 seconds extracts
|
||||
and predictions will be performed individually on each and concatenated.
|
||||
Useful for model with large memory footprint like Tasnet.
|
||||
progress (bool): if True, show a progress bar (requires split=True)
|
||||
"""
|
||||
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
|
||||
device = mix.device
|
||||
channels, length = mix.shape
|
||||
if split:
|
||||
out = th.zeros(len(model.sources), channels, length, device=device)
|
||||
sum_weight = th.zeros(length, device=device)
|
||||
segment = model.segment_length
|
||||
stride = int((1 - overlap) * segment)
|
||||
offsets = range(0, length, stride)
|
||||
scale = stride / model.samplerate
|
||||
if progress:
|
||||
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
|
||||
# We start from a triangle shaped weight, with maximal weight in the middle
|
||||
# of the segment. Then we normalize and take to the power `transition_power`.
|
||||
# Large values of transition power will lead to sharper transitions.
|
||||
weight = th.cat([th.arange(1, segment // 2 + 1),
|
||||
th.arange(segment - segment // 2, 0, -1)]).to(device)
|
||||
assert len(weight) == segment
|
||||
# If the overlap < 50%, this will translate to linear transition when
|
||||
# transition_power is 1.
|
||||
weight = (weight / weight.max())**transition_power
|
||||
for offset in offsets:
|
||||
chunk = TensorChunk(mix, offset, segment)
|
||||
chunk_out = apply_model(model, chunk, shifts=shifts)
|
||||
chunk_length = chunk_out.shape[-1]
|
||||
out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
|
||||
sum_weight[offset:offset + segment] += weight[:chunk_length]
|
||||
offset += segment
|
||||
assert sum_weight.min() > 0
|
||||
out /= sum_weight
|
||||
return out
|
||||
elif shifts:
|
||||
max_shift = int(0.5 * model.samplerate)
|
||||
mix = tensor_chunk(mix)
|
||||
padded_mix = mix.padded(length + 2 * max_shift)
|
||||
out = 0
|
||||
for _ in range(shifts):
|
||||
offset = random.randint(0, max_shift)
|
||||
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
|
||||
shifted_out = apply_model(model, shifted)
|
||||
out += shifted_out[..., max_shift - offset:]
|
||||
out /= shifts
|
||||
return out
|
||||
else:
|
||||
valid_length = model.valid_length(length)
|
||||
mix = tensor_chunk(mix)
|
||||
padded_mix = mix.padded(valid_length)
|
||||
with th.no_grad():
|
||||
out = model(padded_mix.unsqueeze(0))[0]
|
||||
return center_trim(out, length)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_filenames(count, delete=True):
|
||||
names = []
|
||||
try:
|
||||
for _ in range(count):
|
||||
names.append(tempfile.NamedTemporaryFile(delete=False).name)
|
||||
yield names
|
||||
finally:
|
||||
if delete:
|
||||
for name in names:
|
||||
os.unlink(name)
|
||||
|
||||
|
||||
def get_quantizer(model, args, optimizer=None):
|
||||
quantizer = None
|
||||
if args.diffq:
|
||||
quantizer = DiffQuantizer(
|
||||
model, min_size=args.q_min_size, group_size=8)
|
||||
if optimizer is not None:
|
||||
quantizer.setup_optimizer(optimizer)
|
||||
elif args.qat:
|
||||
quantizer = UniformQuantizer(
|
||||
model, bits=args.qat, min_size=args.q_min_size)
|
||||
return quantizer
|
||||
|
||||
|
||||
def load_model(path, strict=False):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
load_from = path
|
||||
package = th.load(load_from, 'cpu')
|
||||
|
||||
klass = package["klass"]
|
||||
args = package["args"]
|
||||
kwargs = package["kwargs"]
|
||||
|
||||
if strict:
|
||||
model = klass(*args, **kwargs)
|
||||
else:
|
||||
sig = inspect.signature(klass)
|
||||
for key in list(kwargs):
|
||||
if key not in sig.parameters:
|
||||
warnings.warn("Dropping inexistant parameter " + key)
|
||||
del kwargs[key]
|
||||
model = klass(*args, **kwargs)
|
||||
|
||||
state = package["state"]
|
||||
training_args = package["training_args"]
|
||||
quantizer = get_quantizer(model, training_args)
|
||||
|
||||
set_state(model, quantizer, state)
|
||||
return model
|
||||
|
||||
|
||||
def get_state(model, quantizer):
|
||||
if quantizer is None:
|
||||
state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
|
||||
else:
|
||||
state = quantizer.get_quantized_state()
|
||||
buf = io.BytesIO()
|
||||
th.save(state, buf)
|
||||
state = {'compressed': zlib.compress(buf.getvalue())}
|
||||
return state
|
||||
|
||||
|
||||
def set_state(model, quantizer, state):
|
||||
if quantizer is None:
|
||||
model.load_state_dict(state)
|
||||
else:
|
||||
buf = io.BytesIO(zlib.decompress(state["compressed"]))
|
||||
state = th.load(buf, "cpu")
|
||||
quantizer.restore_quantized_state(state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def save_state(state, path):
|
||||
buf = io.BytesIO()
|
||||
th.save(state, buf)
|
||||
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
|
||||
|
||||
path = path.parent / (path.stem + "-" + sig + path.suffix)
|
||||
path.write_bytes(buf.getvalue())
|
||||
|
||||
|
||||
def save_model(model, quantizer, training_args, path):
|
||||
args, kwargs = model._init_args_kwargs
|
||||
klass = model.__class__
|
||||
|
||||
state = get_state(model, quantizer)
|
||||
|
||||
save_to = path
|
||||
package = {
|
||||
'klass': klass,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
'state': state,
|
||||
'training_args': training_args,
|
||||
}
|
||||
th.save(package, save_to)
|
||||
|
||||
|
||||
def capture_init(init):
|
||||
@functools.wraps(init)
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._init_args_kwargs = (args, kwargs)
|
||||
init(self, *args, **kwargs)
|
||||
|
||||
return __init__
|
174
demucs/wav.py
Normal file
174
demucs/wav.py
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import math
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import julius
|
||||
import torch as th
|
||||
from torch import distributed
|
||||
import torchaudio as ta
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .audio import convert_audio_channels
|
||||
from .compressed import get_musdb_tracks
|
||||
|
||||
MIXTURE = "mixture"
|
||||
EXT = ".wav"
|
||||
|
||||
|
||||
def _track_metadata(track, sources):
|
||||
track_length = None
|
||||
track_samplerate = None
|
||||
for source in sources + [MIXTURE]:
|
||||
file = track / f"{source}{EXT}"
|
||||
info = ta.info(str(file))
|
||||
length = info.num_frames
|
||||
if track_length is None:
|
||||
track_length = length
|
||||
track_samplerate = info.sample_rate
|
||||
elif track_length != length:
|
||||
raise ValueError(
|
||||
f"Invalid length for file {file}: "
|
||||
f"expecting {track_length} but got {length}.")
|
||||
elif info.sample_rate != track_samplerate:
|
||||
raise ValueError(
|
||||
f"Invalid sample rate for file {file}: "
|
||||
f"expecting {track_samplerate} but got {info.sample_rate}.")
|
||||
if source == MIXTURE:
|
||||
wav, _ = ta.load(str(file))
|
||||
wav = wav.mean(0)
|
||||
mean = wav.mean().item()
|
||||
std = wav.std().item()
|
||||
|
||||
return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
|
||||
|
||||
|
||||
def _build_metadata(path, sources):
|
||||
meta = {}
|
||||
path = Path(path)
|
||||
for file in path.iterdir():
|
||||
meta[file.name] = _track_metadata(file, sources)
|
||||
return meta
|
||||
|
||||
|
||||
class Wavset:
|
||||
def __init__(
|
||||
self,
|
||||
root, metadata, sources,
|
||||
length=None, stride=None, normalize=True,
|
||||
samplerate=44100, channels=2):
|
||||
"""
|
||||
Waveset (or mp3 set for that matter). Can be used to train
|
||||
with arbitrary sources. Each track should be one folder inside of `path`.
|
||||
The folder should contain files named `{source}.{ext}`.
|
||||
Files will be grouped according to `sources` (each source is a list of
|
||||
filenames).
|
||||
|
||||
Sample rate and channels will be converted on the fly.
|
||||
|
||||
`length` is the sample size to extract (in samples, not duration).
|
||||
`stride` is how many samples to move by between each example.
|
||||
"""
|
||||
self.root = Path(root)
|
||||
self.metadata = OrderedDict(metadata)
|
||||
self.length = length
|
||||
self.stride = stride or length
|
||||
self.normalize = normalize
|
||||
self.sources = sources
|
||||
self.channels = channels
|
||||
self.samplerate = samplerate
|
||||
self.num_examples = []
|
||||
for name, meta in self.metadata.items():
|
||||
track_length = int(self.samplerate * meta['length'] / meta['samplerate'])
|
||||
if length is None or track_length < length:
|
||||
examples = 1
|
||||
else:
|
||||
examples = int(math.ceil((track_length - self.length) / self.stride) + 1)
|
||||
self.num_examples.append(examples)
|
||||
|
||||
def __len__(self):
|
||||
return sum(self.num_examples)
|
||||
|
||||
def get_file(self, name, source):
|
||||
return self.root / name / f"{source}{EXT}"
|
||||
|
||||
def __getitem__(self, index):
|
||||
for name, examples in zip(self.metadata, self.num_examples):
|
||||
if index >= examples:
|
||||
index -= examples
|
||||
continue
|
||||
meta = self.metadata[name]
|
||||
num_frames = -1
|
||||
offset = 0
|
||||
if self.length is not None:
|
||||
offset = int(math.ceil(
|
||||
meta['samplerate'] * self.stride * index / self.samplerate))
|
||||
num_frames = int(math.ceil(
|
||||
meta['samplerate'] * self.length / self.samplerate))
|
||||
wavs = []
|
||||
for source in self.sources:
|
||||
file = self.get_file(name, source)
|
||||
wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
|
||||
wav = convert_audio_channels(wav, self.channels)
|
||||
wavs.append(wav)
|
||||
|
||||
example = th.stack(wavs)
|
||||
example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
|
||||
if self.normalize:
|
||||
example = (example - meta['mean']) / meta['std']
|
||||
if self.length:
|
||||
example = example[..., :self.length]
|
||||
example = F.pad(example, (0, self.length - example.shape[-1]))
|
||||
return example
|
||||
|
||||
|
||||
def get_wav_datasets(args, samples, sources):
|
||||
sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
|
||||
metadata_file = args.metadata / (sig + ".json")
|
||||
train_path = args.wav / "train"
|
||||
valid_path = args.wav / "valid"
|
||||
if not metadata_file.is_file() and args.rank == 0:
|
||||
train = _build_metadata(train_path, sources)
|
||||
valid = _build_metadata(valid_path, sources)
|
||||
json.dump([train, valid], open(metadata_file, "w"))
|
||||
if args.world_size > 1:
|
||||
distributed.barrier()
|
||||
train, valid = json.load(open(metadata_file))
|
||||
train_set = Wavset(train_path, train, sources,
|
||||
length=samples, stride=args.data_stride,
|
||||
samplerate=args.samplerate, channels=args.audio_channels,
|
||||
normalize=args.norm_wav)
|
||||
valid_set = Wavset(valid_path, valid, [MIXTURE] + sources,
|
||||
samplerate=args.samplerate, channels=args.audio_channels,
|
||||
normalize=args.norm_wav)
|
||||
return train_set, valid_set
|
||||
|
||||
|
||||
def get_musdb_wav_datasets(args, samples, sources):
|
||||
metadata_file = args.metadata / "musdb_wav.json"
|
||||
root = args.musdb / "train"
|
||||
if not metadata_file.is_file() and args.rank == 0:
|
||||
metadata = _build_metadata(root, sources)
|
||||
json.dump(metadata, open(metadata_file, "w"))
|
||||
if args.world_size > 1:
|
||||
distributed.barrier()
|
||||
metadata = json.load(open(metadata_file))
|
||||
|
||||
train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train")
|
||||
metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks}
|
||||
metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks}
|
||||
train_set = Wavset(root, metadata_train, sources,
|
||||
length=samples, stride=args.data_stride,
|
||||
samplerate=args.samplerate, channels=args.audio_channels,
|
||||
normalize=args.norm_wav)
|
||||
valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources,
|
||||
samplerate=args.samplerate, channels=args.audio_channels,
|
||||
normalize=args.norm_wav)
|
||||
return train_set, valid_set
|
18
diffq/__init__.py
Normal file
18
diffq/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# flake8: noqa
|
||||
"""
|
||||
This package implements different quantization strategies:
|
||||
|
||||
- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits.
|
||||
- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection.
|
||||
|
||||
Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers.
|
||||
"""
|
||||
|
||||
from .uniform import UniformQuantizer
|
||||
from .diffq import DiffQuantizer
|
262
diffq/base.py
Normal file
262
diffq/base.py
Normal file
@ -0,0 +1,262 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from concurrent import futures
|
||||
from fnmatch import fnmatch
|
||||
from functools import partial
|
||||
import io
|
||||
import math
|
||||
from multiprocessing import cpu_count
|
||||
import typing as tp
|
||||
import zlib
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseQuantizer:
|
||||
@dataclass
|
||||
class _QuantizedParam:
|
||||
name: str
|
||||
param: torch.nn.Parameter
|
||||
module: torch.nn.Module
|
||||
# If a Parameter is used multiple times, `other` can be used
|
||||
# to share state between the different Quantizers
|
||||
other: tp.Optional[tp.Any]
|
||||
|
||||
def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
|
||||
exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
|
||||
self.model = model
|
||||
self.min_size = min_size
|
||||
self.float16 = float16
|
||||
self.exclude = exclude
|
||||
self.detect_bound = detect_bound
|
||||
self._quantized = False
|
||||
self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
|
||||
self._post_handle = self.model.register_forward_hook(self._forward_hook)
|
||||
|
||||
self._quantized_state = None
|
||||
self._qparams = []
|
||||
self._float16 = []
|
||||
self._others = []
|
||||
self._rnns = []
|
||||
|
||||
self._saved = []
|
||||
|
||||
self._find_params()
|
||||
|
||||
def _find_params(self):
|
||||
min_params = self.min_size * 2**20 // 4
|
||||
previous = {}
|
||||
for module_name, module in self.model.named_modules():
|
||||
if isinstance(module, torch.nn.RNNBase):
|
||||
self._rnns.append(module)
|
||||
for name, param in list(module.named_parameters(recurse=False)):
|
||||
full_name = f"{module_name}.{name}"
|
||||
matched = False
|
||||
for pattern in self.exclude:
|
||||
if fnmatch(full_name, pattern) or fnmatch(name, pattern):
|
||||
matched = True
|
||||
break
|
||||
|
||||
if param.numel() <= min_params or matched:
|
||||
if id(param) in previous:
|
||||
continue
|
||||
if self.detect_bound:
|
||||
previous[id(param)] = None
|
||||
if self.float16:
|
||||
self._float16.append(param)
|
||||
else:
|
||||
self._others.append(param)
|
||||
else:
|
||||
qparam = self._register_param(name, param, module, previous.get(id(param)))
|
||||
if self.detect_bound:
|
||||
previous[id(param)] = qparam
|
||||
self._qparams.append(qparam)
|
||||
|
||||
def _register_param(self, name, param, module, other):
|
||||
return self.__class__._QuantizedParam(name, param, module, other)
|
||||
|
||||
def _forward_pre_hook(self, module, input):
|
||||
if self.model.training:
|
||||
self._quantized_state = None
|
||||
if self._quantized:
|
||||
self.unquantize()
|
||||
if self._pre_forward_train():
|
||||
self._fix_rnns()
|
||||
else:
|
||||
self.quantize()
|
||||
|
||||
def _forward_hook(self, module, input, output):
|
||||
if self.model.training:
|
||||
if self._post_forward_train():
|
||||
self._fix_rnns(flatten=False) # Hacky, next forward will flatten
|
||||
|
||||
def quantize(self, save=True):
|
||||
"""
|
||||
Immediately apply quantization to the model parameters.
|
||||
If `save` is True, save a copy of the unquantized parameters, that can be
|
||||
restored with `unquantize()`.
|
||||
"""
|
||||
if self._quantized:
|
||||
return
|
||||
if save:
|
||||
self._saved = [qp.param.data.to('cpu', copy=True)
|
||||
for qp in self._qparams if qp.other is None]
|
||||
self.restore_quantized_state(self.get_quantized_state())
|
||||
self._quantized = True
|
||||
self._fix_rnns()
|
||||
|
||||
def unquantize(self):
|
||||
"""
|
||||
Revert a previous call to `quantize()`.
|
||||
"""
|
||||
if not self._quantized:
|
||||
raise RuntimeError("Can only be called on a quantized model.")
|
||||
if not self._saved:
|
||||
raise RuntimeError("Nothing to restore.")
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is None:
|
||||
qparam.param.data[:] = self._saved.pop(0)
|
||||
assert len(self._saved) == 0
|
||||
self._quantized = False
|
||||
self._fix_rnns()
|
||||
|
||||
def _pre_forward_train(self) -> bool:
|
||||
"""
|
||||
Called once before each forward for continuous quantization.
|
||||
Should return True if parameters were changed.
|
||||
"""
|
||||
return False
|
||||
|
||||
def _post_forward_train(self) -> bool:
|
||||
"""
|
||||
Called once after each forward (to restore state for instance).
|
||||
Should return True if parameters were changed.
|
||||
"""
|
||||
return False
|
||||
|
||||
def _fix_rnns(self, flatten=True):
|
||||
"""
|
||||
To be called after quantization happened to fix RNNs.
|
||||
"""
|
||||
for rnn in self._rnns:
|
||||
rnn._flat_weights = [
|
||||
(lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
|
||||
for wn in rnn._flat_weights_names]
|
||||
if flatten:
|
||||
rnn.flatten_parameters()
|
||||
|
||||
def get_quantized_state(self):
|
||||
"""
|
||||
Returns sufficient quantized information to rebuild the model state.
|
||||
|
||||
..Note::
|
||||
To achieve maximum compression, you should compress this with
|
||||
gzip or other, as quantized weights are not optimally coded!
|
||||
"""
|
||||
if self._quantized_state is None:
|
||||
self._quantized_state = self._get_quantized_state()
|
||||
return self._quantized_state
|
||||
|
||||
def _get_quantized_state(self):
|
||||
"""
|
||||
Actual implementation for `get_quantized_state`.
|
||||
"""
|
||||
float16_params = []
|
||||
for p in self._float16:
|
||||
q = p.data.half()
|
||||
float16_params.append(q)
|
||||
|
||||
return {
|
||||
"quantized": [self._quantize_param(qparam) for qparam in self._qparams
|
||||
if qparam.other is None],
|
||||
"float16": float16_params,
|
||||
"others": [p.data.clone() for p in self._others],
|
||||
}
|
||||
|
||||
def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
|
||||
"""
|
||||
To be overriden.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
|
||||
"""
|
||||
To be overriden.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def restore_quantized_state(self, state) -> None:
|
||||
"""
|
||||
Restore the state of the model from the quantized state.
|
||||
"""
|
||||
for p, q in zip(self._float16, state["float16"]):
|
||||
p.data[:] = q.to(p)
|
||||
|
||||
for p, q in zip(self._others, state["others"]):
|
||||
p.data[:] = q
|
||||
|
||||
remaining = list(state["quantized"])
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is not None:
|
||||
# Only unquantize first appearance of nn.Parameter.
|
||||
continue
|
||||
quantized = remaining.pop(0)
|
||||
qparam.param.data[:] = self._unquantize_param(qparam, quantized)
|
||||
self._fix_rnns()
|
||||
|
||||
def detach(self) -> None:
|
||||
"""
|
||||
Detach from the model, removes hooks and anything else.
|
||||
"""
|
||||
self._pre_handle.remove()
|
||||
self._post_handle.remove()
|
||||
|
||||
def model_size(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns an estimate of the quantized model size.
|
||||
"""
|
||||
total = torch.tensor(0.)
|
||||
for p in self._float16:
|
||||
total += 16 * p.numel()
|
||||
for p in self._others:
|
||||
total += 32 * p.numel()
|
||||
return total / 2**20 / 8 # bits to MegaBytes
|
||||
|
||||
def true_model_size(self) -> float:
|
||||
"""
|
||||
Return the true quantized model size, in MB, without extra
|
||||
compression.
|
||||
"""
|
||||
return self.model_size().item()
|
||||
|
||||
def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
|
||||
"""
|
||||
Return the compressed quantized model size, in MB.
|
||||
|
||||
Args:
|
||||
compress_level (int): compression level used with zlib,
|
||||
see `zlib.compress` for details.
|
||||
num_workers (int): will split the final big byte representation in that
|
||||
many chunks processed in parallels.
|
||||
"""
|
||||
out = io.BytesIO()
|
||||
torch.save(self.get_quantized_state(), out)
|
||||
ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
|
||||
return ms / 2 ** 20
|
||||
|
||||
|
||||
def _compress_len(data, compress_level):
|
||||
return len(zlib.compress(data, level=compress_level))
|
||||
|
||||
|
||||
def _parallel_compress_len(data, compress_level, num_workers):
|
||||
num_workers = min(cpu_count(), num_workers)
|
||||
chunk_size = int(math.ceil(len(data) / num_workers))
|
||||
chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
|
||||
with futures.ProcessPoolExecutor(num_workers) as pool:
|
||||
return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
|
286
diffq/diffq.py
Normal file
286
diffq/diffq.py
Normal file
@ -0,0 +1,286 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Differentiable quantizer based on scaled noise injection.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
from .base import BaseQuantizer
|
||||
from .uniform import uniform_quantize, uniform_unquantize
|
||||
from .utils import simple_repr
|
||||
|
||||
|
||||
class DiffQuantizer(BaseQuantizer):
|
||||
@dataclass
|
||||
class _QuantizedParam(BaseQuantizer._QuantizedParam):
|
||||
logit: torch.nn.Parameter
|
||||
|
||||
def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
|
||||
group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
|
||||
param="bits", noise="gaussian",
|
||||
init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
|
||||
exclude: tp.List[str] = [], detect_bound: bool = True):
|
||||
"""
|
||||
Differentiable quantizer based on scaled noise injection.
|
||||
For every parameter `p` in the model, this introduces a number of bits parameter
|
||||
`b` with the same dimensions (when group_size = 1).
|
||||
Before each forward, `p` is replaced by `p + U`
|
||||
with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization
|
||||
step for `b` bits.
|
||||
This noise approximates the quantization noise in a differentiable manner, both
|
||||
with respect to the unquantized parameter `p` and the number of bits `b`.
|
||||
|
||||
At eveluation (as detected with `model.eval()`), the model is replaced
|
||||
by its true quantized version, and restored when going back to training.
|
||||
|
||||
When doing actual quantization (for serialization, or evaluation),
|
||||
the number of bits is rounded to the nearest integer, and needs to be stored along.
|
||||
This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
|
||||
which will use a single noise level for multiple weight entries.
|
||||
|
||||
You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
|
||||
model size in MB. You can then use this estimate as a penalty in your training loss.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to quantize
|
||||
min_size (float): minimum size in MB of a parameter to be quantized.
|
||||
float16 (bool): if a layer is smaller than min_size, should we still do float16?
|
||||
group_size (int): weight entries are groupped together to reduce the number
|
||||
of noise scales to store. This should divide the size of all parameters
|
||||
bigger than min_size.
|
||||
min_bits (float): minimal number of bits.
|
||||
max_bits (float): maximal number of bits.
|
||||
init_bits (float): initial number of bits.
|
||||
extra_bits (float): extra bits to add for actual quantization (before roundoff).
|
||||
suffix (str): suffix used for the name of the extra noise scale parameters.
|
||||
exclude (list[str]): list of patterns used to match parameters to exclude.
|
||||
For instance `['bias']` to exclude all bias terms.
|
||||
detect_bound (bool): if True, will detect bound parameters and reuse
|
||||
the same quantized tensor for both, as well as the same number of bits.
|
||||
|
||||
..Warning::
|
||||
You must call `model.training()` and `model.eval()` for `DiffQuantizer` work properly.
|
||||
|
||||
"""
|
||||
self.group_size = group_size
|
||||
self.min_bits = min_bits
|
||||
self.max_bits = max_bits
|
||||
self.init_bits = init_bits
|
||||
self.extra_bits = extra_bits
|
||||
self.suffix = suffix
|
||||
self.param = param
|
||||
self.noise = noise
|
||||
assert noise in ["gaussian", "uniform"]
|
||||
self._optimizer_setup = False
|
||||
|
||||
self._min_noise = 1 / (2 ** self.max_bits - 1)
|
||||
self._max_noise = 1 / (2 ** self.min_bits - 1)
|
||||
|
||||
assert group_size >= 0
|
||||
assert min_bits < init_bits < max_bits, \
|
||||
"init_bits must be between min_bits and max_bits excluded3"
|
||||
|
||||
for name, _ in model.named_parameters():
|
||||
if name.endswith(suffix):
|
||||
raise RuntimeError("The model already has some noise scales parameters, "
|
||||
"maybe you used twice a DiffQuantizer on the same model?.")
|
||||
|
||||
super().__init__(model, min_size, float16, exclude, detect_bound)
|
||||
|
||||
def _get_bits(self, logit: torch.Tensor):
|
||||
if self.param == "noise":
|
||||
return torch.log2(1 + 1 / self._get_noise_scale(logit))
|
||||
else:
|
||||
t = torch.sigmoid(logit)
|
||||
return self.max_bits * t + (1 - t) * self.min_bits
|
||||
|
||||
def _get_noise_scale(self, logit: torch.Tensor):
|
||||
if self.param == "noise":
|
||||
t = torch.sigmoid(logit)
|
||||
return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
|
||||
else:
|
||||
return 1 / (2 ** self._get_bits(logit) - 1)
|
||||
|
||||
def _register_param(self, name, param, module, other):
|
||||
if other is not None:
|
||||
return self.__class__._QuantizedParam(
|
||||
name=name, param=param, module=module, logit=other.logit, other=other)
|
||||
assert self.group_size == 0 or param.numel() % self.group_size == 0
|
||||
# we want the initial number of bits to be init_bits.
|
||||
if self.param == "noise":
|
||||
noise_scale = 1 / (2 ** self.init_bits - 1)
|
||||
t = (math.log(noise_scale) - math.log(self._max_noise)) / (
|
||||
math.log(self._min_noise) - math.log(self._max_noise))
|
||||
else:
|
||||
t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
|
||||
assert 0 < t < 1
|
||||
logit = torch.logit(torch.tensor(float(t)))
|
||||
assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
|
||||
if self.group_size > 0:
|
||||
nparam = param.numel() // self.group_size
|
||||
else:
|
||||
nparam = 1
|
||||
logit = torch.nn.Parameter(
|
||||
torch.full(
|
||||
(nparam,),
|
||||
logit,
|
||||
device=param.device))
|
||||
module.register_parameter(name + self.suffix, logit)
|
||||
return self.__class__._QuantizedParam(
|
||||
name=name, param=param, module=module, logit=logit, other=None)
|
||||
|
||||
def clear_optimizer(self, optimizer: torch.optim.Optimizer):
|
||||
params = [qp.logit for qp in self._qparams]
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
new_params = []
|
||||
for q in list(group["params"]):
|
||||
matched = False
|
||||
for p in params:
|
||||
if p is q:
|
||||
matched = True
|
||||
if not matched:
|
||||
new_params.append(q)
|
||||
group["params"][:] = new_params
|
||||
|
||||
def setup_optimizer(self, optimizer: torch.optim.Optimizer,
|
||||
lr: float = 1e-3, **kwargs):
|
||||
"""
|
||||
Setup the optimizer to tune the number of bits. In particular, this will deactivate
|
||||
weight decay for the bits parameters.
|
||||
|
||||
Args:
|
||||
optimizer (torch.Optimizer): optimizer to use.
|
||||
lr (float): specific learning rate for the bits parameters. 1e-3
|
||||
is perfect for Adam.,w
|
||||
kwargs (dict): overrides for other optimization parameters for the bits.
|
||||
"""
|
||||
assert not self._optimizer_setup
|
||||
self._optimizer_setup = True
|
||||
|
||||
params = [qp.logit for qp in self._qparams]
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
for q in list(group["params"]):
|
||||
for p in params:
|
||||
if p is q:
|
||||
raise RuntimeError("You should create the optimizer "
|
||||
"before the quantizer!")
|
||||
|
||||
group = {"params": params, "lr": lr, "weight_decay": 0}
|
||||
group.update(kwargs)
|
||||
optimizer.add_param_group(group)
|
||||
|
||||
def no_optimizer(self):
|
||||
"""
|
||||
Call this if you do not want to use an optimizer.
|
||||
"""
|
||||
self._optimizer_setup = True
|
||||
|
||||
def check_unused(self):
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is not None:
|
||||
continue
|
||||
grad = qparam.param.grad
|
||||
if grad is None or (grad == 0).all():
|
||||
if qparam.logit.grad is not None:
|
||||
qparam.logit.grad.data.zero_()
|
||||
|
||||
def model_size(self, exact=False):
|
||||
"""
|
||||
Differentiable estimate of the model size.
|
||||
The size is returned in MB.
|
||||
|
||||
If `exact` is True, then the output is no longer differentiable but
|
||||
reflect exactly an achievable size, even without compression,
|
||||
i.e.same as returned by `naive_model_size()`.
|
||||
"""
|
||||
total = super().model_size()
|
||||
subtotal = 0
|
||||
for qparam in self._qparams:
|
||||
# only count the first appearance of a Parameter
|
||||
if qparam.other is not None:
|
||||
continue
|
||||
bits = self.extra_bits + self._get_bits(qparam.logit)
|
||||
if exact:
|
||||
bits = bits.round().clamp(1, 15)
|
||||
if self.group_size == 0:
|
||||
group_size = qparam.param.numel()
|
||||
else:
|
||||
group_size = self.group_size
|
||||
subtotal += group_size * bits.sum()
|
||||
subtotal += 2 * 32 # param scale
|
||||
|
||||
# Number of bits to represent each number of bits
|
||||
bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
|
||||
subtotal += 8 # 8 bits for bits_bits
|
||||
subtotal += bits_bits * bits.numel()
|
||||
|
||||
subtotal /= 2 ** 20 * 8 # bits -> MegaBytes
|
||||
return total + subtotal
|
||||
|
||||
def true_model_size(self):
|
||||
"""
|
||||
Naive model size without zlib compression.
|
||||
"""
|
||||
return self.model_size(exact=True).item()
|
||||
|
||||
def _pre_forward_train(self):
|
||||
if not self._optimizer_setup:
|
||||
raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
|
||||
"before starting training.")
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is not None:
|
||||
noisy = qparam.other.module._parameters[qparam.other.name]
|
||||
else:
|
||||
bits = self._get_bits(qparam.logit)[:, None]
|
||||
if self.group_size == 0:
|
||||
p_flat = qparam.param.view(-1)
|
||||
else:
|
||||
p_flat = qparam.param.view(-1, self.group_size)
|
||||
scale = p_flat.max() - p_flat.min()
|
||||
unit = 1 / (2**bits - 1)
|
||||
if self.noise == "uniform":
|
||||
noise_source = (torch.rand_like(p_flat) - 0.5)
|
||||
elif self.noise == "gaussian":
|
||||
noise_source = torch.randn_like(p_flat) / 2
|
||||
noise = scale * unit * noise_source
|
||||
noisy = p_flat + noise
|
||||
# We bypass the checks by PyTorch on parameters being leafs
|
||||
qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
|
||||
return True
|
||||
|
||||
def _post_forward_train(self):
|
||||
for qparam in self._qparams:
|
||||
qparam.module._parameters[qparam.name] = qparam.param
|
||||
return True
|
||||
|
||||
def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
|
||||
bits = self.extra_bits + self._get_bits(qparam.logit)
|
||||
bits = bits.round().clamp(1, 15)[:, None].byte()
|
||||
if self.group_size == 0:
|
||||
p = qparam.param.data.view(-1)
|
||||
else:
|
||||
p = qparam.param.data.view(-1, self.group_size)
|
||||
levels, scales = uniform_quantize(p, bits)
|
||||
return levels, scales, bits
|
||||
|
||||
def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
|
||||
levels, param_scale, bits = quantized
|
||||
return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)
|
||||
|
||||
def detach(self):
|
||||
super().detach()
|
||||
for qparam in self._qparams:
|
||||
delattr(qparam.module, qparam.name + self.suffix)
|
||||
|
||||
def __repr__(self):
|
||||
return simple_repr(self)
|
121
diffq/uniform.py
Normal file
121
diffq/uniform.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Classic uniform quantization over n bits.
|
||||
"""
|
||||
from typing import Tuple
|
||||
import torch
|
||||
|
||||
from .base import BaseQuantizer
|
||||
from .utils import simple_repr
|
||||
|
||||
|
||||
def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
|
||||
"""
|
||||
Quantize the given weights over `bits` bits.
|
||||
|
||||
Returns:
|
||||
- quantized levels
|
||||
- (min, max) range.
|
||||
|
||||
"""
|
||||
assert (bits >= 1).all() and (bits <= 15).all()
|
||||
num_levels = (2 ** bits.float()).long()
|
||||
mn = p.min().item()
|
||||
mx = p.max().item()
|
||||
p = (p - mn) / (mx - mn) # put p in [0, 1]
|
||||
unit = 1 / (num_levels - 1) # quantization unit
|
||||
levels = (p / unit).round()
|
||||
if (bits <= 8).all():
|
||||
levels = levels.byte()
|
||||
else:
|
||||
levels = levels.short()
|
||||
return levels, (mn, mx)
|
||||
|
||||
|
||||
def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
|
||||
bits: torch.Tensor = torch.tensor(8.)):
|
||||
"""
|
||||
Unquantize the weights from the levels and scale. Return a float32 tensor.
|
||||
"""
|
||||
mn, mx = scales
|
||||
num_levels = 2 ** bits.float()
|
||||
unit = 1 / (num_levels - 1)
|
||||
levels = levels.float()
|
||||
p = levels * unit # in [0, 1]
|
||||
return p * (mx - mn) + mn
|
||||
|
||||
|
||||
class UniformQuantizer(BaseQuantizer):
|
||||
def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
|
||||
float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
|
||||
"""
|
||||
Args:
|
||||
model (torch.nn.Module): model to quantize
|
||||
bits (float): number of bits to quantize over.
|
||||
min_size (float): minimum size in MB of a parameter to be quantized.
|
||||
float16 (bool): if a layer is smaller than min_size, should we still do float16?
|
||||
qat (bool): perform quantized aware training.
|
||||
exclude (list[str]): list of patterns used to match parameters to exclude.
|
||||
For instance `['bias']` to exclude all bias terms.
|
||||
detect_bound (bool): if True, will detect bound parameters and reuse
|
||||
the same quantized tensor for both.
|
||||
"""
|
||||
self.bits = float(bits)
|
||||
self.qat = qat
|
||||
|
||||
super().__init__(model, min_size, float16, exclude, detect_bound)
|
||||
|
||||
def __repr__(self):
|
||||
return simple_repr(self, )
|
||||
|
||||
def _pre_forward_train(self):
|
||||
if self.qat:
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is not None:
|
||||
new_param = qparam.other.module._parameters[qparam.other.name]
|
||||
else:
|
||||
quantized = self._quantize_param(qparam)
|
||||
qvalue = self._unquantize_param(qparam, quantized)
|
||||
new_param = qparam.param + (qvalue - qparam.param).detach()
|
||||
qparam.module._parameters[qparam.name] = new_param
|
||||
return True
|
||||
return False
|
||||
|
||||
def _post_forward_train(self):
|
||||
if self.qat:
|
||||
for qparam in self._qparams:
|
||||
qparam.module._parameters[qparam.name] = qparam.param
|
||||
return True
|
||||
return False
|
||||
|
||||
def _quantize_param(self, qparam):
|
||||
levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
|
||||
return (levels, scales)
|
||||
|
||||
def _unquantize_param(self, qparam, quantized):
|
||||
levels, scales = quantized
|
||||
return uniform_unquantize(levels, scales, torch.tensor(self.bits))
|
||||
|
||||
def model_size(self):
|
||||
"""
|
||||
Non differentiable model size in MB.
|
||||
"""
|
||||
total = super().model_size()
|
||||
subtotal = 0
|
||||
for qparam in self._qparams:
|
||||
if qparam.other is None: # if parameter is bound, count only one copy.
|
||||
subtotal += self.bits * qparam.param.numel() + 64 # 2 float for the overall scales
|
||||
subtotal /= 2**20 * 8 # bits to MegaBytes
|
||||
return total + subtotal
|
||||
|
||||
def true_model_size(self):
|
||||
"""
|
||||
Return the true quantized model size, in MB, without extra
|
||||
compression.
|
||||
"""
|
||||
return self.model_size().item()
|
37
diffq/utils.py
Normal file
37
diffq/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}):
|
||||
"""
|
||||
Return a simple representation string for `obj`.
|
||||
If `attrs` is not None, it should be a list of attributes to include.
|
||||
"""
|
||||
params = inspect.signature(obj.__class__).parameters
|
||||
attrs_repr = []
|
||||
if attrs is None:
|
||||
attrs = params.keys()
|
||||
for attr in attrs:
|
||||
display = False
|
||||
if attr in overrides:
|
||||
value = overrides[attr]
|
||||
elif hasattr(obj, attr):
|
||||
value = getattr(obj, attr)
|
||||
else:
|
||||
continue
|
||||
if attr in params:
|
||||
param = params[attr]
|
||||
if param.default is inspect._empty or value != param.default:
|
||||
display = True
|
||||
else:
|
||||
display = True
|
||||
|
||||
if display:
|
||||
attrs_repr.append(f"{attr}={value}")
|
||||
return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
|
244
models.py
Normal file
244
models.py
Normal file
@ -0,0 +1,244 @@
|
||||
import torch
|
||||
from torch._C import has_mkl
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
dim_c = 4
|
||||
k = 3
|
||||
model_path = 'model'
|
||||
n_fft_scale = {'bass': 8, 'drums':2, 'other':4, 'vocals':3, '*':2}
|
||||
|
||||
|
||||
class Conv_TDF(nn.Module):
|
||||
def __init__(self, c, l, f, k, bn, bias=True):
|
||||
|
||||
super(Conv_TDF, self).__init__()
|
||||
|
||||
self.use_tdf = bn is not None
|
||||
|
||||
self.H = nn.ModuleList()
|
||||
for i in range(l):
|
||||
self.H.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k//2),
|
||||
nn.BatchNorm2d(c),
|
||||
nn.ReLU(),
|
||||
)
|
||||
)
|
||||
|
||||
if self.use_tdf:
|
||||
if bn==0:
|
||||
self.tdf = nn.Sequential(
|
||||
nn.Linear(f,f, bias=bias),
|
||||
nn.BatchNorm2d(c),
|
||||
nn.ReLU()
|
||||
)
|
||||
else:
|
||||
self.tdf = nn.Sequential(
|
||||
nn.Linear(f,f//bn, bias=bias),
|
||||
nn.BatchNorm2d(c),
|
||||
nn.ReLU(),
|
||||
nn.Linear(f//bn,f, bias=bias),
|
||||
nn.BatchNorm2d(c),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for h in self.H:
|
||||
x = h(x)
|
||||
|
||||
return x + self.tdf(x) if self.use_tdf else x
|
||||
|
||||
|
||||
class Conv_TDF_net_trim(nn.Module):
|
||||
def __init__(self, device, load, model_name, target_name, lr, epoch,
|
||||
L, l, g, dim_f, dim_t, k=3, hop=1024, bn=None, bias=True):
|
||||
|
||||
super(Conv_TDF_net_trim, self).__init__()
|
||||
|
||||
self.dim_f, self.dim_t = 2**dim_f, 2**dim_t
|
||||
self.n_fft = self.dim_f * n_fft_scale[target_name]
|
||||
self.hop = hop
|
||||
self.n_bins = self.n_fft//2+1
|
||||
self.chunk_size = hop * (self.dim_t-1)
|
||||
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
|
||||
self.target_name = target_name
|
||||
self.blender = 'blender' in model_name
|
||||
|
||||
out_c = dim_c*4 if target_name=='*' else dim_c
|
||||
in_c = dim_c*2 if self.blender else dim_c
|
||||
#out_c = dim_c*2 if self.blender else dim_c
|
||||
self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(device)
|
||||
|
||||
self.n = L//2
|
||||
if load:
|
||||
|
||||
self.first_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels=in_c, out_channels=g, kernel_size=1, stride=1),
|
||||
nn.BatchNorm2d(g),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
f = self.dim_f
|
||||
c = g
|
||||
self.ds_dense = nn.ModuleList()
|
||||
self.ds = nn.ModuleList()
|
||||
for i in range(self.n):
|
||||
self.ds_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))
|
||||
|
||||
scale = (2,2)
|
||||
self.ds.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels=c, out_channels=c+g, kernel_size=scale, stride=scale),
|
||||
nn.BatchNorm2d(c+g),
|
||||
nn.ReLU()
|
||||
)
|
||||
)
|
||||
f = f//2
|
||||
c += g
|
||||
|
||||
self.mid_dense = Conv_TDF(c, l, f, k, bn, bias=bias)
|
||||
#if bn is None and mid_tdf:
|
||||
# self.mid_dense = Conv_TDF(c, l, f, k, bn=0, bias=False)
|
||||
|
||||
self.us_dense = nn.ModuleList()
|
||||
self.us = nn.ModuleList()
|
||||
for i in range(self.n):
|
||||
scale = (2,2)
|
||||
self.us.append(
|
||||
nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels=c, out_channels=c-g, kernel_size=scale, stride=scale),
|
||||
nn.BatchNorm2d(c-g),
|
||||
nn.ReLU()
|
||||
)
|
||||
)
|
||||
f = f*2
|
||||
c -= g
|
||||
|
||||
self.us_dense.append(Conv_TDF(c, l, f, k, bn, bias=bias))
|
||||
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels=c, out_channels=out_c, kernel_size=1, stride=1),
|
||||
)
|
||||
|
||||
|
||||
model_cfg = f'L{L}l{l}g{g}'
|
||||
model_cfg += ', ' if (bn is None or bn==0) else f'bn{bn}, '
|
||||
|
||||
stft_cfg = f'f{dim_f}t{dim_t}, '
|
||||
|
||||
model_name = model_name[:model_name.index('(')+1] + model_cfg + stft_cfg + model_name[model_name.index('(')+1:]
|
||||
try:
|
||||
self.load_state_dict(
|
||||
torch.load('{0}/{1}/{2}_lr{3}_e{4:05}.ckpt'.format(model_path, model_name, target_name, lr, epoch), map_location=device)
|
||||
)
|
||||
print(f'Loading model ({target_name})')
|
||||
except FileNotFoundError:
|
||||
print(f'Random init ({target_name})')
|
||||
|
||||
|
||||
def stft(self, x):
|
||||
x = x.reshape([-1, self.chunk_size])
|
||||
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
|
||||
x = x.permute([0,3,1,2])
|
||||
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,dim_c,self.n_bins,self.dim_t])
|
||||
return x[:,:,:self.dim_f]
|
||||
|
||||
def istft(self, x, freq_pad=None):
|
||||
freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
|
||||
x = torch.cat([x, freq_pad], -2)
|
||||
c = 4*2 if self.target_name=='*' else 2
|
||||
x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
|
||||
x = x.permute([0,2,3,1])
|
||||
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
|
||||
return x.reshape([-1,c,self.chunk_size])
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.first_conv(x)
|
||||
|
||||
x = x.transpose(-1,-2)
|
||||
|
||||
ds_outputs = []
|
||||
for i in range(self.n):
|
||||
x = self.ds_dense[i](x)
|
||||
ds_outputs.append(x)
|
||||
x = self.ds[i](x)
|
||||
|
||||
x = self.mid_dense(x)
|
||||
|
||||
for i in range(self.n):
|
||||
x = self.us[i](x)
|
||||
x *= ds_outputs[-i-1]
|
||||
x = self.us_dense[i](x)
|
||||
|
||||
x = x.transpose(-1,-2)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
def stft(wave, nfft, hl):
|
||||
wave_left = np.asfortranarray(wave[0])
|
||||
wave_right = np.asfortranarray(wave[1])
|
||||
spec_left = librosa.stft(wave_left, nfft, hop_length=hl)
|
||||
spec_right = librosa.stft(wave_right, nfft, hop_length=hl)
|
||||
spec = np.asfortranarray([spec_left, spec_right])
|
||||
|
||||
return spec
|
||||
|
||||
def istft(spec, hl):
|
||||
spec_left = np.asfortranarray(spec[0])
|
||||
spec_right = np.asfortranarray(spec[1])
|
||||
|
||||
wave_left = librosa.istft(spec_left, hop_length=hl)
|
||||
wave_right = librosa.istft(spec_right, hop_length=hl)
|
||||
wave = np.asfortranarray([wave_left, wave_right])
|
||||
|
||||
return wave
|
||||
|
||||
def spec_effects(wave, algorithm='default', value=None):
|
||||
spec = [stft(wave[0],2048,1024),stft(wave[1],2048,1024)]
|
||||
if algorithm == 'min_mag':
|
||||
v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
|
||||
wave = istft(v_spec_m,1024)
|
||||
elif algorithm == 'max_mag':
|
||||
v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
|
||||
wave = istft(v_spec_m,1024)
|
||||
elif algorithm == 'default':
|
||||
#wave = [istft(spec[0],1024),istft(spec[1],1024)]
|
||||
wave = (wave[1] * value) + (wave[0] * (1-value))
|
||||
elif algorithm == 'invert_p':
|
||||
X_mag = np.abs(spec[0])
|
||||
y_mag = np.abs(spec[1])
|
||||
max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
|
||||
v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0]))
|
||||
wave = istft(v_spec,1024)
|
||||
return wave
|
||||
|
||||
|
||||
def get_models(name, device, load=True, stems='vocals'):
|
||||
|
||||
if name=='tdf_extra':
|
||||
models = []
|
||||
if 'vocals' in stems:
|
||||
models.append(
|
||||
Conv_TDF_net_trim(
|
||||
device=device, load=load,
|
||||
model_name='Conv-TDF', target_name='vocals',
|
||||
lr=0.0001, epoch=0,
|
||||
L=11, l=3, g=32, bn=8, bias=False,
|
||||
dim_f=11, dim_t=8
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
else:
|
||||
print('Model undefined')
|
||||
return None
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user