mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-23 23:10:57 +01:00
Add files via upload
This commit is contained in:
parent
de71392d59
commit
9da7c77cf3
@ -121,7 +121,18 @@ def tensor_chunk(tensor_or_chunk):
|
||||
assert isinstance(tensor_or_chunk, th.Tensor)
|
||||
return TensorChunk(tensor_or_chunk)
|
||||
|
||||
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
|
||||
def apply_model(model,
|
||||
mix,
|
||||
shifts=1,
|
||||
split=True,
|
||||
overlap=0.25,
|
||||
transition_power=1.,
|
||||
static_shifts=1,
|
||||
set_progress_bar=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
num_workers=0,
|
||||
pool=None):
|
||||
"""
|
||||
Apply model to a given mixture.
|
||||
|
||||
@ -140,8 +151,6 @@ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power
|
||||
be on `device`, while the entire tracks will be stored on `mix.device`.
|
||||
"""
|
||||
|
||||
#print("Progress Bar?: ", type(set_progress_bar))
|
||||
|
||||
global fut_length
|
||||
global bag_num
|
||||
global prog_bar
|
||||
|
@ -689,7 +689,7 @@ class HDemucs(nn.Module):
|
||||
length = x.shape[-1]
|
||||
|
||||
z = self._spec(mix)
|
||||
mag = self._magnitude(z)
|
||||
mag = self._magnitude(z).to(mix.device)
|
||||
x = mag
|
||||
|
||||
B, C, Fq, T = x.shape
|
||||
@ -769,14 +769,24 @@ class HDemucs(nn.Module):
|
||||
S = len(self.sources)
|
||||
x = x.view(B, S, -1, Fq, T)
|
||||
x = x * std[:, None] + mean[:, None]
|
||||
|
||||
# to cpu as mps doesnt support complex numbers
|
||||
# demucs issue #435 ##432
|
||||
# NOTE: in this case z already is on cpu
|
||||
# TODO: remove this when mps supports complex numbers
|
||||
x_is_mps = x.device.type == "mps"
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
zout = self._mask(z, x)
|
||||
x = self._ispec(zout, length)
|
||||
|
||||
# back to mps device
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
|
||||
if self.hybrid:
|
||||
xt = xt.view(B, S, -1, length)
|
||||
xt = xt * stdt[:, None] + meant[:, None]
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
return x
|
@ -536,7 +536,7 @@ class HTDemucs(nn.Module):
|
||||
length_pre_pad = mix.shape[-1]
|
||||
mix = F.pad(mix, (0, training_length - length_pre_pad))
|
||||
z = self._spec(mix)
|
||||
mag = self._magnitude(z)
|
||||
mag = self._magnitude(z).to(mix.device)
|
||||
x = mag
|
||||
|
||||
B, C, Fq, T = x.shape
|
||||
@ -625,6 +625,14 @@ class HTDemucs(nn.Module):
|
||||
x = x.view(B, S, -1, Fq, T)
|
||||
x = x * std[:, None] + mean[:, None]
|
||||
|
||||
# to cpu as mps doesnt support complex numbers
|
||||
# demucs issue #435 ##432
|
||||
# NOTE: in this case z already is on cpu
|
||||
# TODO: remove this when mps supports complex numbers
|
||||
x_is_mps = x.device.type == "mps"
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
zout = self._mask(z, x)
|
||||
if self.use_train_segment:
|
||||
if self.training:
|
||||
@ -634,6 +642,10 @@ class HTDemucs(nn.Module):
|
||||
else:
|
||||
x = self._ispec(zout, length)
|
||||
|
||||
# back to mps device
|
||||
if x_is_mps:
|
||||
x = x.to("mps")
|
||||
|
||||
if self.use_train_segment:
|
||||
if self.training:
|
||||
xt = xt.view(B, S, -1, length)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
@ -11,6 +11,9 @@ import torch as th
|
||||
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
||||
*other, length = x.shape
|
||||
x = x.reshape(-1, length)
|
||||
is_mps = x.device.type == 'mps'
|
||||
if is_mps:
|
||||
x = x.cpu()
|
||||
z = th.stft(x,
|
||||
n_fft * (1 + pad),
|
||||
hop_length or n_fft // 4,
|
||||
@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0):
|
||||
n_fft = 2 * freqs - 2
|
||||
z = z.view(-1, freqs, frames)
|
||||
win_length = n_fft // (1 + pad)
|
||||
is_mps = z.device.type == 'mps'
|
||||
if is_mps:
|
||||
z = z.cpu()
|
||||
x = th.istft(z,
|
||||
n_fft,
|
||||
hop_length,
|
||||
|
Loading…
Reference in New Issue
Block a user