1
0
mirror of synced 2025-01-08 12:31:45 +01:00
Retrieval-based-Voice-Conve.../infer/lib/jit/__init__.py

164 lines
4.7 KiB
Python
Raw Normal View History

chore(sync): merge dev into main (#1379) * Optimize latency (#1259) * add attribute: configs/config.py Optimize latency: tools/rvc_for_realtime.py * new file: assets/Synthesizer_inputs.pth * fix: configs/config.py fix: tools/rvc_for_realtime.py * fix bug: infer/lib/infer_pack/models.py * new file: assets/hubert_inputs.pth new file: assets/rmvpe_inputs.pth modified: configs/config.py new features: infer/lib/rmvpe.py new features: tools/jit_export/__init__.py new features: tools/jit_export/get_hubert.py new features: tools/jit_export/get_rmvpe.py new features: tools/jit_export/get_synthesizer.py optimize: tools/rvc_for_realtime.py * optimize: tools/jit_export/get_synthesizer.py fix bug: tools/jit_export/__init__.py * Fixed a bug caused by using half on the CPU: infer/lib/rmvpe.py Fixed a bug caused by using half on the CPU: tools/jit_export/__init__.py Fixed CIRCULAR IMPORT: tools/jit_export/get_rmvpe.py Fixed CIRCULAR IMPORT: tools/jit_export/get_synthesizer.py Fixed a bug caused by using half on the CPU: tools/rvc_for_realtime.py * Remove useless code: infer/lib/rmvpe.py * Delete gui_v1 copy.py * Delete .vscode/launch.json * Delete jit_export_test.py * Delete tools/rvc_for_realtime copy.py * Delete configs/config.json * Delete .gitignore * Fix exceptions caused by switching inference devices: infer/lib/rmvpe.py Fix exceptions caused by switching inference devices: tools/jit_export/__init__.py Fix exceptions caused by switching inference devices: tools/rvc_for_realtime.py * restore * replace(you can undo this commit) * remove debug_print --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * Fixed some bugs when exporting ONNX model (#1254) * fix import (#1280) * fix import * lint * 🎨 同步 locale (#1242) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Fix jit load and import issue (#1282) * fix jit model loading : infer/lib/rmvpe.py * modified: assets/hubert/.gitignore move file: assets/hubert_inputs.pth -> 
assets/hubert/hubert_inputs.pth modified: assets/rmvpe/.gitignore move file: assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth fix import: gui_v1.py * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * Add input wav and delay time monitor for real-time gui (#1293) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * add input wav and delay time monitor --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> * Optimize latency using scripted jit (#1291) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * Optimize-latency-using-scripted: configs/config.py Optimize-latency-using-scripted: infer/lib/infer_pack/attentions.py Optimize-latency-using-scripted: infer/lib/infer_pack/commons.py Optimize-latency-using-scripted: infer/lib/infer_pack/models.py Optimize-latency-using-scripted: infer/lib/infer_pack/modules.py Optimize-latency-using-scripted: infer/lib/jit/__init__.py Optimize-latency-using-scripted: infer/lib/jit/get_hubert.py Optimize-latency-using-scripted: infer/lib/jit/get_rmvpe.py Optimize-latency-using-scripted: infer/lib/jit/get_synthesizer.py Optimize-latency-using-scripted: infer/lib/rmvpe.py Optimize-latency-using-scripted: tools/rvc_for_realtime.py * modified: infer/lib/infer_pack/models.py * fix some bug: configs/config.py fix some bug: infer/lib/infer_pack/models.py fix some bug: infer/lib/rmvpe.py * Fixed abnormal reference of logger in 
multiprocessing: infer/modules/train/train.py --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Format code (#1298) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 🎨 同步 locale (#1299) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: optimize actions * feat(workflow): add sync dev * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: add jit options (#1303) Delete useless code: infer/lib/jit/get_synthesizer.py Optimized code: tools/rvc_for_realtime.py * Code refactor + re-design inference ui (#1304) * Code refacor + re-design inference ui * Fix tabname * i18n jp --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * feat: optimize actions * feat: optimize actions * Update README & en_US locale file (#1309) * critical: some bug fixes (#1322) * JIT acceleration switch does not support hot update * fix padding bug of rmvpe in torch-directml * fix padding bug of rmvpe in torch-directml * Fix STFT under torch_directml (#1330) * chore(format): run black on dev (#1318) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * chore(i18n): sync locale on dev (#1317) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: allow for tta to be passed to uvr (#1361) * chore(format): run black on dev (#1373) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Added script for automatically download all needed models at install (#1366) * Delete modules.py * Add files via upload * Add files via upload * Add files via upload * Add files via upload * chore(i18n): sync locale on dev (#1377) Co-authored-by: github-actions[bot] 
<github-actions[bot]@users.noreply.github.com> * chore(format): run black on dev (#1376) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Update IPEX library (#1362) * Update IPEX library * Update ipex index * chore(format): run black on dev (#1378) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com> Co-authored-by: Ftps <ftpsflandre@gmail.com> Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com> Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com> Co-authored-by: Rice Cake <gak141808@gmail.com> Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com> Co-authored-by: Dmitry <nda2911@yandex.ru> Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
2023-10-06 11:14:33 +02:00
from io import BytesIO
import pickle
import time
import torch
from tqdm import tqdm
from collections import OrderedDict
def load_inputs(path, device, is_half=False):
    """Load a saved mapping of example inputs and prepare it for a model call.

    The file at ``path`` is read with ``torch.load`` onto the CPU.  Every value
    is then moved to ``device`` and its floating-point precision is aligned
    with ``is_half``: float32 values are cast to float16 when ``is_half`` is
    true, float16 values are cast to float32 when it is false.  Values of any
    other dtype are moved but never cast.

    Returns the loaded mapping itself, with its values replaced in place.
    """
    tensors = torch.load(path, map_location=torch.device("cpu"))
    # Only one dtype ever needs converting for a given precision setting.
    dtype_to_convert = torch.float32 if is_half else torch.float16
    for name, value in tensors.items():
        moved = value.to(device)
        if moved.dtype == dtype_to_convert:
            moved = moved.half() if is_half else moved.float()
        tensors[name] = moved
    return tensors
def benchmark(
    model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
    """Time repeated calls of ``model`` on saved example inputs.

    Args:
        model: callable invoked as ``model(**inputs)``.
        inputs_path: file of example inputs, read through ``load_inputs``.
        device: device the inputs are moved to.
        epoch: number of timed calls.
        is_half: passed to ``load_inputs`` to select the input precision.

    Prints the iteration count and the mean wall-clock time per call in
    milliseconds.  Only the model call itself is inside the timed region.
    """
    parm = load_inputs(inputs_path, device, is_half)
    total_ts = 0.0
    for _ in tqdm(range(epoch)):
        start_time = time.perf_counter()
        model(**parm)
        total_ts += time.perf_counter() - start_time
    # With no iterations there is nothing to average; report 0.0 instead of
    # dividing by zero.
    avg_ms = (total_ts * 1000) / epoch if epoch > 0 else 0.0
    print(f"num_epoch: {epoch} | avg time(ms): {avg_ms}")
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
    """Run ``model`` a few times on saved example inputs to warm it up.

    This simply delegates to ``benchmark`` with a small default ``epoch``, so
    it also prints the timing summary that ``benchmark`` produces.
    """
    benchmark(
        model,
        inputs_path,
        device=device,
        epoch=epoch,
        is_half=is_half,
    )
def to_jit_model(
    model_path,
    model_type: str,
    mode: str = "trace",
    inputs_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Load a model and compile it with TorchScript.

    Args:
        model_path: checkpoint path handed to the model-specific loader.
        model_type: ``"synthesizer"``, ``"rmvpe"`` or ``"hubert"``
            (case-insensitive).
        mode: ``"trace"`` to compile with ``torch.jit.trace`` or ``"script"``
            to compile with ``torch.jit.script``.
        inputs_path: file of example inputs read through ``load_inputs``;
            required when ``mode`` is ``"trace"``.
        device: device the model is loaded on and the compiled module moved to.
        is_half: cast model, example inputs and compiled module to float16
            when true, to float32 otherwise.

    Returns:
        A ``(model, model_jit)`` tuple: the eager model and its compiled
        counterpart.

    Raises:
        ValueError: for an unknown ``model_type`` or ``mode``, or when
            ``mode`` is ``"trace"`` and no ``inputs_path`` is given.
    """
    kind = model_type.lower()
    # Validate every argument up front so a bad call fails before the
    # (potentially slow) checkpoint load.
    if kind not in ("synthesizer", "rmvpe", "hubert"):
        raise ValueError(f"No model type named {model_type}")
    if mode not in ("trace", "script"):
        raise ValueError(f"No jit mode named {mode}")
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode is 'trace'")

    if kind == "synthesizer":
        from .get_synthesizer import get_synthesizer

        model, _ = get_synthesizer(model_path, device)
        # Compile the inference entry point rather than the original forward.
        model.forward = model.infer
    elif kind == "rmvpe":
        from .get_rmvpe import get_rmvpe

        model = get_rmvpe(model_path, device)
    else:
        from .get_hubert import get_hubert_model

        model = get_hubert_model(model_path, device)
        # Compile the inference entry point rather than the original forward.
        model.forward = model.infer

    model = model.eval()
    model = model.half() if is_half else model.float()

    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)

    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    return (model, model_jit)
def export(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile ``model`` with TorchScript and serialize it into a checkpoint dict.

    Args:
        model: module to compile; it is cast in place to float16 or float32
            according to ``is_half`` and switched to eval mode.
        mode: ``"trace"`` (``torch.jit.trace``) or ``"script"``
            (``torch.jit.script``).
        inputs: keyword example inputs for tracing; required when ``mode`` is
            ``"trace"``.
        device: device the compiled module is moved to before saving.
        is_half: select float16 (true) or float32 (false) precision.

    Returns:
        An ``OrderedDict`` with ``"model"`` (the bytes written by
        ``torch.jit.save``) and ``"is_half"``.

    Raises:
        ValueError: for an unknown ``mode``, or when ``mode`` is ``"trace"``
            and ``inputs`` is ``None``.
    """
    # Reject bad arguments before touching the model so a failed call leaves
    # it unmodified.
    if mode not in ("trace", "script"):
        raise ValueError(f"No jit mode named {mode}")
    if mode == "trace" and inputs is None:
        raise ValueError("inputs is required when mode is 'trace'")

    model = model.half() if is_half else model.float()
    model.eval()

    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)

    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    # Drop the compiled module as soon as it has been serialized.
    del model_jit

    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
def load(path: str):
    """Return the object unpickled from the file at ``path``.

    NOTE(review): unpickling can execute arbitrary code, so this must only be
    pointed at files from a trusted source.
    """
    with open(path, "rb") as stream:
        restored = pickle.Unpickler(stream).load()
    return restored
def save(ckpt: dict, save_path: str):
    """Pickle ``ckpt`` into the file at ``save_path``, overwriting any existing file."""
    with open(save_path, "wb") as stream:
        pickle.Pickler(stream).dump(ckpt)
def rmvpe_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Compile an RMVPE checkpoint with TorchScript and save the result.

    Args:
        model_path: checkpoint path handed to ``get_rmvpe``.
        mode: ``"script"`` or ``"trace"``; forwarded to ``export``.
        inputs_path: file of example inputs, required when ``mode`` is
            ``"trace"``.
        save_path: output file.  When empty it is derived from ``model_path``
            by dropping a trailing ``.pth``/``.pt`` extension and appending
            ``.half.jit`` (``is_half``) or ``.jit``.
        device: target device; a bare ``cuda`` device is pinned to ``cuda:0``.
        is_half: export in float16 when true, float32 otherwise.

    Returns:
        The checkpoint dict produced by ``export`` with an added ``"device"``
        entry holding ``str(device)``.  The same dict is written to
        ``save_path``.

    Raises:
        ValueError: when ``mode`` is ``"trace"`` and no ``inputs_path`` is
            given.
    """
    # Fail before loading the checkpoint if tracing has no example inputs.
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode is 'trace'")

    if not save_path:
        # Remove the extension as a suffix.  str.rstrip(".pth") would strip
        # any run of the characters '.', 'p', 't', 'h' and mangle names such
        # as "depth.pth" -> "de".
        stem = model_path
        for suffix in (".pth", ".pt"):
            if model_path.endswith(suffix):
                stem = model_path[: -len(suffix)]
                break
        save_path = stem + (".half.jit" if is_half else ".jit")

    # Record an explicit device index so the stored device string is unambiguous.
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")

    from .get_rmvpe import get_rmvpe

    model = get_rmvpe(model_path, device)
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    ckpt["device"] = str(device)
    save(ckpt, save_path)
    return ckpt
def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Compile a synthesizer checkpoint with TorchScript and save the result.

    Args:
        model_path: checkpoint path handed to ``get_synthesizer``.
        mode: ``"script"`` or ``"trace"``; forwarded to ``export``.
        inputs_path: file of example inputs, required when ``mode`` is
            ``"trace"``.
        save_path: output file.  When empty it is derived from ``model_path``
            by dropping a trailing ``.pth``/``.pt`` extension and appending
            ``.half.jit`` (``is_half``) or ``.jit``.
        device: target device; a bare ``cuda`` device is pinned to ``cuda:0``.
        is_half: export in float16 when true, float32 otherwise.

    Returns:
        The checkpoint dict returned by ``get_synthesizer`` with its
        ``"weight"`` entry removed, ``"model"`` replaced by the serialized
        TorchScript bytes, and ``"device"`` added.  The same dict is written
        to ``save_path``.

    Raises:
        ValueError: when ``mode`` is ``"trace"`` and no ``inputs_path`` is
            given.
    """
    # Fail before loading the checkpoint if tracing has no example inputs.
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode is 'trace'")

    if not save_path:
        # Remove the extension as a suffix.  str.rstrip(".pth") would strip
        # any run of the characters '.', 'p', 't', 'h' and mangle names such
        # as "smooth.pth" -> "smoo".
        stem = model_path
        for suffix in (".pth", ".pt"):
            if model_path.endswith(suffix):
                stem = model_path[: -len(suffix)]
                break
        save_path = stem + (".half.jit" if is_half else ".jit")

    # Record an explicit device index so the stored device is unambiguous.
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")

    from .get_synthesizer import get_synthesizer

    model, cpt = get_synthesizer(model_path, device)
    assert isinstance(cpt, dict)
    # Compile the inference entry point rather than the original forward.
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)

    # The raw weights are superseded by the serialized TorchScript module.
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    # NOTE(review): this stores the device object, whereas rmvpe_jit_export
    # stores str(device) -- confirm which form readers of the file expect.
    cpt["device"] = device
    save(cpt, save_path)
    return cpt