mirror of https://github.com/Anjok07/ultimatevocalremovergui.git (synced 2024-11-12 01:50:48 +01:00)
Delete diffq directory
This commit is contained in:
parent b16a422b39
commit 0e6f64ef20
diffq/__init__.py (18 lines deleted)
@@ -1,18 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
"""
This package implements different quantization strategies:

- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits.
- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection.

Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers.
"""

from .uniform import UniformQuantizer
from .diffq import DiffQuantizer
diffq/base.py (262 lines deleted)
@@ -1,262 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from concurrent import futures
from fnmatch import fnmatch
from functools import partial
import io
import math
from multiprocessing import cpu_count
import typing as tp
import zlib

import torch


class BaseQuantizer:
    @dataclass
    class _QuantizedParam:
        name: str
        param: torch.nn.Parameter
        module: torch.nn.Module
        # If a Parameter is used multiple times, `other` can be used
        # to share state between the different Quantizers
        other: tp.Optional[tp.Any]

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
        self.model = model
        self.min_size = min_size
        self.float16 = float16
        self.exclude = exclude
        self.detect_bound = detect_bound
        self._quantized = False
        self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
        self._post_handle = self.model.register_forward_hook(self._forward_hook)

        self._quantized_state = None
        self._qparams = []
        self._float16 = []
        self._others = []
        self._rnns = []

        self._saved = []

        self._find_params()

    def _find_params(self):
        min_params = self.min_size * 2**20 // 4
        previous = {}
        for module_name, module in self.model.named_modules():
            if isinstance(module, torch.nn.RNNBase):
                self._rnns.append(module)
            for name, param in list(module.named_parameters(recurse=False)):
                full_name = f"{module_name}.{name}"
                matched = False
                for pattern in self.exclude:
                    if fnmatch(full_name, pattern) or fnmatch(name, pattern):
                        matched = True
                        break

                if param.numel() <= min_params or matched:
                    if id(param) in previous:
                        continue
                    if self.detect_bound:
                        previous[id(param)] = None
                    if self.float16:
                        self._float16.append(param)
                    else:
                        self._others.append(param)
                else:
                    qparam = self._register_param(name, param, module, previous.get(id(param)))
                    if self.detect_bound:
                        previous[id(param)] = qparam
                    self._qparams.append(qparam)

    def _register_param(self, name, param, module, other):
        return self.__class__._QuantizedParam(name, param, module, other)

    def _forward_pre_hook(self, module, input):
        if self.model.training:
            self._quantized_state = None
            if self._quantized:
                self.unquantize()
            if self._pre_forward_train():
                self._fix_rnns()
        else:
            self.quantize()

    def _forward_hook(self, module, input, output):
        if self.model.training:
            if self._post_forward_train():
                self._fix_rnns(flatten=False)  # Hacky, next forward will flatten

    def quantize(self, save=True):
        """
        Immediately apply quantization to the model parameters.
        If `save` is True, save a copy of the unquantized parameters, that can be
        restored with `unquantize()`.
        """
        if self._quantized:
            return
        if save:
            self._saved = [qp.param.data.to('cpu', copy=True)
                           for qp in self._qparams if qp.other is None]
        self.restore_quantized_state(self.get_quantized_state())
        self._quantized = True
        self._fix_rnns()

    def unquantize(self):
        """
        Revert a previous call to `quantize()`.
        """
        if not self._quantized:
            raise RuntimeError("Can only be called on a quantized model.")
        if not self._saved:
            raise RuntimeError("Nothing to restore.")
        for qparam in self._qparams:
            if qparam.other is None:
                qparam.param.data[:] = self._saved.pop(0)
        assert len(self._saved) == 0
        self._quantized = False
        self._fix_rnns()

    def _pre_forward_train(self) -> bool:
        """
        Called once before each forward for continuous quantization.
        Should return True if parameters were changed.
        """
        return False

    def _post_forward_train(self) -> bool:
        """
        Called once after each forward (to restore state for instance).
        Should return True if parameters were changed.
        """
        return False

    def _fix_rnns(self, flatten=True):
        """
        To be called after quantization happened to fix RNNs.
        """
        for rnn in self._rnns:
            rnn._flat_weights = [
                (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
                for wn in rnn._flat_weights_names]
            if flatten:
                rnn.flatten_parameters()

    def get_quantized_state(self):
        """
        Returns sufficient quantized information to rebuild the model state.

        ..Note::
            To achieve maximum compression, you should compress this with
            gzip or other, as quantized weights are not optimally coded!
        """
        if self._quantized_state is None:
            self._quantized_state = self._get_quantized_state()
        return self._quantized_state

    def _get_quantized_state(self):
        """
        Actual implementation for `get_quantized_state`.
        """
        float16_params = []
        for p in self._float16:
            q = p.data.half()
            float16_params.append(q)

        return {
            "quantized": [self._quantize_param(qparam) for qparam in self._qparams
                          if qparam.other is None],
            "float16": float16_params,
            "others": [p.data.clone() for p in self._others],
        }

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def restore_quantized_state(self, state) -> None:
        """
        Restore the state of the model from the quantized state.
        """
        for p, q in zip(self._float16, state["float16"]):
            p.data[:] = q.to(p)

        for p, q in zip(self._others, state["others"]):
            p.data[:] = q

        remaining = list(state["quantized"])
        for qparam in self._qparams:
            if qparam.other is not None:
                # Only unquantize first appearance of nn.Parameter.
                continue
            quantized = remaining.pop(0)
            qparam.param.data[:] = self._unquantize_param(qparam, quantized)
        self._fix_rnns()

    def detach(self) -> None:
        """
        Detach from the model, removes hooks and anything else.
        """
        self._pre_handle.remove()
        self._post_handle.remove()

    def model_size(self) -> torch.Tensor:
        """
        Returns an estimate of the quantized model size.
        """
        total = torch.tensor(0.)
        for p in self._float16:
            total += 16 * p.numel()
        for p in self._others:
            total += 32 * p.numel()
        return total / 2**20 / 8  # bits to MegaBytes

    def true_model_size(self) -> float:
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()

    def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
        """
        Return the compressed quantized model size, in MB.

        Args:
            compress_level (int): compression level used with zlib,
                see `zlib.compress` for details.
            num_workers (int): will split the final big byte representation in that
                many chunks processed in parallel.
        """
        out = io.BytesIO()
        torch.save(self.get_quantized_state(), out)
        ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
        return ms / 2 ** 20


def _compress_len(data, compress_level):
    return len(zlib.compress(data, level=compress_level))


def _parallel_compress_len(data, compress_level, num_workers):
    num_workers = min(cpu_count(), num_workers)
    chunk_size = int(math.ceil(len(data) / num_workers))
    chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
    with futures.ProcessPoolExecutor(num_workers) as pool:
        return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
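As the `get_quantized_state` note above suggests, the returned state compresses well with a generic compressor. A hedged sketch of saving and restoring it, assuming `quantizer` is any `BaseQuantizer` subclass attached to a model (the output path is hypothetical):

    import gzip
    import io
    import torch

    # Serialize the quantized state and gzip it, as recommended in the docstring.
    state = quantizer.get_quantized_state()
    buf = io.BytesIO()
    torch.save(state, buf)
    with gzip.open("model.diffq.gz", "wb") as f:  # hypothetical output path
        f.write(buf.getvalue())

    # Later: reload and push the weights back into a model with the same architecture.
    with gzip.open("model.diffq.gz", "rb") as f:
        state = torch.load(io.BytesIO(f.read()))
    quantizer.restore_quantized_state(state)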
diffq/diffq.py (286 lines deleted)
@@ -1,286 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Differentiable quantizer based on scaled noise injection.
"""
from dataclasses import dataclass
import math
import typing as tp

import torch

from .base import BaseQuantizer
from .uniform import uniform_quantize, uniform_unquantize
from .utils import simple_repr


class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
        """
        Differentiable quantizer based on scaled noise injection.
        For every parameter `p` in the model, this introduces a number-of-bits parameter
        `b` with the same dimensions (when group_size = 1).
        Before each forward, `p` is replaced by `p + U`,
        with U uniform iid noise with range [-d/2, d/2], where `d` is the uniform quantization
        step for `b` bits.
        This noise approximates the quantization noise in a differentiable manner, both
        with respect to the unquantized parameter `p` and the number of bits `b`.

        At evaluation (as detected with `model.eval()`), the model is replaced
        by its true quantized version, and restored when going back to training.

        When doing actual quantization (for serialization, or evaluation),
        the number of bits is rounded to the nearest integer, and needs to be stored along.
        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
        which will use a single noise level for multiple weight entries.

        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
        model size in MB. You can then use this estimate as a penalty in your training loss.

        Args:
            model (torch.nn.Module): model to quantize
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            group_size (int): weight entries are grouped together to reduce the number
                of noise scales to store. This should divide the size of all parameters
                bigger than min_size.
            min_bits (float): minimal number of bits.
            max_bits (float): maximal number of bits.
            init_bits (float): initial number of bits.
            extra_bits (float): extra bits to add for actual quantization (before roundoff).
            suffix (str): suffix used for the name of the extra noise scale parameters.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both, as well as the same number of bits.

        ..Warning::
            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

        """
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
            "init_bits must be between min_bits and max_bits (both excluded)"

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scales parameters, "
                                   "maybe you used a DiffQuantizer twice on the same model?")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # we want the initial number of bits to be init_bits.
        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                works well for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale

            # Number of bits to represent each number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits for bits_bits
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2**bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(-1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)
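The `DiffQuantizer` docstring above prescribes a specific ordering: create the optimizer first, then the quantizer, call `setup_optimizer`, and add the differentiable `model_size()` estimate as a penalty to the loss. A hedged sketch of that loop; `model`, `dataloader`, `criterion`, the learning rates and the penalty weight are placeholders, not part of diffq:

    import torch
    from diffq import DiffQuantizer

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # optimizer created before the quantizer
    quantizer = DiffQuantizer(model)                           # default group_size=1, param="bits"
    quantizer.setup_optimizer(optimizer)                       # bits logits get their own lr, no weight decay
    penalty = 1e-3                                             # assumed weight for the size penalty

    model.train()  # DiffQuantizer relies on train()/eval() to switch noise injection on and off
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y) + penalty * quantizer.model_size()
        loss.backward()
        optimizer.step()

    model.eval()   # the next forward replaces the weights by their truly quantized version
    print(f"final size: {quantizer.true_model_size():.2f} MB")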
diffq/uniform.py (121 lines deleted)
@@ -1,121 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Classic uniform quantization over n bits.
"""
from typing import Tuple
import torch

from .base import BaseQuantizer
from .utils import simple_repr


def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    Returns:
        - quantized levels
        - (min, max) range.

    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)


def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Unquantize the weights from the levels and scale. Return a float32 tensor.
    """
    mn, mx = scales
    num_levels = 2 ** bits.float()
    unit = 1 / (num_levels - 1)
    levels = levels.float()
    p = levels * unit  # in [0, 1]
    return p * (mx - mn) + mn


class UniformQuantizer(BaseQuantizer):
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            qat (bool): perform quantization-aware training.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        self.bits = float(bits)
        self.qat = qat

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def __repr__(self):
        return simple_repr(self, )

    def _pre_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                if qparam.other is not None:
                    new_param = qparam.other.module._parameters[qparam.other.name]
                else:
                    quantized = self._quantize_param(qparam)
                    qvalue = self._unquantize_param(qparam, quantized)
                    new_param = qparam.param + (qvalue - qparam.param).detach()
                qparam.module._parameters[qparam.name] = new_param
            return True
        return False

    def _post_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                qparam.module._parameters[qparam.name] = qparam.param
            return True
        return False

    def _quantize_param(self, qparam):
        levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
        return (levels, scales)

    def _unquantize_param(self, qparam, quantized):
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def model_size(self):
        """
        Non differentiable model size in MB.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            if qparam.other is None:  # if parameter is bound, count only one copy.
                subtotal += self.bits * qparam.param.numel() + 64  # 2 floats for the overall scales
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()
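A hedged round-trip sketch for the two helpers above: the reconstruction error should stay within half a quantization step (the tensor is a made-up example):

    import torch
    from diffq.uniform import uniform_quantize, uniform_unquantize

    w = torch.randn(1024)
    levels, (mn, mx) = uniform_quantize(w, torch.tensor(8.))        # uint8 levels plus (min, max) range
    w_hat = uniform_unquantize(levels, (mn, mx), torch.tensor(8.))  # float32 reconstruction
    assert (w - w_hat).abs().max() <= (mx - mn) / (2 ** 8 - 1) / 2 + 1e-6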
diffq/utils.py (37 lines deleted)
@@ -1,37 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Optional, List


def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}):
    """
    Return a simple representation string for `obj`.
    If `attrs` is not None, it should be a list of attributes to include.
    """
    params = inspect.signature(obj.__class__).parameters
    attrs_repr = []
    if attrs is None:
        attrs = params.keys()
    for attr in attrs:
        display = False
        if attr in overrides:
            value = overrides[attr]
        elif hasattr(obj, attr):
            value = getattr(obj, attr)
        else:
            continue
        if attr in params:
            param = params[attr]
            if param.default is inspect._empty or value != param.default:
                display = True
        else:
            display = True

        if display:
            attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
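For illustration, what `simple_repr` produces on a small class (the class itself is hypothetical): arguments left at their default value are omitted from the string.

    from diffq.utils import simple_repr

    class Box:
        def __init__(self, width: int = 1, height: int = 2):
            self.width = width
            self.height = height

    print(simple_repr(Box(width=3)))  # -> Box(width=3); height is at its default and is skipped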