Add files via upload

This commit is contained in:
Anjok07 2023-10-06 23:48:34 -05:00 committed by GitHub
parent de71392d59
commit 9da7c77cf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 9 deletions

View File

@ -121,7 +121,18 @@ def tensor_chunk(tensor_or_chunk):
assert isinstance(tensor_or_chunk, th.Tensor)
return TensorChunk(tensor_or_chunk)
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
def apply_model(model,
mix,
shifts=1,
split=True,
overlap=0.25,
transition_power=1.,
static_shifts=1,
set_progress_bar=None,
device=None,
progress=False,
num_workers=0,
pool=None):
"""
Apply model to a given mixture.
@ -140,8 +151,6 @@ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power
be on `device`, while the entire tracks will be stored on `mix.device`.
"""
#print("Progress Bar?: ", type(set_progress_bar))
global fut_length
global bag_num
global prog_bar

View File

@ -689,7 +689,7 @@ class HDemucs(nn.Module):
length = x.shape[-1]
z = self._spec(mix)
mag = self._magnitude(z)
mag = self._magnitude(z).to(mix.device)
x = mag
B, C, Fq, T = x.shape
@ -769,14 +769,24 @@ class HDemucs(nn.Module):
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
# to cpu as mps doesnt support complex numbers
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()
zout = self._mask(z, x)
x = self._ispec(zout, length)
# back to mps device
if x_is_mps:
x = x.to('mps')
if self.hybrid:
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
x = xt + x
return x
return x

View File

@ -536,7 +536,7 @@ class HTDemucs(nn.Module):
length_pre_pad = mix.shape[-1]
mix = F.pad(mix, (0, training_length - length_pre_pad))
z = self._spec(mix)
mag = self._magnitude(z)
mag = self._magnitude(z).to(mix.device)
x = mag
B, C, Fq, T = x.shape
@ -625,6 +625,14 @@ class HTDemucs(nn.Module):
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
# to cpu as mps doesnt support complex numbers
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()
zout = self._mask(z, x)
if self.use_train_segment:
if self.training:
@ -634,6 +642,10 @@ class HTDemucs(nn.Module):
else:
x = self._ispec(zout, length)
# back to mps device
if x_is_mps:
x = x.to("mps")
if self.use_train_segment:
if self.training:
xt = xt.view(B, S, -1, length)

View File

@ -1,4 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
@ -11,6 +11,9 @@ import torch as th
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
is_mps = x.device.type == 'mps'
if is_mps:
x = x.cpu()
z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0):
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
is_mps = z.device.type == 'mps'
if is_mps:
z = z.cpu()
x = th.istft(z,
n_fft,
hop_length,