mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-24 07:20:10 +01:00
116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import json
|
|
from fractions import Fraction
|
|
from concurrent import futures
|
|
|
|
import musdb
|
|
from torch import distributed
|
|
|
|
from .audio import AudioFile
|
|
|
|
|
|
def get_musdb_tracks(root, *args, **kwargs):
    """Return a ``{track_name: track_path}`` mapping for a MUSDB database.

    ``root`` and any extra positional/keyword arguments are forwarded
    unchanged to ``musdb.DB`` (callers in this module pass e.g. ``subsets``
    and ``split``).
    """
    database = musdb.DB(root, *args, **kwargs)
    tracks = {}
    for track in database:
        tracks[track.name] = track.path
    return tracks
|
|
|
|
|
|
class StemsSet:
    """Indexable dataset of normalized excerpts read from multi-stream audio files.

    Each example is read with ``AudioFile(path).read(...)`` and normalized
    using the per-track ``"mean"`` and ``"std"`` found in ``metadata``.

    Args:
        tracks: mapping of track name -> audio file path.
        metadata: mapping of track name -> dict holding at least
            ``"duration"``, ``"mean"`` and ``"std"``. It is copied, not mutated.
        duration: excerpt length (seconds, as ``seek_time`` is in the same
            unit), or ``None`` to use each whole track as a single example.
        stride: step between the start times of consecutive excerpts.
        samplerate, channels, streams: forwarded to ``AudioFile.read``.

    Raises:
        ValueError: if ``duration`` is set and some track is shorter than it.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        self.metadata = []
        for name, path in tracks.items():
            meta = dict(metadata[name])  # copy so the shared metadata is untouched
            meta["path"] = path
            meta["name"] = name
            self.metadata.append(meta)
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
        # Sort by name so example indices do not depend on the order of ``tracks``.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        """Return how many examples the track described by ``meta`` contributes."""
        if self.duration is None:
            return 1
        return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Map a global example index to ``(track_meta, index_within_track)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is out of range. This also lets the
                legacy ``__getitem__`` iteration protocol terminate.
        """
        if index < 0:
            index += len(self)
        if index >= 0:
            for meta in self.metadata:
                examples = self._examples_count(meta)
                if index < examples:
                    return meta, index
                index -= examples
        raise IndexError("StemsSet index out of range")

    def track_metadata(self, index):
        """Return the metadata dict of the track that holds example ``index``."""
        meta, _ = self._locate(index)
        return meta

    def __getitem__(self, index):
        """Read example ``index`` and return it normalized by the track mean/std."""
        meta, offset = self._locate(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        return (streams - meta["mean"]) / meta["std"]
|
|
|
|
|
|
def _get_track_metadata(path):
    """Return duration plus mean/std statistics for the audio file at ``path``.

    Stream 0 (treated here as the mix) is decoded as mono at 44100 Hz and used
    as the normalization reference. For other channel/sample-rate settings the
    data will not be perfectly normalized, but it should be good enough.

    Returns:
        dict with keys ``"duration"``, ``"std"`` and ``"mean"``; the two
        statistics are converted to Python scalars via ``.item()``.
    """
    audio_file = AudioFile(path)
    reference = audio_file.read(streams=0, channels=1, samplerate=44100)
    stats = {"duration": audio_file.duration}
    stats["std"] = reference.std().item()
    stats["mean"] = reference.mean().item()
    return stats
|
|
|
|
|
|
def _build_metadata(tracks, workers=10):
|
|
pendings = []
|
|
with futures.ProcessPoolExecutor(workers) as pool:
|
|
for name, path in tracks.items():
|
|
pendings.append((name, pool.submit(_get_track_metadata, path)))
|
|
return {name: p.result() for name, p in pendings}
|
|
|
|
|
|
def _build_musdb_metadata(path, musdb, workers):
    """Compute metadata for every MUSDB track and cache it as JSON at ``path``.

    Args:
        path: destination file; expected to be a ``pathlib.Path`` (its
            ``.parent`` directories are created as needed).
        musdb: MUSDB root passed to ``get_musdb_tracks``. Inside this function
            the name shadows the ``musdb`` module, which is unused here.
        workers: number of worker processes for ``_build_metadata``.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Write to a temporary sibling, then rename it into place. An interrupted
    # run therefore never leaves a truncated cache that the caller's
    # ``is_file()`` check would accept. The ``with`` also closes the handle
    # before the rename.
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(metadata, f)
    tmp_path.replace(path)
|
|
|
|
|
|
def get_compressed_datasets(args, samples):
    """Build the MUSDB training and validation ``StemsSet`` datasets.

    Per-track metadata is cached as JSON at ``args.metadata / "musdb.json"``.
    Rank 0 builds it when it is missing. When ``args.world_size > 1``, all
    ranks wait on ``distributed.barrier()`` before reading it.

    Args:
        args: options object; reads ``metadata``, ``rank``, ``musdb``,
            ``workers``, ``world_size``, ``samplerate``, ``data_stride`` and
            ``audio_channels``.
        samples: training excerpt length in samples.

    Returns:
        ``(train_set, valid_set)``.
        - ``train_set`` uses excerpts of ``samples / samplerate`` with a
          stride of ``data_stride / samplerate``, and reads streams ``1:``.
        - ``valid_set`` uses whole tracks with all streams.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    # Every rank must reach the barrier, so it is outside the rank-0 branch.
    if args.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as f:
        metadata = json.load(f)
    # Fractions keep the excerpt boundaries exact (no float rounding).
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set
|