fix mps in gui-v1.py (#769)
* Fix mps on realtime * Added back repeat chs
This commit is contained in:
parent
2e0dfeec50
commit
2b3fe8cf1b
32
gui_v1.py
32
gui_v1.py
@ -1,5 +1,8 @@
|
||||
import os, sys
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
import multiprocessing
|
||||
@ -45,7 +48,7 @@ if __name__ == "__main__":
|
||||
from i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
|
||||
current_dir = os.getcwd()
|
||||
inp_q = Queue()
|
||||
opt_q = Queue()
|
||||
@ -441,8 +444,9 @@ if __name__ == "__main__":
|
||||
"""
|
||||
接受音频输入
|
||||
"""
|
||||
channels = 1 if sys.platform == "darwin" else 2
|
||||
with sd.Stream(
|
||||
channels=2,
|
||||
channels=channels,
|
||||
callback=self.audio_callback,
|
||||
blocksize=self.block_frame,
|
||||
samplerate=self.config.samplerate,
|
||||
@ -524,6 +528,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
+ 1e-8
|
||||
)
|
||||
if sys.platform == "darwin":
|
||||
cor_nom = cor_nom.cpu()
|
||||
cor_den = cor_den.cpu()
|
||||
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
||||
print("sola offset: " + str(int(sola_offset)))
|
||||
self.output_wav[:] = infer_wav[sola_offset : sola_offset + self.block_frame]
|
||||
@ -545,14 +552,23 @@ if __name__ == "__main__":
|
||||
infer_wav[-self.crossfade_frame :] * self.fade_out_window
|
||||
)
|
||||
if self.config.O_noise_reduce:
|
||||
outdata[:] = np.tile(
|
||||
nr.reduce_noise(
|
||||
if sys.platform == "darwin":
|
||||
noise_reduced_signal = nr.reduce_noise(
|
||||
y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
|
||||
),
|
||||
(2, 1),
|
||||
).T
|
||||
)
|
||||
outdata[:] = noise_reduced_signal[:, np.newaxis]
|
||||
else:
|
||||
outdata[:] = np.tile(
|
||||
nr.reduce_noise(
|
||||
y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
|
||||
),
|
||||
(2, 1),
|
||||
).T
|
||||
else:
|
||||
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
|
||||
if sys.platform == "darwin":
|
||||
outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis]
|
||||
else:
|
||||
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
|
||||
total_time = time.perf_counter() - start_time
|
||||
self.window["infer_time"].update(int(total_time * 1000))
|
||||
print("infer time:" + str(total_time))
|
||||
|
Loading…
Reference in New Issue
Block a user