mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-28 01:10:56 +01:00
Add files via upload
This commit is contained in:
parent
de71392d59
commit
9da7c77cf3
@ -121,7 +121,18 @@ def tensor_chunk(tensor_or_chunk):
|
|||||||
assert isinstance(tensor_or_chunk, th.Tensor)
|
assert isinstance(tensor_or_chunk, th.Tensor)
|
||||||
return TensorChunk(tensor_or_chunk)
|
return TensorChunk(tensor_or_chunk)
|
||||||
|
|
||||||
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
|
def apply_model(model,
|
||||||
|
mix,
|
||||||
|
shifts=1,
|
||||||
|
split=True,
|
||||||
|
overlap=0.25,
|
||||||
|
transition_power=1.,
|
||||||
|
static_shifts=1,
|
||||||
|
set_progress_bar=None,
|
||||||
|
device=None,
|
||||||
|
progress=False,
|
||||||
|
num_workers=0,
|
||||||
|
pool=None):
|
||||||
"""
|
"""
|
||||||
Apply model to a given mixture.
|
Apply model to a given mixture.
|
||||||
|
|
||||||
@ -140,8 +151,6 @@ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power
|
|||||||
be on `device`, while the entire tracks will be stored on `mix.device`.
|
be on `device`, while the entire tracks will be stored on `mix.device`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#print("Progress Bar?: ", type(set_progress_bar))
|
|
||||||
|
|
||||||
global fut_length
|
global fut_length
|
||||||
global bag_num
|
global bag_num
|
||||||
global prog_bar
|
global prog_bar
|
||||||
|
@ -689,7 +689,7 @@ class HDemucs(nn.Module):
|
|||||||
length = x.shape[-1]
|
length = x.shape[-1]
|
||||||
|
|
||||||
z = self._spec(mix)
|
z = self._spec(mix)
|
||||||
mag = self._magnitude(z)
|
mag = self._magnitude(z).to(mix.device)
|
||||||
x = mag
|
x = mag
|
||||||
|
|
||||||
B, C, Fq, T = x.shape
|
B, C, Fq, T = x.shape
|
||||||
@ -770,13 +770,23 @@ class HDemucs(nn.Module):
|
|||||||
x = x.view(B, S, -1, Fq, T)
|
x = x.view(B, S, -1, Fq, T)
|
||||||
x = x * std[:, None] + mean[:, None]
|
x = x * std[:, None] + mean[:, None]
|
||||||
|
|
||||||
|
# to cpu as mps doesnt support complex numbers
|
||||||
|
# demucs issue #435 ##432
|
||||||
|
# NOTE: in this case z already is on cpu
|
||||||
|
# TODO: remove this when mps supports complex numbers
|
||||||
|
x_is_mps = x.device.type == "mps"
|
||||||
|
if x_is_mps:
|
||||||
|
x = x.cpu()
|
||||||
|
|
||||||
zout = self._mask(z, x)
|
zout = self._mask(z, x)
|
||||||
x = self._ispec(zout, length)
|
x = self._ispec(zout, length)
|
||||||
|
|
||||||
|
# back to mps device
|
||||||
|
if x_is_mps:
|
||||||
|
x = x.to('mps')
|
||||||
|
|
||||||
if self.hybrid:
|
if self.hybrid:
|
||||||
xt = xt.view(B, S, -1, length)
|
xt = xt.view(B, S, -1, length)
|
||||||
xt = xt * stdt[:, None] + meant[:, None]
|
xt = xt * stdt[:, None] + meant[:, None]
|
||||||
x = xt + x
|
x = xt + x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -536,7 +536,7 @@ class HTDemucs(nn.Module):
|
|||||||
length_pre_pad = mix.shape[-1]
|
length_pre_pad = mix.shape[-1]
|
||||||
mix = F.pad(mix, (0, training_length - length_pre_pad))
|
mix = F.pad(mix, (0, training_length - length_pre_pad))
|
||||||
z = self._spec(mix)
|
z = self._spec(mix)
|
||||||
mag = self._magnitude(z)
|
mag = self._magnitude(z).to(mix.device)
|
||||||
x = mag
|
x = mag
|
||||||
|
|
||||||
B, C, Fq, T = x.shape
|
B, C, Fq, T = x.shape
|
||||||
@ -625,6 +625,14 @@ class HTDemucs(nn.Module):
|
|||||||
x = x.view(B, S, -1, Fq, T)
|
x = x.view(B, S, -1, Fq, T)
|
||||||
x = x * std[:, None] + mean[:, None]
|
x = x * std[:, None] + mean[:, None]
|
||||||
|
|
||||||
|
# to cpu as mps doesnt support complex numbers
|
||||||
|
# demucs issue #435 ##432
|
||||||
|
# NOTE: in this case z already is on cpu
|
||||||
|
# TODO: remove this when mps supports complex numbers
|
||||||
|
x_is_mps = x.device.type == "mps"
|
||||||
|
if x_is_mps:
|
||||||
|
x = x.cpu()
|
||||||
|
|
||||||
zout = self._mask(z, x)
|
zout = self._mask(z, x)
|
||||||
if self.use_train_segment:
|
if self.use_train_segment:
|
||||||
if self.training:
|
if self.training:
|
||||||
@ -634,6 +642,10 @@ class HTDemucs(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = self._ispec(zout, length)
|
x = self._ispec(zout, length)
|
||||||
|
|
||||||
|
# back to mps device
|
||||||
|
if x_is_mps:
|
||||||
|
x = x.to("mps")
|
||||||
|
|
||||||
if self.use_train_segment:
|
if self.use_train_segment:
|
||||||
if self.training:
|
if self.training:
|
||||||
xt = xt.view(B, S, -1, length)
|
xt = xt.view(B, S, -1, length)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
@ -11,6 +11,9 @@ import torch as th
|
|||||||
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
||||||
*other, length = x.shape
|
*other, length = x.shape
|
||||||
x = x.reshape(-1, length)
|
x = x.reshape(-1, length)
|
||||||
|
is_mps = x.device.type == 'mps'
|
||||||
|
if is_mps:
|
||||||
|
x = x.cpu()
|
||||||
z = th.stft(x,
|
z = th.stft(x,
|
||||||
n_fft * (1 + pad),
|
n_fft * (1 + pad),
|
||||||
hop_length or n_fft // 4,
|
hop_length or n_fft // 4,
|
||||||
@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0):
|
|||||||
n_fft = 2 * freqs - 2
|
n_fft = 2 * freqs - 2
|
||||||
z = z.view(-1, freqs, frames)
|
z = z.view(-1, freqs, frames)
|
||||||
win_length = n_fft // (1 + pad)
|
win_length = n_fft // (1 + pad)
|
||||||
|
is_mps = z.device.type == 'mps'
|
||||||
|
if is_mps:
|
||||||
|
z = z.cpu()
|
||||||
x = th.istft(z,
|
x = th.istft(z,
|
||||||
n_fft,
|
n_fft,
|
||||||
hop_length,
|
hop_length,
|
||||||
|
Loading…
Reference in New Issue
Block a user