Add files via upload

This commit is contained in:
Anjok07 2023-10-06 19:06:09 -05:00 committed by GitHub
parent 932ab13040
commit de71392d59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,6 +10,10 @@ class STFT:
self.dim_f = dim_f self.dim_f = dim_f
def __call__(self, x): def __call__(self, x):
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()
window = self.window.to(x.device) window = self.window.to(x.device)
batch_dims = x.shape[:-2] batch_dims = x.shape[:-2]
c, t = x.shape[-2:] c, t = x.shape[-2:]
@ -17,9 +21,17 @@ class STFT:
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False) x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False)
x = x.permute([0, 3, 1, 2]) x = x.permute([0, 3, 1, 2])
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
if x_is_mps:
x = x.to('mps')
return x[..., :self.dim_f, :] return x[..., :self.dim_f, :]
def inverse(self, x): def inverse(self, x):
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()
window = self.window.to(x.device) window = self.window.to(x.device)
batch_dims = x.shape[:-3] batch_dims = x.shape[:-3]
c, f, t = x.shape[-3:] c, f, t = x.shape[-3:]
@ -31,6 +43,10 @@ class STFT:
x = x[..., 0] + x[..., 1] * 1.j x = x[..., 0] + x[..., 1] * 1.j
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
x = x.reshape([*batch_dims, 2, -1]) x = x.reshape([*batch_dims, 2, -1])
if x_is_mps:
x = x.to('mps')
return x return x