Merge pull request #480 from 233lol/patch-1

fix stft and istft in pyotrch 2.0.0
This commit is contained in:
Anjok07 2023-04-11 00:48:35 -05:00 committed by GitHub
commit 18d32660db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -392,7 +392,8 @@ class SeperateMDX(SeperateAttributes):
def stft(self, x): def stft(self, x):
x = x.reshape([-1, self.chunk_size]) x = x.reshape([-1, self.chunk_size])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True)
x=torch.view_as_real(x)
x = x.permute([0,3,1,2]) x = x.permute([0,3,1,2])
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t]) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
return x[:,:,:self.dim_f] return x[:,:,:self.dim_f]
@ -402,6 +403,8 @@ class SeperateMDX(SeperateAttributes):
x = torch.cat([x, freq_pad], -2) x = torch.cat([x, freq_pad], -2)
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
x = x.permute([0,2,3,1]) x = x.permute([0,2,3,1])
x=x.contiguous()
x=torch.view_as_complex(x)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1,2,self.chunk_size]) return x.reshape([-1,2,self.chunk_size])