Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git, synced 2024-11-12 01:50:48 +01:00

Commit 0e6f64ef20: Delete diffq directory
Parent commit: b16a422b39
diffq/__init__.py
@@ -1,18 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
"""
This package implements different quantization strategies:

- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits.
- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection.

Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers.
"""

from .uniform import UniformQuantizer
from .diffq import DiffQuantizer
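For orientation, a minimal sketch (not part of the deleted files) of how the two exported quantizers would be used; `model` is a placeholder module:

# Illustrative only: the two strategies exported by diffq/__init__.py.
from diffq import UniformQuantizer, DiffQuantizer

# fixed number of bits for every large enough parameter:
#   quantizer = UniformQuantizer(model, bits=8)
# learned, differentiable number of bits per group of weights:
#   quantizer = DiffQuantizer(model, group_size=8)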
diffq/base.py
@@ -1,262 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from concurrent import futures
from fnmatch import fnmatch
from functools import partial
import io
import math
from multiprocessing import cpu_count
import typing as tp
import zlib

import torch


class BaseQuantizer:
    @dataclass
    class _QuantizedParam:
        name: str
        param: torch.nn.Parameter
        module: torch.nn.Module
        # If a Parameter is used multiple times, `other` can be used
        # to share state between the different Quantizers
        other: tp.Optional[tp.Any]

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
        self.model = model
        self.min_size = min_size
        self.float16 = float16
        self.exclude = exclude
        self.detect_bound = detect_bound
        self._quantized = False
        self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
        self._post_handle = self.model.register_forward_hook(self._forward_hook)

        self._quantized_state = None
        self._qparams = []
        self._float16 = []
        self._others = []
        self._rnns = []

        self._saved = []

        self._find_params()

    def _find_params(self):
        min_params = self.min_size * 2**20 // 4
        previous = {}
        for module_name, module in self.model.named_modules():
            if isinstance(module, torch.nn.RNNBase):
                self._rnns.append(module)
            for name, param in list(module.named_parameters(recurse=False)):
                full_name = f"{module_name}.{name}"
                matched = False
                for pattern in self.exclude:
                    if fnmatch(full_name, pattern) or fnmatch(name, pattern):
                        matched = True
                        break

                if param.numel() <= min_params or matched:
                    if id(param) in previous:
                        continue
                    if self.detect_bound:
                        previous[id(param)] = None
                    if self.float16:
                        self._float16.append(param)
                    else:
                        self._others.append(param)
                else:
                    qparam = self._register_param(name, param, module, previous.get(id(param)))
                    if self.detect_bound:
                        previous[id(param)] = qparam
                    self._qparams.append(qparam)

    def _register_param(self, name, param, module, other):
        return self.__class__._QuantizedParam(name, param, module, other)

    def _forward_pre_hook(self, module, input):
        if self.model.training:
            self._quantized_state = None
            if self._quantized:
                self.unquantize()
            if self._pre_forward_train():
                self._fix_rnns()
        else:
            self.quantize()

    def _forward_hook(self, module, input, output):
        if self.model.training:
            if self._post_forward_train():
                self._fix_rnns(flatten=False)  # Hacky, next forward will flatten

    def quantize(self, save=True):
        """
        Immediately apply quantization to the model parameters.
        If `save` is True, save a copy of the unquantized parameters that can be
        restored with `unquantize()`.
        """
        if self._quantized:
            return
        if save:
            self._saved = [qp.param.data.to('cpu', copy=True)
                           for qp in self._qparams if qp.other is None]
        self.restore_quantized_state(self.get_quantized_state())
        self._quantized = True
        self._fix_rnns()

    def unquantize(self):
        """
        Revert a previous call to `quantize()`.
        """
        if not self._quantized:
            raise RuntimeError("Can only be called on a quantized model.")
        if not self._saved:
            raise RuntimeError("Nothing to restore.")
        for qparam in self._qparams:
            if qparam.other is None:
                qparam.param.data[:] = self._saved.pop(0)
        assert len(self._saved) == 0
        self._quantized = False
        self._fix_rnns()

    def _pre_forward_train(self) -> bool:
        """
        Called once before each forward for continuous quantization.
        Should return True if parameters were changed.
        """
        return False

    def _post_forward_train(self) -> bool:
        """
        Called once after each forward (to restore state for instance).
        Should return True if parameters were changed.
        """
        return False

    def _fix_rnns(self, flatten=True):
        """
        To be called after quantization happened, to fix RNNs.
        """
        for rnn in self._rnns:
            rnn._flat_weights = [
                (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
                for wn in rnn._flat_weights_names]
            if flatten:
                rnn.flatten_parameters()

    def get_quantized_state(self):
        """
        Returns sufficient quantized information to rebuild the model state.

        .. note::
            To achieve maximum compression, you should compress this with
            gzip or similar, as quantized weights are not optimally coded!
        """
        if self._quantized_state is None:
            self._quantized_state = self._get_quantized_state()
        return self._quantized_state

    def _get_quantized_state(self):
        """
        Actual implementation for `get_quantized_state`.
        """
        float16_params = []
        for p in self._float16:
            q = p.data.half()
            float16_params.append(q)

        return {
            "quantized": [self._quantize_param(qparam) for qparam in self._qparams
                          if qparam.other is None],
            "float16": float16_params,
            "others": [p.data.clone() for p in self._others],
        }

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def restore_quantized_state(self, state) -> None:
        """
        Restore the state of the model from the quantized state.
        """
        for p, q in zip(self._float16, state["float16"]):
            p.data[:] = q.to(p)

        for p, q in zip(self._others, state["others"]):
            p.data[:] = q

        remaining = list(state["quantized"])
        for qparam in self._qparams:
            if qparam.other is not None:
                # Only unquantize the first appearance of an nn.Parameter.
                continue
            quantized = remaining.pop(0)
            qparam.param.data[:] = self._unquantize_param(qparam, quantized)
        self._fix_rnns()

    def detach(self) -> None:
        """
        Detach from the model, removing hooks and anything else.
        """
        self._pre_handle.remove()
        self._post_handle.remove()

    def model_size(self) -> torch.Tensor:
        """
        Returns an estimate of the quantized model size.
        """
        total = torch.tensor(0.)
        for p in self._float16:
            total += 16 * p.numel()
        for p in self._others:
            total += 32 * p.numel()
        return total / 2**20 / 8  # bits to MegaBytes

    def true_model_size(self) -> float:
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()

    def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
        """
        Return the compressed quantized model size, in MB.

        Args:
            compress_level (int): compression level used with zlib,
                see `zlib.compress` for details.
            num_workers (int): will split the final big byte representation in that
                many chunks processed in parallel.
        """
        out = io.BytesIO()
        torch.save(self.get_quantized_state(), out)
        ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
        return ms / 2 ** 20


def _compress_len(data, compress_level):
    return len(zlib.compress(data, level=compress_level))


def _parallel_compress_len(data, compress_level, num_workers):
    num_workers = min(cpu_count(), num_workers)
    chunk_size = int(math.ceil(len(data) / num_workers))
    chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
    with futures.ProcessPoolExecutor(num_workers) as pool:
        return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
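For reference, a hedged sketch (not part of the deleted sources) of the save-and-compress workflow that `get_quantized_state` and `compressed_model_size` describe above; it uses `UniformQuantizer` since `BaseQuantizer` is abstract, and the model and output path are throwaway examples:

import gzip
import io

import torch
from diffq import UniformQuantizer

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(),
                            torch.nn.Linear(512, 512))
quantizer = UniformQuantizer(model, bits=8)

print(f"estimated size: {quantizer.true_model_size():.3f} MB")
# compressed_model_size spawns worker processes to zlib-compress the state in chunks
print(f"zlib-compressed size: {quantizer.compressed_model_size():.3f} MB")

# As the get_quantized_state docstring suggests, compress the state before storing it.
buf = io.BytesIO()
torch.save(quantizer.get_quantized_state(), buf)
with gzip.open("quantized_state.pt.gz", "wb") as f:
    f.write(buf.getvalue())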
diffq/diffq.py
@@ -1,286 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Differentiable quantizer based on scaled noise injection.
"""
from dataclasses import dataclass
import math
import typing as tp

import torch

from .base import BaseQuantizer
from .uniform import uniform_quantize, uniform_unquantize
from .utils import simple_repr


class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
        """
        Differentiable quantizer based on scaled noise injection.
        For every parameter `p` in the model, this introduces a number-of-bits parameter
        `b` with the same dimensions (when group_size = 1).
        Before each forward, `p` is replaced by `p + U`,
        with U uniform iid noise with range [-d/2, d/2], and `d` the uniform quantization
        step for `b` bits.
        This noise approximates the quantization noise in a differentiable manner, both
        with respect to the unquantized parameter `p` and the number of bits `b`.

        At evaluation (as detected with `model.eval()`), the model is replaced
        by its true quantized version, and restored when going back to training.

        When doing actual quantization (for serialization, or evaluation),
        the number of bits is rounded to the nearest integer, and needs to be stored along.
        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
        which will use a single noise level for multiple weight entries.

        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
        model size in MB. You can then use this estimate as a penalty in your training loss.

        Args:
            model (torch.nn.Module): model to quantize
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            group_size (int): weight entries are grouped together to reduce the number
                of noise scales to store. This should divide the size of all parameters
                bigger than min_size.
            min_bits (float): minimal number of bits.
            max_bits (float): maximal number of bits.
            init_bits (float): initial number of bits.
            extra_bits (float): extra bits to add for actual quantization (before roundoff).
            suffix (str): suffix used for the name of the extra noise scale parameters.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both, as well as the same number of bits.

        .. warning::
            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

        """
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
            "init_bits must be strictly between min_bits and max_bits"

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scale parameters, "
                                   "maybe you used a DiffQuantizer twice on the same model?")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # we want the initial number of bits to be init_bits.
        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                works well for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `true_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale

            # Number of bits to represent each number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits for bits_bits
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2**bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(-1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)
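A hedged training-loop sketch (not part of the deleted sources), following the workflow the `DiffQuantizer` and `setup_optimizer` docstrings describe: create the optimizer first, register the bits parameters, then penalize the differentiable size estimate. The model, synthetic data, and penalty weight are placeholders:

import torch
import torch.nn.functional as F
from diffq import DiffQuantizer

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(),
                            torch.nn.Linear(1024, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

quantizer = DiffQuantizer(model, group_size=8)   # created after the optimizer, as required
quantizer.setup_optimizer(optimizer)             # adds the bits logits, without weight decay

penalty = 5.0    # weight of the model-size term (size is in MB)
model.train()
for _ in range(100):
    x = torch.randn(32, 1024)                    # synthetic batch for the sketch
    y = torch.randint(0, 10, (32,))
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss = loss + penalty * quantizer.model_size()   # differentiable size estimate
    loss.backward()
    optimizer.step()

model.eval()     # on the next forward, the hooks swap in the truly quantized weights
print(f"quantized size: {quantizer.true_model_size():.2f} MB")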
diffq/uniform.py
@@ -1,121 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Classic uniform quantization over n bits.
"""
from typing import Tuple
import torch

from .base import BaseQuantizer
from .utils import simple_repr


def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    Returns:
        - quantized levels
        - (min, max) range.

    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)


def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Unquantize the weights from the levels and scale. Return a float32 tensor.
    """
    mn, mx = scales
    num_levels = 2 ** bits.float()
    unit = 1 / (num_levels - 1)
    levels = levels.float()
    p = levels * unit  # in [0, 1]
    return p * (mx - mn) + mn


class UniformQuantizer(BaseQuantizer):
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            qat (bool): perform quantization-aware training.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        self.bits = float(bits)
        self.qat = qat

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def __repr__(self):
        return simple_repr(self)

    def _pre_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                if qparam.other is not None:
                    new_param = qparam.other.module._parameters[qparam.other.name]
                else:
                    quantized = self._quantize_param(qparam)
                    qvalue = self._unquantize_param(qparam, quantized)
                    new_param = qparam.param + (qvalue - qparam.param).detach()
                qparam.module._parameters[qparam.name] = new_param
            return True
        return False

    def _post_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                qparam.module._parameters[qparam.name] = qparam.param
            return True
        return False

    def _quantize_param(self, qparam):
        levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
        return (levels, scales)

    def _unquantize_param(self, qparam, quantized):
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def model_size(self):
        """
        Non-differentiable model size in MB.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            if qparam.other is None:  # if a parameter is bound, count only one copy.
                subtotal += self.bits * qparam.param.numel() + 64  # 2 floats for the overall scales
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()
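A small round-trip sketch (not part of the deleted sources) for `uniform_quantize` / `uniform_unquantize`; the input tensor is arbitrary:

import torch
from diffq.uniform import uniform_quantize, uniform_unquantize

w = torch.randn(4, 8)
levels, (mn, mx) = uniform_quantize(w, torch.tensor(8.))   # uint8 levels plus the (min, max) range
w_hat = uniform_unquantize(levels, (mn, mx), torch.tensor(8.))
# worst-case error is about half a quantization step, i.e. (mx - mn) / (2 * 255) for 8 bits
print((w - w_hat).abs().max().item())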
diffq/utils.py
@@ -1,37 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Optional, List


def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}):
    """
    Return a simple representation string for `obj`.
    If `attrs` is not None, it should be a list of attributes to include.
    """
    params = inspect.signature(obj.__class__).parameters
    attrs_repr = []
    if attrs is None:
        attrs = params.keys()
    for attr in attrs:
        display = False
        if attr in overrides:
            value = overrides[attr]
        elif hasattr(obj, attr):
            value = getattr(obj, attr)
        else:
            continue
        if attr in params:
            param = params[attr]
            if param.default is inspect._empty or value != param.default:
                display = True
        else:
            display = True

        if display:
            attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
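For illustration (not part of the deleted file), `simple_repr` only reports constructor arguments whose current value differs from the signature default, plus those without a default:

import torch
from diffq import UniformQuantizer

quantizer = UniformQuantizer(torch.nn.Linear(512, 512), bits=4)
print(quantizer)
# prints something like:
# UniformQuantizer(model=Linear(in_features=512, out_features=512, bias=True),bits=4.0)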