import torch

import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

original_torch_bmm = torch.bmm


def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)

    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
    batch_size_attention, input_tokens, mat2_shape = (
        input.shape[0],
        input.shape[1],
        mat2.shape[2],
    )
    block_multiply = input.element_size()
    slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    if block_size > 4:
        do_split = True
        # Find something divisible with the input_tokens
        while (split_slice_size * slice_block_size) > 4:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
                break
    else:
        do_split = False

    split_2_slice_size = input_tokens
    if split_slice_size * slice_block_size > 4:
        slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
        do_split_2 = True
        # Find something divisible with the input_tokens
        while (split_2_slice_size * slice_block_size2) > 4:
            split_2_slice_size = split_2_slice_size // 2
            if split_2_slice_size <= 1:
                split_2_slice_size = 1
                break
    else:
        do_split_2 = False

    if do_split:
        hidden_states = torch.zeros(
            input.shape[0],
            input.shape[1],
            mat2.shape[2],
            device=input.device,
            dtype=input.dtype,
        )
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(
                    input_tokens // split_2_slice_size
                ):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    hidden_states[
                        start_idx:end_idx, start_idx_2:end_idx_2
                    ] = original_torch_bmm(
                        input[start_idx:end_idx, start_idx_2:end_idx_2],
                        mat2[start_idx:end_idx, start_idx_2:end_idx_2],
                        out=out,
                    )
            else:
                hidden_states[start_idx:end_idx] = original_torch_bmm(
                    input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
                )
    else:
        return original_torch_bmm(input, mat2, out=out)
    return hidden_states

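# Worked example of the slicing heuristic above (illustrative shapes, not part of
# the original module): for a float16 bmm with input (8, 1024, 1024) and mat2
# (8, 1024, 1024), element_size() is 2, so slice_block_size = 1024 * 1024 / 1024
# / 1024 * 2 = 2.0 and block_size = 8 * 2.0 = 16.0 > 4, so the batch dimension is
# split. split_slice_size halves 8 -> 4 -> 2 until 2 * 2.0 <= 4, giving four bmm
# calls over batch chunks of size 2; the second split over input_tokens is skipped
# because split_slice_size * slice_block_size is already <= 4.
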
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention


def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
    if len(query.shape) == 3:
        batch_size_attention, query_tokens, shape_four = query.shape
        shape_one = 1
        no_shape_one = True
    else:
        shape_one, batch_size_attention, query_tokens, shape_four = query.shape
        no_shape_one = False

    block_multiply = query.element_size()
    slice_block_size = (
        shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
    )
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    if block_size > 4:
        do_split = True
        # Find something divisible with the shape_one
        while (split_slice_size * slice_block_size) > 4:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
                break
    else:
        do_split = False

    split_2_slice_size = query_tokens
    if split_slice_size * slice_block_size > 4:
        slice_block_size2 = (
            shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
        )
        do_split_2 = True
        # Find something divisible with the batch_size_attention
        while (split_2_slice_size * slice_block_size2) > 4:
            split_2_slice_size = split_2_slice_size // 2
            if split_2_slice_size <= 1:
                split_2_slice_size = 1
                break
    else:
        do_split_2 = False

    if do_split:
        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(
                    query_tokens // split_2_slice_size
                ):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if no_shape_one:
                        hidden_states[
                            start_idx:end_idx, start_idx_2:end_idx_2
                        ] = original_scaled_dot_product_attention(
                            query[start_idx:end_idx, start_idx_2:end_idx_2],
                            key[start_idx:end_idx, start_idx_2:end_idx_2],
                            value[start_idx:end_idx, start_idx_2:end_idx_2],
                            attn_mask=attn_mask[
                                start_idx:end_idx, start_idx_2:end_idx_2
                            ]
                            if attn_mask is not None
                            else attn_mask,
                            dropout_p=dropout_p,
                            is_causal=is_causal,
                        )
                    else:
                        hidden_states[
                            :, start_idx:end_idx, start_idx_2:end_idx_2
                        ] = original_scaled_dot_product_attention(
                            query[:, start_idx:end_idx, start_idx_2:end_idx_2],
                            key[:, start_idx:end_idx, start_idx_2:end_idx_2],
                            value[:, start_idx:end_idx, start_idx_2:end_idx_2],
                            attn_mask=attn_mask[
                                :, start_idx:end_idx, start_idx_2:end_idx_2
                            ]
                            if attn_mask is not None
                            else attn_mask,
                            dropout_p=dropout_p,
                            is_causal=is_causal,
                        )
            else:
                if no_shape_one:
                    hidden_states[
                        start_idx:end_idx
                    ] = original_scaled_dot_product_attention(
                        query[start_idx:end_idx],
                        key[start_idx:end_idx],
                        value[start_idx:end_idx],
                        attn_mask=attn_mask[start_idx:end_idx]
                        if attn_mask is not None
                        else attn_mask,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
                else:
                    hidden_states[
                        :, start_idx:end_idx
                    ] = original_scaled_dot_product_attention(
                        query[:, start_idx:end_idx],
                        key[:, start_idx:end_idx],
                        value[:, start_idx:end_idx],
                        attn_mask=attn_mask[:, start_idx:end_idx]
                        if attn_mask is not None
                        else attn_mask,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
    else:
        return original_scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
    return hidden_states

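# Illustrative note (not part of the original module): the wrapper above applies
# the same chunking idea to F.scaled_dot_product_attention, handling 3-D inputs
# (batch, tokens, head_dim) and the usual 4-D layout (batch, heads, tokens,
# head_dim) via the `no_shape_one` flag, and slicing attn_mask with the same
# indices when a mask is given. For example, a float16 query of shape
# (2, 8, 1024, 64) gives slice_block_size = 2 * 1024 * 64 / 1024 / 1024 * 2 = 0.25
# and block_size = 8 * 0.25 = 2.0 <= 4, so the original kernel is called once
# without slicing.
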
def attention_init():
    # ARC GPUs can't allocate more than 4GB to a single block:
    torch.bmm = torch_bmm
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
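
# Minimal usage sketch (illustrative, not part of the original module). The patch
# is applied once, after which torch.bmm and F.scaled_dot_product_attention are
# transparently routed through the sliced wrappers. The "xpu" device below is an
# assumption that an Intel GPU is available through IPEX; it falls back to CPU
# with float32 otherwise.
if __name__ == "__main__":
    attention_init()
    _device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    _dtype = torch.float16 if _device == "xpu" else torch.float32
    _q = torch.randn(2, 8, 64, 64, device=_device, dtype=_dtype)
    # Small enough that no slicing triggers; the original kernel is called once.
    _out = torch.nn.functional.scaled_dot_product_attention(_q, _q, _q)
    print(tuple(_out.shape))  # (2, 8, 64, 64)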