Add files via upload

This commit is contained in:
Anjok07 2023-10-10 21:49:34 -05:00 committed by GitHub
parent 3e3d1628f5
commit e1eac27754
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,16 +1,19 @@
import torch
import torch.nn as nn
from functools import partial
import torch_directml
class STFT:
def __init__(self, n_fft, hop_length, dim_f):
def __init__(self, n_fft, hop_length, dim_f, device):
self.n_fft = n_fft
self.hop_length = hop_length
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
self.dim_f = dim_f
self.device = device
def __call__(self, x):
x_is_mps = x.device.type == "mps"
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps:
x = x.cpu()
@ -23,12 +26,13 @@ class STFT:
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
if x_is_mps:
x = x.to('mps')
x = x.to(self.device)
return x[..., :self.dim_f, :]
def inverse(self, x):
x_is_mps = x.device.type == "mps"
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps:
x = x.cpu()
@ -45,7 +49,7 @@ class STFT:
x = x.reshape([*batch_dims, 2, -1])
if x_is_mps:
x = x.to('mps')
x = x.to(self.device)
return x
@ -145,9 +149,10 @@ class TFC_TDF(nn.Module):
class TFC_TDF_net(nn.Module):
def __init__(self, config):
def __init__(self, config, device):
super().__init__()
self.config = config
self.device = device
norm = get_norm(norm_type=config.model.norm)
act = get_act(act_type=config.model.act)
@ -192,7 +197,7 @@ class TFC_TDF_net(nn.Module):
nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
)
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f)
self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)
def cac2cws(self, x):
k = self.num_subbands