mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-23 23:10:57 +01:00
Add files via upload
This commit is contained in:
parent
db741801d8
commit
ac8338e231
@ -713,7 +713,6 @@ def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
|
||||
return array_1
|
||||
|
||||
def change_pitch_semitones(y, sr, semitone_shift):
|
||||
|
||||
factor = 2 ** (semitone_shift / 12) # Convert semitone shift to factor for resampling
|
||||
y_pitch_tuned = []
|
||||
for y_channel in y:
|
||||
|
@ -3,14 +3,16 @@ import torch.nn as nn
|
||||
from functools import partial
|
||||
|
||||
class STFT:
|
||||
def __init__(self, n_fft, hop_length, dim_f):
|
||||
def __init__(self, n_fft, hop_length, dim_f, device):
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
||||
self.dim_f = dim_f
|
||||
self.device = device
|
||||
|
||||
def __call__(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
|
||||
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
@ -23,12 +25,13 @@ class STFT:
|
||||
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
x = x.to(self.device)
|
||||
|
||||
return x[..., :self.dim_f, :]
|
||||
|
||||
def inverse(self, x):
|
||||
x_is_mps = x.device.type == "mps"
|
||||
|
||||
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
||||
if x_is_mps:
|
||||
x = x.cpu()
|
||||
|
||||
@ -45,11 +48,10 @@ class STFT:
|
||||
x = x.reshape([*batch_dims, 2, -1])
|
||||
|
||||
if x_is_mps:
|
||||
x = x.to('mps')
|
||||
x = x.to(self.device)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def get_norm(norm_type):
|
||||
def norm(c, norm_type):
|
||||
if norm_type == 'BatchNorm':
|
||||
@ -145,9 +147,10 @@ class TFC_TDF(nn.Module):
|
||||
|
||||
|
||||
class TFC_TDF_net(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, device):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.device = device
|
||||
|
||||
norm = get_norm(norm_type=config.model.norm)
|
||||
act = get_act(act_type=config.model.act)
|
||||
@ -192,7 +195,7 @@ class TFC_TDF_net(nn.Module):
|
||||
nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
|
||||
)
|
||||
|
||||
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f)
|
||||
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)
|
||||
|
||||
def cac2cws(self, x):
|
||||
k = self.num_subbands
|
||||
|
Loading…
Reference in New Issue
Block a user