diff --git a/lib_v5/tfc_tdf_v3.py b/lib_v5/tfc_tdf_v3.py index 8f72595..c8ba535 100644 --- a/lib_v5/tfc_tdf_v3.py +++ b/lib_v5/tfc_tdf_v3.py @@ -10,6 +10,10 @@ class STFT: self.dim_f = dim_f def __call__(self, x): + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + window = self.window.to(x.device) batch_dims = x.shape[:-2] c, t = x.shape[-2:] @@ -17,9 +21,17 @@ class STFT: x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False) x = x.permute([0, 3, 1, 2]) x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + + if x_is_mps: + x = x.to('mps') + return x[..., :self.dim_f, :] def inverse(self, x): + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + window = self.window.to(x.device) batch_dims = x.shape[:-3] c, f, t = x.shape[-3:] @@ -31,6 +43,10 @@ class STFT: x = x[..., 0] + x[..., 1] * 1.j x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) x = x.reshape([*batch_dims, 2, -1]) + + if x_is_mps: + x = x.to('mps') + return x