mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-27 17:00:59 +01:00
Add files via upload
This commit is contained in:
parent
932ab13040
commit
de71392d59
@ -10,6 +10,10 @@ class STFT:
|
||||
self.dim_f = dim_f
|
||||
|
||||
def __call__(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
window = self.window.to(x.device)
|
||||
batch_dims = x.shape[:-2]
|
||||
c, t = x.shape[-2:]
|
||||
@ -17,9 +21,17 @@ class STFT:
|
||||
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False)
|
||||
x = x.permute([0, 3, 1, 2])
|
||||
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
|
||||
return x[..., :self.dim_f, :]
|
||||
|
||||
def inverse(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
window = self.window.to(x.device)
|
||||
batch_dims = x.shape[:-3]
|
||||
c, f, t = x.shape[-3:]
|
||||
@ -31,6 +43,10 @@ class STFT:
|
||||
x = x[..., 0] + x[..., 1] * 1.j
|
||||
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
|
||||
x = x.reshape([*batch_dims, 2, -1])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user