diff --git a/separate.py b/separate.py index 33a51ed..abb3ac5 100644 --- a/separate.py +++ b/separate.py @@ -761,8 +761,9 @@ class SeperateMDXC(SeperateAttributes): chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1) batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)] - X = torch.zeros(S, *mix.shape).to(self.device) if S > 1 else torch.zeros_like(mix).to(self.device) - + X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix) + X = X.to(self.device) + with torch.no_grad(): cnt = 0 for batch in batches: