diff --git a/diffq/__init__.py b/diffq/__init__.py deleted file mode 100644 index b67e784..0000000 --- a/diffq/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# flake8: noqa -""" -This package implements different quantization strategies: - -- `diffq.uniform.UniformQuantizer`: classic uniform quantization over n bits. -- `diffq.diffq.DiffQuantizer`: differentiable quantizer based on scaled noise injection. -- `diffq.lsq.LSQ`: Learnt Step size Quantizer based on [Esser et al. 2019] https://arxiv.org/abs/1902.08153 -- `diffq.bitpack`: efficient CPU bit-packing for returning quantized states. -- `diffq.torch_pack`: torch based bit-packing compatible with torchscript. -- `diffq.ts_export`: support exporting DiffQ based models to torchscript. - - -Also, do check `diffq.base.BaseQuantizer` for the common methods of all Quantizers. -""" - -from .uniform import UniformQuantizer -from .diffq import DiffQuantizer -from .lsq import LSQ -from .base import restore_quantized_state -from . import ts_export - - -__version__ = "0.2.2" diff --git a/diffq/base.py b/diffq/base.py deleted file mode 100644 index acbded3..0000000 --- a/diffq/base.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Base class for all quantizers.""" -from contextlib import contextmanager -from dataclasses import dataclass -from concurrent import futures -from fnmatch import fnmatch -from functools import partial -import io -import math -from multiprocessing import cpu_count -import pickle -import typing as tp -import zlib - -import torch - -from . import bitpack -from . 
import torch_pack as torch_pack_mod - - -class BaseQuantizer: - @dataclass - class _QuantizedParam: - name: str - param: torch.nn.Parameter - module: torch.nn.Module - # If a Parameter is used multiple times, `other` can be used - # to share state between the different Quantizers - other: tp.Optional[tp.Any] - - def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False, - exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True): - self.model = model - self.min_size = min_size - self.float16 = float16 - self.exclude = exclude - self.detect_bound = detect_bound - self._quantized = False - self._need_unquantize = None - self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook) - self._post_handle = self.model.register_forward_hook(self._forward_hook) - - self._qparams = [] - self._float16 = [] - self._others = [] - self._rnns = [] - - self._saved = [] - - self._find_params() - - def _find_params(self): - min_params = self.min_size * 2**20 // 4 - previous = {} - for module_name, module in self.model.named_modules(): - if isinstance(module, torch.nn.RNNBase): - self._rnns.append(module) - for name, param in list(module.named_parameters(recurse=False)): - full_name = f"{module_name}.{name}" - matched = False - for pattern in self.exclude: - if fnmatch(full_name, pattern) or fnmatch(name, pattern): - matched = True - break - - if param.numel() <= min_params or matched: - if id(param) in previous: - continue - if self.detect_bound: - previous[id(param)] = None - if self.float16: - self._float16.append(param) - else: - self._others.append(param) - else: - qparam = self._register_param(name, param, module, previous.get(id(param))) - if self.detect_bound: - previous[id(param)] = qparam - self._qparams.append(qparam) - - def _register_param(self, name, param, module, other): - return self.__class__._QuantizedParam(name, param, module, other) - - def _forward_pre_hook(self, module, input): - if self.model.training: - self._quantized_state = None - self.unquantize() - if self._pre_forward_train(): - self._fix_rnns() - else: - assert self._need_unquantize is None - self._need_unquantize = self.quantize() - - def _forward_hook(self, module, input, output): - if self.model.training: - if self._post_forward_train(): - self._fix_rnns(flatten=False) # Hacky, next forward will flatten - else: - if self._need_unquantize: - self._need_unquantize = None - self.unquantize() - - def quantize(self): - """ - Immediately apply quantization to the model parameters. - Model parameters are saved to later allow restoring the unquantized state. - - Note that you shouldn't need to call this for model evaluation, as long as - you properly call `model.train()` and `model.eval()`, but this can be - useful for weight inspection. - """ - if self._quantized: - return False - self._saved = [qp.param.data.to('cpu', copy=True) - for qp in self._qparams if qp.other is None] - self.restore_quantized_state(self.get_quantized_state(packed=False)) - self._quantized = True - self._fix_rnns() - return True - - @contextmanager - def enter_quantize(self): - """Context manager for entering quantized state.""" - self.quantize() - try: - yield - finally: - self.unquantize() - - def unquantize(self): - """ - Revert a previous call to `quantize()`. - """ - if not self._quantized: - return - if not self._saved: - raise RuntimeError("Nothing to restore. 
This shouldn't happen") - for qparam in self._qparams: - if qparam.other is None: - qparam.param.data[:] = self._saved.pop(0) - assert len(self._saved) == 0 - self._quantized = False - self._fix_rnns() - - def _pre_forward_train(self) -> bool: - """ - Called once before each forward for continuous quantization. - Should return True if parameters were changed. - """ - return False - - def _post_forward_train(self) -> bool: - """ - Called once after each forward (to restore state for instance). - Should return True if parameters were changed. - """ - return False - - def _fix_rnns(self, flatten=True): - """ - To be called after quantization happened to fix RNNs. - """ - for rnn in self._rnns: - rnn._flat_weights = [ - (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn) - for wn in rnn._flat_weights_names] - if flatten: - rnn.flatten_parameters() - - def _bit_pack_param(self, qparam: _QuantizedParam, quantized: tp.Any, - pack_fn: tp.Any) -> tp.Any: - """Further bitpack the quantized representation. - This is used to return the quantized state. Should be overriden. - """ - return quantized - - def _bit_unpack_param(self, qparam: _QuantizedParam, packed: tp.Any, - unpack_fn: tp.Any) -> tp.Any: - """Unpack bitpacked representation. Should be overriden - """ - return packed - - def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: - """ - To be overriden. - """ - raise NotImplementedError() - - def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: - """ - To be overriden. - """ - raise NotImplementedError() - - def get_quantized_state(self, packed=True, torch_pack=False): - """ - Return a quantized representation fo the weights. If `packed` is True, - this will also perform bitpacking to ensure optimal store. - If `torck_pack` is true, the bitpacking from `torch_pack` will be used. - It is slower (except maybe on GPU), but is compatible with torchscript. - - You can restore a model from a quantized state either using - `BaseQuantizer.restore_quantized_state` or `diffq.restore_quantized_state` - if you do not have the original quantizer around anymore. - """ - float16_params = [] - for p in self._float16: - q = p.data.half() - float16_params.append(q) - - if torch_pack: - pack_fn = torch_pack_mod.pack - else: - pack_fn = bitpack.pack - - all_quantized = [] - for qparam in self._qparams: - if qparam.other is not None: - continue - quantized = self._quantize_param(qparam) - if packed: - quantized = self._bit_pack_param(qparam, quantized, pack_fn=pack_fn) - all_quantized.append(quantized) - - state = { - "quantized": all_quantized, - "float16": float16_params, - "others": [p.data.clone() for p in self._others], - } - - kwargs = dict(self._init_kwargs) - kwargs.pop("model") - state["meta"] = { - "init_kwargs": kwargs, - "klass": self.__class__, - "packed": packed, - "torch_pack": torch_pack - } - return state - - def restore_quantized_state(self, state) -> None: - """ - Restore the state of the model from the quantized state. - """ - for p, q in zip(self._float16, state["float16"]): - p.data[:] = q.to(p) - - for p, q in zip(self._others, state["others"]): - p.data[:] = q - - meta = state.get("meta", {}) - packed = meta.get("packed", False) - torch_pack = meta.get("torch_pack", False) - - if torch_pack: - unpack_fn = torch_pack_mod.unpack - else: - unpack_fn = bitpack.unpack - - remaining = list(state["quantized"]) - for qparam in self._qparams: - if qparam.other is not None: - # Only unquantize first appearance of nn.Parameter. 
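A minimal usage sketch of the state API shown above in diffq/base.py (the toy model and file name are illustrative; any BaseQuantizer subclass works the same way):

    import torch
    import diffq

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
    quantizer = diffq.UniformQuantizer(model, bits=8)
    state = quantizer.get_quantized_state(packed=True)  # bit-packed weights + metadata
    torch.save(state, "quantized.th")

    # Later, without keeping the original quantizer around
    # (recent PyTorch may need torch.load(..., weights_only=False)):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
    diffq.restore_quantized_state(model, torch.load("quantized.th"))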
- continue - quantized = remaining.pop(0) - if packed: - quantized = self._bit_unpack_param(qparam, quantized, unpack_fn) - qparam.param.data[:] = self._unquantize_param(qparam, quantized) - assert not remaining - self._fix_rnns() - - def detach(self) -> None: - """ - Detach from the model, removes hooks and anything else. - """ - self._pre_handle.remove() - self._post_handle.remove() - - def model_size(self) -> torch.Tensor: - """ - Returns an estimate of the quantized model size. - """ - total = torch.tensor(0.) - for p in self._float16: - total += 16 * p.numel() - for p in self._others: - total += 32 * p.numel() - return total / 2**20 / 8 # bits to MegaBytes - - def true_model_size(self) -> float: - """ - Return the true quantized model size, in MB, without extra - compression. - """ - return self.model_size().item() - - def packed_model_size(self) -> float: - """Return the packed model size, when stored with pickle. - This should be mostly equivalent to `true_model_size` up to some - slight overhead for storing metadata. - """ - state = self.get_quantized_state(packed=True) - return len(pickle.dumps(state)) / 2 ** 20 - - def compressed_model_size(self, compress_level=-1, num_workers=8) -> float: - """ - Return the compressed quantized model size, in MB. - - Args: - compress_level (int): compression level used with zlib, - see `zlib.compress` for details. - num_workers (int): will split the final big byte representation in that - many chunks processed in parallels. - """ - out = io.BytesIO() - torch.save(self.get_quantized_state(packed=False), out) - ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers) - return ms / 2 ** 20 - - -def restore_quantized_state(model: torch.nn.Module, state: dict): - assert "meta" in state - quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"]) - quantizer.restore_quantized_state(state) - quantizer.detach() - - -def _compress_len(data, compress_level): - return len(zlib.compress(data, level=compress_level)) - - -def _parallel_compress_len(data, compress_level, num_workers): - num_workers = min(cpu_count(), num_workers) - chunk_size = int(math.ceil(len(data) / num_workers)) - chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)] - with futures.ThreadPoolExecutor(num_workers) as pool: - # thread pool is okay here, zlib calls an external C lib and GIL is released - # before the call. - return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks)) diff --git a/diffq/bitpack.cp39-win_amd64.pyd b/diffq/bitpack.cp39-win_amd64.pyd deleted file mode 100644 index a28d6e9..0000000 Binary files a/diffq/bitpack.cp39-win_amd64.pyd and /dev/null differ diff --git a/diffq/diffq.py b/diffq/diffq.py deleted file mode 100644 index 5fa5113..0000000 --- a/diffq/diffq.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Differentiable quantizer based on scaled noise injection. 
-""" -from dataclasses import dataclass -import math -import typing as tp - -import torch - -from .base import BaseQuantizer -from .uniform import uniform_quantize, uniform_unquantize -from .utils import capture_init, simple_repr - - -class DiffQuantizer(BaseQuantizer): - @dataclass - class _QuantizedParam(BaseQuantizer._QuantizedParam): - logit: torch.nn.Parameter - - @capture_init - def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False, - group_size: int = 1, min_bits: float = 2, max_bits: float = 15, - param="bits", noise="gaussian", - init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq", - exclude: tp.List[str] = [], detect_bound: bool = True): - """ - Differentiable quantizer based on scaled noise injection. - For every parameter `p` in the model, this introduces a number of bits parameter - `b` with the same dimensions (when group_size = 1). - Before each forward, `p` is replaced by `p + U` - with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization - step for `b` bits. - This noise approximates the quantization noise in a differentiable manner, both - with respect to the unquantized parameter `p` and the number of bits `b`. - - At eveluation (as detected with `model.eval()`), the model is replaced - by its true quantized version, and restored when going back to training. - - When doing actual quantization (for serialization, or evaluation), - the number of bits is rounded to the nearest integer, and needs to be stored along. - This will cost a few bits per dimension. To reduce this cost, one can use `group_size`, - which will use a single noise level for multiple weight entries. - - You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the - model size in MB. You can then use this estimate as a penalty in your training loss. - - Args: - model (torch.nn.Module): model to quantize - min_size (float): minimum size in MB of a parameter to be quantized. - float16 (bool): if a layer is smaller than min_size, should we still do float16? - group_size (int): weight entries are groupped together to reduce the number - of noise scales to store. This should divide the size of all parameters - bigger than min_size. - min_bits (float): minimal number of bits. - max_bits (float): maximal number of bits. - init_bits (float): initial number of bits. - extra_bits (float): extra bits to add for actual quantization (before roundoff). - suffix (str): suffix used for the name of the extra noise scale parameters. - exclude (list[str]): list of patterns used to match parameters to exclude. - For instance `['bias']` to exclude all bias terms. - detect_bound (bool): if True, will detect bound parameters and reuse - the same quantized tensor for both, as well as the same number of bits. - - ..Warning:: - You must call `model.training()` and `model.eval()` for `DiffQuantizer` work properly. 
- - """ - self.group_size = group_size - self.min_bits = min_bits - self.max_bits = max_bits - self.init_bits = init_bits - self.extra_bits = extra_bits - self.suffix = suffix - self.param = param - self.noise = noise - assert noise in ["gaussian", "uniform"] - self._optimizer_setup = False - - self._min_noise = 1 / (2 ** self.max_bits - 1) - self._max_noise = 1 / (2 ** self.min_bits - 1) - - assert group_size >= 0 - assert min_bits < init_bits < max_bits, \ - "init_bits must be between min_bits and max_bits excluded3" - - for name, _ in model.named_parameters(): - if name.endswith(suffix): - raise RuntimeError("The model already has some noise scales parameters, " - "maybe you used twice a DiffQuantizer on the same model?.") - - super().__init__(model, min_size, float16, exclude, detect_bound) - - def _get_bits(self, logit: torch.Tensor): - if self.param == "noise": - return torch.log2(1 + 1 / self._get_noise_scale(logit)) - else: - t = torch.sigmoid(logit) - return self.max_bits * t + (1 - t) * self.min_bits - - def _get_noise_scale(self, logit: torch.Tensor): - if self.param == "noise": - t = torch.sigmoid(logit) - return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise)) - else: - return 1 / (2 ** self._get_bits(logit) - 1) - - def _register_param(self, name, param, module, other): - if other is not None: - return self.__class__._QuantizedParam( - name=name, param=param, module=module, logit=other.logit, other=other) - assert self.group_size == 0 or param.numel() % self.group_size == 0 - # we want the initial number of bits to be init_bits. - if self.param == "noise": - noise_scale = 1 / (2 ** self.init_bits - 1) - t = (math.log(noise_scale) - math.log(self._max_noise)) / ( - math.log(self._min_noise) - math.log(self._max_noise)) - else: - t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits) - assert 0 < t < 1 - logit = torch.logit(torch.tensor(float(t))) - assert abs(self._get_bits(logit) - self.init_bits) < 1e-5 - if self.group_size > 0: - nparam = param.numel() // self.group_size - else: - nparam = 1 - logit = torch.nn.Parameter( - torch.full( - (nparam,), - logit, - device=param.device)) - module.register_parameter(name + self.suffix, logit) - return self.__class__._QuantizedParam( - name=name, param=param, module=module, logit=logit, other=None) - - def clear_optimizer(self, optimizer: torch.optim.Optimizer): - params = [qp.logit for qp in self._qparams] - - for group in optimizer.param_groups: - new_params = [] - for q in list(group["params"]): - matched = False - for p in params: - if p is q: - matched = True - if not matched: - new_params.append(q) - group["params"][:] = new_params - - def setup_optimizer(self, optimizer: torch.optim.Optimizer, - lr: float = 1e-3, **kwargs): - """ - Setup the optimizer to tune the number of bits. In particular, this will deactivate - weight decay for the bits parameters. - - Args: - optimizer (torch.Optimizer): optimizer to use. - lr (float): specific learning rate for the bits parameters. 1e-3 - is perfect for Adam.,w - kwargs (dict): overrides for other optimization parameters for the bits. 
- """ - assert not self._optimizer_setup - self._optimizer_setup = True - - params = [qp.logit for qp in self._qparams] - - for group in optimizer.param_groups: - for q in list(group["params"]): - for p in params: - if p is q: - raise RuntimeError("You should create the optimizer " - "before the quantizer!") - - group = {"params": params, "lr": lr, "weight_decay": 0} - group.update(kwargs) - optimizer.add_param_group(group) - - def no_optimizer(self): - """ - Call this if you do not want to use an optimizer. - """ - self._optimizer_setup = True - - def check_unused(self): - for qparam in self._qparams: - if qparam.other is not None: - continue - grad = qparam.param.grad - if grad is None or (grad == 0).all(): - if qparam.logit.grad is not None: - qparam.logit.grad.data.zero_() - - def model_size(self, exact=False): - """ - Differentiable estimate of the model size. - The size is returned in MB. - - If `exact` is True, then the output is no longer differentiable but - reflect exactly an achievable size, even without compression, - i.e.same as returned by `naive_model_size()`. - """ - total = super().model_size() - subtotal = 0 - for qparam in self._qparams: - # only count the first appearance of a Parameter - if qparam.other is not None: - continue - bits = self.extra_bits + self._get_bits(qparam.logit) - if exact: - bits = bits.round().clamp(1, 15) - if self.group_size == 0: - group_size = qparam.param.numel() - else: - group_size = self.group_size - subtotal += group_size * bits.sum() - subtotal += 2 * 32 # param scale - - # Number of bits to represent each number of bits - bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits))) - subtotal += 8 # 8 bits for bits_bits - subtotal += bits_bits * bits.numel() - - subtotal /= 2 ** 20 * 8 # bits -> MegaBytes - return total + subtotal - - def true_model_size(self): - """ - Naive model size without zlib compression. 
- """ - return self.model_size(exact=True).item() - - def _pre_forward_train(self): - if not self._optimizer_setup: - raise RuntimeError("You must call `setup_optimizer()` on your optimizer " - "before starting training.") - for qparam in self._qparams: - if qparam.other is not None: - noisy = qparam.other.module._parameters[qparam.other.name] - else: - bits = self._get_bits(qparam.logit)[:, None] - if self.group_size == 0: - p_flat = qparam.param.view(-1) - else: - p_flat = qparam.param.view(-1, self.group_size) - scale = p_flat.max() - p_flat.min() - unit = 1 / (2**bits - 1) - if self.noise == "uniform": - noise_source = (torch.rand_like(p_flat) - 0.5) - elif self.noise == "gaussian": - noise_source = torch.randn_like(p_flat) / 2 - noise = scale * unit * noise_source - noisy = p_flat + noise - # We bypass the checks by PyTorch on parameters being leafs - qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param) - return True - - def _post_forward_train(self): - for qparam in self._qparams: - qparam.module._parameters[qparam.name] = qparam.param - return True - - def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: - bits = self.extra_bits + self._get_bits(qparam.logit) - bits = bits.round().clamp(1, 15)[:, None].byte() - if self.group_size == 0: - p = qparam.param.data.view(1, -1) - else: - p = qparam.param.data.view(-1, self.group_size) - levels, scales = uniform_quantize(p, bits) - return levels, scales, bits[:, 0] - - def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: - levels, param_scale, bits = quantized - bits = bits[:, None] - return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data) - - def _bit_pack_param(self, qparam, quantized, pack_fn): - levels, scales, bits = quantized - all_packed = [] - for bit in range(1, 15): - sub_levels = levels[bits == bit] - if not sub_levels.numel(): - all_packed.append(None) - else: - packed = pack_fn(sub_levels, bit) - all_packed.append(packed) - packed_bits = pack_fn(bits - self.min_bits) - return (all_packed, scales, packed_bits) - - def _bit_unpack_param(self, qparam, packed, unpack_fn): - """Unpack bitpacked representation. Should be overriden. - """ - packed_all_levels, scales, packed_bits = packed - bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits - bits = bits.to(qparam.param.device) - levels = torch.empty(qparam.logit.numel(), self.group_size, - dtype=torch.short, device=qparam.param.device) - for idx, packed_levels in enumerate(packed_all_levels): - bit = idx + 1 - if packed_levels is None: - continue - sub_levels = levels[bits == bit] - levels[bits == bit] = unpack_fn( - packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels) - return (levels, scales, bits) - - def detach(self): - super().detach() - for qparam in self._qparams: - delattr(qparam.module, qparam.name + self.suffix) - - def __repr__(self): - return simple_repr(self) diff --git a/diffq/lsq.py b/diffq/lsq.py deleted file mode 100644 index 95238b0..0000000 --- a/diffq/lsq.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153. 
-""" -from dataclasses import dataclass -import typing as tp - -import torch - -from .base import BaseQuantizer -from .utils import capture_init, simple_repr - - -class LSQ(BaseQuantizer): - """Implements weight only quantization based on [Esser et al. 2019]. - https://arxiv.org/abs/1902.08153 - """ - @dataclass - class _QuantizedParam(BaseQuantizer._QuantizedParam): - scale: torch.nn.Parameter - - @capture_init - def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01, - float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True): - assert 0 < bits <= 15 - self.suffix = suffix - self._optimizer_setup = False - self.bits = bits - - for name, _ in model.named_parameters(): - if name.endswith(suffix): - raise RuntimeError("The model already has some noise scales parameters, " - "maybe you used twice a LSQ on the same model?.") - - super().__init__(model, min_size, float16, exclude, detect_bound) - - def _register_param(self, name, param, module, other): - if other is not None: - return self.__class__._QuantizedParam( - name=name, param=param, module=module, scale=other.scale, other=other) - # we want the initial number of bits to be init_bits. - scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5 - scale = torch.nn.Parameter(scale) - module.register_parameter(name + self.suffix, scale) - return self.__class__._QuantizedParam( - name=name, param=param, module=module, scale=scale, other=None) - - def clear_optimizer(self, optimizer: torch.optim.Optimizer): - params = [qp.scale for qp in self._qparams] - - for group in optimizer.param_groups: - new_params = [] - for q in list(group["params"]): - matched = False - for p in params: - if p is q: - matched = True - if not matched: - new_params.append(q) - group["params"][:] = new_params - - def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs): - """ - Setup the optimizer to tune the scale parameter. - Following [Esser et al. 2019], we use the same LR and weight decay - as the base optimizer, unless specified otherwise. - - Args: - optimizer (torch.Optimizer): optimizer to use. - kwargs (dict): overrides for optimization parameters - """ - assert not self._optimizer_setup - self._optimizer_setup = True - - params = [qp.scale for qp in self._qparams] - - for group in optimizer.param_groups: - for q in list(group["params"]): - for p in params: - if p is q: - raise RuntimeError("You should create the optimizer " - "before the quantizer!") - - group = {"params": params} - group.update(kwargs) - optimizer.add_param_group(group) - - def no_optimizer(self): - """ - Call this if you do not want to use an optimizer. - """ - self._optimizer_setup = True - - def model_size(self, exact=False): - """ - Differentiable estimate of the model size. - The size is returned in MB. - - If `exact` is True, then the output is no longer differentiable but - reflect exactly an achievable size, even without compression, - i.e.same as returned by `naive_model_size()`. - """ - total = super().model_size() - subtotal = 0 - for qparam in self._qparams: - # only count the first appearance of a Parameter - if qparam.other is not None: - continue - bits = qparam.param.numel() * self.bits - subtotal += bits - subtotal += 1 * 32 # param scale - - subtotal /= 2 ** 20 * 8 # bits -> MegaBytes - return total + subtotal - - def true_model_size(self): - """ - Naive model size without zlib compression. 
- """ - return self.model_size(exact=True).item() - - def _pre_forward_train(self): - if not self._optimizer_setup: - raise RuntimeError("You must call `setup_optimizer()` on your optimizer " - "before starting training.") - for qparam in self._qparams: - scale = qparam.scale - quant, _ = quantize(qparam.param, scale, self.bits) - # We bypass the checks by PyTorch on parameters being leafs - qparam.module._parameters[qparam.name] = quant - return True - - def _post_forward_train(self): - for qparam in self._qparams: - qparam.module._parameters[qparam.name] = qparam.param - return True - - def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any: - _, index = quantize(qparam.param, qparam.scale, self.bits) - assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max() - assert (index >= (-2 ** (self.bits - 1))).all(), index.min() - return index.detach().short(), qparam.scale.detach() - - def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor: - index, scale = quantized - return index.float() * scale - - def _bit_pack_param(self, qparam, quantized, pack_fn): - levels, scale = quantized - packed = pack_fn(levels + 2 ** (self.bits - 1)) - return (packed, scale) - - def _bit_unpack_param(self, qparam, packed, unpack_fn): - """Unpack bitpacked representation. Should be overriden - """ - packed_levels, scale = packed - levels = unpack_fn( - packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param) - levels -= 2 ** (self.bits - 1) - return (levels, scale) - - def detach(self): - super().detach() - for qparam in self._qparams: - delattr(qparam.module, qparam.name + self.suffix) - - def __repr__(self): - return simple_repr(self) - - -def roundpass(x): - return (x.round() - x).detach() + x - - -def gradscale(x, scale): - return (x - x * scale).detach() + x * scale - - -def quantize(tensor, scale, bits): - low = - 2 ** (bits - 1) - high = 2 ** (bits - 1) - 1 - scale = gradscale(scale, 1 / (tensor.numel() * high)**0.5) - - index = tensor / scale - index = index.clamp(low, high) - index = roundpass(index) - return index * scale, index diff --git a/diffq/torch_pack.py b/diffq/torch_pack.py deleted file mode 100644 index 926af7d..0000000 --- a/diffq/torch_pack.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Bit packing in pure PyTorch. -Slower than bitpack.pyx but compatible with torchscript. -""" -import math -import typing as tp -import torch -from torch.nn import functional as F - - -def as_rectangle(p: torch.Tensor, side: int): - """Reshape as rectangle, using padding when necessary so that out shape is [*, side]""" - p_flat = p.view(-1) - ideal_length = int(math.ceil(len(p_flat) / side) * side) - p_flat_pad = F.pad(p_flat, (0, ideal_length - len(p_flat))) - return p_flat_pad.view(side, -1) - - -def _storage_size(dtype: torch.dtype): - if dtype == torch.int64: - return 64 - elif dtype == torch.int32: - return 32 - elif dtype == torch.int16: - return 16 - elif dtype == torch.uint8: - return 8 - else: - raise ValueError("Invalid bitpacking storage type") - - -def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16): - """You can think of indexes as a "Tensor" of bits of shape [L, nbits]. - Instead of concatenating naively as [L * nbits], we instead look at it transposed as - [nbits, L]. 
For L = 16 * G, we get [nbits, G, 16] which is trivial to store - efficiently on int16 integers. - There will be overhead if L is far from a multiple of 16 (e.g. 1) but for large - model layers this is acceptable. Storage type can be changed. - - `nbits` should be the number of bits on which the indexes are coded, and will - actually be determined automatically if set to 0. - """ - assert not indexes.dtype.is_floating_point - if indexes.numel() > 0: - assert indexes.max().item() < 2 ** 15 - assert indexes.min().item() >= 0 - if nbits == 0: - nbits = int(math.ceil(math.log2(1 + (indexes.max())))) - else: - assert indexes.max().item() < 2 ** nbits - - indexes = indexes.reshape(-1) - storage_size = _storage_size(storage_dtype) - rect = as_rectangle(indexes, storage_size) - out = torch.zeros(nbits, rect.shape[1], dtype=storage_dtype, device=indexes.device) - for in_bit in range(nbits): - for out_bit in range(storage_size): - d = ((rect[out_bit] >> in_bit) & 1).to(out.dtype) << out_bit - out[in_bit, :] |= d - return out - - -def unpack(packed: torch.Tensor, length: tp.Optional[int] = None): - """Opposite of `pack`. You might need to specify the original length.""" - storage_size = _storage_size(packed.dtype) - nbits, groups = packed.shape - out = torch.zeros(storage_size, groups, dtype=torch.int16, device=packed.device) - for in_bit in range(storage_size): - for out_bit in range(nbits): - bit_value = (packed[out_bit, :] >> in_bit) & 1 - out[in_bit, :] = out[in_bit, :] | (bit_value.to(out) << out_bit) - out = out.view(-1) - if length is not None: - out = out[:length] - return out diff --git a/diffq/ts_export.py b/diffq/ts_export.py deleted file mode 100644 index d67ab2b..0000000 --- a/diffq/ts_export.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""TorchScript export support. -We have to do a lot of black magic for TorchScript to be happy -because we cannot dynamically allocate new weights when loading the model. - -Here is how it works: -- we generate code in a temporary python file for the given model that explicitely - override all the weights on the first forward from their packed version. - This is because TorchScript does not let us iterate over parameters in a generic manner. -- we zero out all the original weights. We cannot simply remove those weights - because TorchScript won't let us recreate them. -- A TorchScript file is just a zip file, but stored without compression. - In order to remove the cost of storing the zeroed out weights, we unzip the file, - and zip it again with compression. -""" -import importlib -import os -from pathlib import Path -import random -import sys -import typing as tp -import tempfile -import zipfile - -import torch -from torch import jit - -from .diffq import DiffQuantizer -from .uniform import uniform_unquantize -from .torch_pack import unpack - -_DiffQPacked = tp.Tuple[ - tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float], - torch.Tensor, tp.List[int]] - -# This is the template for the generated class. 
-TEMPLATE = ''' -import typing as tp -import torch -from torch import jit - -from diffq.ts_export import _unpack_param, _DiffQPacked - -from {module} import {klass} - - -class DiffQTSModel(torch.nn.Module): - def __init__(self, model: {klass}, group_size: int, min_bits: int, - packed: tp.List[_DiffQPacked]): - super().__init__() - self.group_size = group_size - self.min_bits = min_bits - self.model = model - self._unpacked = False - self._packed = packed - - @jit.export - def unpack(self): - """ - Unpack the weights, automatically called on the first forward, - or explicitely.""" - if self._unpacked: - return -{unpack_assigns} - self._unpacked = True - - def forward(self, x: torch.Tensor): - self.unpack() - return self.model.forward(x) -''' - -# those are the assignments for each quantized weight. -UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = ' - '_unpack_param(self._packed[{index}], ' - 'group_size=self.group_size, min_bits=self.min_bits)') -UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}' - - -def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]): - """Export the given quantized model to the given path. - We must save the quantized model ourselves, as we need to recompress - the zip archive afterwards. - """ - packed: tp.List[_DiffQPacked] = [] - uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)]) - with tempfile.TemporaryDirectory() as tmpdir: - sys.path.insert(0, tmpdir) - try: - code = _codegen(quantizer) - with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f: - f.write(code) - module = importlib.import_module(uniq_name) - ts_klass = module.DiffQTSModel - state = quantizer.get_quantized_state(packed=True, torch_pack=True) - quantized = state["quantized"] - for qparam in quantizer._qparams: - if qparam.other is None: - levels, scales, bits = quantized.pop(0) - size = qparam.param.size() - packed.append((levels, scales, bits, list(size))) - qparam.param.data.zero_() - quantizer.detach() - ts_premodel = ts_klass(quantizer.model, quantizer.group_size, - quantizer.min_bits, packed) - ts_model = jit.script(ts_premodel) - if path is not None: - jit.save(ts_model, path) - recompress(path) - finally: - sys.path.pop(0) - - return ts_model - - -def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor: - """Function called from TorchScript on the first forward to decode the - packed weights to FP32. - """ - packed_all_levels, scales, packed_bits, shape = packed - numel = 1 - for dim in shape: - numel *= dim - bits = unpack(packed_bits, numel // group_size) + min_bits - levels = torch.empty(bits.numel(), group_size, dtype=torch.short) - for idx, packed_levels in enumerate(packed_all_levels): - bit = idx + 1 - if packed_levels is not None: - sub_levels = levels[bits == bit] - levels[bits == bit] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels) - bits = bits[:, None] - unquant = uniform_unquantize(levels, scales, bits) - if len(shape) == 4: - return unquant.view(shape[0], shape[1], shape[2], shape[3]) - elif len(shape) == 3: - return unquant.view(shape[0], shape[1], shape[2]) - elif len(shape) == 2: - return unquant.view(shape[0], shape[1]) - elif len(shape) == 1: - return unquant.view(shape[0]) - else: - raise RuntimeError("Invalid numbr of dim") - - -def recompress(path: tp.Union[str, Path]): - """After having saved the torchscript file, this will recompress it - to make sure all the zeroed out parameters don't actually take any space. 
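Continuing the DiffQuantizer sketch from earlier, exporting to TorchScript with the `export` helper above might look like this (the file name is arbitrary; `export` zeroes out and detaches the float weights, so keep it as a final step):

    import torch
    from diffq import ts_export

    ts_export.export(quantizer, "model_quantized.ts")  # `quantizer` is the trained DiffQuantizer
    ts_model = torch.jit.load("model_quantized.ts")
    y = ts_model(torch.randn(32, 1024))                # first forward unpacks the packed weights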
- """ - with tempfile.TemporaryDirectory() as tmpdir: - with zipfile.ZipFile(path) as zipin: - zipin.extractall(tmpdir) - with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED, - compresslevel=1) as zipout: - for root, folders, files in os.walk(tmpdir): - for file in files: - fp = Path(root) / file - name = fp.relative_to(tmpdir) - zipout.write(fp, name) - - -def _get_full_name_access(full_name): - # When generating code, we need to handle attributes vs. indexing. - parts = [] - for part in full_name.split("."): - try: - index = int(part) - except ValueError: - parts.append("." + part) - else: - parts.append(f"[{index}]") - return "".join(parts) - - -def _codegen(quantizer: DiffQuantizer): - # Generates the code for the given quantizer - module = quantizer.model.__class__.__module__ - klass = quantizer.model.__class__.__name__ - model = quantizer.model - - assert not quantizer.float16 - names = {} - for mod_name, mod in model.named_modules(): - names[mod] = mod_name - unpack_assigns = [] - - index = 0 - for qparam in quantizer._qparams: - mod_name = names[qparam.module] - if mod_name == '': - full_name = qparam.name - else: - full_name = mod_name + '.' + qparam.name - full_name = _get_full_name_access(full_name) - if qparam.other is None: - unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index)) - index += 1 - else: - other_name = names[(qparam.other.module, qparam.other.name)] - other_name = _get_full_name_access(other_name) - unpack_assigns.append( - UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name)) - - return TEMPLATE.format( - module=module, - klass=klass, - unpack_assigns='\n'.join(unpack_assigns)) diff --git a/diffq/uniform.py b/diffq/uniform.py deleted file mode 100644 index 7d02a4b..0000000 --- a/diffq/uniform.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Classic uniform quantization over n bits. -""" -from typing import Tuple -import torch - -from .base import BaseQuantizer -from .utils import capture_init, simple_repr - - -def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)): - """ - Quantize the given weights over `bits` bits. - - Returns: - - quantized levels - - (min, max) range. - - """ - assert (bits >= 1).all() and (bits <= 15).all() - num_levels = (2 ** bits.float()).long() - mn = p.min().item() - mx = p.max().item() - p = (p - mn) / (mx - mn) # put p in [0, 1] - unit = 1 / (num_levels - 1) # quantization unit - levels = (p / unit).round() - if (bits <= 8).all(): - levels = levels.byte() - else: - levels = levels.short() - return levels, (mn, mx) - - -def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float], - bits: torch.Tensor = torch.tensor(8.)): - """ - Unquantize the weights from the levels and scale. Return a float32 tensor. - """ - mn, mx = scales - num_levels = 2 ** bits.float() - unit = 1 / (num_levels - 1) - levels = levels.float() - p = levels * unit # in [0, 1] - return p * (mx - mn) + mn - - -class UniformQuantizer(BaseQuantizer): - @capture_init - def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01, - float16: bool = False, qat: bool = False, exclude=[], detect_bound=True): - """ - Args: - model (torch.nn.Module): model to quantize - bits (float): number of bits to quantize over. 
- min_size (float): minimum size in MB of a parameter to be quantized. - float16 (bool): if a layer is smaller than min_size, should we still do float16? - qat (bool): perform quantized aware training. - exclude (list[str]): list of patterns used to match parameters to exclude. - For instance `['bias']` to exclude all bias terms. - detect_bound (bool): if True, will detect bound parameters and reuse - the same quantized tensor for both. - """ - self.bits = float(bits) - self.qat = qat - - super().__init__(model, min_size, float16, exclude, detect_bound) - - def __repr__(self): - return simple_repr(self, ) - - def _pre_forward_train(self): - if self.qat: - for qparam in self._qparams: - if qparam.other is not None: - new_param = qparam.other.module._parameters[qparam.other.name] - else: - quantized = self._quantize_param(qparam) - qvalue = self._unquantize_param(qparam, quantized) - new_param = qparam.param + (qvalue - qparam.param).detach() - qparam.module._parameters[qparam.name] = new_param - return True - return False - - def _post_forward_train(self): - if self.qat: - for qparam in self._qparams: - qparam.module._parameters[qparam.name] = qparam.param - return True - return False - - def _quantize_param(self, qparam): - levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits)) - return (levels, scales) - - def _unquantize_param(self, qparam, quantized): - levels, scales = quantized - return uniform_unquantize(levels, scales, torch.tensor(self.bits)) - - def _bit_pack_param(self, qparam, quantized, pack_fn): - levels, scales = quantized - packed = pack_fn(levels, self.bits) - return (packed, scales) - - def _bit_unpack_param(self, qparam, packed, unpack_fn): - """Unpack bitpacked representation. Should be overriden - """ - packed_levels, scales = packed - levels = unpack_fn( - packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param) - return (levels, scales) - - def model_size(self): - """ - Non differentiable model size in MB. - """ - total = super().model_size() - subtotal = 0 - for qparam in self._qparams: - if qparam.other is None: # if parameter is bound, count only one copy. - subtotal += self.bits * qparam.param.numel() + 64 # 2 float for the overall scales - subtotal /= 2**20 * 8 # bits to MegaBytes - return total + subtotal - - def true_model_size(self): - """ - Return the true quantized model size, in MB, without extra - compression. - """ - return self.model_size().item() diff --git a/diffq/utils.py b/diffq/utils.py deleted file mode 100644 index 4cd810c..0000000 --- a/diffq/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import functools -import inspect -from typing import Optional, List - - -def simple_repr(obj, attrs: Optional[List[str]] = None, overrides={}): - """ - Return a simple representation string for `obj`. - If `attrs` is not None, it should be a list of attributes to include. 
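A roundtrip sketch for the helpers in diffq/uniform.py above (standalone; the error bound follows from the quantization step (max - min) / (2^bits - 1)):

    import torch
    from diffq.uniform import uniform_quantize, uniform_unquantize

    w = torch.randn(1024)
    bits = torch.tensor(8.)
    levels, scales = uniform_quantize(w, bits)        # uint8 levels + (min, max) range
    w_hat = uniform_unquantize(levels, scales, bits)
    assert (w - w_hat).abs().max() <= (w.max() - w.min()) / (2 ** 8 - 1)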
- """ - params = inspect.signature(obj.__class__).parameters - attrs_repr = [] - if attrs is None: - attrs = params.keys() - for attr in attrs: - display = False - if attr in overrides: - value = overrides[attr] - elif hasattr(obj, attr): - value = getattr(obj, attr) - else: - continue - if attr in params: - param = params[attr] - if param.default is inspect._empty or value != param.default: - display = True - else: - display = True - - if display: - attrs_repr.append(f"{attr}={value}") - return f"{obj.__class__.__name__}({','.join(attrs_repr)})" - - -def capture_init(init): - """capture_init. - - Decorate `__init__` with this, and you can then - recover the *args and **kwargs passed to it in `self._init_args_kwargs` - """ - signature = inspect.signature(init) - - @functools.wraps(init) - def __init__(self, *args, **kwargs): - bound = signature.bind(self, *args, **kwargs) - actual_kwargs = dict(bound.arguments) - del actual_kwargs['self'] - actual_kwargs.update(bound.kwargs) - self._init_kwargs = actual_kwargs - init(self, *args, **kwargs) - - return __init__