mirror of
https://github.com/Anjok07/ultimatevocalremovergui.git
synced 2024-11-13 18:40:48 +01:00
210 lines
7.6 KiB
Python
210 lines
7.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""TorchScript export support.
|
|
We have to do a lot of black magic for TorchScript to be happy
|
|
because we cannot dynamically allocate new weights when loading the model.
|
|
|
|
Here is how it works:
|
|
- we generate code in a temporary python file for the given model that explicitly
|
|
overrides all the weights on the first forward from their packed version.
|
|
This is because TorchScript does not let us iterate over parameters in a generic manner.
|
|
- we zero out all the original weights. We cannot simply remove those weights
|
|
because TorchScript won't let us recreate them.
|
|
- A TorchScript file is just a zip file, but stored without compression.
|
|
In order to remove the cost of storing the zeroed out weights, we unzip the file,
|
|
and zip it again with compression.
|
|
"""
|
|
import importlib
|
|
import os
|
|
from pathlib import Path
|
|
import random
|
|
import sys
|
|
import typing as tp
|
|
import tempfile
|
|
import zipfile
|
|
|
|
import torch
|
|
from torch import jit
|
|
|
|
from .diffq import DiffQuantizer
|
|
from .uniform import uniform_unquantize
|
|
from .torch_pack import unpack
|
|
|
|
# Packed representation of a single quantized parameter, as built by `export`
# and decoded by `_unpack_param`. The four fields are, in order:
#   - the packed levels, one optional tensor per bit width (entry `i` is used
#     for groups quantized on `i + 1` bits, `None` when absent),
#   - the scales handed to `uniform_unquantize`,
#   - the packed per-group bit widths (stored relative to `min_bits`),
#   - the original shape of the parameter as a list of ints.
_DiffQPacked = tp.Tuple[
    tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float],
    torch.Tensor, tp.List[int]]
|
|
|
|
# Source template for the generated TorchScript wrapper class.
# `_codegen` fills it with `str.format`, providing:
#   - {module} / {klass}: where to import the wrapped model class from,
#   - {unpack_assigns}: the already indented assignment lines (built from
#     `UNPACK_ASSIGN` / `UNPACK_ASSIGN_SAME`) spliced into `unpack()`.
# The placeholder line for {unpack_assigns} is deliberately not indented:
# each assignment line carries its own 8 spaces of indentation.
TEMPLATE = '''
import typing as tp
import torch
from torch import jit

from diffq.ts_export import _unpack_param, _DiffQPacked

from {module} import {klass}


class DiffQTSModel(torch.nn.Module):
    def __init__(self, model: {klass}, group_size: int, min_bits: int,
                 packed: tp.List[_DiffQPacked]):
        super().__init__()
        self.group_size = group_size
        self.min_bits = min_bits
        self.model = model
        self._unpacked = False
        self._packed = packed

    @jit.export
    def unpack(self):
        """
        Unpack the weights, automatically called on the first forward,
        or explicitely."""
        if self._unpacked:
            return
{unpack_assigns}
        self._unpacked = True

    def forward(self, x: torch.Tensor):
        self.unpack()
        return self.model.forward(x)
'''
|
|
|
|
# Line templates emitted inside the generated `unpack()` method, one per
# quantized parameter. Both start with 8 spaces so that they sit at method-body
# depth once spliced into `TEMPLATE`.
#   - UNPACK_ASSIGN: decode packed entry number {index} into the parameter.
#   - UNPACK_ASSIGN_SAME: point a parameter at another, already unpacked, one.
UNPACK_ASSIGN = (
    '        self.model{full_name}.data[:] = _unpack_param('
    'self._packed[{index}], group_size=self.group_size, min_bits=self.min_bits)'
)
UNPACK_ASSIGN_SAME = '        self.model{full_name} = self.model{other_name}'
|
|
|
|
|
|
def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]):
    """Export the given quantized model to the given path.

    We must save the quantized model ourselves, as we need to recompress
    the zip archive afterwards.

    The wrapper class is generated by `_codegen`, written to a throw-away
    module in a temporary directory, imported, instantiated around
    `quantizer.model` and then scripted with `torch.jit.script`.

    Warning: this is destructive for the quantizer and its model. The data of
    every quantized parameter is zeroed in place and `quantizer.detach()` is
    called before scripting.

    Args:
        quantizer: quantizer wrapping the model to export.
        path: destination of the TorchScript archive. When `None`, nothing is
            written to disk and only the scripted module is returned.

    Returns:
        The scripted `DiffQTSModel` wrapping the model.
    """
    packed: tp.List[_DiffQPacked] = []
    # Random module name so the generated file cannot shadow a real module.
    uniq_name = ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12))
    with tempfile.TemporaryDirectory() as tmpdir:
        # The generated module must be importable by name.
        sys.path.insert(0, tmpdir)
        try:
            code = _codegen(quantizer)
            # Python source files are UTF-8 by default, do not depend on the locale.
            with open(Path(tmpdir) / f'{uniq_name}.py', 'w', encoding='utf-8') as f:
                f.write(code)
            module = importlib.import_module(uniq_name)
            ts_klass = module.DiffQTSModel

            state = quantizer.get_quantized_state(packed=True, torch_pack=True)
            # Consumed in order: one entry per parameter that owns its storage.
            quantized = iter(state["quantized"])
            for qparam in quantizer._qparams:
                if qparam.other is None:
                    levels, scales, bits = next(quantized)
                    size = qparam.param.size()
                    packed.append((levels, scales, bits, list(size)))
                # Zeroed weights compress to almost nothing, see `recompress`.
                qparam.param.data.zero_()
            quantizer.detach()

            ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
                                   quantizer.min_bits, packed)
            ts_model = jit.script(ts_premodel)
            if path is not None:
                jit.save(ts_model, path)
                recompress(path)
        finally:
            # Remove our own entry by value: code run during the import or the
            # scripting may have altered `sys.path`, in which case popping
            # index 0 would drop an unrelated entry.
            if tmpdir in sys.path:
                sys.path.remove(tmpdir)

    return ts_model
|
|
|
|
|
|
def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Decode one packed parameter back to a tensor of its original shape.

    Called from TorchScript on the first forward to turn the packed weights
    into FP32, so the body sticks to constructs TorchScript can compile.

    Args:
        packed: packed levels (one optional tensor per bit width), scales,
            packed per-group bit widths and the original shape.
        group_size: number of weights sharing one bit width.
        min_bits: offset added to the stored bit widths.

    Raises:
        RuntimeError: if the original shape has fewer than 1 or more than 4 dims.
    """
    packed_all_levels, scales, packed_bits, shape = packed

    # Total number of weights, computed with a plain loop over the shape.
    total = 1
    for extent in shape:
        total *= extent

    # One bit width per group of `group_size` weights.
    bits = unpack(packed_bits, total // group_size) + min_bits
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)

    # Entry `position` holds the levels of every group stored on
    # `position + 1` bits; `None` entries are skipped.
    for position, packed_levels in enumerate(packed_all_levels):
        if packed_levels is None:
            continue
        selected = bits == (position + 1)
        chunk = levels[selected]
        levels[selected] = unpack(packed_levels, chunk.numel()).view_as(chunk)

    unquant = uniform_unquantize(levels, scales, bits[:, None])

    # The rank is spelled out case by case, with explicit integer arguments.
    rank = len(shape)
    if rank == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    if rank == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    if rank == 2:
        return unquant.view(shape[0], shape[1])
    if rank == 1:
        return unquant.view(shape[0])
    raise RuntimeError("Invalid numbr of dim")
