# Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git
# synced 2024-11-28 01:10:56 +01:00 (136 lines, 4.8 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
Classic uniform quantization over n bits.
|
|
"""
|
|
from typing import Tuple
|
|
import torch
|
|
|
|
from .base import BaseQuantizer
|
|
from .utils import capture_init, simple_repr
|
|
|
|
|
|
def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    The tensor is affinely rescaled onto [0, 1] using its own min and max,
    then rounded onto `2 ** bits` evenly spaced levels.

    Args:
        p (torch.Tensor): tensor to quantize.
        bits (torch.Tensor): number of bits; every entry must lie in [1, 15].

    Returns:
        - quantized levels: `uint8` when all `bits <= 8`, `int16` otherwise,
        - (min, max) range of `p`, as Python numbers.
    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    if mx == mn:
        # Constant tensor: (p - mn) / (mx - mn) would be 0 / 0 = NaN, and
        # casting NaN to an integer dtype is undefined. Level 0 combined with
        # the (mn, mx) range reconstructs mn exactly (0 * (mx - mn) + mn).
        p = torch.zeros_like(p)
    else:
        p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    # 2**8 - 1 fits in uint8 and 2**15 - 1 fits in int16.
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)
|
|
|
|
|
|
def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Map integer quantization levels back to floating point values.

    Args:
        levels (torch.Tensor): integer levels, e.g. as returned by `uniform_quantize`.
        scales (tuple of float): the (min, max) range the levels were computed from.
        bits (torch.Tensor): number of bits the levels were quantized over.

    Returns:
        A float32 tensor. Levels in [0, 2**bits - 1] land in [min, max].
    """
    lo, hi = scales
    # Spacing between two consecutive levels once rescaled to [0, 1].
    step = 1 / (2 ** bits.float() - 1)
    normalized = levels.float() * step  # in [0, 1]
    return normalized * (hi - lo) + lo
|
|
|
|
|
|
class UniformQuantizer(BaseQuantizer):
    """
    Quantizer that stores each selected parameter as `bits`-bit uniform levels
    plus a (min, max) range.

    With `qat=True`, the forward pass sees the quantized weights. A
    straight-through estimator lets gradients still reach the float weights.
    """

    @capture_init
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            qat (bool): perform quantization aware training.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        self.bits = float(bits)
        self.qat = qat

        # The `exclude=[]` default is one list object shared by every call.
        # Give the base class its own copy in case it keeps or mutates it, so
        # later instances built with the default cannot be affected.
        super().__init__(model, min_size, float16, list(exclude), detect_bound)

    def __repr__(self):
        return simple_repr(self)

    def _pre_forward_train(self):
        """
        When QAT is enabled, swap every quantized parameter for a tensor whose
        value is the quantized weight but whose gradient flows to the float
        weight. Returns True if the swap happened, False otherwise.
        """
        if not self.qat:
            return False
        for qparam in self._qparams:
            if qparam.other is not None:
                # Bound parameter: reuse the tensor already installed on the
                # module it is bound to. This presumably relies on that other
                # qparam appearing earlier in `self._qparams`; verify in BaseQuantizer.
                new_param = qparam.other.module._parameters[qparam.other.name]
            else:
                quantized = self._quantize_param(qparam)
                qvalue = self._unquantize_param(qparam, quantized)
                # Straight-through estimator: forward value is `qvalue`, while
                # the gradient w.r.t. `qparam.param` is the identity.
                new_param = qparam.param + (qvalue - qparam.param).detach()
            qparam.module._parameters[qparam.name] = new_param
        return True

    def _post_forward_train(self):
        """
        Undo `_pre_forward_train` by putting the original float parameters back.
        Returns True when QAT is enabled, False otherwise.
        """
        if not self.qat:
            return False
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam):
        """Return `(levels, (min, max))` for `qparam.param` over `self.bits` bits."""
        return uniform_quantize(qparam.param.data, torch.tensor(self.bits))

    def _unquantize_param(self, qparam, quantized):
        """Rebuild a float tensor from the `(levels, (min, max))` pair."""
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        """Bit-pack the levels with `pack_fn` and keep the (min, max) scales as is."""
        levels, scales = quantized
        packed = pack_fn(levels, self.bits)
        return (packed, scales)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """
        Inverse of `_bit_pack_param`: unpack the levels with `unpack_fn`, then
        move them to the device and shape of `qparam.param`.
        """
        packed_levels, scales = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        return (levels, scales)

    def model_size(self):
        """
        Non differentiable model size in MB.

        Adds, to the base class estimate, `bits` per quantized element plus
        64 bits (two floats for the (min, max) scales) per quantized tensor.
        """
        total = super().model_size()
        # If a parameter is bound to another, count only one copy.
        subtotal = sum(self.bits * qparam.param.numel() + 64
                       for qparam in self._qparams if qparam.other is None)
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression. Here this is `model_size()` as a Python float.
        """
        # `float()` gives the same value as `.item()` on a 0-dim tensor, and
        # also accepts a plain number if the base estimate is not a tensor.
        return float(self.model_size())
|