From 9da7c77cf3c620954e3c8af2ea57d532a04dd55c Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Fri, 6 Oct 2023 23:48:34 -0500 Subject: [PATCH] Add files via upload --- demucs/apply.py | 15 ++++++++++++--- demucs/hdemucs.py | 18 ++++++++++++++---- demucs/htdemucs.py | 14 +++++++++++++- demucs/spec.py | 8 +++++++- 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/demucs/apply.py b/demucs/apply.py index ac4029b..be6d930 100644 --- a/demucs/apply.py +++ b/demucs/apply.py @@ -121,7 +121,18 @@ def tensor_chunk(tensor_or_chunk): assert isinstance(tensor_or_chunk, th.Tensor) return TensorChunk(tensor_or_chunk) -def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None): +def apply_model(model, + mix, + shifts=1, + split=True, + overlap=0.25, + transition_power=1., + static_shifts=1, + set_progress_bar=None, + device=None, + progress=False, + num_workers=0, + pool=None): """ Apply model to a given mixture. @@ -140,8 +151,6 @@ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power be on `device`, while the entire tracks will be stored on `mix.device`. """ - #print("Progress Bar?: ", type(set_progress_bar)) - global fut_length global bag_num global prog_bar diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py index d776d55..eea7b4d 100644 --- a/demucs/hdemucs.py +++ b/demucs/hdemucs.py @@ -689,7 +689,7 @@ class HDemucs(nn.Module): length = x.shape[-1] z = self._spec(mix) - mag = self._magnitude(z) + mag = self._magnitude(z).to(mix.device) x = mag B, C, Fq, T = x.shape @@ -769,14 +769,24 @@ class HDemucs(nn.Module): S = len(self.sources) x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] + + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() zout = self._mask(z, x) x = self._ispec(zout, length) + # back to mps device + if x_is_mps: + x = x.to('mps') + if self.hybrid: xt = xt.view(B, S, -1, length) xt = xt * stdt[:, None] + meant[:, None] x = xt + x - return x - - + return x \ No newline at end of file diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py index ffa466b..b9686c3 100644 --- a/demucs/htdemucs.py +++ b/demucs/htdemucs.py @@ -536,7 +536,7 @@ class HTDemucs(nn.Module): length_pre_pad = mix.shape[-1] mix = F.pad(mix, (0, training_length - length_pre_pad)) z = self._spec(mix) - mag = self._magnitude(z) + mag = self._magnitude(z).to(mix.device) x = mag B, C, Fq, T = x.shape @@ -625,6 +625,14 @@ class HTDemucs(nn.Module): x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + zout = self._mask(z, x) if self.use_train_segment: if self.training: @@ -634,6 +642,10 @@ class HTDemucs(nn.Module): else: x = self._ispec(zout, length) + # back to mps device + if x_is_mps: + x = x.to("mps") + if self.use_train_segment: if self.training: xt = xt.view(B, S, -1, length) diff --git a/demucs/spec.py b/demucs/spec.py index 85e5dc9..2925045 100644 --- a/demucs/spec.py +++ b/demucs/spec.py @@ -1,4 +1,4 @@ -# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -11,6 +11,9 @@ import torch as th def spectro(x, n_fft=512, hop_length=None, pad=0): *other, length = x.shape x = x.reshape(-1, length) + is_mps = x.device.type == 'mps' + if is_mps: + x = x.cpu() z = th.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, @@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0): n_fft = 2 * freqs - 2 z = z.view(-1, freqs, frames) win_length = n_fft // (1 + pad) + is_mps = z.device.type == 'mps' + if is_mps: + z = z.cpu() x = th.istft(z, n_fft, hop_length,