Update separate.py

This commit is contained in:
Anjok07 2023-10-17 21:32:35 -05:00 committed by GitHub
parent 7036cc2ab0
commit b20ab25080
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -757,7 +757,6 @@ class SeperateMDXC(SeperateAttributes):
mix_shape = mix.shape[1] mix_shape = mix.shape[1]
pad_size = hop_size - (mix_shape - chunk_size) % hop_size pad_size = hop_size - (mix_shape - chunk_size) % hop_size
mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1) mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
mix = mix.to(self.device)
chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1) chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)] batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
@ -769,7 +768,7 @@ class SeperateMDXC(SeperateAttributes):
cnt = 0 cnt = 0
for batch in batches: for batch in batches:
self.running_inference_progress_bar(len(batches)) self.running_inference_progress_bar(len(batches))
x = model(batch) x = model(batch.to(self.device))
for w in x: for w in x:
X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w