2022-06-13 02:10:39 -05:00

84 lines
3.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Ways to make the model stronger."""
import random
import torch
def power_iteration(m, niters=1, bs=1):
    """Estimate the magnitude of the dominant eigenvalue of `m` with the power method.

    `bs` random starting vectors are iterated in parallel (one per column) and
    the per-column estimates are averaged.

    Args:
        m: square 2-D tensor.
        niters: number of multiply-and-normalize steps, must be at least 1.
        bs: number of random starting vectors to run in parallel.

    Returns:
        A 0-dim tensor holding the mean, over the `bs` columns, of the norm of
        `m @ b` at the last step, where every column of `b` has (nearly) unit norm.

    Raises:
        AssertionError: if `m` is not a square 2-D tensor.
        ValueError: if `niters` is smaller than 1.
    """
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    if niters < 1:
        # With no step at all there is no norm to report; failing early is clearer
        # than reaching the `return` with `norm` never assigned.
        raise ValueError(f"niters must be at least 1, got {niters}")

    dim = m.shape[0]
    b = torch.randn(dim, bs, device=m.device, dtype=m.dtype)
    # Start from unit-norm columns: the norm of `m @ b` is only an eigenvalue
    # estimate when `b` is normalized. Without this, a single iteration would
    # return a value scaled by the norm of the random starting vector.
    b = b / (1e-10 + b.norm(dim=0, keepdim=True))

    for _ in range(niters):
        n = m.mm(b)
        norm = n.norm(dim=0, keepdim=True)
        # The small epsilon avoids a division by zero if a column collapses to 0.
        b = n / (1e-10 + norm)

    return norm.mean()
# We need a shared RNG, created with a fixed seed, to make sure all the distributed
# workers skip the penalty together, as otherwise we wouldn't get any speed-up.
# NOTE(review): this presumably relies on every worker drawing from it the same
# number of times -- confirm against the training loop.
penalty_rng = random.Random(1234)
def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
                proba=1, conv_only=False, exact=False, bs=1):
    """Penalty on the squared largest singular value of the large layers of a model.

    Every parameter that passes the size filter is flattened to a 2-D matrix
    (first dimension kept, the others merged) and an estimate of the square of
    its largest singular value is added to the total.

    Args:
        model: model to penalize. The direct parameters of each of its modules
            are visited.
        min_size: minimum size of a parameter to penalize, measured as
            `numel / 2**18` (i.e. MB when each element takes 4 bytes).
        dim: projection dimension passed to `torch.svd_lowrank`.
            Higher is better but slower.
        niters: number of iterations, used by `torch.svd_lowrank` or, when
            `powm` is set, by the power method.
        powm: use the power method instead of the low-rank SVD. The original
            author reports it as both slower and less stable.
        convtr: when True, the first two dimensions of 3-D / 4-D parameters of
            `ConvTranspose1d` / `ConvTranspose2d` modules are swapped before
            flattening. Kept for compatibility with older experiments.
        proba: probability to apply the penalty. When the penalty is applied,
            the total is divided by `proba`.
        conv_only: only penalize 3-D and 4-D parameters (conv and transposed
            conv), not e.g. LSTM weights (might not be reliable for other
            models than Demucs).
        exact: use the exact SVD (slow but useful at validation). Takes
            precedence over `powm`.
        bs: batch size (number of parallel starting points) for the power method.

    Returns:
        `0.` when the penalty is skipped for this call (or `proba <= 0`),
        otherwise the sum of the estimates divided by `proba`. This is a tensor
        as soon as one parameter was penalized, and a plain float otherwise.

    Raises:
        AssertionError: if a selected parameter is not 2-D after flattening
            (for instance a 0-D or 5-D parameter when `conv_only` is False).
    """
    # Always draw from the shared RNG, even when `proba` alone already decides the
    # outcome, so that its state advances identically on every call.
    draw = penalty_rng.random()
    if proba <= 0 or draw > proba:
        # The `proba <= 0` test also protects the final division by `proba`.
        return 0.

    total = 0
    transposed_convs = (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)
    for m in model.modules():
        swap_first_dims = convtr and isinstance(m, transposed_convs)
        for name, p in m.named_parameters(recurse=False):
            if p.numel() / 2**18 < min_size:
                continue

            ndim = p.dim()
            if ndim in (3, 4):
                if swap_first_dims:
                    p = p.transpose(0, 1)
                # `reshape` behaves like `view` on contiguous tensors but also
                # accepts non-contiguous ones (transposed or channels-last weights).
                p = p.reshape(len(p), -1)
            elif ndim == 1 or conv_only:
                # 1-D parameters are never penalized; with `conv_only`
                # everything that is not 3-D / 4-D is skipped as well.
                continue
            assert p.dim() == 2, (name, p.shape)

            if exact:
                # NOTE(review): `torch.svd` is deprecated in recent PyTorch releases;
                # `torch.linalg.svdvals` is the suggested replacement once older
                # versions no longer need to be supported.
                estimate = torch.svd(p, compute_uv=False)[1].pow(2).max()
            elif powm:
                rows, cols = p.shape
                # Build the smaller of the two Gram matrices; its largest eigenvalue
                # is the square of the largest singular value of `p`.
                if rows < cols:
                    gram = p.mm(p.t())
                else:
                    gram = p.t().mm(p)
                estimate = power_iteration(gram, niters, bs)
            else:
                estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2)
            total += estimate
    return total / proba