diff --git a/UVR.py b/UVR.py
index 1dd5017..5460415 100644
--- a/UVR.py
+++ b/UVR.py
@@ -44,7 +44,11 @@ from lib_v5.vr_network.model_param_init import ModelParameters
 from kthread import KThread
 from lib_v5 import spec_utils
 from pathlib import Path
-from separate import SeperateDemucs, SeperateMDX, SeperateMDXC, SeperateVR, save_format
+from separate import (
+    SeperateDemucs, SeperateMDX, SeperateMDXC, SeperateVR, # Model-related
+    save_format, clear_gpu_cache, # Utility functions
+    cuda_available, directml_available, mps_available
+)
 from playsound import playsound
 from typing import List
 import onnx
@@ -53,9 +57,15 @@ import sys
 import yaml
 from ml_collections import ConfigDict
 from collections import Counter
-import torch_directml
 
-is_gpu_available = torch.backends.mps.is_available() if (OPERATING_SYSTEM == 'Darwin') else (torch_directml.is_available() or torch.cuda.is_available())
+if not is_macos:
+    import torch_directml
+
+is_choose_arch = cuda_available and directml_available
+is_opencl_only = not cuda_available and directml_available
+is_cuda_only = cuda_available and not directml_available
+
+is_gpu_available = cuda_available or directml_available or mps_available
 
 # Change the current working directory to the directory
 # this file sits in
@@ -112,8 +122,6 @@ elif OPERATING_SYSTEM=="Windows":
     right_click_button = ''
     application_extension = ".exe"
 
-clear_gpu_cache = torch.mps.empty_cache if is_macos else torch.cuda.empty_cache
-
 def right_click_release_linux(window, top_win=None):
     if OPERATING_SYSTEM=="Linux":
         root.bind('', lambda e:window.destroy())
@@ -330,13 +338,15 @@ class ModelData():
                  is_get_hash_dir_only=False,
                  is_vocal_split_model=False):
+        device_set = root.device_set_var.get()
 
         self.DENOISER_MODEL = DENOISER_MODEL_PATH
         self.DEVERBER_MODEL = DEVERBER_MODEL_PATH
         self.is_deverb_vocals = root.is_deverb_vocals_var.get() if os.path.isfile(DEVERBER_MODEL_PATH) else False
         self.deverb_vocal_opt = DEVERB_MAPPER[root.deverb_vocal_opt_var.get()]
         self.is_denoise_model = True if root.denoise_option_var.get() == DENOISE_M and os.path.isfile(DENOISER_MODEL_PATH) else False
         self.is_gpu_conversion = 0 if root.is_gpu_conversion_var.get() else -1
-        self.is_normalization = root.is_normalization_var.get()
+        self.is_normalization = root.is_normalization_var.get()#
+        self.is_use_opencl = True if is_opencl_only else root.is_use_opencl_var.get()
         self.is_primary_stem_only = root.is_primary_stem_only_var.get()
         self.is_secondary_stem_only = root.is_secondary_stem_only_var.get()
         self.is_denoise = True if not root.denoise_option_var.get() == DENOISE_NONE else False
@@ -359,7 +369,8 @@ class ModelData():
         self.mdx_stem_count = 1
         self.compensate = None
         self.mdx_n_fft_scale_set = None
-        self.wav_type_set = root.wav_type_set
+        self.wav_type_set = root.wav_type_set#
+        self.device_set = device_set.split(':')[-1].strip() if ':' in device_set else device_set
         self.mp3_bit_set = root.mp3_bit_set_var.get()
         self.save_format = root.save_format_var.get()
         self.is_invert_spec = root.is_invert_spec_var.get()#
@@ -1284,7 +1295,6 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         style.configure('TCombobox', selectbackground='#0c0c0c')
         #style.configure('TCheckbutton', indicatorsize=30)
-
 
         # Calculate window height
         height = self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT
         height += self.CONVERSIONBUTTON_HEIGHT + self.COMMAND_HEIGHT + self.PROGRESS_HEIGHT
@@ -1315,7 +1325,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         #Load Images
         img = ImagePath(BASE_PATH)
-        self.logo_img = img.open_image(path=img.banner_path, size=(width, 9999))
+        self.logo_img = img.open_image(path=img.banner_path, size=(width, height))
         self.efile_img = img.efile_img
         self.stop_img = img.stop_img
         self.help_img = img.help_img
@@ -1490,6 +1500,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.change_state_lambda = None
         self.file_one_sub_var = tk.StringVar(value=FILE_ONE_MAIN_LABEL)
         self.file_two_sub_var = tk.StringVar(value=FILE_TWO_MAIN_LABEL)
+        self.cuda_device_list = GPU_DEVICE_NUM_OPTS
+        self.opencl_list = GPU_DEVICE_NUM_OPTS
 
         #Model Update
         self.last_found_ensembles = ENSEMBLE_OPTIONS
@@ -1516,6 +1528,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.check_dual_paths()
         if not is_windows:
             self.update_idletasks()
+        self.fill_gpu_list()
         self.online_data_refresh(user_refresh=False, is_start_up=True)
 
     # Menu Functions
@@ -2110,6 +2123,12 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         if not is_macos:
             self.bind("", self.adjust_toplevel_positions)
 
+    def auto_save(self):
+        try:
+            self.save_values(app_close=False, is_auto_save=True)
+        except Exception as e:
+            print(e)
+
     #--Input/Export Methods--
 
     def linux_filebox_fix(self, is_on=True):
@@ -3149,6 +3168,30 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         return self.DualBatch_inputPaths
 
+    def fill_gpu_list(self):
+        try:
+            if cuda_available:
+                self.cuda_device_list = [f"{torch.cuda.get_device_properties(i).name}:{i}" for i in range(torch.cuda.device_count())]
+                self.cuda_device_list.insert(0, DEFAULT)
+                print(self.cuda_device_list)
+
+            if directml_available:
+                self.opencl_list = [f"{torch_directml.device_name(i)}:{i}" for i in range(torch_directml.device_count())]
+                self.opencl_list.insert(0, DEFAULT)
+        except Exception as e:
+            print(e)
+
+        if is_cuda_only:
+            self.is_use_opencl_var.set(False)
+
+        check_gpu_list = self.opencl_list if is_opencl_only or self.is_use_opencl_var.get() else self.cuda_device_list
+        if not self.device_set_var.get() in check_gpu_list:
+            self.device_set_var.set(DEFAULT)
+
+    def loop_gpu_list(self, option_menu:ComboBoxMenu, menu_name, option_list):
+        option_menu['values'] = option_list
+        option_menu.update_dropdown_size(option_list, menu_name)
+
     def menu_settings(self, select_tab_2=False, select_tab_3=False):#**
         """Open Settings and Download Center"""
 
@@ -3202,7 +3245,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         select_Label = self.menu_sub_LABEL_SET(settings_menu_main_Frame, ADDITIONAL_MENUS_INFORMATION_TEXT)
         select_Label.grid(pady=MENU_PADDING_1)
 
-        select_Option = ComboBoxMenu(settings_menu_main_Frame, textvariable=self.main_menu_var, values=OPTION_LIST, width=GEN_SETTINGS_WIDTH+(3 if is_windows else 3))
+        select_Option = ComboBoxMenu(settings_menu_main_Frame, textvariable=self.main_menu_var, values=OPTION_LIST, width=GEN_SETTINGS_WIDTH+3)
         select_Option.update_dropdown_size(OPTION_LIST, 'menuchoose', command=lambda e:(self.check_is_menu_open(self.main_menu_var.get()), close_window()))
         select_Option.grid(pady=MENU_PADDING_1)
 
@@ -3223,7 +3266,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         delete_your_settings_Label.grid(pady=MENU_PADDING_2)
         self.help_hints(delete_your_settings_Label, text=DELETE_YOUR_SETTINGS_HELP)
 
-        delete_your_settings_Option = ComboBoxMenu(settings_menu_main_Frame, textvariable=option_var, width=GEN_SETTINGS_WIDTH+(3 if is_windows else 3))
+        delete_your_settings_Option = ComboBoxMenu(settings_menu_main_Frame, textvariable=option_var, width=GEN_SETTINGS_WIDTH+3)
         delete_your_settings_Option.grid(padx=20,pady=MENU_PADDING_1)
         self.deletion_list_fill(delete_your_settings_Option, option_var, SETTINGS_CACHE_DIR, SELECT_SAVED_SETTING, menu_name='deletesetting')
 
@@ -3261,7 +3304,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         mp3_bit_set_Option = ComboBoxMenu(settings_menu_format_Frame, textvariable=self.mp3_bit_set_var, values=MP3_BIT_RATES, width=HELP_HINT_CHECKBOX_WIDTH)
         mp3_bit_set_Option.grid(padx=20,pady=MENU_PADDING_1)
-        
+
         audio_format_title_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, GENERAL_PROCESS_SETTINGS_TEXT)
         audio_format_title_Label.grid(pady=MENU_PADDING_2)
 
@@ -3291,9 +3334,29 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         change_model_default_Button = ttk.Button(settings_menu_format_Frame, text=CHANGE_MODEL_DEFAULTS_TEXT, command=lambda:self.pop_up_change_model_defaults(settings_menu), width=SETTINGS_BUT_WIDTH-2)#
         change_model_default_Button.grid(pady=MENU_PADDING_4)
-
-        self.vocal_splitter_Button_opt(settings_menu, settings_menu_format_Frame, width=SETTINGS_BUT_WIDTH-2, pady=MENU_PADDING_4)
-
+
+        if not is_choose_arch:
+            self.vocal_splitter_Button_opt(settings_menu, settings_menu_format_Frame, width=SETTINGS_BUT_WIDTH-2, pady=MENU_PADDING_4)
+
+        if not is_macos and self.is_gpu_available:
+            gpu_list_options = lambda:self.loop_gpu_list(device_set_Option, 'gpudevice', self.opencl_list if is_opencl_only or self.is_use_opencl_var.get() else self.cuda_device_list)
+            device_set_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, CUDA_NUM_TEXT)
+            device_set_Label.grid(pady=MENU_PADDING_2)
+
+            device_set_Option = ComboBoxMenu(settings_menu_format_Frame, textvariable=self.device_set_var, values=GPU_DEVICE_NUM_OPTS, width=GEN_SETTINGS_WIDTH+1)
+            device_set_Option.grid(padx=20,pady=MENU_PADDING_1)
+            gpu_list_options()
+            self.help_hints(device_set_Label, text=IS_CUDA_SELECT_HELP)
+
+            if is_choose_arch:
+                is_use_opencl_Option = ttk.Checkbutton(settings_menu_format_Frame,
+                                                       text=USE_OPENCL_TEXT,
+                                                       width=9,
+                                                       variable=self.is_use_opencl_var,
+                                                       command=lambda:(gpu_list_options(), self.device_set_var.set(DEFAULT)))
+                is_use_opencl_Option.grid()
+                self.help_hints(is_use_opencl_Option, text=IS_NORMALIZATION_HELP)
+
         model_sample_mode_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, MODEL_SAMPLE_MODE_SETTINGS_TEXT)
         model_sample_mode_Label.grid(pady=MENU_PADDING_2)
 
@@ -5131,7 +5194,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
                 self.download_update_link_var.set(UPDATE_LINUX_REPO)
 
             if not user_refresh:
-                if not is_beta_version and not self.lastest_version == PATCH:
+                if not is_beta_version and not self.lastest_version == current_patch:
                     self.command_Text.write(NEW_UPDATE_FOUND_TEXT(self.lastest_version))
 
@@ -5481,12 +5544,11 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
     def update_loop(self):
         """Update the model dropdown menus"""
-        
+
         if self.clear_cache_torch:
-            #self.set_app_font(is_chosen_font=True)
             clear_gpu_cache()
             self.clear_cache_torch = False
-            
+
         if self.is_process_stopped:
             if self.thread_check(self.active_processing_thread):
                 self.conversion_Button_Text_var.set(STOP_PROCESSING)
@@ -5499,7 +5561,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
                 self.progress_bar_main_var.set(0)
                 clear_gpu_cache()
                 self.is_process_stopped = False
-            
+
         if self.is_confirm_error_var.get():
             self.check_is_menu_open(ERROR_OPTION)
             self.is_confirm_error_var.set(False)
@@ -5513,6 +5575,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
             close_process(self.msg_queue)
             self.is_check_splash = False
 
+        #self.auto_save()
+
         self.update_available_models()
         self.after(600, self.update_loop)
 
@@ -5556,7 +5620,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         option_menu['values'] = option_list_option_menu
         option_menu.set(current_selection)
-        option_menu.update_dropdown_size(model_list, model_type)
+        option_menu.update_dropdown_size(option_list, model_type)
 
         if self.is_root_defined_var.get() and model_type == MDX_ARCH_TYPE and self.chosen_process_method_var.get() == MDX_ARCH_TYPE:
             self.selection_action_models_sub(current_selection, model_type, option_var)
@@ -6152,6 +6216,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.active_processing_thread.start()
 
     def process_button_init(self):
+        self.auto_save()
         self.conversion_Button_Text_var.set(WAIT_PROCESSING)
         self.conversion_Button.configure(state=tk.DISABLED)
         self.command_Text.clear()
@@ -6182,6 +6247,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
     def confirm_stop_process(self):
         """Asks for confirmation before halting active process"""
+        self.auto_save()
+
         if self.thread_check(self.active_processing_thread):
             confirm = messagebox.askyesno(parent=root, title=STOP_PROCESS_CONFIRM[0], message=STOP_PROCESS_CONFIRM[1])
 
@@ -6197,6 +6264,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
     def process_end(self, error=None):
         """End of process actions"""
+        self.auto_save()
         self.cached_sources_clear()
         self.clear_cache_torch = True
         self.conversion_Button_Text_var.set(START_PROCESSING)
@@ -6753,7 +6821,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.semitone_shift_var = tk.StringVar(value=data['semitone_shift'])
         self.mp3_bit_set_var = tk.StringVar(value=data['mp3_bit_set'])
         self.save_format_var = tk.StringVar(value=data['save_format'])
-        self.wav_type_set_var = tk.StringVar(value=data['wav_type_set'])
+        self.wav_type_set_var = tk.StringVar(value=data['wav_type_set'])#
+        self.device_set_var = tk.StringVar(value=data['device_set'])#
         self.user_code_var = tk.StringVar(value=data['user_code'])
         self.is_gpu_conversion_var = tk.BooleanVar(value=data['is_gpu_conversion'])
         self.is_primary_stem_only_var = tk.BooleanVar(value=data['is_primary_stem_only'])
@@ -6765,6 +6834,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.is_accept_any_input_var = tk.BooleanVar(value=data['is_accept_any_input'])
         self.is_task_complete_var = tk.BooleanVar(value=data['is_task_complete'])
         self.is_normalization_var = tk.BooleanVar(value=data['is_normalization'])#
+        self.is_use_opencl_var = tk.BooleanVar(value=True if is_opencl_only else data['is_use_opencl'])#
        self.is_wav_ensemble_var = tk.BooleanVar(value=data['is_wav_ensemble'])#
         self.is_create_model_folder_var = tk.BooleanVar(value=data['is_create_model_folder'])
         self.help_hints_var = tk.BooleanVar(value=data['help_hints_var'])
@@ -6895,7 +6965,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.mp3_bit_set_var.set(loaded_setting['mp3_bit_set'])
         self.semitone_shift_var.set(loaded_setting['semitone_shift'])#
         self.save_format_var.set(loaded_setting['save_format'])
-        self.wav_type_set_var.set(loaded_setting['wav_type_set'])
+        self.wav_type_set_var.set(loaded_setting['wav_type_set'])#
+        self.device_set_var.set(loaded_setting['device_set'])#
         self.user_code_var.set(loaded_setting['user_code'])
         self.phase_option_var.set(loaded_setting['phase_option'])#
         self.phase_shifts_var.set(loaded_setting['phase_shifts'])#
@@ -6913,6 +6984,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
         self.is_gpu_conversion_var.set(loaded_setting['is_gpu_conversion'])
         self.is_normalization_var.set(loaded_setting['is_normalization'])#
+        self.is_use_opencl_var.set(True if is_opencl_only else loaded_setting['is_use_opencl'])#
         self.is_wav_ensemble_var.set(loaded_setting['is_wav_ensemble'])#
         self.help_hints_var.set(loaded_setting['help_hints_var'])
         self.is_wav_ensemble_var.set(loaded_setting['is_wav_ensemble'])
@@ -6928,7 +7000,7 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
         self.model_sample_mode_duration_checkbox_var.set(SAMPLE_MODE_CHECKBOX(self.model_sample_mode_duration_var.get()))
         self.model_sample_mode_duration_label_var.set(f'{self.model_sample_mode_duration_var.get()} Seconds')
 
-    def save_values(self, app_close=True, is_restart=False):
+    def save_values(self, app_close=True, is_restart=False, is_auto_save=False):
         """Saves application data"""
 
         # -Save Data-
@@ -7023,12 +7095,14 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
             'is_accept_any_input': self.is_accept_any_input_var.get(),
             'is_task_complete': self.is_task_complete_var.get(),
             'is_normalization': self.is_normalization_var.get(),#
+            'is_use_opencl': self.is_use_opencl_var.get(),#
             'is_wav_ensemble': self.is_wav_ensemble_var.get(),#
             'is_create_model_folder': self.is_create_model_folder_var.get(),
             'mp3_bit_set': self.mp3_bit_set_var.get(),
             'semitone_shift': self.semitone_shift_var.get(),#
             'save_format': self.save_format_var.get(),
-            'wav_type_set': self.wav_type_set_var.get(),
+            'wav_type_set': self.wav_type_set_var.get(),#
+            'device_set': self.device_set_var.get(),#
             'user_code': self.user_code_var.get(),
             'help_hints_var': self.help_hints_var.get(),
             'set_vocal_splitter': self.set_vocal_splitter_var.get(),
@@ -7084,6 +7158,8 @@ class MainWindow(TkinterDnD.Tk if is_dnd_compatible else tk.Tk):
 
             self.destroy()
 
+        elif is_auto_save:
+            save_data(data={**main_settings, **other_data})
         else:
             return {**main_settings, **user_saved_extras}
 
diff --git a/separate.py b/separate.py
index 7bfe921..08c01e5 100644
--- a/separate.py
+++ b/separate.py
@@ -31,19 +31,40 @@ import math
 #import random
 from onnx import load
 from onnx2pytorch import ConvertModel
-import torch_directml
 import gc
-
-cpu = 'cpu'
-
+
 if TYPE_CHECKING:
     from UVR import ModelData
 
+if not is_macos:
+    import torch_directml
+
+mps_available = torch.backends.mps.is_available() if is_macos else False
+cuda_available = torch.cuda.is_available()
+
+def get_gpu_info():
+    directml_device, directml_available = DIRECTML_DEVICE, False
+
+    if not is_macos:
+        directml_available = torch_directml.is_available()
+
+        if directml_available:
+            directml_device = str(torch_directml.device()).partition(":")[0]
+
+    return directml_device, directml_available
+
+DIRECTML_DEVICE, directml_available = get_gpu_info()
+
+def clear_gpu_cache():
+    gc.collect()
+    if is_macos:
+        torch.mps.empty_cache()
+    else:
+        torch.cuda.empty_cache()
+
 warnings.filterwarnings("ignore")
 cpu = torch.device('cpu')
 
-clear_gpu_cache = torch.mps.empty_cache if is_macos else torch.cuda.empty_cache
-
 class SeperateAttributes:
     def __init__(self, model_data: ModelData,
                  process_data: dict,
@@ -123,7 +144,7 @@ class SeperateAttributes:
         self.main_model_primary = main_model_primary
         self.ensemble_primary_stem = model_data.ensemble_primary_stem
         self.is_multi_stem_ensemble = model_data.is_multi_stem_ensemble
-        self.is_mps = False
+        self.is_other_gpu = False
         self.is_deverb = True
         self.DENOISER_MODEL = model_data.DENOISER_MODEL
         self.DEVERBER_MODEL = model_data.DEVERBER_MODEL
@@ -146,6 +167,8 @@ class SeperateAttributes:
         self.device = cpu
         self.run_type = ['CPUExecutionProvider']
         self.is_opencl = False
+        self.device_set = model_data.device_set
+        self.is_use_opencl = model_data.is_use_opencl
 
         if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
             self.is_primary_stem_only = False
@@ -155,12 +178,19 @@ class SeperateAttributes:
             self.primary_stem, self.secondary_stem = main_model_primary, secondary_stem(main_model_primary)
 
         if self.is_gpu_conversion >= 0:
-            if OPERATING_SYSTEM == 'Darwin' and torch.backends.mps.is_available():
-                self.device, self.is_mps = 'mps', True
-            elif torch_directml.is_available():
-                self.device, self.is_mps, self.is_opencl = torch_directml.device(), True, True
-            elif torch.cuda.is_available():
-                self.device, self.run_type = 'cuda:0', ['CUDAExecutionProvider']
+            if mps_available:
+                self.device, self.is_other_gpu = 'mps', True
+            else:
+                device_prefix = None
+                if self.device_set != DEFAULT:
+                    device_prefix = DIRECTML_DEVICE if self.is_use_opencl and directml_available else CUDA_DEVICE
+
+                if directml_available and self.is_use_opencl:
+                    self.device = torch_directml.device() if not device_prefix else f'{device_prefix}:{self.device_set}'
+                    self.is_other_gpu = True
+                elif cuda_available and not self.is_use_opencl:
+                    self.device = CUDA_DEVICE if not device_prefix else f'{device_prefix}:{self.device_set}'
+                    self.run_type = ['CUDAExecutionProvider']
 
         if model_data.process_method == MDX_ARCH_TYPE:
             self.is_mdx_ckpt = model_data.is_mdx_ckpt
@@ -173,7 +203,6 @@ class SeperateAttributes:
             self.mdx_segment_size = model_data.mdx_segment_size
 
             if self.is_mdx_c:
-                self.device = cpu if self.is_opencl else self.device
                 if not self.is_4_stem_ensemble:
                     self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
                     self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
@@ -200,7 +229,7 @@ class SeperateAttributes:
             self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
             self.demucs_stem_count = model_data.demucs_stem_count
             self.pre_proc_model = model_data.pre_proc_model
-            self.device = cpu if self.is_mps and not self.demucs_version == DEMUCS_V4 else self.device
+            self.device = cpu if self.is_other_gpu and not self.demucs_version in [DEMUCS_V3, DEMUCS_V4] else self.device
             self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
             self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
 
@@ -457,7 +486,7 @@ class SeperateMDX(SeperateAttributes):
             separator = MdxnetSet.ConvTDFNet(**model_params)
             self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
         else:
-            if self.mdx_segment_size == self.dim_t and not self.is_mps:
+            if self.mdx_segment_size == self.dim_t and not self.is_other_gpu:
                 ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
                 self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
             else:
@@ -708,8 +737,9 @@ class SeperateMDXC(SeperateAttributes):
         if self.is_pitch_change:
             mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
 
-        model = TFC_TDF_net(self.mdx_c_configs, device=self.device).eval().to(self.device)
+        model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
         model.load_state_dict(torch.load(self.model_path, map_location=cpu))
+        model.to(self.device).eval()
         mix = torch.tensor(mix, dtype=torch.float32)
 
         try:
@@ -720,49 +750,37 @@
         mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
 
         batch_size = self.mdx_batch_size
-        C = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
-        N = self.overlap_mdx23
+        chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
+        overlap = self.overlap_mdx23
 
-        H = C // N
-        L = mix.shape[1]
-        pad_size = H - (L - C) % H
-        mix = torch.cat([torch.zeros(2, C - H), mix, torch.zeros(2, pad_size + C - H)], 1)
+        hop_size = chunk_size // overlap
+        mix_shape = mix.shape[1]
+        pad_size = hop_size - (mix_shape - chunk_size) % hop_size
+        mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
         mix = mix.to(self.device)
 
-        chunks = []
-        i = 0
-        while i + C <= mix.shape[1]:
-            chunks.append(mix[:, i:i + C])
-            i += H
-        chunks = torch.stack(chunks)
-
-        batches = []
-        i = 0
-        while i < len(chunks):
-            batches.append(chunks[i:i + batch_size])
-            i = i + batch_size
-
-        X = torch.zeros(S, 2, C - H) if S > 1 else torch.zeros(2, C - H)
-        X = X.to(self.device)
-
-        #with torch.cuda.amp.autocast():
+        chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
+        batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
+
+        X = torch.zeros(S, *mix.shape).to(self.device) if S > 1 else torch.zeros_like(mix).to(self.device)
+
         with torch.no_grad():
+            cnt = 0
             for batch in batches:
                 self.running_inference_progress_bar(len(batches))
                 x = model(batch)
+
                 for w in x:
-                    a = X[..., :-(C - H)]
-                    b = X[..., -(C - H):] + w[..., :(C - H)]
-                    c = w[..., (C - H):]
-                    X = torch.cat([a, b, c], -1)
-
-        estimated_sources = X[..., C - H:-(pad_size + C - H)] / N
+                    X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
+                    cnt += 1
+        estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
+        del X
 
         pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
         if S > 1:
             sources = {k: pitch_fix(v) if self.is_pitch_change else v for k, v in zip(self.mdx_c_configs.training.instruments, estimated_sources.cpu().detach().numpy())}
-            
+            del estimated_sources
             if self.is_denoise_model:
                 if VOCAL_STEM in sources.keys() and INST_STEM in sources.keys():
                     sources[VOCAL_STEM] = vr_denoiser(sources[VOCAL_STEM], self.device, model_path=self.DENOISER_MODEL)
 
@@ -773,6 +791,7 @@ class SeperateMDXC(SeperateAttributes):
             return sources
         else:
             est_s = estimated_sources.cpu().detach().numpy()
+            del estimated_sources
             return pitch_fix(est_s) if self.is_pitch_change else est_s
 
 class SeperateDemucs(SeperateAttributes):
@@ -1339,8 +1358,6 @@ def vr_denoiser(X, device, hop_length=1024, n_fft=2048, cropsize=256, is_deverbe
     hop_length=1024
     nout, nout_lstm = 16, 128
 
-    device = cpu if not device in ['cuda', 'cpu'] else device
-
     model = nets_new.CascadedNet(n_fft, nout=nout, nout_lstm=nout_lstm)
     model.load_state_dict(torch.load(model_path, map_location=cpu))
     model.to(device)
@@ -1439,4 +1456,4 @@ def loading_mix(X, mp):
 
     del X_wave, X_spec_s
 
-    return X_spec
\ No newline at end of file
+    return X_spec
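
Note on the rewritten chunking in SeperateMDXC.demix (illustrative sketch, not part of the patch): the old sliding-window torch.cat loop is replaced by unfold plus overlap-add. The toy example below, with made-up sizes and an identity "model" standing in for TFC_TDF_net, shows why dividing the accumulated output by `overlap` recovers the signal: after padding by chunk_size - hop_size on each side, every retained sample is covered by exactly `overlap` chunks.

import torch

chunk_size, overlap = 8, 4                  # toy values; chunk_size divisible by overlap
hop_size = chunk_size // overlap
mix = torch.arange(20, dtype=torch.float32).repeat(2, 1)    # stand-in 2-channel signal

pad_size = hop_size - (mix.shape[1] - chunk_size) % hop_size
padded = torch.cat([torch.zeros(2, chunk_size - hop_size), mix,
                    torch.zeros(2, pad_size + chunk_size - hop_size)], 1)

chunks = padded.unfold(1, chunk_size, hop_size).transpose(0, 1)   # (n_chunks, 2, chunk_size)

acc = torch.zeros_like(padded)
for cnt, w in enumerate(chunks):            # a real model would transform each chunk here
    acc[..., cnt * hop_size : cnt * hop_size + chunk_size] += w

restored = acc[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
print(torch.allclose(restored, mix))        # True

With a real model the per-chunk outputs differ, so the division by `overlap` averages the `overlap` predictions made for each sample instead of reproducing the input exactly.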
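
Note on the new device selection in SeperateAttributes (a sketch under stated assumptions, not the patch's code): when GPU conversion is enabled, the precedence is MPS first, then DirectML when the "OpenCL" option is active, then CUDA, with an optional ':<index>' suffix taken from device_set. The stand-alone helper below mirrors that order; DEFAULT and CUDA_DEVICE are stand-ins for the UVR constants of the same name, and the 'dml' prefix is only a placeholder for the value the patch derives from str(torch_directml.device()).

DEFAULT = "Default"
CUDA_DEVICE = "cuda"

def pick_device(device_set=DEFAULT, is_use_opencl=False,
                mps_available=False, directml_available=False, cuda_available=False):
    """Return the device string this precedence would hand to torch's .to()."""
    if mps_available:                        # Apple Silicon / Metal always wins
        return "mps"
    if directml_available and is_use_opencl:
        # real code: torch_directml.device() for DEFAULT, else f'{DIRECTML_DEVICE}:{index}'
        return "dml" if device_set == DEFAULT else f"dml:{device_set}"
    if cuda_available and not is_use_opencl:
        return CUDA_DEVICE if device_set == DEFAULT else f"{CUDA_DEVICE}:{device_set}"
    return "cpu"                             # mirrors the cpu default set earlier in __init__

print(pick_device(device_set="1", cuda_available=True))   # -> cuda:1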