diff --git a/separate.py b/separate.py index 23e8a12..c1fcabc 100644 --- a/separate.py +++ b/separate.py @@ -392,7 +392,8 @@ class SeperateMDX(SeperateAttributes): def stft(self, x): x = x.reshape([-1, self.chunk_size]) - x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True) + x=torch.view_as_real(x) x = x.permute([0,3,1,2]) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t]) return x[:,:,:self.dim_f] @@ -402,6 +403,8 @@ class SeperateMDX(SeperateAttributes): x = torch.cat([x, freq_pad], -2) x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) x = x.permute([0,2,3,1]) + x=x.contiguous() + x=torch.view_as_complex(x) x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) return x.reshape([-1,2,self.chunk_size]) @@ -936,4 +939,4 @@ def save_format(audio_path, save_format, mp3_bit_set): try: os.remove(audio_path) except Exception as e: - print(e) \ No newline at end of file + print(e)