From db741801d83121cccbadb75106c12e691d73c8be Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Tue, 17 Oct 2023 01:02:16 -0500 Subject: [PATCH] Add files via upload --- demucs/hdemucs.py | 16 ++++++++++------ demucs/htdemucs.py | 16 ++++++++++------ demucs/spec.py | 14 ++++++++++---- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py index eea7b4d..ea4e652 100644 --- a/demucs/hdemucs.py +++ b/demucs/hdemucs.py @@ -770,20 +770,24 @@ class HDemucs(nn.Module): x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] - # to cpu as mps doesnt support complex numbers + # to cpu as non-cuda GPUs don't support complex numbers # demucs issue #435 ##432 # NOTE: in this case z already is on cpu # TODO: remove this when mps supports complex numbers - x_is_mps = x.device.type == "mps" - if x_is_mps: + + device_type = x.device.type + device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type + x_is_other_gpu = not device_type in ["cuda", "cpu"] + + if x_is_other_gpu: x = x.cpu() zout = self._mask(z, x) x = self._ispec(zout, length) - # back to mps device - if x_is_mps: - x = x.to('mps') + # back to other device + if x_is_other_gpu: + x = x.to(device_load) if self.hybrid: xt = xt.view(B, S, -1, length) diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py index b9686c3..bad6b11 100644 --- a/demucs/htdemucs.py +++ b/demucs/htdemucs.py @@ -625,12 +625,16 @@ class HTDemucs(nn.Module): x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] - # to cpu as mps doesnt support complex numbers + # to cpu as non-cuda GPUs don't support complex numbers # demucs issue #435 ##432 # NOTE: in this case z already is on cpu # TODO: remove this when mps supports complex numbers - x_is_mps = x.device.type == "mps" - if x_is_mps: + + device_type = x.device.type + device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type + x_is_other_gpu = not device_type in ["cuda", "cpu"] + + if x_is_other_gpu: x = x.cpu() zout = self._mask(z, x) @@ -642,9 +646,9 @@ class HTDemucs(nn.Module): else: x = self._ispec(zout, length) - # back to mps device - if x_is_mps: - x = x.to("mps") + # back to other device + if x_is_other_gpu: + x = x.to(device_load) if self.use_train_segment: if self.training: diff --git a/demucs/spec.py b/demucs/spec.py index 2925045..9c79cb6 100644 --- a/demucs/spec.py +++ b/demucs/spec.py @@ -11,8 +11,11 @@ import torch as th def spectro(x, n_fft=512, hop_length=None, pad=0): *other, length = x.shape x = x.reshape(-1, length) - is_mps = x.device.type == 'mps' - if is_mps: + + device_type = x.device.type + is_other_gpu = not device_type in ["cuda", "cpu"] + + if is_other_gpu: x = x.cpu() z = th.stft(x, n_fft * (1 + pad), @@ -32,8 +35,11 @@ def ispectro(z, hop_length=None, length=None, pad=0): n_fft = 2 * freqs - 2 z = z.view(-1, freqs, frames) win_length = n_fft // (1 + pad) - is_mps = z.device.type == 'mps' - if is_mps: + + device_type = z.device.type + is_other_gpu = not device_type in ["cuda", "cpu"] + + if is_other_gpu: z = z.cpu() x = th.istft(z, n_fft,