# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from collections import defaultdict, namedtuple
from pathlib import Path

import musdb
import numpy as np
import torch as th
import tqdm
from torch.utils.data import DataLoader

from .audio import AudioFile

# One fixed-size training chunk: the file it comes from, its sample offset
# within that file, and its index among that file's chunks.
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])


class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Expects one interleaved float32 file per track and stream, named
    `<name>.<stream>.raw` with streams numbered from 0 (the layout
    written by `build_raw` below).
    """
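    # A minimal usage sketch (hypothetical path, assuming `build_raw` below
    # has already written the dataset to disk):
    #
    #     dataset = Rawset("musdb_raw/train", samples=10 * 44100)
    #     chunk = dataset[0]  # tensor of shape (streams, channels, samples)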
    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            stride = samples if samples is not None else 0
        self.stride = stride
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    path = Path(root) / file
                    name, stream = path.stem.rsplit('.', 1)
                    entries[(path.parent.relative_to(self.path), name)].append(int(stream))

        self._entries = list(entries.keys())

        sizes = []
        self._lengths = []
        ref_streams = sorted(entries[self._entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        if streams is None:
            self.streams = ref_streams
        else:
            self.streams = streams
        for entry in sorted(entries.keys()):
            streams = entries[entry]
            assert sorted(streams) == ref_streams
            file = self._path(*entry)
            # 4 bytes per float32 sample, interleaved across channels.
            length = file.stat().st_size // (4 * channels)
            if samples is None:
                sizes.append(1)
            else:
                if length < samples:
                    # Too short to yield even one chunk; drop the track.
                    self._entries.remove(entry)
                    continue
                sizes.append((length - samples) // stride + 1)
            self._lengths.append(length)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes
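    # Worked example of the chunk-count formula above: a 60 s track at
    # 44.1 kHz with samples=stride=441000 gives
    # (2646000 - 441000) // 441000 + 1 = 6 chunks.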
    def __len__(self):
        # Total number of extractable chunks across all files.
        return self._cumulative_sizes[-1]

    @property
    def total_length(self):
        # Total number of samples across all files.
        return sum(self._lengths)
    def chunk_info(self, index):
        # Map a global chunk index to its file and in-file offset using
        # the cumulative chunk counts.
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)
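    # For example, with per-file chunk counts [3, 2] (cumulative [3, 5]),
    # index 4 falls in the second file and maps to
    # ChunkInfo(file_index=1, offset=self.stride * 1, local_index=1).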
    def _path(self, folder, name, stream=0):
        return self.path / folder / (name + f'.{stream}.raw')
    def __getitem__(self, index):
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]

        length = self.samples or self._lengths[chunk.file_index]
        streams = []
        offset = chunk.offset * 4 * self.channels
        to_read = length * self.channels * 4
        for stream in self.streams:
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(offset)
                content = file.read(to_read)
            assert len(content) == to_read
            content = np.frombuffer(content, dtype=np.float32)
            content = content.copy()  # make writable
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        # Shape: (num_streams, channels, length).
        return th.stack(streams, dim=0)
    def name(self, index):
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name


class MusDBSet:
    """Wrap a `musdb.DB` so tracks can be loaded as (name, streams) pairs."""
    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.mus = mus
        self.streams = streams
        self.samplerate = samplerate
        self.channels = channels

    def __len__(self):
        return len(self.mus.tracks)

    def __getitem__(self, index):
        track = self.mus.tracks[index]
        return (track.name, AudioFile(track.path).read(channels=self.channels,
                                                       seek_time=0,
                                                       streams=self.streams,
                                                       samplerate=self.samplerate))


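# Convert a musdb split to the raw float32 layout consumed by `Rawset`:
# one `<track name>.<stream>.raw` file per stream, normalized by the mono
# mixture's mean and standard deviation when `normalize` is set.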
def build_raw(mus, destination, normalize, workers, samplerate, channels):
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=lambda x: x[0])
    for name, streams in tqdm.tqdm(loader):
        if normalize:
            ref = streams[0].mean(dim=0)  # use mono mixture as reference
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            with open(destination / (name + f'.{index}.raw'), "wb") as out:
                out.write(stream.t().numpy().tobytes())


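# A minimal invocation sketch (hypothetical paths; the module path is an
# assumption, adjust it to wherever this file lives in the package):
#
#     python -m demucs.raw --workers 4 /data/musdb ./musdb_raw
#
# This writes `train/` and `valid/` raw datasets under ./musdb_raw.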
def main():
    parser = argparse.ArgumentParser('rawset')
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('--samplerate', type=int, default=44100)
    parser.add_argument('--channels', type=int, default=2)
    parser.add_argument('musdb', type=Path)
    parser.add_argument('destination', type=Path)

    args = parser.parse_args()

    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
              args.destination / "train",
              normalize=True,
              channels=args.channels,
              samplerate=args.samplerate,
              workers=args.workers)
    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
              args.destination / "valid",
              normalize=True,
              channels=args.channels,
              samplerate=args.samplerate,
              workers=args.workers)


if __name__ == "__main__":
    main()