diff --git a/config.py b/config.py index d978577..0261b02 100644 --- a/config.py +++ b/config.py @@ -148,39 +148,28 @@ class Config: x_max = 32 if self.dml: print("use DirectML instead") - try: - os.rename( - "runtime\Lib\site-packages\onnxruntime", - "runtime\Lib\site-packages\onnxruntime-cuda", - ) - except: - pass - try: - os.rename( - "runtime\Lib\site-packages\onnxruntime-dml", - "runtime\Lib\site-packages\onnxruntime", - ) - except: - pass + if(os.path.exists("runtime\Lib\site-packages\onnxruntime\capi\DirectML.dll")==False): + try: + os.rename("runtime\Lib\site-packages\onnxruntime", "runtime\Lib\site-packages\onnxruntime-cuda") + except: + pass + try: + os.rename("runtime\Lib\site-packages\onnxruntime-dml", "runtime\Lib\site-packages\onnxruntime") + except: + pass import torch_directml - self.device = torch_directml.device(torch_directml.default_device()) self.is_half = False else: if self.instead: print(f"use {self.instead} instead") - try: - os.rename( - "runtime\Lib\site-packages\onnxruntime", - "runtime\Lib\site-packages\onnxruntime-dml", - ) - except: - pass - try: - os.rename( - "runtime\Lib\site-packages\onnxruntime-cuda", - "runtime\Lib\site-packages\onnxruntime", - ) - except: - pass + if(os.path.exists("runtime\Lib\site-packages\onnxruntime\capi\onnxruntime_providers_cuda.dll")==False): + try: + os.rename("runtime\Lib\site-packages\onnxruntime", "runtime\Lib\site-packages\onnxruntime-dml") + except: + pass + try: + os.rename("runtime\Lib\site-packages\onnxruntime-cuda", "runtime\Lib\site-packages\onnxruntime") + except: + pass return x_pad, x_query, x_center, x_max