Mirror of https://github.com/Anjok07/ultimatevocalremovergui.git, synced 2025-01-31 04:13:45 +01:00
Delete diffq directory
This commit is contained in: parent 8def1cb6c5, commit cba7e1c4f6
diffq/__init__.py
@@ -1,29 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
"""
This package implements different quantization strategies:

- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits.
- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection.
- `diffq.lsq.LSQ`: Learnt Step size Quantizer based on [Esser et al. 2019] https://arxiv.org/abs/1902.08153
- `diffq.bitpack`: efficient CPU bit-packing for returning quantized states.
- `diffq.torch_pack`: torch based bit-packing compatible with torchscript.
- `diffq.ts_export`: support exporting DiffQ based models to torchscript.

Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers.
"""

from .uniform import UniformQuantizer
from .diffq import DiffQuantizer
from .lsq import LSQ
from .base import restore_quantized_state
from . import ts_export


__version__ = "0.2.2"
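Example (illustration only, not part of the deleted files): a minimal sketch of the package-level API listed above. The class names and methods come from the listings in this commit; the toy model and the `min_size=0` setting are assumptions made for the example.

import torch
from diffq import UniformQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
quantizer = UniformQuantizer(model, bits=8, min_size=0)   # min_size=0 quantizes even small layers
print("true size (MB):", quantizer.true_model_size())
state = quantizer.get_quantized_state(packed=True)        # bit-packed weights, ready to pickle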
diffq/base.py
@@ -1,343 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base class for all quantizers."""
from contextlib import contextmanager
from dataclasses import dataclass
from concurrent import futures
from fnmatch import fnmatch
from functools import partial
import io
import math
from multiprocessing import cpu_count
import pickle
import typing as tp
import zlib

import torch

from . import bitpack
from . import torch_pack as torch_pack_mod


class BaseQuantizer:
    @dataclass
    class _QuantizedParam:
        name: str
        param: torch.nn.Parameter
        module: torch.nn.Module
        # If a Parameter is used multiple times, `other` can be used
        # to share state between the different Quantizers
        other: tp.Optional[tp.Any]

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
        self.model = model
        self.min_size = min_size
        self.float16 = float16
        self.exclude = exclude
        self.detect_bound = detect_bound
        self._quantized = False
        self._need_unquantize = None
        self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
        self._post_handle = self.model.register_forward_hook(self._forward_hook)

        self._qparams = []
        self._float16 = []
        self._others = []
        self._rnns = []

        self._saved = []

        self._find_params()

    def _find_params(self):
        min_params = self.min_size * 2**20 // 4
        previous = {}
        for module_name, module in self.model.named_modules():
            if isinstance(module, torch.nn.RNNBase):
                self._rnns.append(module)
            for name, param in list(module.named_parameters(recurse=False)):
                full_name = f"{module_name}.{name}"
                matched = False
                for pattern in self.exclude:
                    if fnmatch(full_name, pattern) or fnmatch(name, pattern):
                        matched = True
                        break

                if param.numel() <= min_params or matched:
                    if id(param) in previous:
                        continue
                    if self.detect_bound:
                        previous[id(param)] = None
                    if self.float16:
                        self._float16.append(param)
                    else:
                        self._others.append(param)
                else:
                    qparam = self._register_param(name, param, module, previous.get(id(param)))
                    if self.detect_bound:
                        previous[id(param)] = qparam
                    self._qparams.append(qparam)

    def _register_param(self, name, param, module, other):
        return self.__class__._QuantizedParam(name, param, module, other)

    def _forward_pre_hook(self, module, input):
        if self.model.training:
            self._quantized_state = None
            self.unquantize()
            if self._pre_forward_train():
                self._fix_rnns()
        else:
            assert self._need_unquantize is None
            self._need_unquantize = self.quantize()

    def _forward_hook(self, module, input, output):
        if self.model.training:
            if self._post_forward_train():
                self._fix_rnns(flatten=False)  # Hacky, next forward will flatten
        else:
            if self._need_unquantize:
                self._need_unquantize = None
                self.unquantize()

    def quantize(self):
        """
        Immediately apply quantization to the model parameters.
        Model parameters are saved to later allow restoring the unquantized state.

        Note that you shouldn't need to call this for model evaluation, as long as
        you properly call `model.train()` and `model.eval()`, but this can be
        useful for weight inspection.
        """
        if self._quantized:
            return False
        self._saved = [qp.param.data.to('cpu', copy=True)
                       for qp in self._qparams if qp.other is None]
        self.restore_quantized_state(self.get_quantized_state(packed=False))
        self._quantized = True
        self._fix_rnns()
        return True

    @contextmanager
    def enter_quantize(self):
        """Context manager for entering quantized state."""
        self.quantize()
        try:
            yield
        finally:
            self.unquantize()

    def unquantize(self):
        """
        Revert a previous call to `quantize()`.
        """
        if not self._quantized:
            return
        if not self._saved:
            raise RuntimeError("Nothing to restore. This shouldn't happen")
        for qparam in self._qparams:
            if qparam.other is None:
                qparam.param.data[:] = self._saved.pop(0)
        assert len(self._saved) == 0
        self._quantized = False
        self._fix_rnns()

    def _pre_forward_train(self) -> bool:
        """
        Called once before each forward for continuous quantization.
        Should return True if parameters were changed.
        """
        return False

    def _post_forward_train(self) -> bool:
        """
        Called once after each forward (to restore state for instance).
        Should return True if parameters were changed.
        """
        return False

    def _fix_rnns(self, flatten=True):
        """
        To be called after quantization happened to fix RNNs.
        """
        for rnn in self._rnns:
            rnn._flat_weights = [
                (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
                for wn in rnn._flat_weights_names]
            if flatten:
                rnn.flatten_parameters()

    def _bit_pack_param(self, qparam: _QuantizedParam, quantized: tp.Any,
                        pack_fn: tp.Any) -> tp.Any:
        """Further bitpack the quantized representation.
        This is used to return the quantized state. Should be overridden.
        """
        return quantized

    def _bit_unpack_param(self, qparam: _QuantizedParam, packed: tp.Any,
                          unpack_fn: tp.Any) -> tp.Any:
        """Unpack the bitpacked representation. Should be overridden.
        """
        return packed

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def get_quantized_state(self, packed=True, torch_pack=False):
        """
        Return a quantized representation of the weights. If `packed` is True,
        this will also perform bitpacking to ensure optimal storage.
        If `torch_pack` is True, the bitpacking from `torch_pack` will be used.
        It is slower (except maybe on GPU), but is compatible with torchscript.

        You can restore a model from a quantized state either using
        `BaseQuantizer.restore_quantized_state` or `diffq.restore_quantized_state`
        if you do not have the original quantizer around anymore.
        """
        float16_params = []
        for p in self._float16:
            q = p.data.half()
            float16_params.append(q)

        if torch_pack:
            pack_fn = torch_pack_mod.pack
        else:
            pack_fn = bitpack.pack

        all_quantized = []
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            quantized = self._quantize_param(qparam)
            if packed:
                quantized = self._bit_pack_param(qparam, quantized, pack_fn=pack_fn)
            all_quantized.append(quantized)

        state = {
            "quantized": all_quantized,
            "float16": float16_params,
            "others": [p.data.clone() for p in self._others],
        }

        kwargs = dict(self._init_kwargs)
        kwargs.pop("model")
        state["meta"] = {
            "init_kwargs": kwargs,
            "klass": self.__class__,
            "packed": packed,
            "torch_pack": torch_pack
        }
        return state

    def restore_quantized_state(self, state) -> None:
        """
        Restore the state of the model from the quantized state.
        """
        for p, q in zip(self._float16, state["float16"]):
            p.data[:] = q.to(p)

        for p, q in zip(self._others, state["others"]):
            p.data[:] = q

        meta = state.get("meta", {})
        packed = meta.get("packed", False)
        torch_pack = meta.get("torch_pack", False)

        if torch_pack:
            unpack_fn = torch_pack_mod.unpack
        else:
            unpack_fn = bitpack.unpack

        remaining = list(state["quantized"])
        for qparam in self._qparams:
            if qparam.other is not None:
                # Only unquantize the first appearance of an nn.Parameter.
                continue
            quantized = remaining.pop(0)
            if packed:
                quantized = self._bit_unpack_param(qparam, quantized, unpack_fn)
            qparam.param.data[:] = self._unquantize_param(qparam, quantized)
        assert not remaining
        self._fix_rnns()

    def detach(self) -> None:
        """
        Detach from the model, removes hooks and anything else.
        """
        self._pre_handle.remove()
        self._post_handle.remove()

    def model_size(self) -> torch.Tensor:
        """
        Returns an estimate of the quantized model size.
        """
        total = torch.tensor(0.)
        for p in self._float16:
            total += 16 * p.numel()
        for p in self._others:
            total += 32 * p.numel()
        return total / 2**20 / 8  # bits to MegaBytes

    def true_model_size(self) -> float:
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()

    def packed_model_size(self) -> float:
        """Return the packed model size, when stored with pickle.
        This should be mostly equivalent to `true_model_size`, up to some
        slight overhead for storing metadata.
        """
        state = self.get_quantized_state(packed=True)
        return len(pickle.dumps(state)) / 2 ** 20

    def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
        """
        Return the compressed quantized model size, in MB.

        Args:
            compress_level (int): compression level used with zlib,
                see `zlib.compress` for details.
            num_workers (int): will split the final big byte representation in that
                many chunks processed in parallel.
        """
        out = io.BytesIO()
        torch.save(self.get_quantized_state(packed=False), out)
        ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
        return ms / 2 ** 20


def restore_quantized_state(model: torch.nn.Module, state: dict):
    assert "meta" in state
    quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"])
    quantizer.restore_quantized_state(state)
    quantizer.detach()


def _compress_len(data, compress_level):
    return len(zlib.compress(data, level=compress_level))


def _parallel_compress_len(data, compress_level, num_workers):
    num_workers = min(cpu_count(), num_workers)
    chunk_size = int(math.ceil(len(data) / num_workers))
    chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
    with futures.ThreadPoolExecutor(num_workers) as pool:
        # a thread pool is okay here, zlib calls an external C lib and the GIL is released
        # before the call.
        return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
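Example (illustration only, not from the deleted files): the save/restore path described in `get_quantized_state` above, using the module-level `restore_quantized_state` when the original quantizer object is no longer around. The toy model and the file name are assumptions.

import torch
from diffq import DiffQuantizer, restore_quantized_state

model = torch.nn.Linear(1024, 1024)
quantizer = DiffQuantizer(model)
torch.save(quantizer.get_quantized_state(packed=True), "quantized.diffq")  # hypothetical path

fresh = torch.nn.Linear(1024, 1024)                 # same architecture, new weights
state = torch.load("quantized.diffq")               # recent PyTorch may need weights_only=False
restore_quantized_state(fresh, state)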
Binary file not shown.
diffq/diffq.py
@@ -1,318 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Differentiable quantizer based on scaled noise injection.
"""
from dataclasses import dataclass
import math
import typing as tp

import torch

from .base import BaseQuantizer
from .uniform import uniform_quantize, uniform_unquantize
from .utils import capture_init, simple_repr


class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
        """
        Differentiable quantizer based on scaled noise injection.
        For every parameter `p` in the model, this introduces a number of bits parameter
        `b` with the same dimensions (when group_size = 1).
        Before each forward, `p` is replaced by `p + U`
        with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization
        step for `b` bits.
        This noise approximates the quantization noise in a differentiable manner, both
        with respect to the unquantized parameter `p` and the number of bits `b`.

        At evaluation (as detected with `model.eval()`), the model is replaced
        by its true quantized version, and restored when going back to training.

        When doing actual quantization (for serialization, or evaluation),
        the number of bits is rounded to the nearest integer, and needs to be stored along.
        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
        which will use a single noise level for multiple weight entries.

        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
        model size in MB. You can then use this estimate as a penalty in your training loss.

        Args:
            model (torch.nn.Module): model to quantize
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            group_size (int): weight entries are grouped together to reduce the number
                of noise scales to store. This should divide the size of all parameters
                bigger than min_size.
            min_bits (float): minimal number of bits.
            max_bits (float): maximal number of bits.
            init_bits (float): initial number of bits.
            extra_bits (float): extra bits to add for actual quantization (before roundoff).
            suffix (str): suffix used for the name of the extra noise scale parameters.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both, as well as the same number of bits.

        ..Warning::
            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

        """
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
            "init_bits must be between min_bits and max_bits excluded"

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scale parameters, "
                                   "maybe you used a DiffQuantizer twice on the same model?")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # we want the initial number of bits to be init_bits.
        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                is perfect for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale

            # Number of bits to represent each number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits for bits_bits
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2**bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(1, -1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits[:, 0]

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        bits = bits[:, None]
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scales, bits = quantized
        all_packed = []
        for bit in range(1, 15):
            sub_levels = levels[bits == bit]
            if not sub_levels.numel():
                all_packed.append(None)
            else:
                packed = pack_fn(sub_levels, bit)
                all_packed.append(packed)
        packed_bits = pack_fn(bits - self.min_bits)
        return (all_packed, scales, packed_bits)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack the bitpacked representation.
        """
        packed_all_levels, scales, packed_bits = packed
        bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
        bits = bits.to(qparam.param.device)
        levels = torch.empty(qparam.logit.numel(), self.group_size,
                             dtype=torch.short, device=qparam.param.device)
        for idx, packed_levels in enumerate(packed_all_levels):
            bit = idx + 1
            if packed_levels is None:
                continue
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack_fn(
                packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
        return (levels, scales, bits)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)
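Example (illustration only, not from the deleted files): a minimal training-loop sketch of the workflow described in the `DiffQuantizer` docstring above, with the differentiable `model_size()` used as a loss penalty. The toy model, dummy data, and the penalty weight 1e-3 are assumptions.

import torch
from diffq import DiffQuantizer

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)   # create the optimizer first
quantizer = DiffQuantizer(model, group_size=8)
quantizer.setup_optimizer(optimizer)                        # adds the bits parameters to it

model.train()
for _ in range(10):
    x = torch.randn(32, 1024)                               # dummy batch (assumption)
    loss = model(x).pow(2).mean()
    loss = loss + 1e-3 * quantizer.model_size()             # size penalty in MB, differentiable
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()   # on the next forward, the pre-hook applies true quantization
print("true size (MB):", quantizer.true_model_size())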
diffq/lsq.py
@@ -1,192 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.
"""
from dataclasses import dataclass
import typing as tp

import torch

from .base import BaseQuantizer
from .utils import capture_init, simple_repr


class LSQ(BaseQuantizer):
    """Implements weight-only quantization based on [Esser et al. 2019].
    https://arxiv.org/abs/1902.08153
    """
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        scale: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01,
                 float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True):
        assert 0 < bits <= 15
        self.suffix = suffix
        self._optimizer_setup = False
        self.bits = bits

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some scale parameters, "
                                   "maybe you used an LSQ twice on the same model?")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, scale=other.scale, other=other)
        # we want the initial number of bits to be init_bits.
        scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5
        scale = torch.nn.Parameter(scale)
        module.register_parameter(name + self.suffix, scale)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, scale=scale, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.scale for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
        """
        Setup the optimizer to tune the scale parameter.
        Following [Esser et al. 2019], we use the same LR and weight decay
        as the base optimizer, unless specified otherwise.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            kwargs (dict): overrides for optimization parameters
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.scale for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = qparam.param.numel() * self.bits
            subtotal += bits
            subtotal += 1 * 32  # param scale

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            scale = qparam.scale
            quant, _ = quantize(qparam.param, scale, self.bits)
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = quant
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        _, index = quantize(qparam.param, qparam.scale, self.bits)
        assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max()
        assert (index >= (-2 ** (self.bits - 1))).all(), index.min()
        return index.detach().short(), qparam.scale.detach()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        index, scale = quantized
        return index.float() * scale

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scale = quantized
        packed = pack_fn(levels + 2 ** (self.bits - 1))
        return (packed, scale)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack the bitpacked representation.
        """
        packed_levels, scale = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        levels -= 2 ** (self.bits - 1)
        return (levels, scale)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)


def roundpass(x):
    return (x.round() - x).detach() + x


def gradscale(x, scale):
    return (x - x * scale).detach() + x * scale


def quantize(tensor, scale, bits):
    low = - 2 ** (bits - 1)
    high = 2 ** (bits - 1) - 1
    scale = gradscale(scale, 1 / (tensor.numel() * high)**0.5)

    index = tensor / scale
    index = index.clamp(low, high)
    index = roundpass(index)
    return index * scale, index
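Example (illustration only, not from the deleted files): a quick check of the module-level helpers above. `roundpass` rounds in the forward pass but passes gradients straight through, and `quantize` returns both the dequantized weights and the integer levels; the tensor values below are arbitrary.

import torch
from diffq.lsq import roundpass, quantize

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
roundpass(x).sum().backward()
print(x.grad)                            # tensor([1., 1., 1.]): round() is bypassed in backward

w = torch.randn(16, requires_grad=True)
scale = torch.tensor(0.1, requires_grad=True)
deq, idx = quantize(w, scale, bits=8)    # dequantized weights and integer levels in [-128, 127]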
diffq/torch_pack.py
@@ -1,80 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Bit packing in pure PyTorch.
Slower than bitpack.pyx but compatible with torchscript.
"""
import math
import typing as tp
import torch
from torch.nn import functional as F


def as_rectangle(p: torch.Tensor, side: int):
    """Reshape as a rectangle, using padding when necessary so that the output shape is [side, -1]."""
    p_flat = p.view(-1)
    ideal_length = int(math.ceil(len(p_flat) / side) * side)
    p_flat_pad = F.pad(p_flat, (0, ideal_length - len(p_flat)))
    return p_flat_pad.view(side, -1)


def _storage_size(dtype: torch.dtype):
    if dtype == torch.int64:
        return 64
    elif dtype == torch.int32:
        return 32
    elif dtype == torch.int16:
        return 16
    elif dtype == torch.uint8:
        return 8
    else:
        raise ValueError("Invalid bitpacking storage type")


def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16):
    """You can think of indexes as a "Tensor" of bits of shape [L, nbits].
    Instead of concatenating naively as [L * nbits], we instead look at it transposed as
    [nbits, L]. For L = 16 * G, we get [nbits, G, 16] which is trivial to store
    efficiently on int16 integers.
    There will be overhead if L is far from a multiple of 16 (e.g. 1) but for large
    model layers this is acceptable. The storage type can be changed.

    `nbits` should be the number of bits on which the indexes are coded, and will
    actually be determined automatically if set to 0.
    """
    assert not indexes.dtype.is_floating_point
    if indexes.numel() > 0:
        assert indexes.max().item() < 2 ** 15
        assert indexes.min().item() >= 0
    if nbits == 0:
        nbits = int(math.ceil(math.log2(1 + (indexes.max()))))
    else:
        assert indexes.max().item() < 2 ** nbits

    indexes = indexes.reshape(-1)
    storage_size = _storage_size(storage_dtype)
    rect = as_rectangle(indexes, storage_size)
    out = torch.zeros(nbits, rect.shape[1], dtype=storage_dtype, device=indexes.device)
    for in_bit in range(nbits):
        for out_bit in range(storage_size):
            d = ((rect[out_bit] >> in_bit) & 1).to(out.dtype) << out_bit
            out[in_bit, :] |= d
    return out


def unpack(packed: torch.Tensor, length: tp.Optional[int] = None):
    """Opposite of `pack`. You might need to specify the original length."""
    storage_size = _storage_size(packed.dtype)
    nbits, groups = packed.shape
    out = torch.zeros(storage_size, groups, dtype=torch.int16, device=packed.device)
    for in_bit in range(storage_size):
        for out_bit in range(nbits):
            bit_value = (packed[out_bit, :] >> in_bit) & 1
            out[in_bit, :] = out[in_bit, :] | (bit_value.to(out) << out_bit)
    out = out.view(-1)
    if length is not None:
        out = out[:length]
    return out
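Example (illustration only, not from the deleted files): a pack/unpack round trip with the functions above, using arbitrary 5-bit indexes.

import torch
from diffq.torch_pack import pack, unpack

levels = torch.randint(0, 32, (1000,)).short()   # 5-bit indexes, stored as int16
packed = pack(levels, nbits=5)                   # shape [5, ceil(1000 / 16)]
restored = unpack(packed, length=1000)
assert torch.equal(restored, levels)             # lossless round trip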
diffq/ts_export.py
@@ -1,209 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""TorchScript export support.
We have to do a lot of black magic for TorchScript to be happy
because we cannot dynamically allocate new weights when loading the model.

Here is how it works:
- we generate code in a temporary python file for the given model that explicitly
  overrides all the weights on the first forward from their packed version.
  This is because TorchScript does not let us iterate over parameters in a generic manner.
- we zero out all the original weights. We cannot simply remove those weights
  because TorchScript won't let us recreate them.
- A TorchScript file is just a zip file, but stored without compression.
  In order to remove the cost of storing the zeroed out weights, we unzip the file,
  and zip it again with compression.
"""
import importlib
import os
from pathlib import Path
import random
import sys
import typing as tp
import tempfile
import zipfile

import torch
from torch import jit

from .diffq import DiffQuantizer
from .uniform import uniform_unquantize
from .torch_pack import unpack

_DiffQPacked = tp.Tuple[
    tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float],
    torch.Tensor, tp.List[int]]

# This is the template for the generated class.
TEMPLATE = '''
import typing as tp
import torch
from torch import jit

from diffq.ts_export import _unpack_param, _DiffQPacked

from {module} import {klass}


class DiffQTSModel(torch.nn.Module):
    def __init__(self, model: {klass}, group_size: int, min_bits: int,
                 packed: tp.List[_DiffQPacked]):
        super().__init__()
        self.group_size = group_size
        self.min_bits = min_bits
        self.model = model
        self._unpacked = False
        self._packed = packed

    @jit.export
    def unpack(self):
        """
        Unpack the weights, automatically called on the first forward,
        or explicitly."""
        if self._unpacked:
            return
{unpack_assigns}
        self._unpacked = True

    def forward(self, x: torch.Tensor):
        self.unpack()
        return self.model.forward(x)
'''

# those are the assignments for each quantized weight.
UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = '
                             '_unpack_param(self._packed[{index}], '
                             'group_size=self.group_size, min_bits=self.min_bits)')
UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}'


def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]):
    """Export the given quantized model to the given path.
    We must save the quantized model ourselves, as we need to recompress
    the zip archive afterwards.
    """
    packed: tp.List[_DiffQPacked] = []
    uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)])
    with tempfile.TemporaryDirectory() as tmpdir:
        sys.path.insert(0, tmpdir)
        try:
            code = _codegen(quantizer)
            with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f:
                f.write(code)
            module = importlib.import_module(uniq_name)
            ts_klass = module.DiffQTSModel
            state = quantizer.get_quantized_state(packed=True, torch_pack=True)
            quantized = state["quantized"]
            for qparam in quantizer._qparams:
                if qparam.other is None:
                    levels, scales, bits = quantized.pop(0)
                    size = qparam.param.size()
                    packed.append((levels, scales, bits, list(size)))
                    qparam.param.data.zero_()
            quantizer.detach()
            ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
                                   quantizer.min_bits, packed)
            ts_model = jit.script(ts_premodel)
            if path is not None:
                jit.save(ts_model, path)
                recompress(path)
        finally:
            sys.path.pop(0)

    return ts_model


def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Function called from TorchScript on the first forward to decode the
    packed weights to FP32.
    """
    packed_all_levels, scales, packed_bits, shape = packed
    numel = 1
    for dim in shape:
        numel *= dim
    bits = unpack(packed_bits, numel // group_size) + min_bits
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    for idx, packed_levels in enumerate(packed_all_levels):
        bit = idx + 1
        if packed_levels is not None:
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels)
    bits = bits[:, None]
    unquant = uniform_unquantize(levels, scales, bits)
    if len(shape) == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    elif len(shape) == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    elif len(shape) == 2:
        return unquant.view(shape[0], shape[1])
    elif len(shape) == 1:
        return unquant.view(shape[0])
    else:
        raise RuntimeError("Invalid number of dims")


def recompress(path: tp.Union[str, Path]):
    """After having saved the torchscript file, this will recompress it
    to make sure all the zeroed out parameters don't actually take any space.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(path) as zipin:
            zipin.extractall(tmpdir)
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED,
                             compresslevel=1) as zipout:
            for root, folders, files in os.walk(tmpdir):
                for file in files:
                    fp = Path(root) / file
                    name = fp.relative_to(tmpdir)
                    zipout.write(fp, name)


def _get_full_name_access(full_name):
    # When generating code, we need to handle attributes vs. indexing.
    parts = []
    for part in full_name.split("."):
        try:
            index = int(part)
        except ValueError:
            parts.append("." + part)
        else:
            parts.append(f"[{index}]")
    return "".join(parts)


def _codegen(quantizer: DiffQuantizer):
    # Generates the code for the given quantizer
    module = quantizer.model.__class__.__module__
    klass = quantizer.model.__class__.__name__
    model = quantizer.model

    assert not quantizer.float16
    names = {}
    for mod_name, mod in model.named_modules():
        names[mod] = mod_name
    unpack_assigns = []

    index = 0
    for qparam in quantizer._qparams:
        mod_name = names[qparam.module]
        if mod_name == '':
            full_name = qparam.name
        else:
            full_name = mod_name + '.' + qparam.name
        full_name = _get_full_name_access(full_name)
        if qparam.other is None:
            unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index))
            index += 1
        else:
            other_name = names[(qparam.other.module, qparam.other.name)]
            other_name = _get_full_name_access(other_name)
            unpack_assigns.append(
                UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name))

    return TEMPLATE.format(
        module=module,
        klass=klass,
        unpack_assigns='\n'.join(unpack_assigns))
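Example (illustration only, not from the deleted files): a sketch of the `export` entry point above, assuming a model whose class TorchScript can compile and re-import from its defining module (here `torch.nn.Sequential`), and a hypothetical output path.

import torch
from diffq import DiffQuantizer
from diffq.ts_export import export

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
quantizer = DiffQuantizer(model)
ts_model = export(quantizer, "model_ts.zip")   # hypothetical path; also recompresses the archive

loaded = torch.jit.load("model_ts.zip")
out = loaded(torch.randn(4, 1024))             # weights are unpacked on the first forward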
diffq/uniform.py
@@ -1,135 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Classic uniform quantization over n bits.
"""
from typing import Tuple
import torch

from .base import BaseQuantizer
from .utils import capture_init, simple_repr


def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    Returns:
        - quantized levels
        - (min, max) range.

    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)


def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Unquantize the weights from the levels and scale. Return a float32 tensor.
    """
    mn, mx = scales
    num_levels = 2 ** bits.float()
    unit = 1 / (num_levels - 1)
    levels = levels.float()
    p = levels * unit  # in [0, 1]
    return p * (mx - mn) + mn


class UniformQuantizer(BaseQuantizer):
    @capture_init
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            qat (bool): perform quantization-aware training.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        self.bits = float(bits)
        self.qat = qat

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def __repr__(self):
        return simple_repr(self)

    def _pre_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                if qparam.other is not None:
                    new_param = qparam.other.module._parameters[qparam.other.name]
                else:
                    quantized = self._quantize_param(qparam)
                    qvalue = self._unquantize_param(qparam, quantized)
                    new_param = qparam.param + (qvalue - qparam.param).detach()
                qparam.module._parameters[qparam.name] = new_param
            return True
        return False

    def _post_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                qparam.module._parameters[qparam.name] = qparam.param
            return True
        return False

    def _quantize_param(self, qparam):
        levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
        return (levels, scales)

    def _unquantize_param(self, qparam, quantized):
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scales = quantized
        packed = pack_fn(levels, self.bits)
        return (packed, scales)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack the bitpacked representation.
        """
        packed_levels, scales = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        return (levels, scales)

    def model_size(self):
        """
        Non differentiable model size in MB.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            if qparam.other is None:  # if parameter is bound, count only one copy.
                subtotal += self.bits * qparam.param.numel() + 64  # 2 floats for the overall scales
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()
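Example (illustration only, not from the deleted files): a round trip through the two module-level functions above, with arbitrary weights. The reconstruction error is bounded by about half a quantization step.

import torch
from diffq.uniform import uniform_quantize, uniform_unquantize

w = torch.randn(4, 4)
levels, (mn, mx) = uniform_quantize(w, torch.tensor(8.))        # uint8 levels + (min, max) range
approx = uniform_unquantize(levels, (mn, mx), torch.tensor(8.))
print((w - approx).abs().max())                                  # roughly (mx - mn) / (2 * 255)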
diffq/utils.py
@@ -1,58 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import functools
import inspect
from typing import Optional, List


def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}):
    """
    Return a simple representation string for `obj`.
    If `attrs` is not None, it should be a list of attributes to include.
    """
    params = inspect.signature(obj.__class__).parameters
    attrs_repr = []
    if attrs is None:
        attrs = params.keys()
    for attr in attrs:
        display = False
        if attr in overrides:
            value = overrides[attr]
        elif hasattr(obj, attr):
            value = getattr(obj, attr)
        else:
            continue
        if attr in params:
            param = params[attr]
            if param.default is inspect._empty or value != param.default:
                display = True
        else:
            display = True

        if display:
            attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_kwargs`.
    """
    signature = inspect.signature(init)

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        actual_kwargs = dict(bound.arguments)
        del actual_kwargs['self']
        actual_kwargs.update(bound.kwargs)
        self._init_kwargs = actual_kwargs
        init(self, *args, **kwargs)

    return __init__
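Example (illustration only, not from the deleted files): how the two helpers above work together, on a hypothetical class defined just for this sketch.

from diffq.utils import capture_init, simple_repr


class Config:
    @capture_init
    def __init__(self, lr=1e-3, momentum=0.9):
        self.lr = lr
        self.momentum = momentum

    def __repr__(self):
        return simple_repr(self)


cfg = Config(lr=1e-2)
print(cfg)                  # Config(lr=0.01) -- only non-default arguments are shown
print(cfg._init_kwargs)     # {'lr': 0.01}, captured by the decorator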