Add files via upload

This commit is contained in:
Anjok07 2023-10-17 01:02:43 -05:00 committed by GitHub
parent db741801d8
commit ac8338e231
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 9 deletions

View File

@ -713,7 +713,6 @@ def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
return array_1 return array_1
def change_pitch_semitones(y, sr, semitone_shift): def change_pitch_semitones(y, sr, semitone_shift):
factor = 2 ** (semitone_shift / 12) # Convert semitone shift to factor for resampling factor = 2 ** (semitone_shift / 12) # Convert semitone shift to factor for resampling
y_pitch_tuned = [] y_pitch_tuned = []
for y_channel in y: for y_channel in y:

View File

@ -3,14 +3,16 @@ import torch.nn as nn
from functools import partial from functools import partial
class STFT: class STFT:
def __init__(self, n_fft, hop_length, dim_f): def __init__(self, n_fft, hop_length, dim_f, device):
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.window = torch.hann_window(window_length=self.n_fft, periodic=True) self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
self.dim_f = dim_f self.dim_f = dim_f
self.device = device
def __call__(self, x): def __call__(self, x):
x_is_mps = x.device.type == "mps"
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps: if x_is_mps:
x = x.cpu() x = x.cpu()
@ -23,12 +25,13 @@ class STFT:
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
if x_is_mps: if x_is_mps:
x = x.to('mps') x = x.to(self.device)
return x[..., :self.dim_f, :] return x[..., :self.dim_f, :]
def inverse(self, x): def inverse(self, x):
x_is_mps = x.device.type == "mps"
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps: if x_is_mps:
x = x.cpu() x = x.cpu()
@ -45,11 +48,10 @@ class STFT:
x = x.reshape([*batch_dims, 2, -1]) x = x.reshape([*batch_dims, 2, -1])
if x_is_mps: if x_is_mps:
x = x.to('mps') x = x.to(self.device)
return x return x
def get_norm(norm_type): def get_norm(norm_type):
def norm(c, norm_type): def norm(c, norm_type):
if norm_type == 'BatchNorm': if norm_type == 'BatchNorm':
@ -145,9 +147,10 @@ class TFC_TDF(nn.Module):
class TFC_TDF_net(nn.Module): class TFC_TDF_net(nn.Module):
def __init__(self, config): def __init__(self, config, device):
super().__init__() super().__init__()
self.config = config self.config = config
self.device = device
norm = get_norm(norm_type=config.model.norm) norm = get_norm(norm_type=config.model.norm)
act = get_act(act_type=config.model.act) act = get_act(act_type=config.model.act)
@ -192,7 +195,7 @@ class TFC_TDF_net(nn.Module):
nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
) )
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f) self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)
def cac2cws(self, x): def cac2cws(self, x):
k = self.num_subbands k = self.num_subbands