mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-12-01 02:27:21 +01:00
110 lines
4.0 KiB
Python
110 lines
4.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import gzip
|
|
import sys
|
|
from concurrent import futures
|
|
|
|
import musdb
|
|
import museval
|
|
import torch as th
|
|
import tqdm
|
|
from scipy.io import wavfile
|
|
from torch import distributed
|
|
|
|
from .audio import convert_audio
|
|
from .utils import apply_model
|
|
|
|
|
|
def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate model on the MusDB test subset using museval. Run the model
    on a single device, the bottleneck being the call to museval.

    For every test track handled by this process (indices `rank`,
    `rank + world_size`, ...), the mixture is normalized, separated with
    `apply_model` and scored against the reference stems with
    `museval.evaluate`. The scores of each track are written to
    `eval_folder / "results/test" / "<track name>.json.gz"`. Tracks whose
    result file already exists are skipped.

    Note that this disables gradients on all parameters of `model`
    (`requires_grad = False`, `grad = None`) as a side effect.

    Args:
        model: separation model; this function reads `parameters()`,
            `samplerate`, `audio_channels` and `sources` from it.
        musdb_path: dataset root, forwarded to `musdb.DB`.
        eval_folder: output folder; must support the `/` path operator
            (e.g. `pathlib.Path`).
        workers: size of the process pool used for museval. When 0,
            museval is called inline in the current process instead.
        device: device the mixture is moved to before calling the model.
        rank: index of this process among `world_size` processes.
        save: when true, also write every estimated source as a wav file
            under `eval_folder / "wav/test" / "<track name>"`.
        shifts: forwarded to `apply_model`.
        split: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: total number of processes; when greater than 1,
            `distributed.barrier()` is called before returning.
    """

    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    src_rate = 44100  # hardcoded for now...

    # Evaluation only: no gradient is needed, and dropping existing ones frees memory.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # museval window and hop: one second worth of samples at the model rate.
    # They do not depend on the track, so compute them once.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            # Skip tracks already evaluated by a previous run.
            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                continue

            mix = th.from_numpy(track.audio).t().float()
            ref = mix.mean(dim=0)  # mono mixture
            # Normalize with the statistics of the mono mixture; the same
            # statistics are used below to undo the normalization.
            ref_mean = ref.mean()
            ref_std = ref.std()
            mix = (mix - ref_mean) / ref_std
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * ref_std + ref_mean

            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()
            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The estimates were produced from audio converted to
                    # `model.samplerate`, so that is the rate to store.
                    wavfile.write(str(folder / (name + ".wav")), model.samplerate, estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Release the large per-track arrays before loading the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                # `pending` is a future here; block until museval is done.
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): TrackStore receives a fixed 44100 rather than the
            # win/hop passed to museval.evaluate above -- confirm which unit
            # (seconds or samples) museval expects before changing it.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }

                track_store.add_target(target_name=target, values=values)

            # Write the result once all targets are in the store, and close
            # the handle deterministically so the gzip stream is finalized.
            json_path = json_folder / f"{track_name}.json.gz"
            with gzip.open(json_path, "w") as json_file:
                json_file.write(track_store.json.encode('utf-8'))
    if world_size > 1:
        # Wait for the other processes to finish their share of the tracks.
        distributed.barrier()
|