mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-28 01:10:56 +01:00
175 lines
6.6 KiB
Python
175 lines
6.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from collections import OrderedDict
|
|
import hashlib
|
|
import math
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import julius
|
|
import torch as th
|
|
from torch import distributed
|
|
import torchaudio as ta
|
|
from torch.nn import functional as F
|
|
|
|
from .audio import convert_audio_channels
|
|
from .compressed import get_musdb_tracks
|
|
|
|
# Source name of the per-track mixture file (loaded as `mixture.wav`).
MIXTURE: str = "mixture"
# Extension appended to every source name to form a track's file names.
EXT: str = ".wav"
|
|
|
|
|
|
def _track_metadata(track, sources):
    """Summarize one track folder: frame count, sample rate and mixture stats.

    Every `{source}{EXT}` file named in `sources`, followed by the mixture
    file, is probed with `torchaudio.info`. All of them must share the same
    number of frames and the same sample rate; the first mismatch raises.
    The mixture is then loaded, averaged over its first axis, and its mean
    and standard deviation are recorded.

    Args:
        track: path-like folder supporting `/` (a `pathlib.Path` when called
            from `_build_metadata`).
        sources: list of source names (file stems without the extension).

    Returns:
        dict with keys "length", "mean", "std" and "samplerate".

    Raises:
        ValueError: if a file's frame count or sample rate differs from the
            first file probed.
    """
    expected_length = None
    expected_rate = None
    for name in sources + [MIXTURE]:
        path = track / f"{name}{EXT}"
        probe = ta.info(str(path))
        frames = probe.num_frames
        if expected_length is None:
            # The first file probed defines the reference length and rate.
            expected_length = frames
            expected_rate = probe.sample_rate
        elif frames != expected_length:
            raise ValueError(
                f"Invalid length for file {path}: "
                f"expecting {expected_length} but got {frames}.")
        elif probe.sample_rate != expected_rate:
            raise ValueError(
                f"Invalid sample rate for file {path}: "
                f"expecting {expected_rate} but got {probe.sample_rate}.")
        if name == MIXTURE:
            signal, _ = ta.load(str(path))
            signal = signal.mean(0)
            mean = signal.mean().item()
            std = signal.std().item()

    return {"length": frames, "mean": mean, "std": std, "samplerate": expected_rate}
