Add files via upload

This commit is contained in:
Anjok07 2023-10-17 01:02:16 -05:00 committed by GitHub
parent bee41d7636
commit db741801d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 16 deletions

View File

@ -770,20 +770,24 @@ class HDemucs(nn.Module):
x = x.view(B, S, -1, Fq, T) x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None] x = x * std[:, None] + mean[:, None]
# to cpu as mps doesnt support complex numbers # to cpu as non-cuda GPUs don't support complex numbers
# demucs issue #435 ##432 # demucs issue #435 ##432
# NOTE: in this case z already is on cpu # NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers # TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps: device_type = x.device.type
device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
x_is_other_gpu = not device_type in ["cuda", "cpu"]
if x_is_other_gpu:
x = x.cpu() x = x.cpu()
zout = self._mask(z, x) zout = self._mask(z, x)
x = self._ispec(zout, length) x = self._ispec(zout, length)
# back to mps device # back to other device
if x_is_mps: if x_is_other_gpu:
x = x.to('mps') x = x.to(device_load)
if self.hybrid: if self.hybrid:
xt = xt.view(B, S, -1, length) xt = xt.view(B, S, -1, length)

View File

@ -625,12 +625,16 @@ class HTDemucs(nn.Module):
x = x.view(B, S, -1, Fq, T) x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None] x = x * std[:, None] + mean[:, None]
# to cpu as mps doesnt support complex numbers # to cpu as non-cuda GPUs don't support complex numbers
# demucs issue #435 ##432 # demucs issue #435 ##432
# NOTE: in this case z already is on cpu # NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers # TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps: device_type = x.device.type
device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
x_is_other_gpu = not device_type in ["cuda", "cpu"]
if x_is_other_gpu:
x = x.cpu() x = x.cpu()
zout = self._mask(z, x) zout = self._mask(z, x)
@ -642,9 +646,9 @@ class HTDemucs(nn.Module):
else: else:
x = self._ispec(zout, length) x = self._ispec(zout, length)
# back to mps device # back to other device
if x_is_mps: if x_is_other_gpu:
x = x.to("mps") x = x.to(device_load)
if self.use_train_segment: if self.use_train_segment:
if self.training: if self.training:

View File

@ -11,8 +11,11 @@ import torch as th
def spectro(x, n_fft=512, hop_length=None, pad=0): def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape *other, length = x.shape
x = x.reshape(-1, length) x = x.reshape(-1, length)
is_mps = x.device.type == 'mps'
if is_mps: device_type = x.device.type
is_other_gpu = not device_type in ["cuda", "cpu"]
if is_other_gpu:
x = x.cpu() x = x.cpu()
z = th.stft(x, z = th.stft(x,
n_fft * (1 + pad), n_fft * (1 + pad),
@ -32,8 +35,11 @@ def ispectro(z, hop_length=None, length=None, pad=0):
n_fft = 2 * freqs - 2 n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames) z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad) win_length = n_fft // (1 + pad)
is_mps = z.device.type == 'mps'
if is_mps: device_type = z.device.type
is_other_gpu = not device_type in ["cuda", "cpu"]
if is_other_gpu:
z = z.cpu() z = z.cpu()
x = th.istft(z, x = th.istft(z,
n_fft, n_fft,