1
0
mirror of synced 2024-11-23 23:21:03 +01:00

Feature: Intel ARC GPU support with IPEX (#1204)

* Initial Intel ARC support with IPEX

* Fix infer

* Fix train model

* Cleanup

* Cleanup

* Update README

* Make pylint happy

* Move dataloader fix to hijacks

* Fix torch.linalg.solve

* Fix SDP

* Add has_xpu to config.py

* Revert return_xpu fix
This commit is contained in:
Disty0 2023-09-09 07:00:29 +03:00 committed by GitHub
parent c761bda09a
commit 0c94f60093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 817 additions and 20 deletions

View File

@ -5,7 +5,13 @@ import json
from multiprocessing import cpu_count
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
ipex_init()
except Exception:
pass
import logging
logger = logging.getLogger(__name__)
@ -103,12 +109,22 @@ class Config:
except Exception:
return False
@staticmethod
def has_xpu() -> bool:
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True
else:
return False
def use_fp32_config(self):
for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False
def device_config(self) -> tuple:
if torch.cuda.is_available():
if self.has_xpu():
self.device = self.instead = "xpu:0"
self.is_half = True
i_device = int(self.device.split(":")[-1])
self.gpu_name = torch.cuda.get_device_name(i_device)
if (

View File

@ -43,6 +43,7 @@ This repository has the following features:
+ Use the UVR5 model to quickly separate vocals and instruments.
+ Use the most powerful High-pitch Voice Extraction Algorithm [InterSpeech2023-RMVPE](#Credits) to prevent the muted sound problem. Provides the best results (significantly) and is faster, with even lower resource consumption than Crepe_full.
+ AMD/Intel graphics cards acceleration supported.
+ Intel ARC graphics cards acceleration with IPEX supported.
## Preparing the environment
The following commands need to be executed in the environment of Python version 3.8 or higher.
@ -77,6 +78,9 @@ for Nvidia graphics cards
for AMD/Intel graphics cards
pip install -r requirements-dml.txt
for Intel ARC graphics cards on Linux / WSL using Python 3.10:
pip install -r requirements-ipex.txt
```
------
@ -124,6 +128,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt
https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx
```
Intel ARC graphics cards users needs to run `source /opt/intel/oneapi/setvars.sh` command before starting Webui.
Then use this command to start Webui:
```bash
python infer-web.py

View File

@ -42,6 +42,7 @@ Bu depo aşağıdaki özelliklere sahiptir:
+ UVR5 modelini kullanarak hızla vokalleri ve enstrümanları ayırma.
+ En güçlü Yüksek tiz Ses Çıkarma Algoritması [InterSpeech2023-RMVPE](#Krediler) sessiz ses sorununu önlemek için kullanılır. En iyi sonuçları (önemli ölçüde) sağlar ve Crepe_full'den daha hızlı çalışır, hatta daha düşük kaynak tüketimi sağlar.
+ AMD/Intel grafik kartları hızlandırması desteklenir.
+ Intel ARC grafik kartları hızlandırması IPEX ile desteklenir.
## Ortamın Hazırlanması
Aşağıdaki komutlar, Python sürümü 3.8 veya daha yüksek olan bir ortamda çalıştırılmalıdır.
@ -73,11 +74,12 @@ Ayrıca bunları pip kullanarak da kurabilirsiniz:
Nvidia grafik kartları için
pip install -r requirements.txt
AMD/Intel grafik kartları için
pip install -r requirements-dml.txt
Intel ARC grafik kartları için Linux / WSL ile Python 3.10 kullanarak:
pip install -r requirements-ipex.txt
```
------
@ -125,6 +127,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt
https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx
```
Intel ARC grafik kartları kullanıcıları Webui'yi başlatmadan önce `source /opt/intel/oneapi/setvars.sh` komutunu çalıştırmalı.
Daha sonra bu komutu kullanarak Webui'yi başlatabilirsiniz:
```bash
python infer-web.py

View File

@ -13,6 +13,7 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from infer.lib.infer_pack import attentions, commons, modules
from infer.lib.infer_pack.commons import get_padding, init_weights
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder256(nn.Module):
def __init__(
@ -1156,6 +1157,9 @@ class DiscriminatorP(torch.nn.Module):
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)

View File

@ -2,6 +2,14 @@ import pdb, os
import numpy as np
import torch
try:
#Fix "Torch not compiled with CUDA enabled"
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
ipex_init()
except Exception:
pass
import torch.nn as nn
import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny

View File

@ -0,0 +1,165 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
attention_init()
except Exception as e:
return False, e
return True, None

View File

@ -0,0 +1,128 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the input_tokens
while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_2_slice_size = input_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the input_tokens
while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the shape_one
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
split_2_slice_size = query_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the batch_size_attention
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@ -0,0 +1,179 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

View File

@ -0,0 +1,196 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
else:
return original_linalg_solve(A, B, *args, **kwargs)
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context

View File

@ -17,6 +17,18 @@ n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
from infer.modules.ipex.gradscaler import gradscaler_init
from torch.xpu.amp import autocast
GradScaler = gradscaler_init()
ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception:
from torch.cuda.amp import GradScaler, autocast
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
@ -25,7 +37,6 @@ from time import time as ttime
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
@ -185,7 +196,9 @@ def run(rank, n_gpus, hps):
)
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
if torch.cuda.is_available():
if hasattr(torch, "xpu") and torch.xpu.is_available():
pass
elif torch.cuda.is_available():
net_g = DDP(net_g, device_ids=[rank])
net_d = DDP(net_d, device_ids=[rank])
else:
@ -212,19 +225,33 @@ def run(rank, n_gpus, hps):
if hps.pretrainG != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG))
if hasattr(net_g, "module"):
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
else:
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
if hps.pretrainD != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD))
if hasattr(net_d, "module"):
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
else:
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2

