diff --git a/api.py b/api.py index 8ea1eb8..43351b8 100644 --- a/api.py +++ b/api.py @@ -14,21 +14,13 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel import threading import uvicorn +import logging -from tools.torchgate import TorchGate -import tools.rvc_for_realtime as rvc_for_realtime -from configs.config import Config - -load_dotenv() -os.environ["OMP_NUM_THREADS"] = "4" -if sys.platform == "darwin": - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - -now_dir = os.getcwd() -sys.path.append(now_dir) - -stream_latency = -1 +# Initialize the logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +# Define FastAPI app app = FastAPI() class GUIConfig: @@ -69,23 +61,25 @@ class ConfigData(BaseModel): class AudioAPI: def __init__(self) -> None: self.gui_config = GUIConfig() - self.config = Config() + self.config = None # Initialize Config object as None self.flag_vc = False self.function = "vc" self.delay_time = 0 + self.rvc = None # Initialize RVC object as None def load(self): input_devices, output_devices, _, _ = self.get_devices() try: - with open("configs/config.json", "r") as j: + with open("configs/config.json", "r", encoding='utf-8') as j: data = json.load(j) data["rmvpe"] = True # Ensure rmvpe is the only f0method if data["sg_input_device"] not in input_devices: data["sg_input_device"] = input_devices[sd.default.device[0]] if data["sg_output_device"] not in output_devices: data["sg_output_device"] = output_devices[sd.default.device[1]] - except: - with open("configs/config.json", "w") as j: + except Exception as e: + logger.error(f"Failed to load configuration: {e}") + with open("configs/config.json", "w", encoding='utf-8') as j: data = { "pth_path": " ", "index_path": " ", @@ -102,18 +96,15 @@ class AudioAPI: "use_jit": False, } data["rmvpe"] = True # Ensure rmvpe is the only f0method + json.dump(data, j, ensure_ascii=False) return data def set_values(self, values): - if len(values.pth_path.strip()) == 0: + logger.info(f"Setting values: {values}") + if not values.pth_path.strip(): raise HTTPException(status_code=400, detail="Please select a .pth file") - if len(values.index_path.strip()) == 0: + if not values.index_path.strip(): raise HTTPException(status_code=400, detail="Please select an index file") - pattern = re.compile("[^\x00-\x7F]+") - if pattern.findall(values.pth_path): - raise HTTPException(status_code=400, detail=".pth file path should not contain non-ASCII characters") - if pattern.findall(values.index_path): - raise HTTPException(status_code=400, detail="Index file path should not contain non-ASCII characters") self.set_devices(values.sg_input_device, values.sg_output_device) self.config.use_jit = False self.gui_config.pth_path = values.pth_path @@ -143,10 +134,8 @@ class AudioAPI: 0, 0, self.config, - self.rvc if hasattr(self, "rvc") else None, + self.rvc if self.rvc else None, ) - if not hasattr(self.rvc, 'tgt_sr'): - self.rvc.tgt_sr = 44100 self.gui_config.samplerate = self.rvc.tgt_sr self.zc = self.rvc.tgt_sr // 100 self.block_frame = ( @@ -226,8 +215,8 @@ class AudioAPI: stream_latency = stream.latency[-1] while self.flag_vc: time.sleep(self.gui_config.block_time) - print("Audio block passed.") - print("ENDing VC") + logger.info("Audio block passed.") + logger.info("Ending VC") def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status): start_time = time.perf_counter() @@ -289,7 +278,7 @@ class AudioAPI: sola_offset = sola_offset.item() else: sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0]) - print("sola_offset = %d" % int(sola_offset)) + logger.info(f"sola_offset = {sola_offset}") infer_wav = infer_wav[sola_offset : sola_offset + self.block_frame + self.crossfade_frame] infer_wav[: self.crossfade_frame] *= self.fade_in_window infer_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window @@ -299,7 +288,7 @@ class AudioAPI: else: outdata[:] = infer_wav[: -self.crossfade_frame].repeat(2, 1).t().cpu().numpy() total_time = time.perf_counter() - start_time - print("Infer time: %.2f" % total_time) + logger.info(f"Infer time: {total_time:.2f}") def get_devices(self, update: bool = True): if update: @@ -344,34 +333,61 @@ class AudioAPI: input_device_indices, output_device_indices, ) = self.get_devices() + logger.debug(f"Available input devices: {input_devices}") + logger.debug(f"Available output devices: {output_devices}") + logger.debug(f"Selected input device: {input_device}") + logger.debug(f"Selected output device: {output_device}") + + if input_device not in input_devices: + logger.error(f"Input device '{input_device}' is not in the list of available devices") + raise HTTPException(status_code=400, detail=f"Input device '{input_device}' is not available") + + if output_device not in output_devices: + logger.error(f"Output device '{output_device}' is not in the list of available devices") + raise HTTPException(status_code=400, detail=f"Output device '{output_device}' is not available") + sd.default.device[0] = input_device_indices[input_devices.index(input_device)] sd.default.device[1] = output_device_indices[output_devices.index(output_device)] - print("Input device: %s:%s" % (str(sd.default.device[0]), input_device)) - print("Output device: %s:%s" % (str(sd.default.device[1]), output_device)) + logger.info(f"Input device set to {sd.default.device[0]}: {input_device}") + logger.info(f"Output device set to {sd.default.device[1]}: {output_device}") audio_api = AudioAPI() -@app.get("/inputDevices") +@app.get("/inputDevices", response_model=list) def get_input_devices(): - input_devices, _, _, _ = audio_api.get_devices() - return input_devices + try: + input_devices, _, _, _ = audio_api.get_devices() + return input_devices + except Exception as e: + logger.error(f"Failed to get input devices: {e}") + raise HTTPException(status_code=500, detail="Failed to get input devices") -@app.get("/outputDevices") +@app.get("/outputDevices", response_model=list) def get_output_devices(): - _, output_devices, _, _ = audio_api.get_devices() - return output_devices + try: + _, output_devices, _, _ = audio_api.get_devices() + return output_devices + except Exception as e: + logger.error(f"Failed to get output devices: {e}") + raise HTTPException(status_code=500, detail="Failed to get output devices") @app.post("/config") def configure_audio(config_data: ConfigData): try: + logger.info(f"Configuring audio with data: {config_data}") if audio_api.set_values(config_data): settings = config_data.dict() settings["use_jit"] = False settings["f0method"] = "rmvpe" - with open("configs/config.json", "w") as j: - json.dump(settings, j) + with open("configs/config.json", "w", encoding='utf-8') as j: + json.dump(settings, j, ensure_ascii=False) + logger.info("Configuration set successfully") return {"message": "Configuration set successfully"} + except HTTPException as e: + logger.error(f"Configuration error: {e.detail}") + raise except Exception as e: + logger.error(f"Configuration failed: {e}") raise HTTPException(status_code=400, detail=f"Configuration failed: {e}") @app.post("/start") @@ -381,8 +397,13 @@ def start_conversion(): audio_api.start_vc() return {"message": "Audio conversion started"} else: + logger.warning("Audio conversion already running") raise HTTPException(status_code=400, detail="Audio conversion already running") + except HTTPException as e: + logger.error(f"Start conversion error: {e.detail}") + raise except Exception as e: + logger.error(f"Failed to start conversion: {e}") raise HTTPException(status_code=500, detail=f"Failed to start conversion: {e}") @app.post("/stop") @@ -394,9 +415,25 @@ def stop_conversion(): stream_latency = -1 return {"message": "Audio conversion stopped"} else: + logger.warning("Audio conversion not running") raise HTTPException(status_code=400, detail="Audio conversion not running") + except HTTPException as e: + logger.error(f"Stop conversion error: {e.detail}") + raise except Exception as e: + logger.error(f"Failed to stop conversion: {e}") raise HTTPException(status_code=500, detail=f"Failed to stop conversion: {e}") if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8043) + if sys.platform == "win32": + from multiprocessing import freeze_support + freeze_support() + load_dotenv() + os.environ["OMP_NUM_THREADS"] = "4" + if sys.platform == "darwin": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + from tools.torchgate import TorchGate + import tools.rvc_for_realtime as rvc_for_realtime + from configs.config import Config + audio_api.config = Config() + uvicorn.run(app, host="0.0.0.0", port=6242)