mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-12-01 02:27:21 +01:00
Add files via upload
This commit is contained in:
parent
3e3d1628f5
commit
e1eac27754
@ -1,16 +1,19 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch_directml
|
||||
|
||||
class STFT:
|
||||
def __init__(self, n_fft, hop_length, dim_f):
|
||||
def __init__(self, n_fft, hop_length, dim_f, device):
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
||||
self.dim_f = dim_f
|
||||
self.device = device
|
||||
|
||||
def __call__(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
|
||||
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
@ -23,12 +26,13 @@ class STFT:
|
||||
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
x = x.to(self.device)
|
||||
|
||||
return x[..., :self.dim_f, :]
|
||||
|
||||
def inverse(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
|
||||
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
@ -45,7 +49,7 @@ class STFT:
|
||||
x = x.reshape([*batch_dims, 2, -1])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
x = x.to(self.device)
|
||||
|
||||
return x
|
||||
|
||||
@ -145,9 +149,10 @@ class TFC_TDF(nn.Module):
|
||||
|
||||
|
||||
class TFC_TDF_net(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, device):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.device = device
|
||||
|
||||
norm = get_norm(norm_type=config.model.norm)
|
||||
act = get_act(act_type=config.model.act)
|
||||
@ -192,7 +197,7 @@ class TFC_TDF_net(nn.Module):
|
||||
nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
|
||||
)
|
||||
|
||||
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f)
|
||||
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)
|
||||
|
||||
def cac2cws(self, x):
|
||||
k = self.num_subbands
|
||||
|
Loading…
Reference in New Issue
Block a user