mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2025-01-19 01:14:03 +01:00
188 lines
6.0 KiB
Python
188 lines
6.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import errno
|
|
import functools
|
|
import gzip
|
|
import os
|
|
import random
|
|
import socket
|
|
import tempfile
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
import torch as th
|
|
import tqdm
|
|
from torch import distributed
|
|
from torch.nn import functional as F
|
|
|
|
|
|
def center_trim(tensor, reference):
    """
    Trim `tensor` symmetrically along its last dimension down to the length of `reference`.

    `reference` may be a tensor-like object (its last-dimension size is used) or a
    plain integer length. When the amount to remove is odd, the right side loses
    one more sample than the left side.

    Raises:
        ValueError: if `tensor` is shorter than the requested length.
    """
    target = reference.size(-1) if hasattr(reference, "size") else reference
    excess = tensor.size(-1) - target
    if excess < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {excess}.")
    if not excess:
        return tensor
    # Left side gets the floor of half the excess, right side the remainder.
    left = excess // 2
    right = excess - left
    return tensor[..., left:-right]
|
|
|
|
|
|
def average_metric(metric, count=1., device='cuda'):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).

    Args:
        metric (float): this host's value of the metric.
        count (float): weight of this host in the average.
        device (str or torch.device): device on which the reduction tensor is
            created. It defaults to 'cuda' as before. Pass 'cpu' for CPU-only
            backends such as gloo.

    Returns:
        float: the count-weighted mean of `metric` over all processes. If
        torch.distributed is unavailable or no process group has been
        initialized, this is a single-process run, and `metric` is returned
        unchanged.
    """
    if not distributed.is_available() or not distributed.is_initialized():
        # Only one participant: the weighted average is the metric itself.
        return float(metric)
    # Reduce [total weight, total weighted metric] in a single collective call.
    totals = th.tensor([count, count * metric], dtype=th.float32, device=device)
    distributed.all_reduce(totals, op=distributed.ReduceOp.SUM)
    return totals[1].item() / totals[0].item()
|
|
|
|
|
|
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.

    Random ports in [low, high] (inclusive) are probed by binding a throwaway
    socket to them. The first one that binds successfully is returned. The probe
    socket is closed before returning, so the port is released for the caller.
    This could suffer from a race condition, because another process may grab
    the port before the caller binds it, although that should be quite rare.

    Args:
        host (str): address to bind to while probing ('' means all interfaces).
        low (int): smallest candidate port.
        high (int): largest candidate port.

    Raises:
        OSError: for any bind failure other than "address already in use".
    """
    while True:
        port = random.randint(low, high)
        # A fresh socket per attempt, closed deterministically by the `with`
        # block instead of relying on garbage collection.
        with socket.socket() as sock:
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
        return port