|
|
|
|
|
|
def _build_metadata(path, sources):
|
|
meta = {}
|
|
path = Path(path)
|
|
for file in path.iterdir():
|
|
meta[file.name] = _track_metadata(file, sources)
|
|
return meta
|
|
|
|
|
|
class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            length=None, stride=None, normalize=True,
            samplerate=44100, channels=2):
        """
        Dataset of multi-source audio tracks. Can be used to train with
        arbitrary sources. Each track should be one folder inside of `root`.
        The folder should contain one file per source, named `{source}{EXT}`
        (`EXT` is the module-level extension constant).

        `metadata` maps each track folder name to a dict with keys "length"
        (in frames at the file's own rate), "samplerate", "mean" and "std".
        `sources` is the list of source names to load; each example stacks
        them in that order.

        Sample rate and channels will be converted on the fly.

        `length` is the sample size to extract (in samples at the target
        `samplerate`, not duration). When None, every track yields a single
        example covering the whole file.
        `stride` is how many samples to move by between each example; it
        defaults to `length`.
        When `normalize` is true, examples are shifted and scaled by the
        track's "mean" and "std" metadata.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.length = length
        self.stride = stride or length
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        # Number of examples contributed by each track, in metadata order.
        self.num_examples = []
        for name, meta in self.metadata.items():
            # Track length expressed at the target sample rate.
            track_length = int(self.samplerate * meta['length'] / meta['samplerate'])
            if length is None or track_length < length:
                examples = 1
            else:
                examples = int(math.ceil((track_length - self.length) / self.stride) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        """Total number of examples across all tracks."""
        return sum(self.num_examples)

    def get_file(self, name, source):
        """Return the path of `source`'s file inside track folder `name`."""
        return self.root / name / f"{source}{EXT}"

    def __getitem__(self, index):
        """Load example `index`, stacking one waveform per entry of `self.sources`.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if `index` is outside `[-len(self), len(self))`. Raising
                (instead of falling through and returning None) also ends
                Python's legacy `__getitem__`-based iteration.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"Wavset index out of range for {total} examples")
        for name, examples in zip(self.metadata, self.num_examples):
            # Walk the tracks until `index` falls inside the current one.
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            # -1 is forwarded to torchaudio.load, which then reads to the end of the file.
            num_frames = -1
            offset = 0
            if self.length is not None:
                # Convert stride/length from the target rate to the file's own rate.
                offset = int(math.ceil(
                    meta['samplerate'] * self.stride * index / self.samplerate))
                num_frames = int(math.ceil(
                    meta['samplerate'] * self.length / self.samplerate))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.length:
                # Trim any resampling overshoot, then zero-pad examples cut short
                # at the end of a track.
                example = example[..., :self.length]
                example = F.pad(example, (0, self.length - example.shape[-1]))
            return example
|
|
|
|
|
|
def get_wav_datasets(args, samples, sources):
    """Create the train and validation `Wavset`s for a folder of wav tracks.

    `args.wav` must contain `train` and `valid` sub-folders with one folder per
    track. Track metadata is computed by rank 0 only and cached as JSON in
    `args.metadata`, under a name derived from a SHA-1 hash of `args.wav`. When
    `args.world_size > 1`, all ranks synchronize on a distributed barrier
    before reading the cache.

    Args:
        args: training arguments. Reads `wav`, `metadata`, `rank`, `world_size`,
            `data_stride`, `samplerate`, `audio_channels` and `norm_wav`.
        samples: length of each training example, in samples.
        sources: list of source names to load.

    Returns:
        (train_set, valid_set). The validation set has no fixed `length` (one
        full-track example per track) and loads the mixture before `sources`.
    """
    sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
    metadata_file = args.metadata / (sig + ".json")
    train_path = args.wav / "train"
    valid_path = args.wav / "valid"
    if not metadata_file.is_file() and args.rank == 0:
        train = _build_metadata(train_path, sources)
        valid = _build_metadata(valid_path, sources)
        # Write to a temporary file and rename it into place. An interrupted run
        # then never leaves a truncated cache that `is_file()` would accept later.
        tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
        with open(tmp_file, "w") as f:
            json.dump([train, valid], f)
        tmp_file.replace(metadata_file)
    if args.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as f:
        train, valid = json.load(f)
    train_set = Wavset(train_path, train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set
|
|
|
|
|
|
def get_musdb_wav_datasets(args, samples, sources):
    """Create train and validation `Wavset`s from the MUSDB train folder (wav).

    Metadata is computed, by rank 0 only, for every track folder under
    `args.musdb / "train"` and cached as JSON in `args.metadata`. Tracks that
    `get_musdb_tracks(..., split="train")` lists go to the training set; every
    other track becomes the validation set.

    Args:
        args: training arguments. Reads `musdb`, `metadata`, `rank`,
            `world_size`, `data_stride`, `samplerate`, `audio_channels` and
            `norm_wav`.
        samples: length of each training example, in samples.
        sources: list of source names to load.

    Returns:
        (train_set, valid_set). The validation set has no fixed `length` (one
        full-track example per track) and loads the mixture before `sources`.
    """
    # NOTE(review): unlike get_wav_datasets, this cache name does not depend on
    # `args.musdb`, so a cache built for one MUSDB location is reused for another.
    # Confirm this is intended before changing the data root.
    metadata_file = args.metadata / "musdb_wav.json"
    root = args.musdb / "train"
    if not metadata_file.is_file() and args.rank == 0:
        metadata = _build_metadata(root, sources)
        # Atomic write: an interrupted run cannot leave a truncated cache behind.
        tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
        with open(tmp_file, "w") as f:
            json.dump(metadata, f)
        tmp_file.replace(metadata_file)
    if args.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as f:
        metadata = json.load(f)

    train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train")
    # Build the membership set once instead of scanning `train_tracks` per track.
    train_names = set(train_tracks)
    metadata_train = {name: meta for name, meta in metadata.items() if name in train_names}
    metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_names}
    train_set = Wavset(root, metadata_train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set
|