Add files via upload

Anjok07 2024-04-02 01:02:07 -05:00 committed by GitHub
parent 72531ac3a0
commit 82769c7acd
4 changed files with 1248 additions and 1 deletion

110  lib_v5/attend.py  Normal file

@@ -0,0 +1,110 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None  # optional explicit softmax scale; LinearAttention in bs_roformer.py constructs Attend(scale=...)
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            # scaled_dot_product_attention applies 1 / sqrt(d) internally,
            # so pre-scale q to realize the explicit softmax scale instead
            q = q * self.scale * (q.shape[-1] ** 0.5)

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = self.scale if exists(self.scale) else q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
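
A minimal smoke test for the module above, assuming the lib_v5 package layout from this commit; the tensor shapes are illustrative only. It exercises the plain einsum path (flash=False), which needs nothing beyond PyTorch and einops.

import torch
from lib_v5.attend import Attend

attend = Attend(dropout=0., flash=False)   # flash=True additionally requires PyTorch >= 2.0

q = torch.randn(1, 8, 128, 64)             # (batch, heads, query length, head dim)
k = torch.randn(1, 8, 256, 64)             # (batch, heads, key length, head dim)
v = torch.randn(1, 8, 256, 64)

out = attend(q, k, v)                       # softmax attention over the key axis
print(out.shape)                            # torch.Size([1, 8, 128, 64])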

607  lib_v5/bs_roformer.py  Normal file

@@ -0,0 +1,607 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from .attend import Attend
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
# helper functions
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# norm
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x.to(self.gamma.device)
return F.normalize(x, dim=-1) * self.scale * self.gamma
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash=flash, dropout=dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class LinearAttention(Module):
"""
this flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
"""
@beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
def forward(
self,
x
):
x = self.norm(x)
q, k, v = self.to_qkv(x)
q, k = map(l2norm, (q, k))
q = q * self.temperature.exp()
out = self.attend(q, k, v)
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# bandsplit module
class BandSplit(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
self.to_features.append(net)
def forward(self, x):
x = x.split(self.dim_inputs, dim=-1)
outs = []
for split_input, to_feature in zip(x, self.to_features):
split_output = to_feature(split_input)
outs.append(split_output)
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)
net = []
dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
net.append(nn.Linear(layer_dim_in, layer_dim_out))
if is_last:
continue
net.append(activation())
return nn.Sequential(*net)
class MaskEstimator(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
def forward(self, x):
x = x.unbind(dim=-2)
outs = []
for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)
return torch.cat(outs, dim=-1)
# main class
DEFAULT_FREQS_PER_BANDS = (
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
12, 12, 12, 12, 12, 12, 12, 12,
24, 24, 24, 24, 24, 24, 24, 24,
48, 48, 48, 48, 48, 48, 48, 48,
128, 129,
)
class BSRoformer(Module):
@beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window
):
super().__init__()
self.stereo = stereo
self.audio_channels = 2 if stereo else 1
self.num_stems = num_stems
self.layers = ModuleList([])
transformer_kwargs = dict(
dim=dim,
heads=heads,
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn,
norm_output=False
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
for _ in range(depth):
tran_modules = []
if linear_transformer_depth > 0:
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
tran_modules.append(
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
)
tran_modules.append(
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
)
self.layers.append(nn.ModuleList(tran_modules))
self.final_norm = RMSNorm(dim)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
)
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
assert len(freqs_per_bands) > 1
assert sum(
freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
depth=mask_estimator_depth
)
self.mask_estimators.append(mask_estimator)
# for the multi-resolution stft loss
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
"""
einops
b - batch
f - freq
t - time
s - audio channel (1 for mono, 2 for stereo)
n - number of 'stems'
c - complex (2)
d - feature dimension
"""
original_device = raw_audio.device
x_is_mps = original_device.type == 'mps'
if x_is_mps:
raw_audio = raw_audio.cpu()
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True when passing in stereo audio (channel dimension of 2), and to False for mono audio (channel dimension of 1)'
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = rearrange(stft_repr,
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
x = self.band_split(x)
# axial / hierarchical attention
for transformer_block in self.layers:
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
else:
time_transformer, freq_transformer = transformer_block
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
x = self.final_norm(x)
num_stems = len(self.mask_estimators)
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
if x_is_mps:
mask = mask.to('cpu')
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
# complex number multiplication
stft_repr = torch.view_as_complex(stft_repr)
mask = torch.view_as_complex(mask)
stft_repr = stft_repr * mask
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
# if a target is passed in, calculate loss for learning
if not exists(target):
return recon_audio
if self.num_stems > 1:
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = target[..., :recon_audio.shape[-1]]
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
n_fft=max(window_size, self.multi_stft_n_fft),
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
total_loss = loss + weighted_multi_resolution_loss
if not return_loss_breakdown:
# Move the result back to the original device if it was moved to CPU for MPS compatibility
if x_is_mps:
total_loss = total_loss.to(original_device)
return total_loss
# For detailed loss breakdown, ensure all components are moved back to the original device for MPS
if x_is_mps:
loss = loss.to(original_device)
multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
return total_loss, (loss, multi_stft_resolution_loss)
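
A hedged usage sketch for BSRoformer as defined above; the hyperparameters (dim=192, depth=6) and the two-second stereo clip are illustrative assumptions rather than values from this commit, and the default freqs_per_bands already matches the stft_n_fft=2048 / 1025-bin configuration checked in __init__.

import torch
from lib_v5.bs_roformer import BSRoformer

model = BSRoformer(
    dim=192,                 # illustrative embedding size
    depth=6,                 # illustrative number of axial transformer blocks
    stereo=True,
    num_stems=1,
    flash_attn=False         # enable on PyTorch >= 2.0 for fused attention kernels
)

audio = torch.randn(1, 2, 44100 * 2)   # (batch, channels, samples): 2 s of stereo at 44.1 kHz
stems = model(audio)                   # inference: separated audio, shape (1, 2, t) since num_stems == 1

target = torch.randn_like(stems)
loss = model(audio, target=target)     # training: scalar L1 + weighted multi-resolution STFT loss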

528  lib_v5/mel_band_roformer.py  Normal file

@@ -0,0 +1,528 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from .attend import Attend
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange, pack, unpack, reduce, repeat
from librosa import filters
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def pad_at_dim(t, pad, dim=-1, value=0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value=value)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x.to(self.gamma.device)
return F.normalize(x, dim=-1) * self.scale * self.gamma
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash=flash, dropout=dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed,
flash=flash_attn),
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class BandSplit(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
self.to_features.append(net)
def forward(self, x):
x = x.split(self.dim_inputs, dim=-1)
outs = []
for split_input, to_feature in zip(x, self.to_features):
split_output = to_feature(split_input)
outs.append(split_output)
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)
net = []
dims = (dim_in, *((dim_hidden,) * depth), dim_out)
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
net.append(nn.Linear(layer_dim_in, layer_dim_out))
if is_last:
continue
net.append(activation())
return nn.Sequential(*net)
class MaskEstimator(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
def forward(self, x):
x = x.unbind(dim=-2)
outs = []
for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)
return torch.cat(outs, dim=-1)
class MelBandRoformer(Module):
@beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100,
stft_n_fft=2048,
stft_hop_length=512,
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False,
):
super().__init__()
self.stereo = stereo
self.audio_channels = 2 if stereo else 1
self.num_stems = num_stems
self.layers = ModuleList([])
transformer_kwargs = dict(
dim=dim,
heads=heads,
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
]))
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
mel_filter_bank[0][0] = 1.
mel_filter_bank[-1, -1] = 1.
freqs_per_band = mel_filter_bank > 0
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
freq_indices = repeated_freq_indices[freqs_per_band]
if stereo:
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
freq_indices = freq_indices * 2 + torch.arange(2)
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
self.register_buffer('freq_indices', freq_indices, persistent=False)
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
depth=mask_estimator_depth
)
self.mask_estimators.append(mask_estimator)
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.match_input_audio_length = match_input_audio_length
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
"""
einops
b - batch
f - freq
t - time
s - audio channel (1 for mono, 2 for stereo)
n - number of 'stems'
c - complex (2)
d - feature dimension
"""
original_device = raw_audio.device
x_is_mps = original_device.type == 'mps'
if x_is_mps:
raw_audio = raw_audio.cpu()
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
batch, channels, raw_audio_length = raw_audio.shape
istft_length = raw_audio_length if self.match_input_audio_length else None
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True when passing in stereo audio (channel dimension of 2), and to False for mono audio (channel dimension of 1)'
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = rearrange(stft_repr,
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
batch_arange = torch.arange(batch, device=device)[..., None]
x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
x = rearrange(x, 'b f t c -> b t (f c)')
x = self.band_split(x)
for time_transformer, freq_transformer in self.layers:
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
num_stems = len(self.mask_estimators)
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
if x_is_mps:
masks = masks.cpu()
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
stft_repr = torch.view_as_complex(stft_repr)
masks = torch.view_as_complex(masks)
masks = masks.type(stft_repr.dtype)
if x_is_mps:
scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
else:
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
if x_is_mps:
denom = denom.cpu()
masks_averaged = masks_summed / denom.clamp(min=1e-8)
stft_repr = stft_repr * masks_averaged
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
length=istft_length)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
if not exists(target):
return recon_audio
if self.num_stems > 1:
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = target[..., :recon_audio.shape[-1]]
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
n_fft=max(window_size, self.multi_stft_n_fft),
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
total_loss = loss + weighted_multi_resolution_loss
# Move the total loss back to the original device if necessary
if x_is_mps:
total_loss = total_loss.to(original_device)
if not return_loss_breakdown:
return total_loss
# If detailed loss breakdown is requested, ensure all components are on the original device
return total_loss, (loss.to(original_device) if x_is_mps else loss,
multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
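
A matching sketch for MelBandRoformer, which replaces the fixed frequency split with num_bands mel-scale bands built from librosa filters; again, dim, depth, and the one-second mono clip are illustrative assumptions, not values from the commit.

import torch
from lib_v5.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(
    dim=192,
    depth=6,
    stereo=False,
    num_bands=60,                   # mel bands derived from librosa.filters.mel
    flash_attn=False,
    match_input_audio_length=True   # pad/trim the istft output back to the input length
)

audio = torch.randn(1, 44100)       # (batch, samples) mono; a channel axis of 1 is added internally
stems = model(audio)                # (1, 1, 44100) separated audio since num_stems == 1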

View File

@@ -541,7 +541,9 @@ def ensembling(a, inputs, is_wavs=False):
if MIN_SPEC == a:
input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
if MAX_SPEC == a:
input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
#input = np.array(np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), inputs[i], input), dtype=object)
input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
#max_spec = np.array([np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), s, specs[0]) for s in specs[1:]], dtype=object)[-1]
#linear_ensemble
#input = ensemble_wav(inputs, split_size=1)
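
For context, the MIN_SPEC / MAX_SPEC branches above keep, element by element, whichever input spectrogram has the smaller or larger magnitude. A tiny self-contained illustration of the MAX_SPEC rule follows; the array values are made up:

import numpy as np

a = np.array([ 1.0, -3.0, 0.5])
b = np.array([-2.0,  1.0, 0.7])

ensembled = np.where(np.abs(b) >= np.abs(a), b, a)
print(ensembled)   # [-2.  -3.   0.7]  -> the larger-magnitude value survives at each position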