|
|
|
|
|
|
def sizeof_fmt(num, suffix='B'):
    """
    Render a byte count `num` as a human readable string with binary
    (1024-based) prefixes, e.g. 1536 -> "1.5KiB".

    Values too large for the "Zi" prefix are expressed with "Yi".
    Adapted from https://stackoverflow.com/a/1094933
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    value = num
    index = 0
    # `not (... < ...)` rather than `>=` so that NaN keeps being scaled.
    while index < len(prefixes) and not abs(value) < 1024.0:
        value /= 1024.0
        index += 1
    if index == len(prefixes):
        return f"{value:.1f}Yi{suffix}"
    return f"{value:3.1f}{prefixes[index]}{suffix}"
|
|
|
|
|
|
def human_seconds(seconds, display='.2f'):
    """
    Render a duration of `seconds` seconds as a human readable string.

    Starting from microseconds, the value is converted to the next larger unit
    (ms, s, min, hrs, days) as long as the converted value would be at least
    0.3. It is then formatted with the `display` format spec, e.g.
    human_seconds(90) -> "1.50 min".
    """
    steps = (('ms', 1e3), ('s', 1e3), ('min', 60), ('hrs', 60), ('days', 24))
    amount = seconds * 1e6
    unit = 'us'
    for next_unit, factor in steps:
        converted = amount / factor
        if converted < 0.3:
            break
        amount, unit = converted, next_unit
    return f"{format(amount, display)} {unit}"
|
|
|
|
|
|
def apply_model_v1(model, mix, shifts=None, split=False, progress=False):
    """
    Apply model to a given mixture.

    Args:
        model: separation model. It must expose `samplerate` and
            `valid_length(length)`, and be callable on a batch built with
            `mix.unsqueeze(0)`.
        mix (Tensor): mixture of shape (channels, length).
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 10 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)

    Returns:
        Tensor: the model output for the whole mixture, trimmed to `length`
        samples along the last dimension.
    """
    channels, length = mix.size()
    device = mix.device
    if split:
        # NOTE(review): the output buffer assumes the model produces exactly 4
        # sources; confirm against the models this is used with.
        out = th.zeros(4, channels, length, device=device)
        shift = model.samplerate * 10
        offsets = range(0, length, shift)
        if progress:
            # Each step covers 10 seconds of audio, hence unit_scale=10.
            offsets = tqdm.tqdm(offsets, unit_scale=10, ncols=120, unit='seconds')
        for offset in offsets:
            chunk = mix[..., offset:offset + shift]
            chunk_out = apply_model_v1(model, chunk, shifts=shifts)
            out[..., offset:offset + shift] = chunk_out
        return out
    elif shifts:
        max_shift = int(model.samplerate / 2)
        mix = F.pad(mix, (max_shift, max_shift))
        offsets = list(range(max_shift))
        random.shuffle(offsets)
        # At most `max_shift` distinct offsets exist, so fewer than `shifts`
        # predictions may be made; average over the ones actually computed.
        chosen = offsets[:shifts]
        out = 0
        for offset in chosen:
            shifted = mix[..., offset:offset + length + max_shift]
            shifted_out = apply_model_v1(model, shifted)
            # Undo the shift: realign the prediction with the unpadded mixture.
            out += shifted_out[..., max_shift - offset:max_shift - offset + length]
        out /= len(chosen)
        return out
    else:
        # Pad symmetrically to a length the model accepts, then trim back.
        valid_length = model.valid_length(length)
        delta = valid_length - length
        padded = F.pad(mix, (delta // 2, delta - delta // 2))
        with th.no_grad():
            out = model(padded.unsqueeze(0))[0]
        return center_trim(out, mix)
|
|
|
|
|
|
@contextmanager
def temp_filenames(count, delete=True, **kwargs):
    """
    Context manager yielding a list of `count` paths to newly created temporary files.

    Args:
        count (int): number of temporary files to create.
        delete (bool): if True, the files are removed when the context exits.
        **kwargs: forwarded to `tempfile.NamedTemporaryFile`
            (e.g. `suffix`, `prefix`, `dir`).
    """
    names = []
    try:
        for _ in range(count):
            # Close each handle immediately: only the path is needed, and an
            # open handle would leak a file descriptor (and, on Windows, keep
            # the file locked).
            with tempfile.NamedTemporaryFile(delete=False, **kwargs) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # The caller already removed it; don't mask a pending error.
                    pass
|
|
|
|
|
|
def load_model(path):
    """
    Rebuild a model from a checkpoint holding `(klass, args, kwargs, state)`.

    The checkpoint is read with `th.load` and mapped to the CPU. It is read
    through gzip when `path` ends with ".gz". The model is then built as
    `klass(*args, **kwargs)` and `state` is loaded into it. Warnings emitted
    while loading are suppressed.

    Security: the checkpoint is unpickled, including the class object, so
    only load files from trusted sources.

    Args:
        path (str or path-like): location of the checkpoint.

    Returns:
        The instantiated model with the saved state loaded.
    """
    import inspect

    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle the
    # class object stored in the checkpoint. Request full unpickling wherever
    # the option exists; older releases without it always unpickle fully.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(th.load).parameters:
        load_kwargs['weights_only'] = False
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if str(path).endswith(".gz"):
            # Close the gzip stream once loading is done.
            with gzip.open(path, "rb") as load_from:
                klass, args, kwargs, state = th.load(load_from, 'cpu', **load_kwargs)
        else:
            klass, args, kwargs, state = th.load(path, 'cpu', **load_kwargs)
    model = klass(*args, **kwargs)
    model.load_state_dict(state)
    return model
|
|
|
|
|
|
def save_model(model, path):
    """
    Save `model` as a `(klass, args, kwargs, state)` tuple with `th.save`.

    `args`/`kwargs` come from `model._init_args_kwargs`, which `capture_init`
    records. The state dict tensors are copied to the CPU. When `path` ends
    with ".gz", the file is gzip-compressed (level 5).

    Args:
        model: module whose constructor arguments were captured.
        path (str or path-like): destination file.
    """
    args, kwargs = model._init_args_kwargs
    klass = model.__class__
    state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
    payload = (klass, args, kwargs, state)
    if str(path).endswith(".gz"):
        # Closing the gzip stream flushes the data and writes the gzip trailer;
        # without it, the file may be incomplete until garbage collection.
        with gzip.open(path, "wb", compresslevel=5) as save_to:
            th.save(payload, save_to)
    else:
        th.save(payload, path)
|
|
|
|
|
|
def capture_init(init):
    """
    Decorate an `__init__` so that each instance remembers its constructor arguments.

    Before delegating to `init`, the wrapper stores `(args, kwargs)` on the
    instance as `_init_args_kwargs`, so the object can later be re-created
    (see `save_model` / `load_model`). The wrapped function keeps `init`'s
    name and docstring via `functools.wraps`.
    """
    def recording_init(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return functools.wraps(init)(recording_init)
|