# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp

import torch
from torch.nn import functional as F


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
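
# Usage sketch (illustrative, doctest-style): frame a length-10 signal into
# five overlapping windows of size 4 with stride 2; the tail is zero-padded.
#     >>> x = torch.arange(10.).reshape(1, 10)
#     >>> unfold(x, kernel_size=4, stride=2).shape
#     torch.Size([1, 5, 4])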


def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    ref_size: int
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {delta}.")
    if delta:
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor
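
# Usage sketch (illustrative): trim a length-100 signal down to a reference
# length; with an odd difference, the extra sample comes off the right side.
#     >>> t = torch.zeros(1, 2, 100)
#     >>> center_trim(t, torch.zeros(1, 2, 96)).shape
#     torch.Size([1, 2, 96])
#     >>> center_trim(t, 95).shape
#     torch.Size([1, 2, 95])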


def pull_metric(history: tp.List[dict], name: str):
    """Collect metric `name` from every entry of `history`.
    Dots in `name` descend into nested dicts, e.g. "valid.loss"."""
    out = []
    for metrics in history:
        metric = metrics
        for part in name.split("."):
            metric = metric[part]
        out.append(metric)
    return out
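
# Usage sketch (illustrative): pull a nested metric out of a training history.
#     >>> history = [{"valid": {"loss": 0.5}}, {"valid": {"loss": 0.4}}]
#     >>> pull_metric(history, "valid.loss")
#     [0.5, 0.4]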


def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal total, fix
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}
    return _update
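
# Usage sketch (illustrative): the callback keeps a running average across
# successive metric dicts; `fix` normalizes away the startup bias toward zero.
#     >>> ema = EMA(beta=0.9)
#     >>> ema({"loss": 1.0})
#     {'loss': 1.0}
#     >>> round(ema({"loss": 0.0})["loss"], 4)  # 0.9 / 1.9
#     0.4737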


def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, 'Yi', suffix)
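
# Usage sketch (illustrative):
#     >>> sizeof_fmt(4096)
#     '4.0KiB'
#     >>> sizeof_fmt(123)
#     '123.0B'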


@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield a list of `count` temporary file names; the files are removed
    on exit unless `delete` is False."""
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
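
# Usage sketch (illustrative): grab two scratch paths that are cleaned up
# automatically when the block exits.
#     >>> with temp_filenames(2) as (src, dst):
#     ...     os.path.exists(src) and os.path.exists(dst)
#     True
#     >>> os.path.exists(src)
#     False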


class DummyPoolExecutor:
    """Synchronous stand-in for a `concurrent.futures` executor: `submit`
    returns a result object that runs the function in the calling thread
    when `result()` is invoked."""
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
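
# Usage sketch (illustrative): same surface as a real executor, but the work
# is deferred until `result()` and runs in the caller's thread.
#     >>> with DummyPoolExecutor() as pool:
#     ...     future = pool.submit(pow, 2, 10)
#     >>> future.result()
#     1024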