from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from .attend import Attend

from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack, reduce, repeat

from librosa import filters


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def pad_at_dim(t, pad, dim=-1, value=0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value=value)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x.to(self.gamma.device)
        return F.normalize(x, dim=-1) * self.scale * self.gamma


class FeedForward(Module):
    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(Module):
    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            flash_attn=True
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
                          rotary_embed=rotary_embed, flash=flash_attn),
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
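# Usage sketch for the Transformer building block above (illustrative only;
# the dimensions here are made up, not defaults from this repository). It is a
# stack of pre-norm residual attention + feedforward blocks over
# (batch, sequence, feature) tokens, with optional shared rotary positions:
#
#   rotary = RotaryEmbedding(dim=64)
#   transformer = Transformer(dim=384, depth=2, heads=8, dim_head=64, rotary_embed=rotary)
#   tokens = torch.randn(2, 100, 384)             # (batch, sequence, feature)
#   out = transformer(tokens)                     # (2, 100, 384)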
class BandSplit(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []
        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
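# Shape sketch for BandSplit / MaskEstimator (illustrative per-band widths, not
# values computed by the model below). BandSplit projects each band's flattened
# slice to a common feature size; MaskEstimator inverts that per band with an
# MLP whose doubled output is halved again by the GLU:
#
#   dim_inputs = (10, 24, 48)                     # per-band widths, sum = 82
#   split = BandSplit(dim=384, dim_inputs=dim_inputs)
#   estimator = MaskEstimator(dim=384, dim_inputs=dim_inputs, depth=1)
#   feats = split(torch.randn(2, 50, 82))         # (2, 50, 3, 384), one feature per band
#   masks = estimator(feats)                      # (2, 50, 82), back to flattened bands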
class MelBandRoformer(Module):

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            num_bands=60,
            dim_head=64,
            heads=8,
            attn_dropout=0.1,
            ff_dropout=0.1,
            flash_attn=True,
            dim_freqs_in=1025,
            sample_rate=44100,  # needed for the mel filter bank
            stft_n_fft=2048,
            stft_hop_length=512,
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            mask_estimator_depth=1,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            match_input_audio_length=False,  # if True, pad the reconstructed audio to the input length
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            ]))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        # probe a dummy stft to get the number of frequency bins

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        # create the mel filter bank with librosa

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # the filter bank can leave the first and last stft bins uncovered; force coverage

        mel_filter_bank[0, 0] = 1.
        mel_filter_bank[-1, -1] = 1.

        # binarize the filter bank; overlapping band masks are averaged later in forward()

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by at least one band'

        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, 'f s -> (f s)')

        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')

        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)

        # band split module and per-stem mask estimators

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth
            )

            self.mask_estimators.append(mask_estimator)

        # multi-resolution stft loss settings

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

        self.match_input_audio_length = match_input_audio_length
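    # Worked sketch of the band bookkeeping in __init__ (illustrative numbers):
    # with overlapping mel filters, freqs_per_band marks which stft bins each
    # band covers, freq_indices lists those bins band by band, and
    # num_bands_per_freq is the denominator used in forward() to average masks
    # over bins shared between bands.
    #
    #   fb = torch.tensor([[0.7, 0.3, 0.0],       # band 0 covers bins 0, 1
    #                      [0.0, 0.5, 0.9]])      # band 1 covers bins 1, 2
    #   fpb = fb > 0
    #   repeat(torch.arange(3), 'f -> b f', b=2)[fpb]   # tensor([0, 1, 1, 2])
    #   reduce(fpb, 'b f -> f', 'sum')                  # tensor([1, 2, 1]); bin 1 is shared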
    def forward(
            self,
            raw_audio,
            target=None,
            return_loss_breakdown=False
    ):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        # torch.stft / istft lack mps support, so fall back to cpu for this pass

        original_device = raw_audio.device
        x_is_mps = original_device.type == 'mps'

        if x_is_mps:
            raw_audio = raw_audio.cpu()

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), \
            'stereo must be set to True for stereo input (2 channels) and False for mono input (1 channel)'

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        # gather the stft bins band by band (bins repeat where mel bands overlap)

        batch_arange = torch.arange(batch, device=device)[..., None]

        freq_indices = self.freq_indices.cpu() if x_is_mps else self.freq_indices
        x = stft_repr[batch_arange, freq_indices]

        x = rearrange(x, 'b f t c -> b t (f c)')

        x = self.band_split(x)

        # axial attention across time, then across frequency bands

        for time_transformer, freq_transformer in self.layers:
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        num_stems = len(self.mask_estimators)

        masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)

        if x_is_mps:
            masks = masks.cpu()

        # modulate the stft repr with the estimated masks (complex multiplication)

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # scatter masks back onto the stft bins, averaging where bands overlap

        scatter_indices = repeat(freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

        denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)

        if x_is_mps:
            denom = denom.cpu()

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        stft_repr = stft_repr * masks_averaged

        # istft back to audio

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
                                  length=istft_length)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

        # if no target is passed in, this is inference

        if not exists(target):
            return recon_audio
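        # Training objective below, in sketch form (default weight is 1.):
        #
        #   total = l1(recon, target)
        #         + weight * sum_w l1(stft_w(recon), stft_w(target))
        #
        # where w ranges over multi_stft_resolutions_window_sizes and each
        # stft_w uses n_fft = max(w, stft_n_fft) with a window of length w.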
        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        # protect against time resolution lost on the istft
        target = target[..., :recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        # move losses back to the original device if the input arrived on mps

        if x_is_mps:
            total_loss = total_loss.to(original_device)
            loss = loss.to(original_device)
            multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
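# End-to-end usage sketch (illustrative hyperparameters; only the keyword
# defaults in __init__ above come from this module):
#
#   model = MelBandRoformer(dim=384, depth=6)
#   audio = torch.randn(1, 44100 * 2)             # two seconds of mono audio
#   stems = model(audio)                          # (1, 1, t): (batch, channel, time)
#   target = torch.randn(1, 44100 * 2)
#   loss = model(audio, target=target)            # scalar total loss
#   loss.backward()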