Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git (synced 2024-11-23 23:10:57 +01:00)

Commit db741801d8 (parent bee41d7636): Add files via upload
@@ -770,20 +770,24 @@ class HDemucs(nn.Module):
         x = x.view(B, S, -1, Fq, T)
         x = x * std[:, None] + mean[:, None]

-        # to cpu as mps doesnt support complex numbers
+        # to cpu as non-cuda GPUs don't support complex numbers
         # demucs issue #435 ##432
         # NOTE: in this case z already is on cpu
         # TODO: remove this when mps supports complex numbers
-        x_is_mps = x.device.type == "mps"
-        if x_is_mps:
+
+        device_type = x.device.type
+        device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
+        x_is_other_gpu = not device_type in ["cuda", "cpu"]
+
+        if x_is_other_gpu:
             x = x.cpu()

         zout = self._mask(z, x)
         x = self._ispec(zout, length)

-        # back to mps device
-        if x_is_mps:
-            x = x.to('mps')
+        # back to other device
+        if x_is_other_gpu:
+            x = x.to(device_load)

         if self.hybrid:
             xt = xt.view(B, S, -1, length)
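The hunk above moves the complex-valued masking/iSTFT step to the CPU whenever the input tensor sits on a GPU backend other than CUDA, remembers the original device (including its index, except for mps, which has none), and moves the result back afterwards. A minimal, self-contained sketch of that round-trip pattern follows; it assumes plain PyTorch only, and complex_work / run_with_cpu_fallback are illustrative names standing in for _mask/_ispec, not UVR's actual API.

import torch


def complex_work(x):
    # Stand-in for the complex-valued spectrogram processing (_mask / _ispec):
    # a round trip through the frequency domain, which produces complex tensors.
    z = torch.fft.rfft(x, dim=-1)
    return torch.fft.irfft(z, n=x.shape[-1], dim=-1)


def run_with_cpu_fallback(x):
    device_type = x.device.type
    # Remember where the tensor came from; keep the index (e.g. "xpu:0") so the
    # result returns to the same card. "mps" devices have no index.
    device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
    x_is_other_gpu = device_type not in ("cuda", "cpu")

    if x_is_other_gpu:
        x = x.cpu()  # complex ops are only relied upon on cpu/cuda here

    out = complex_work(x)

    if x_is_other_gpu:
        out = out.to(device_load)  # back to the original accelerator
    return out


if __name__ == "__main__":
    # Runs on CPU as-is; on an unsupported GPU backend the work happens on CPU
    # and the output is moved back to that device afterwards.
    print(run_with_cpu_fallback(torch.randn(2, 1024)).shape)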
@@ -625,12 +625,16 @@ class HTDemucs(nn.Module):
         x = x.view(B, S, -1, Fq, T)
         x = x * std[:, None] + mean[:, None]

-        # to cpu as mps doesnt support complex numbers
+        # to cpu as non-cuda GPUs don't support complex numbers
         # demucs issue #435 ##432
         # NOTE: in this case z already is on cpu
         # TODO: remove this when mps supports complex numbers
-        x_is_mps = x.device.type == "mps"
-        if x_is_mps:
+
+        device_type = x.device.type
+        device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
+        x_is_other_gpu = not device_type in ["cuda", "cpu"]
+
+        if x_is_other_gpu:
             x = x.cpu()

         zout = self._mask(z, x)
@@ -642,9 +646,9 @@ class HTDemucs(nn.Module):
         else:
             x = self._ispec(zout, length)

-        # back to mps device
-        if x_is_mps:
-            x = x.to("mps")
+        # back to other device
+        if x_is_other_gpu:
+            x = x.to(device_load)

         if self.use_train_segment:
             if self.training:
@@ -11,8 +11,11 @@ import torch as th
 def spectro(x, n_fft=512, hop_length=None, pad=0):
     *other, length = x.shape
     x = x.reshape(-1, length)
-    is_mps = x.device.type == 'mps'
-    if is_mps:
+
+    device_type = x.device.type
+    is_other_gpu = not device_type in ["cuda", "cpu"]
+
+    if is_other_gpu:
         x = x.cpu()
     z = th.stft(x,
                 n_fft * (1 + pad),
@@ -32,8 +35,11 @@ def ispectro(z, hop_length=None, length=None, pad=0):
     n_fft = 2 * freqs - 2
     z = z.view(-1, freqs, frames)
     win_length = n_fft // (1 + pad)
-    is_mps = z.device.type == 'mps'
-    if is_mps:
+
+    device_type = z.device.type
+    is_other_gpu = not device_type in ["cuda", "cpu"]
+
+    if is_other_gpu:
         z = z.cpu()
     x = th.istft(z,
                  n_fft,
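The spectro/ispectro hunks above apply the same pattern at the STFT level: when the tensor lives on a non-CUDA GPU it is pulled to the CPU before th.stft/th.istft, since those calls require complex-tensor support, and (per the NOTE in the model code) the complex result is simply left on the CPU. Below is a minimal sketch of such a wrapper, assuming plain PyTorch with a Hann window and return_complex=True; stft_with_fallback is an illustrative name, not the project's function.

import torch as th


def stft_with_fallback(x, n_fft=512, hop_length=None):
    # Flatten all leading dimensions into one batch axis, as spectro() does.
    *other, length = x.shape
    x = x.reshape(-1, length)

    device_type = x.device.type
    is_other_gpu = device_type not in ("cuda", "cpu")
    if is_other_gpu:
        # torch.stft with return_complex=True needs complex support,
        # so fall back to CPU; the complex result stays on CPU.
        x = x.cpu()

    hop_length = hop_length or n_fft // 4
    z = th.stft(x, n_fft, hop_length=hop_length,
                window=th.hann_window(n_fft, device=x.device),
                return_complex=True)
    _, freqs, frames = z.shape
    return z.view(*other, freqs, frames)


if __name__ == "__main__":
    spec = stft_with_fallback(th.randn(2, 2, 44100))
    print(spec.shape, spec.dtype)  # complex spectrogram, e.g. torch.complex64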