|
|
|
|
|
|
def recompress(path: tp.Union[str, Path]):
    """Rewrite the zip archive at `path` in place, with compression enabled.

    After having saved the torchscript file, this makes sure all the zeroed
    out parameters don't actually take any space: the archive is extracted to
    a temporary directory and every extracted file is written back to `path`
    using DEFLATE at compression level 1.
    """
    with tempfile.TemporaryDirectory() as scratch:
        with zipfile.ZipFile(path) as source_archive:
            source_archive.extractall(scratch)

        target_archive = zipfile.ZipFile(
            path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=1)
        with target_archive:
            for dirpath, _dirnames, filenames in os.walk(scratch):
                base = Path(dirpath)
                for filename in filenames:
                    member = base / filename
                    # Store entries relative to the extraction root.
                    target_archive.write(member, member.relative_to(scratch))
|
|
|
|
|
|
def _get_full_name_access(full_name):
|
|
# When generating code, we need to handle attributes vs. indexing.
|
|
parts = []
|
|
for part in full_name.split("."):
|
|
try:
|
|
index = int(part)
|
|
except ValueError:
|
|
parts.append("." + part)
|
|
else:
|
|
parts.append(f"[{index}]")
|
|
return "".join(parts)
|
|
|
|
|
|
def _codegen(quantizer: DiffQuantizer):
    """Generate the Python source of the TorchScript wrapper for `quantizer`.

    One assignment line is produced per entry of `quantizer._qparams` and the
    result is spliced into `TEMPLATE`:

    - a parameter with `qparam.other is None` gets an `UNPACK_ASSIGN` line,
      numbered by its position among such parameters (the same order in which
      `export` fills the packed list);
    - a parameter with `qparam.other` set gets an `UNPACK_ASSIGN_SAME` line
      that points it at that other parameter.

    Returns:
        The generated source code, as a string.
    """
    model = quantizer.model
    module = model.__class__.__module__
    klass = model.__class__.__name__

    # float16 quantizers are rejected up front.
    assert not quantizer.float16

    # Map each sub-module object to its dotted name ('' for the root model).
    names = {mod: mod_name for mod_name, mod in model.named_modules()}

    def _access(qparam):
        # Access expression of a quantized parameter, relative to the model.
        mod_name = names[qparam.module]
        full_name = qparam.name if mod_name == '' else mod_name + '.' + qparam.name
        return _get_full_name_access(full_name)

    unpack_assigns = []
    index = 0
    for qparam in quantizer._qparams:
        full_name = _access(qparam)
        if qparam.other is None:
            unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index))
            index += 1
        else:
            # `names` is keyed by module object only, so the name of the tied
            # parameter must be built the same way as `full_name`; looking up
            # a `(module, name)` tuple in `names` always raised `KeyError`.
            other_name = _access(qparam.other)
            unpack_assigns.append(
                UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name))

    return TEMPLATE.format(
        module=module,
        klass=klass,
        unpack_assigns='\n'.join(unpack_assigns))
|