View File

@ -75,6 +75,14 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
)
infos.append("%s->Success" % (os.path.basename(inp_path)))
yield "\n".join(infos)
except:
try:
if done == 0:
pre_fun._path_audio_(
inp_path, save_root_ins, save_root_vocal, format0
)
infos.append("%s->Success" % (os.path.basename(inp_path)))
yield "\n".join(infos)
except:
infos.append(
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())

View File

@ -362,7 +362,7 @@ class Pipeline(object):
)
pitch = pitch[:p_len]
pitchf = pitchf[:p_len]
if self.device == "mps":
if self.device == "mps" or "xpu" in self.device:
pitchf = pitchf.astype(np.float32)
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()

54
requirements-ipex.txt Normal file
View File

@ -0,0 +1,54 @@
torch==2.0.1a0
intel_extension_for_pytorch==2.0.110+xpu
torchvision==0.15.2a0
https://github.com/Disty0/Retrieval-based-Voice-Conversion-WebUI/releases/download/torchaudio_wheels_for_ipex/torchaudio-2.0.2+31de77d-cp310-cp310-linux_x86_64.whl
-f https://developer.intel.com/ipex-whl-stable-xpu
joblib>=1.1.0
numba==0.56.4
numpy==1.23.5
scipy
librosa==0.9.1
llvmlite==0.39.0
fairseq==0.12.2
faiss-cpu==1.7.3
gradio==3.34.0
Cython
pydub>=0.25.1
soundfile>=0.12.1
ffmpeg-python>=0.2.0
tensorboardX
Jinja2>=3.1.2
json5
Markdown
matplotlib>=3.7.0
matplotlib-inline>=0.1.3
praat-parselmouth>=0.4.2
Pillow>=9.1.1
resampy>=0.4.2
scikit-learn
tensorboard
tqdm>=4.63.1
tornado>=6.1
Werkzeug>=2.2.3
uc-micro-py>=1.0.1
sympy>=1.11.1
tabulate>=0.8.10
PyYAML>=6.0
pyasn1>=0.4.8
pyasn1-modules>=0.2.8
fsspec>=2022.11.0
absl-py>=1.2.0
audioread
uvicorn>=0.21.1
colorama>=0.4.5
pyworld==0.3.2
httpx
onnxruntime; sys_platform == 'darwin'
onnxruntime-gpu; sys_platform != 'darwin'
torchcrepe==0.0.20
fastapi==0.88
ffmpy==0.3.1
python-dotenv>=1.0.0
av
PySimpleGUI
sounddevice