fix stft and istft in pyotrch 2.0.0

fix stft and istft in pyotrch 2.0.0
in pytorch 2.0.0 not support real output(stft)and real input(istft)
This commit is contained in:
233lol 2023-04-06 09:55:10 +08:00 committed by GitHub
parent 5cef5eebbe
commit 535225172e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -392,7 +392,8 @@ class SeperateMDX(SeperateAttributes):
def stft(self, x):
x = x.reshape([-1, self.chunk_size])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True)
x=torch.view_as_real(x)
x = x.permute([0,3,1,2])
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
return x[:,:,:self.dim_f]
@ -402,6 +403,8 @@ class SeperateMDX(SeperateAttributes):
x = torch.cat([x, freq_pad], -2)
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
x = x.permute([0,2,3,1])
x=x.contiguous()
x=torch.view_as_complex(x)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1,2,self.chunk_size])
@ -936,4 +939,4 @@ def save_format(audio_path, save_format, mp3_bit_set):
try:
os.remove(audio_path)
except Exception as e:
print(e)
print(e)