Add files via upload

More efficient MDX23C memory usage.
This commit is contained in:
Anjok07 2023-10-15 17:01:03 -05:00 committed by GitHub
parent 916358e32c
commit 54ba75acd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -727,32 +727,20 @@ class SeperateMDXC(SeperateAttributes):
mix = torch.cat([torch.zeros(2, C - H), mix, torch.zeros(2, pad_size + C - H)], 1)
mix = mix.to(self.device)
chunks = []
i = 0
while i + C <= mix.shape[1]:
chunks.append(mix[:, i:i + C])
i += H
chunks = torch.stack(chunks)
chunks = mix.unfold(1, C, H).transpose(0, 1)
batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
batches = []
i = 0
while i < len(chunks):
batches.append(chunks[i:i + batch_size])
i = i + batch_size
X = torch.zeros(S, 2, C - H) if S > 1 else torch.zeros(2, C - H)
X = X.to(self.device)
X = torch.zeros(S, *mix.shape).to(self.device) if S > 1 else torch.zeros_like(mix)
#with torch.cuda.amp.autocast():
with torch.no_grad():
cnt = 0
for batch in batches:
self.running_inference_progress_bar(len(batches))
x = model(batch)
for w in x:
a = X[..., :-(C - H)]
b = X[..., -(C - H):] + w[..., :(C - H)]
c = w[..., (C - H):]
X = torch.cat([a, b, c], -1)
X[..., cnt * H : cnt * H + C] += w
cnt += 1
estimated_sources = X[..., C - H:-(pad_size + C - H)] / N