Format code (#1193)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 72a18e66b6
commit a6456f6d46
@@ -5,10 +5,13 @@ import json

from multiprocessing import cpu_count

import torch

try:
    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

    if torch.xpu.is_available():
        from infer.modules.ipex import ipex_init

        ipex_init()
except Exception:
    pass
gui_v1.py
@@ -478,15 +478,28 @@ if __name__ == "__main__":
                inp_q,
                opt_q,
                device,
-               self.rvc if hasattr(self, "rvc") else None
+               self.rvc if hasattr(self, "rvc") else None,
            )
            self.config.samplerate = self.rvc.tgt_sr
            self.zc = self.rvc.tgt_sr // 100
-           self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
+           self.block_frame = (
+               int(np.round(self.config.block_time * self.config.samplerate / self.zc))
+               * self.zc
+           )
            self.block_frame_16k = 160 * self.block_frame // self.zc
-           self.crossfade_frame = int(np.round(self.config.crossfade_time * self.config.samplerate / self.zc)) * self.zc
+           self.crossfade_frame = (
+               int(
+                   np.round(
+                       self.config.crossfade_time * self.config.samplerate / self.zc
+                   )
+               )
+               * self.zc
+           )
            self.sola_search_frame = self.zc
-           self.extra_frame = int(np.round(self.config.extra_time * self.config.samplerate / self.zc)) * self.zc
+           self.extra_frame = (
+               int(np.round(self.config.extra_time * self.config.samplerate / self.zc))
+               * self.zc
+           )
            self.input_wav: torch.Tensor = torch.zeros(
                self.extra_frame
                + self.crossfade_frame
@@ -495,7 +508,11 @@ if __name__ == "__main__":
                device=device,
                dtype=torch.float32,
            )
-           self.input_wav_res: torch.Tensor= torch.zeros(160 * self.input_wav.shape[0] // self.zc, device=device,dtype=torch.float32)
+           self.input_wav_res: torch.Tensor = torch.zeros(
+               160 * self.input_wav.shape[0] // self.zc,
+               device=device,
+               dtype=torch.float32,
+           )
            self.pitch: np.ndarray = np.zeros(
                self.input_wav.shape[0] // self.zc,
                dtype="int32",
@@ -509,7 +526,9 @@ if __name__ == "__main__":
            )
            self.nr_buffer: torch.Tensor = self.sola_buffer.clone()
            self.output_buffer: torch.Tensor = self.input_wav.clone()
-           self.res_buffer: torch.Tensor = torch.zeros(2 * self.zc, device=device,dtype=torch.float32)
+           self.res_buffer: torch.Tensor = torch.zeros(
+               2 * self.zc, device=device, dtype=torch.float32
+           )
            self.valid_rate = 1 - (self.extra_frame - 1) / self.input_wav.shape[0]
            self.fade_in_window: torch.Tensor = (
                torch.sin(
@@ -529,7 +548,9 @@ if __name__ == "__main__":
            self.resampler = tat.Resample(
                orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
            ).to(device)
-           self.tg = TorchGate(sr=self.config.samplerate, n_fft=4*self.zc, prop_decrease=0.9).to(device)
+           self.tg = TorchGate(
+               sr=self.config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9
+           ).to(device)
            thread_vc = threading.Thread(target=self.soundinput)
            thread_vc.start()

@@ -560,7 +581,7 @@ if __name__ == "__main__":
            indata = librosa.to_mono(indata.T)
            if self.config.threhold > -60:
                rms = librosa.feature.rms(
-                   y=indata, frame_length=4*self.zc, hop_length=self.zc
+                   y=indata, frame_length=4 * self.zc, hop_length=self.zc
                )
                db_threhold = (
                    librosa.amplitude_to_db(rms, ref=1.0)[0] < self.config.threhold
@@ -568,28 +589,44 @@ if __name__ == "__main__":
                for i in range(db_threhold.shape[0]):
                    if db_threhold[i]:
                        indata[i * self.zc : (i + 1) * self.zc] = 0
-           self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :].clone()
-           self.input_wav[-self.block_frame: ] = torch.from_numpy(indata).to(device)
-           self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
+           self.input_wav[: -self.block_frame] = self.input_wav[
+               self.block_frame :
+           ].clone()
+           self.input_wav[-self.block_frame :] = torch.from_numpy(indata).to(device)
+           self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[
+               self.block_frame_16k :
+           ].clone()
            # input noise reduction and resampling
            if self.config.I_noise_reduce:
-               input_wav = self.input_wav[-self.crossfade_frame -self.block_frame-2*self.zc: ]
-               input_wav = self.tg(input_wav.unsqueeze(0), self.input_wav.unsqueeze(0))[0, 2*self.zc:]
+               input_wav = self.input_wav[
+                   -self.crossfade_frame - self.block_frame - 2 * self.zc :
+               ]
+               input_wav = self.tg(
+                   input_wav.unsqueeze(0), self.input_wav.unsqueeze(0)
+               )[0, 2 * self.zc :]
                input_wav[: self.crossfade_frame] *= self.fade_in_window
-               input_wav[: self.crossfade_frame] += self.nr_buffer * self.fade_out_window
-               self.nr_buffer[:] = input_wav[-self.crossfade_frame: ]
-               input_wav = torch.cat((self.res_buffer[:], input_wav[: self.block_frame]))
-               self.res_buffer[:] = input_wav[-2*self.zc: ]
-               self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(input_wav)[160: ]
+               input_wav[: self.crossfade_frame] += (
+                   self.nr_buffer * self.fade_out_window
+               )
+               self.nr_buffer[:] = input_wav[-self.crossfade_frame :]
+               input_wav = torch.cat(
+                   (self.res_buffer[:], input_wav[: self.block_frame])
+               )
+               self.res_buffer[:] = input_wav[-2 * self.zc :]
+               self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
+                   input_wav
+               )[160:]
            else:
-               self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(self.input_wav[-self.block_frame-2*self.zc: ])[160: ]
+               self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
+                   self.input_wav[-self.block_frame - 2 * self.zc :]
+               )[160:]
            # infer
            f0_extractor_frame = self.block_frame_16k + 800
-           if self.config.f0method == 'rmvpe':
+           if self.config.f0method == "rmvpe":
                f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
            infer_wav = self.rvc.infer(
                self.input_wav_res,
-               self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
+               self.input_wav_res[-f0_extractor_frame:].cpu().numpy(),
                self.block_frame_16k,
                self.valid_rate,
                self.pitch,
@@ -601,48 +638,77 @@ if __name__ == "__main__":
            ]
            # output noise reduction
            if self.config.O_noise_reduce:
-               self.output_buffer[: -self.block_frame] = self.output_buffer[self.block_frame :].clone()
-               self.output_buffer[-self.block_frame: ] = infer_wav[-self.block_frame:]
-               infer_wav = self.tg(infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)).squeeze(0)
+               self.output_buffer[: -self.block_frame] = self.output_buffer[
+                   self.block_frame :
+               ].clone()
+               self.output_buffer[-self.block_frame :] = infer_wav[-self.block_frame :]
+               infer_wav = self.tg(
+                   infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)
+               ).squeeze(0)
            # volume envelop mixing
            if self.config.rms_mix_rate < 1:
                rms1 = librosa.feature.rms(
-                   y=self.input_wav_res[-160*infer_wav.shape[0]//self.zc :].cpu().numpy(),
-                   frame_length=640,
-                   hop_length=160,
+                   y=self.input_wav_res[-160 * infer_wav.shape[0] // self.zc :]
+                   .cpu()
+                   .numpy(),
+                   frame_length=640,
+                   hop_length=160,
                )
                rms1 = torch.from_numpy(rms1).to(device)
                rms1 = F.interpolate(
-                   rms1.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
-               )[0,0,:-1]
+                   rms1.unsqueeze(0),
+                   size=infer_wav.shape[0] + 1,
+                   mode="linear",
+                   align_corners=True,
+               )[0, 0, :-1]
                rms2 = librosa.feature.rms(
-                   y=infer_wav[:].cpu().numpy(), frame_length=4*self.zc, hop_length=self.zc
+                   y=infer_wav[:].cpu().numpy(),
+                   frame_length=4 * self.zc,
+                   hop_length=self.zc,
                )
                rms2 = torch.from_numpy(rms2).to(device)
                rms2 = F.interpolate(
-                   rms2.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
-               )[0,0,:-1]
+                   rms2.unsqueeze(0),
+                   size=infer_wav.shape[0] + 1,
+                   mode="linear",
+                   align_corners=True,
+               )[0, 0, :-1]
                rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
-               infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate))
+               infer_wav *= torch.pow(
+                   rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate)
+               )
            # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
-           conv_input = infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
+           conv_input = infer_wav[
+               None, None, : self.crossfade_frame + self.sola_search_frame
+           ]
            cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
            cor_den = torch.sqrt(
-               F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=device)) + 1e-8)
+               F.conv1d(
+                   conv_input**2,
+                   torch.ones(1, 1, self.crossfade_frame, device=device),
+               )
+               + 1e-8
+           )
            if sys.platform == "darwin":
                _, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
                sola_offset = sola_offset.item()
            else:
                sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
            logger.debug("sola_offset = %d", int(sola_offset))
-           infer_wav = infer_wav[sola_offset: sola_offset + self.block_frame + self.crossfade_frame]
+           infer_wav = infer_wav[
+               sola_offset : sola_offset + self.block_frame + self.crossfade_frame
+           ]
            infer_wav[: self.crossfade_frame] *= self.fade_in_window
-           infer_wav[: self.crossfade_frame] += self.sola_buffer *self.fade_out_window
-           self.sola_buffer[:] = infer_wav[-self.crossfade_frame:]
+           infer_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
+           self.sola_buffer[:] = infer_wav[-self.crossfade_frame :]
            if sys.platform == "darwin":
-               outdata[:] = infer_wav[:-self.crossfade_frame].cpu().numpy()[:, np.newaxis]
+               outdata[:] = (
+                   infer_wav[: -self.crossfade_frame].cpu().numpy()[:, np.newaxis]
+               )
            else:
-               outdata[:] = infer_wav[:-self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
+               outdata[:] = (
+                   infer_wav[: -self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
+               )
            total_time = time.perf_counter() - start_time
            self.window["infer_time"].update(int(total_time * 1000))
            logger.info("Infer time: %.2f", total_time)
@@ -698,9 +764,7 @@ if __name__ == "__main__":
            sd.default.device[1] = output_device_indices[
                output_devices.index(output_device)
            ]
-           logger.info(
-               "Input device: %s:%s", str(sd.default.device[0]), input_device
-           )
+           logger.info("Input device: %s:%s", str(sd.default.device[0]), input_device)
            logger.info(
                "Output device: %s:%s", str(sd.default.device[1]), output_device
            )
@@ -1028,7 +1028,7 @@ with gr.Blocks(title="RVC WebUI") as app:
                fn=vc.get_vc,
                inputs=[sid0, protect0, protect1],
                outputs=[spk_item, protect0, protect1, file_index2, file_index4],
-               api_name="infer_change_voice"
+               api_name="infer_change_voice",
            )
        with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
            with gr.Group():
@@ -3,38 +3,49 @@ import numpy as np
import av
from io import BytesIO


def wav2(i, o, format):
-   inp = av.open(i, 'rb')
-   if format == "m4a": format = "mp4"
-   out = av.open(o, 'wb', format=format)
-   if format == "ogg": format = "libvorbis"
-   if format == "mp4": format = "aac"
+   inp = av.open(i, "rb")
+   if format == "m4a":
+       format = "mp4"
+   out = av.open(o, "wb", format=format)
+   if format == "ogg":
+       format = "libvorbis"
+   if format == "mp4":
+       format = "aac"

    ostream = out.add_stream(format)

    for frame in inp.decode(audio=0):
-       for p in ostream.encode(frame): out.mux(p)
+       for p in ostream.encode(frame):
+           out.mux(p)

-   for p in ostream.encode(None): out.mux(p)
+   for p in ostream.encode(None):
+       out.mux(p)

    out.close()
    inp.close()


def audio2(i, o, format, sr):
-   inp = av.open(i, 'rb')
-   out = av.open(o, 'wb', format=format)
-   if format == "ogg": format = "libvorbis"
-   if format == "f32le": format = "pcm_f32le"
+   inp = av.open(i, "rb")
+   out = av.open(o, "wb", format=format)
+   if format == "ogg":
+       format = "libvorbis"
+   if format == "f32le":
+       format = "pcm_f32le"

    ostream = out.add_stream(format, channels=1)
    ostream.sample_rate = sr

    for frame in inp.decode(audio=0):
-       for p in ostream.encode(frame): out.mux(p)
+       for p in ostream.encode(frame):
+           out.mux(p)

    out.close()
    inp.close()


def load_audio(file, sr):
    try:
        file = (
@@ -15,6 +15,7 @@ from infer.lib.infer_pack.commons import get_padding, init_weights

has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())

+
class TextEncoder256(nn.Module):
    def __init__(
        self,
@@ -1158,7 +1159,9 @@ class DiscriminatorP(torch.nn.Module):
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            if has_xpu and x.dtype == torch.bfloat16:
-               x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
+               x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
+                   dtype=torch.bfloat16
+               )
            else:
                x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
@@ -2,11 +2,14 @@ import pdb, os

import numpy as np
import torch

try:
-   #Fix "Torch not compiled with CUDA enabled"
+   # Fix "Torch not compiled with CUDA enabled"
    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

    if torch.xpu.is_available():
        from infer.modules.ipex import ipex_init

        ipex_init()
except Exception:
    pass
@@ -2,15 +2,16 @@ import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init

# pylint: disable=protected-access, missing-function-docstring, line-too-long

-def ipex_init(): # pylint: disable=too-many-statements
+
+def ipex_init():  # pylint: disable=too-many-statements
    try:
-       #Replace cuda with xpu:
+       # Replace cuda with xpu:
        torch.cuda.current_device = torch.xpu.current_device
        torch.cuda.current_stream = torch.xpu.current_stream
        torch.cuda.device = torch.xpu.device
@@ -91,11 +92,11 @@ def ipex_init(): # pylint: disable=too-many-statements
        torch.cuda.CharStorage = torch.xpu.CharStorage
        torch.cuda.__file__ = torch.xpu.__file__
        torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
-       #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
+       # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

-       #Memory:
+       # Memory:
        torch.cuda.memory = torch.xpu.memory
-       if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
+       if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
            torch.xpu.empty_cache = lambda: None
        torch.cuda.empty_cache = torch.xpu.empty_cache
        torch.cuda.memory_stats = torch.xpu.memory_stats
@@ -111,9 +112,11 @@ def ipex_init(): # pylint: disable=too-many-statements
        torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
        torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
        torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
-       torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
+       torch.cuda.reset_accumulated_memory_stats = (
+           torch.xpu.reset_accumulated_memory_stats
+       )

-       #RNG:
+       # RNG:
        torch.cuda.get_rng_state = torch.xpu.get_rng_state
        torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
        torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -124,35 +127,44 @@ def ipex_init(): # pylint: disable=too-many-statements
        torch.cuda.seed_all = torch.xpu.seed_all
        torch.cuda.initial_seed = torch.xpu.initial_seed

-       #AMP:
+       # AMP:
        torch.cuda.amp = torch.xpu.amp
        if not hasattr(torch.cuda.amp, "common"):
            torch.cuda.amp.common = contextlib.nullcontext()
        torch.cuda.amp.common.amp_definitely_not_available = lambda: False
        try:
            torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
        except Exception:  # pylint: disable=broad-exception-caught
            try:
-               from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
+               from .gradscaler import (
+                   gradscaler_init,
+               )  # pylint: disable=import-outside-toplevel, import-error

                gradscaler_init()
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception:  # pylint: disable=broad-exception-caught
                torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

-       #C
+       # C
        torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
        ipex._C._DeviceProperties.major = 2023
        ipex._C._DeviceProperties.minor = 2

-       #Fix functions with ipex:
-       torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
+       # Fix functions with ipex:
+       torch.cuda.mem_get_info = lambda device=None: [
+           (
+               torch.xpu.get_device_properties(device).total_memory
+               - torch.xpu.memory_allocated(device)
+           ),
+           torch.xpu.get_device_properties(device).total_memory,
+       ]
        torch._utils._get_available_device_type = lambda: "xpu"
        torch.has_cuda = True
        torch.cuda.has_half = True
        torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
        torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
        torch.version.cuda = "11.7"
-       torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
+       torch.cuda.get_device_capability = lambda *args, **kwargs: [11, 7]
        torch.cuda.get_device_properties.major = 11
        torch.cuda.get_device_properties.minor = 7
        torch.cuda.ipc_collect = lambda *args, **kwargs: None
@@ -1,22 +1,32 @@
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

original_torch_bmm = torch.bmm


def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)

-   #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
-   batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
+   # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+   batch_size_attention, input_tokens, mat2_shape = (
+       input.shape[0],
+       input.shape[1],
+       mat2.shape[2],
+   )
    block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
-   block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+   block_size = (
+       (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply
+   )  # MB
    split_slice_size = batch_size_attention
    if block_size >= 4000:
        do_split = True
-       #Find something divisible with the input_tokens
-       while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
+       # Find something divisible with the input_tokens
+       while (
+           (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
+       ) > 4000:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
@@ -24,12 +34,16 @@ def torch_bmm(input, mat2, *, out=None):
    else:
        do_split = False

-   split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+   split_block_size = (
+       (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
+   )  # MB
    split_2_slice_size = input_tokens
    if split_block_size >= 4000:
        do_split_2 = True
-       #Find something divisible with the input_tokens
-       while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
+       # Find something divisible with the input_tokens
+       while (
+           (split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply
+       ) > 4000:
            split_2_slice_size = split_2_slice_size // 2
            if split_2_slice_size <= 1:
                split_2_slice_size = 1
@@ -38,40 +52,61 @@ def torch_bmm(input, mat2, *, out=None):
        do_split_2 = False

    if do_split:
-       hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
+       hidden_states = torch.zeros(
+           input.shape[0],
+           input.shape[1],
+           mat2.shape[2],
+           device=input.device,
+           dtype=input.dtype,
+       )
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
-               for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
+               for i2 in range(
+                   input_tokens // split_2_slice_size
+               ):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
-                   hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
+                   hidden_states[
+                       start_idx:end_idx, start_idx_2:end_idx_2
+                   ] = original_torch_bmm(
                        input[start_idx:end_idx, start_idx_2:end_idx_2],
                        mat2[start_idx:end_idx, start_idx_2:end_idx_2],
-                       out=out
+                       out=out,
                    )
            else:
                hidden_states[start_idx:end_idx] = original_torch_bmm(
-                   input[start_idx:end_idx],
-                   mat2[start_idx:end_idx],
-                   out=out
+                   input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
                )
    else:
        return original_torch_bmm(input, mat2, out=out)
    return hidden_states


original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
-   #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+
+
+def scaled_dot_product_attention(
+   query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+):
+   # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
    shape_one, batch_size_attention, query_tokens, shape_four = query.shape
    block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
-   block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
+   block_size = (
+       (shape_one * batch_size_attention * query_tokens * shape_four)
+       / 1024
+       * block_multiply
+   )  # MB
    split_slice_size = batch_size_attention
    if block_size >= 4000:
        do_split = True
-       #Find something divisible with the shape_one
-       while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
+       # Find something divisible with the shape_one
+       while (
+           (shape_one * split_slice_size * query_tokens * shape_four)
+           / 1024
+           * block_multiply
+       ) > 4000:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
@@ -79,12 +114,20 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
    else:
        do_split = False

-   split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
+   split_block_size = (
+       (shape_one * split_slice_size * query_tokens * shape_four)
+       / 1024
+       * block_multiply
+   )  # MB
    split_2_slice_size = query_tokens
    if split_block_size >= 4000:
        do_split_2 = True
-       #Find something divisible with the batch_size_attention
-       while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
+       # Find something divisible with the batch_size_attention
+       while (
+           (shape_one * split_slice_size * split_2_slice_size * shape_four)
+           / 1024
+           * block_multiply
+       ) > 4000:
            split_2_slice_size = split_2_slice_size // 2
            if split_2_slice_size <= 1:
                split_2_slice_size = 1
@@ -98,31 +141,49 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
-               for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
+               for i2 in range(
+                   query_tokens // split_2_slice_size
+               ):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
-                   hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                   hidden_states[
+                       :, start_idx:end_idx, start_idx_2:end_idx_2
+                   ] = original_scaled_dot_product_attention(
                        query[:, start_idx:end_idx, start_idx_2:end_idx_2],
                        key[:, start_idx:end_idx, start_idx_2:end_idx_2],
                        value[:, start_idx:end_idx, start_idx_2:end_idx_2],
-                       attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                       dropout_p=dropout_p, is_causal=is_causal
+                       attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2]
+                       if attn_mask is not None
+                       else attn_mask,
+                       dropout_p=dropout_p,
+                       is_causal=is_causal,
                    )
            else:
-               hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
+               hidden_states[
+                   :, start_idx:end_idx
+               ] = original_scaled_dot_product_attention(
                    query[:, start_idx:end_idx],
                    key[:, start_idx:end_idx],
                    value[:, start_idx:end_idx],
-                   attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
-                   dropout_p=dropout_p, is_causal=is_causal
+                   attn_mask=attn_mask[:, start_idx:end_idx]
+                   if attn_mask is not None
+                   else attn_mask,
+                   dropout_p=dropout_p,
+                   is_causal=is_causal,
                )
    else:
        return original_scaled_dot_product_attention(
-           query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+           query,
+           key,
+           value,
+           attn_mask=attn_mask,
+           dropout_p=dropout_p,
+           is_causal=is_causal,
        )
    return hidden_states


def attention_init():
-   #ARC GPUs can't allocate more than 4GB to a single block:
+   # ARC GPUs can't allocate more than 4GB to a single block:
    torch.bmm = torch_bmm
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
@ -1,15 +1,20 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
_refresh_per_optimizer_state = (
|
||||||
|
ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||||
|
)
|
||||||
|
|
||||||
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
|
||||||
|
def _unscale_grads_(
|
||||||
|
self, optimizer, inv_scale, found_inf, allow_fp16
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
||||||
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
||||||
|
|
||||||
@@ -43,9 +48,9 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint

                    # -: is there a way to split by device and dtype without appending in the inner loop?
                    to_unscale = to_unscale.to("cpu")
-                   per_device_and_dtype_grads[to_unscale.device][
-                       to_unscale.dtype
-                   ].append(to_unscale)
+                   per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
+                       to_unscale
+                   )

        for _, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
@@ -57,6 +62,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint

    return per_device_found_inf._per_device_tensors

+
def unscale_(self, optimizer):
    """
    Divides ("unscales") the optimizer's gradient tensors by the scale factor.
@@ -87,7 +93,7 @@ def unscale_(self, optimizer):

    optimizer_state = self._per_optimizer_states[id(optimizer)]

-   if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
+   if optimizer_state["stage"] is OptState.UNSCALED:  # pylint: disable=no-else-raise
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
@@ -96,16 +102,17 @@ def unscale_(self, optimizer):

    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
-   inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
-   found_inf = torch.full(
-       (1,), 0.0, dtype=torch.float32, device=self._scale.device
-   )
+   inv_scale = (
+       self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
+   )
+   found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    optimizer_state["stage"] = OptState.UNSCALED


def update(self, new_scale=None):
    """
    Updates the scale factor.
@@ -171,6 +178,7 @@ def update(self, new_scale=None):
    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

+
def gradscaler_init():
    torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
    torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
@ -1,45 +1,59 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import importlib
|
import importlib
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||||
|
|
||||||
class CondFunc: # pylint: disable=missing-class-docstring
|
|
||||||
|
class CondFunc: # pylint: disable=missing-class-docstring
|
||||||
def __new__(cls, orig_func, sub_func, cond_func):
|
def __new__(cls, orig_func, sub_func, cond_func):
|
||||||
self = super(CondFunc, cls).__new__(cls)
|
self = super(CondFunc, cls).__new__(cls)
|
||||||
if isinstance(orig_func, str):
|
if isinstance(orig_func, str):
|
||||||
func_path = orig_func.split('.')
|
func_path = orig_func.split(".")
|
||||||
for i in range(len(func_path)-1, -1, -1):
|
for i in range(len(func_path) - 1, -1, -1):
|
||||||
try:
|
try:
|
||||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
resolved_obj = importlib.import_module(".".join(func_path[:i]))
|
||||||
break
|
break
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
for attr_name in func_path[i:-1]:
|
for attr_name in func_path[i:-1]:
|
||||||
resolved_obj = getattr(resolved_obj, attr_name)
|
resolved_obj = getattr(resolved_obj, attr_name)
|
||||||
orig_func = getattr(resolved_obj, func_path[-1])
|
orig_func = getattr(resolved_obj, func_path[-1])
|
||||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
setattr(
|
||||||
|
resolved_obj,
|
||||||
|
func_path[-1],
|
||||||
|
lambda *args, **kwargs: self(*args, **kwargs),
|
||||||
|
)
|
||||||
self.__init__(orig_func, sub_func, cond_func)
|
self.__init__(orig_func, sub_func, cond_func)
|
||||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||||
|
|
||||||
def __init__(self, orig_func, sub_func, cond_func):
|
def __init__(self, orig_func, sub_func, cond_func):
|
||||||
self.__orig_func = orig_func
|
self.__orig_func = orig_func
|
||||||
self.__sub_func = sub_func
|
self.__sub_func = sub_func
|
||||||
self.__cond_func = cond_func
|
self.__cond_func = cond_func
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
||||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.__orig_func(*args, **kwargs)
|
return self.__orig_func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
_utils = torch.utils.data._utils
|
_utils = torch.utils.data._utils
|
||||||
|
|
||||||
|
|
||||||
def _shutdown_workers(self):
|
def _shutdown_workers(self):
|
||||||
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
|
if (
|
||||||
|
torch.utils.data._utils is None
|
||||||
|
or torch.utils.data._utils.python_exit_status is True
|
||||||
|
or torch.utils.data._utils.python_exit_status is None
|
||||||
|
):
|
||||||
return
|
return
|
||||||
if hasattr(self, "_shutdown") and not self._shutdown:
|
if hasattr(self, "_shutdown") and not self._shutdown:
|
||||||
self._shutdown = True
|
self._shutdown = True
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_pin_memory_thread'):
|
if hasattr(self, "_pin_memory_thread"):
|
||||||
self._pin_memory_thread_done_event.set()
|
self._pin_memory_thread_done_event.set()
|
||||||
self._worker_result_queue.put((None, None))
|
self._worker_result_queue.put((None, None))
|
||||||
self._pin_memory_thread.join()
|
self._pin_memory_thread.join()
|
||||||
@@ -49,145 +63,292 @@ def _shutdown_workers(self):
            for worker_id in range(len(self._workers)):
                if self._persistent_workers or self._workers_status[worker_id]:
                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
            for w in self._workers:  # pylint: disable=invalid-name
                w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
            for q in self._index_queues:  # pylint: disable=invalid-name
                q.cancel_join_thread()
                q.close()
        finally:
            if self._worker_pids_set:
                torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
            for w in self._workers:  # pylint: disable=invalid-name
                if w.is_alive():
                    w.terminate()

-class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
-   def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
+
+class DummyDataParallel(
+   torch.nn.Module
+):  # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
+   def __new__(
+       cls, module, device_ids=None, output_device=None, dim=0
+   ):  # pylint: disable=unused-argument
        if isinstance(device_ids, list) and len(device_ids) > 1:
            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
        return module.to("xpu")

-def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
+
+def return_null_context(*args, **kwargs):  # pylint: disable=unused-argument
    return contextlib.nullcontext()


def check_device(device):
-   return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
+   return bool(
+       (isinstance(device, torch.device) and device.type == "cuda")
+       or (isinstance(device, str) and "cuda" in device)
+       or isinstance(device, int)
+   )


def return_xpu(device):
-   return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
+   return (
+       f"xpu:{device[-1]}"
+       if isinstance(device, str) and ":" in device
+       else f"xpu:{device}"
+       if isinstance(device, int)
+       else torch.device("xpu")
+       if isinstance(device, torch.device)
+       else "xpu"
+   )


def ipex_no_cuda(orig_func, *args, **kwargs):
    torch.cuda.is_available = lambda: False
    orig_func(*args, **kwargs)
    torch.cuda.is_available = torch.xpu.is_available


original_autocast = torch.autocast


def ipex_autocast(*args, **kwargs):
    if len(args) > 0 and args[0] == "cuda":
        return original_autocast("xpu", *args[1:], **kwargs)
    else:
        return original_autocast(*args, **kwargs)


original_torch_cat = torch.cat


def torch_cat(tensor, *args, **kwargs):
-   if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
-       return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
+   if len(tensor) == 3 and (
+       tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype
+   ):
+       return original_torch_cat(
+           [tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)],
+           *args,
+           **kwargs,
+       )
    else:
        return original_torch_cat(tensor, *args, **kwargs)


original_interpolate = torch.nn.functional.interpolate
-def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
+
+
+def interpolate(
+   tensor,
+   size=None,
+   scale_factor=None,
+   mode="nearest",
+   align_corners=None,
+   recompute_scale_factor=None,
+   antialias=False,
+):  # pylint: disable=too-many-arguments
    if antialias or align_corners is not None:
        return_device = tensor.device
        return_dtype = tensor.dtype
-       return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
-       align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
+       return original_interpolate(
+           tensor.to("cpu", dtype=torch.float32),
+           size=size,
+           scale_factor=scale_factor,
+           mode=mode,
+           align_corners=align_corners,
+           recompute_scale_factor=recompute_scale_factor,
+           antialias=antialias,
+       ).to(return_device, dtype=return_dtype)
    else:
-       return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
-       align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
+       return original_interpolate(
+           tensor,
+           size=size,
+           scale_factor=scale_factor,
+           mode=mode,
+           align_corners=align_corners,
+           recompute_scale_factor=recompute_scale_factor,
+           antialias=antialias,
+       )


original_linalg_solve = torch.linalg.solve
-def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
+
+
+def linalg_solve(A, B, *args, **kwargs):  # pylint: disable=invalid-name
    if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
        return_device = A.device
-       return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
+       return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(
+           return_device
+       )
    else:
        return original_linalg_solve(A, B, *args, **kwargs)


def ipex_hijacks():
-   CondFunc('torch.Tensor.to',
-       lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
-       lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+   CondFunc(
+       "torch.Tensor.to",
+       lambda orig_func, self, device=None, *args, **kwargs: orig_func(
+           self, return_xpu(device), *args, **kwargs
+       ),
+       lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
+   )
-   CondFunc('torch.Tensor.cuda',
-       lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
-       lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+   CondFunc(
+       "torch.Tensor.cuda",
+       lambda orig_func, self, device=None, *args, **kwargs: orig_func(
+           self, return_xpu(device), *args, **kwargs
+       ),
+       lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
+   )
-   CondFunc('torch.empty',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
+   CondFunc(
+       "torch.empty",
+       lambda orig_func, *args, device=None, **kwargs: orig_func(
+           *args, device=return_xpu(device), **kwargs
+       ),
+       lambda orig_func, *args, device=None, **kwargs: check_device(device),
+   )
-   CondFunc('torch.load',
-       lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
-       lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
+   CondFunc(
+       "torch.load",
+       lambda orig_func, *args, map_location=None, **kwargs: orig_func(
+           *args, return_xpu(map_location), **kwargs
+       ),
+       lambda orig_func, *args, map_location=None, **kwargs: map_location is None
+       or check_device(map_location),
+   )
-   CondFunc('torch.randn',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
+   CondFunc(
+       "torch.randn",
+       lambda orig_func, *args, device=None, **kwargs: orig_func(
+           *args, device=return_xpu(device), **kwargs
+       ),
+       lambda orig_func, *args, device=None, **kwargs: check_device(device),
+   )
-   CondFunc('torch.ones',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
+   CondFunc(
+       "torch.ones",
+       lambda orig_func, *args, device=None, **kwargs: orig_func(
+           *args, device=return_xpu(device), **kwargs
+       ),
+       lambda orig_func, *args, device=None, **kwargs: check_device(device),
+   )
-   CondFunc('torch.zeros',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
+   CondFunc(
+       "torch.zeros",
+       lambda orig_func, *args, device=None, **kwargs: orig_func(
+           *args, device=return_xpu(device), **kwargs
+       ),
+       lambda orig_func, *args, device=None, **kwargs: check_device(device),
-   CondFunc('torch.tensor',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
-   CondFunc('torch.linspace',
-       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-       lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||||
|
)
|
||||||
|
CondFunc(
|
||||||
|
"torch.tensor",
|
||||||
|
lambda orig_func, *args, device=None, **kwargs: orig_func(
|
||||||
|
*args, device=return_xpu(device), **kwargs
|
||||||
|
),
|
||||||
|
lambda orig_func, *args, device=None, **kwargs: check_device(device),
|
||||||
|
)
|
||||||
|
CondFunc(
|
||||||
|
"torch.linspace",
|
||||||
|
lambda orig_func, *args, device=None, **kwargs: orig_func(
|
||||||
|
*args, device=return_xpu(device), **kwargs
|
||||||
|
),
|
||||||
|
lambda orig_func, *args, device=None, **kwargs: check_device(device),
|
||||||
|
)
|
||||||
|
|
||||||
CondFunc('torch.Generator',
|
CondFunc(
|
||||||
|
"torch.Generator",
|
||||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
lambda orig_func, device=None: device is not None
|
||||||
|
and device != torch.device("cpu")
|
||||||
|
and device != "cpu",
|
||||||
|
)
|
||||||
|
|
||||||
CondFunc('torch.batch_norm',
|
CondFunc(
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
"torch.batch_norm",
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
input,
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
weight
|
||||||
CondFunc('torch.instance_norm',
|
if weight is not None
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
else torch.ones(input.size()[1], device=input.device),
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
bias
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
if bias is not None
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
else torch.zeros(input.size()[1], device=input.device),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
|
||||||
|
)
|
||||||
|
CondFunc(
|
||||||
|
"torch.instance_norm",
|
||||||
|
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
|
||||||
|
input,
|
||||||
|
weight
|
||||||
|
if weight is not None
|
||||||
|
else torch.ones(input.size()[1], device=input.device),
|
||||||
|
bias
|
||||||
|
if bias is not None
|
||||||
|
else torch.zeros(input.size()[1], device=input.device),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
||||||
    # Functions with dtype errors:
    CondFunc(
        "torch.nn.modules.GroupNorm.forward",
        lambda orig_func, self, input: orig_func(
            self, input.to(self.weight.data.dtype)
        ),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
    )
    CondFunc(
        "torch.nn.modules.linear.Linear.forward",
        lambda orig_func, self, input: orig_func(
            self, input.to(self.weight.data.dtype)
        ),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
    )
    CondFunc(
        "torch.nn.modules.conv.Conv2d.forward",
        lambda orig_func, self, input: orig_func(
            self, input.to(self.weight.data.dtype)
        ),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
    )
    CondFunc(
        "torch.nn.functional.layer_norm",
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(
            input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs
        ),
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight
        is not None
        and input.dtype != weight.data.dtype,
    )

    # Diffusers Float64 (ARC GPUs doesn't support double or Float64):
    if not torch.xpu.has_fp64_dtype():
        CondFunc(
            "torch.from_numpy",
            lambda orig_func, ndarray: orig_func(ndarray.astype("float32")),
            lambda orig_func, ndarray: ndarray.dtype == float,
        )

    # Broken functions when torch.cuda.is_available is True:
    CondFunc(
        "torch.utils.data.dataloader._BaseDataLoaderIter.__init__",
        lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
        lambda orig_func, *args, **kwargs: True,
    )

    # Functions that make compile mad with CondFunc:
    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = (
        _shutdown_workers
    )
    torch.nn.DataParallel = DummyDataParallel
    torch.autocast = ipex_autocast
    torch.cat = torch_cat

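CondFunc, return_xpu, and check_device are helpers used throughout the registrations above, but their definitions are not part of this hunk. A rough sketch of the conditional monkey-patch pattern they imply, assuming CondFunc resolves a dotted attribute path, keeps a reference to the original callable, and reroutes a call only when the condition lambda returns True (the name and the details below are illustrative, not the project's implementation):

import importlib


def cond_func_sketch(orig_func_path, sub_func, cond_func):
    """Hypothetical stand-in for CondFunc: wrap the callable at `orig_func_path` so the
    call is rerouted to `sub_func` only when `cond_func` says so."""
    parts = orig_func_path.split(".")
    parent = importlib.import_module(parts[0])
    for name in parts[1:-1]:  # walk e.g. "torch.Tensor.to" down to torch.Tensor
        parent = getattr(parent, name)
    orig_func = getattr(parent, parts[-1])

    def wrapper(*args, **kwargs):
        # Reroute only when the condition matches; otherwise call straight through.
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(parent, parts[-1], wrapper)  # plain function, so instance methods still bind
    return wrapper

Under that reading, the "torch.empty" entry above simply forces device=return_xpu(device) whenever check_device(device) matches, and leaves every other call untouched.
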
@ -17,12 +17,15 @@ n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle

import torch

try:
    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

    if torch.xpu.is_available():
        from infer.modules.ipex import ipex_init
        from infer.modules.ipex.gradscaler import gradscaler_init
        from torch.xpu.amp import autocast

        GradScaler = gradscaler_init()
        ipex_init()
    else:
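The hunk above swaps in an XPU-aware scaler and autocast (gradscaler_init() and torch.xpu.amp.autocast) when IPEX is usable, presumably as drop-in replacements for the stock pair. For reference, a minimal mixed-precision step written against the stock torch.cuda.amp API; the model, optimizer, and data are illustrative, not from this commit, and the snippet degrades to a plain FP32 step on CPU:

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 16, device=device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled
scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
scaler.update()
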
13 modules.py
@ -288,14 +288,13 @@ class VC:
                        tgt_sr,
                    )
                else:
                    path = "%s/%s.%s" % (
                        opt_root,
                        os.path.basename(path),
                        format1,
                    )
                    with BytesIO() as wavf:
                        sf.write(wavf, audio_opt, tgt_sr, format="wav")
                        wavf.seek(0, 0)
                        with open(path, "wb") as outf:
                            wav2(wavf, outf, format1)
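The else branch above renders the converted audio to an in-memory WAV first and only then hands the buffer to wav2 for conversion into the requested format1. A self-contained sketch of that buffer round-trip, using a silent placeholder signal and a plain sf.read where wav2 would do the transcoding:

from io import BytesIO

import numpy as np
import soundfile as sf

audio_opt = np.zeros(16000, dtype=np.float32)  # placeholder: one second of silence
tgt_sr = 16000

with BytesIO() as wavf:
    sf.write(wavf, audio_opt, tgt_sr, format="wav")  # render WAV into memory
    wavf.seek(0, 0)
    data, sr = sf.read(wavf)  # wav2 would read from this point and transcode
    assert sr == tgt_sr and len(data) == len(audio_opt)
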
@ -357,19 +357,13 @@ class RVC:
        with torch.no_grad():
            if self.if_f0 == 1:
                # print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
                infered_audio = self.net_g.infer(
                    feats, p_len, cache_pitch, cache_pitchf, sid, rate
                )[0][0, 0].data.float()
            else:
                infered_audio = self.net_g.infer(feats, p_len, sid, rate)[0][
                    0, 0
                ].data.float()
        t5 = ttime()
        logger.info(
            "Spent time: fea = %.2fs, index = %.2fs, f0 = %.2fs, model = %.2fs",
@ -3,7 +3,9 @@ from torch.types import Number


@torch.no_grad()
def amp_to_db(
    x: torch.Tensor, eps=torch.finfo(torch.float64).eps, top_db=40
) -> torch.Tensor:
    """
    Convert the input tensor from amplitude to decibel scale.
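Only the reflowed signature and the first docstring line of amp_to_db are visible in this hunk. For orientation, a sketch of what an amplitude-to-decibel conversion with a top_db floor conventionally looks like; the 20·log10 form and the per-row floor are assumptions, not necessarily the project's exact body:

import torch


@torch.no_grad()
def amp_to_db_sketch(
    x: torch.Tensor, eps=torch.finfo(torch.float64).eps, top_db=40
) -> torch.Tensor:
    # Amplitude -> dB, with eps guarding log(0) and top_db limiting the dynamic range
    # relative to the maximum along the last dimension.
    x_db = 20 * torch.log10(x.abs() + eps)
    return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))
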
@ -40,7 +42,9 @@ def temperature_sigmoid(x: torch.Tensor, x0: float, temp_coeff: float) -> torch.


@torch.no_grad()
def linspace(
    start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs
) -> torch.Tensor:
    """
    Generate a linearly spaced 1-D tensor.
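Likewise, only the signature of the reflowed linspace is shown. A plausible endpoint-aware wrapper over torch.linspace (which always includes the endpoint) would look like the following; the body is an assumption, mirroring NumPy's endpoint semantics:

import torch
from torch.types import Number


@torch.no_grad()
def linspace_sketch(
    start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs
) -> torch.Tensor:
    if endpoint:
        return torch.linspace(start, stop, num, **kwargs)
    # NumPy-style endpoint=False: sample one extra point and drop the last.
    return torch.linspace(start, stop, num + 1, **kwargs)[:-1]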