From 8f30a09d68f42c9496e83852ddb043e622f799de Mon Sep 17 00:00:00 2001
From: Anjok07 <68268275+Anjok07@users.noreply.github.com>
Date: Mon, 20 Jul 2020 16:52:35 -0500
Subject: [PATCH] Add files via upload

---
 VocalRemover.py | 542 ++++++++++++++++++++++++++++++++++++++++++++++++
 augment.py      |  76 +++++++
 inference.py    | 200 ++++++++++++++++++
 train.py        | 223 ++++++++++++++++++++
 4 files changed, 1041 insertions(+)
 create mode 100644 VocalRemover.py
 create mode 100644 augment.py
 create mode 100644 inference.py
 create mode 100644 train.py

diff --git a/VocalRemover.py b/VocalRemover.py
new file mode 100644
index 0000000..16ee1a5
--- /dev/null
+++ b/VocalRemover.py
@@ -0,0 +1,542 @@
+# GUI modules
+import tkinter as tk
+import tkinter.ttk as ttk
+import tkinter.messagebox
+import tkinter.filedialog
+import tkinter.font
+from datetime import datetime
+# Images
+from PIL import Image
+from PIL import ImageTk
+import pickle  # Save Data
+# Other Modules
+import subprocess  # Run python file
+# Pathfinding
+import pathlib
+import os
+from collections import defaultdict
+# Used for live text displaying
+import queue
+import threading  # Run the algorithm inside a thread
+
+import torch
+import inference
+
+# --Global Variables--
+base_path = os.path.dirname(__file__)
+os.chdir(base_path)  # Change the current working directory to the base path
+models_dir = os.path.join(base_path, 'models')
+logo_path = os.path.join(base_path, 'Images/UVR-logo.png')
+DEFAULT_DATA = {
+    'exportPath': '',
+    'gpuConversion': False,
+    'postprocessing': False,
+    'mask': False,
+    'stackLoops': False,
+    'srValue': 44100,
+    'hopValue': 1024,
+    'stackLoopsNum': 1,
+    'winSize': 512,
+}
+# Supported Music Files
+AVAILABLE_FORMATS = ['.mp3', '.mp4', '.m4a', '.flac', '.wav']
+
+
+def open_image(path: str, size: tuple = None, keep_aspect: bool = True, rotate: int = 0) -> tuple:
+    """
+    Open the image at the given path and apply the given settings
+
+    Parameters:
+        path(str):
+            Absolute path of the image
+        size(tuple):
+            first value - width
+            second value - height
+        keep_aspect(bool):
+            keep aspect ratio of image and resize
+            to maximum possible width and height
+            (maxima are given by size)
+        rotate(int):
+            clockwise rotation of image
+    Returns(tuple):
+        (ImageTk.PhotoImage, Image)
+    """
+    img = Image.open(path)
+    ratio = img.height/img.width
+    img = img.rotate(angle=-rotate)
+    if size is not None:
+        size = (int(size[0]), int(size[1]))
+        if keep_aspect:
+            img = img.resize((size[0], int(size[0] * ratio)), Image.ANTIALIAS)
+        else:
+            img = img.resize(size, Image.ANTIALIAS)
+    img = img.convert(mode='RGBA')
+    return ImageTk.PhotoImage(img), img
+
+
+def save_data(data):
+    """
+    Saves given data as a .pkl (pickle) file
+
+    Parameters:
+        data(dict):
+            Dictionary containing all the necessary data to save
+    """
+    # Open data file, create it if it does not exist
+    with open('data.pkl', 'wb') as data_file:
+        pickle.dump(data, data_file)
+
+
+def load_data() -> dict:
+    """
+    Loads the saved pkl file and returns the stored data
+
+    Returns(dict):
+        Dictionary containing all the saved data
+    """
+    try:
+        with open('data.pkl', 'rb') as data_file:  # Open data file
+            data = pickle.load(data_file)
+
+        return data
+    except (ValueError, FileNotFoundError):
+        # Data file is corrupted or not found, so recreate it
+        save_data(data=DEFAULT_DATA)
+
+        return load_data()
+
+
+class ThreadSafeConsole(tk.Text):
+    """
+    Text widget which is thread safe for tkinter
+    """
+
+    def __init__(self, master, **options):
+        tk.Text.__init__(self, master, **options)
+        self.queue = queue.Queue()
+        self.update_me()
+
+    def write(self, line):
+        self.queue.put(line)
+
+    def clear(self):
+        self.queue.put(None)
+
+    def update_me(self):
+        self.configure(state=tk.NORMAL)
+        try:
+            while True:
+                line = self.queue.get_nowait()
+                if line is None:
+                    self.delete(1.0, tk.END)
+                else:
+                    self.insert(tk.END, str(line))
+                self.see(tk.END)
+                self.update_idletasks()
+        except queue.Empty:
+            pass
+        self.configure(state=tk.DISABLED)
+        self.after(100, self.update_me)
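+
+
+# Illustrative note (not part of the original upload): tkinter widgets must
+# only be touched from the main thread, so ThreadSafeConsole funnels writes
+# through a queue that update_me() drains on the Tk event loop (re-scheduled
+# every 100 ms via after()). A hypothetical worker thread would simply do:
+#
+#     threading.Thread(target=lambda: console.write('hello\n'),
+#                      daemon=True).start()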
+
+
+class MainWindow(tk.Tk):
+    # --Constants--
+    # None
+
+    def __init__(self):
+        # Run the __init__ method on the tk.Tk class
+        super().__init__()
+
+        # --Window Settings--
+        self.title('Desktop Application')
+        # Set Geometry and Center Window
+        self.geometry('{width}x{height}+{xpad}+{ypad}'.format(
+            width=530,
+            height=690,
+            xpad=int(self.winfo_screenwidth()/2 - 530/2),
+            ypad=int(self.winfo_screenheight()/2 - 690/2)))
+        self.configure(bg='#FFFFFF')  # Set background color to white
+        self.resizable(False, False)
+        self.update()
+
+        # --Variables--
+        self.logo_img = open_image(path=logo_path,
+                                   size=(self.winfo_width(), 9999),
+                                   keep_aspect=True)[0]
+        self.label_to_path = defaultdict(lambda: '')
+        # -Tkinter Value Holders-
+        data = load_data()
+        self.exportPath_var = tk.StringVar(value=data['exportPath'])
+        self.filePaths = ''
+        self.gpuConversion_var = tk.BooleanVar(value=data['gpuConversion'])
+        self.postprocessing_var = tk.BooleanVar(value=data['postprocessing'])
+        self.mask_var = tk.BooleanVar(value=data['mask'])
+        self.stackLoops_var = tk.IntVar(value=data['stackLoops'])
+        self.srValue_var = tk.IntVar(value=data['srValue'])
+        self.hopValue_var = tk.IntVar(value=data['hopValue'])
+        self.winSize_var = tk.IntVar(value=data['winSize'])
+        self.stackLoopsNum_var = tk.IntVar(value=data['stackLoopsNum'])
+        self.model_var = tk.StringVar(value='')
+
+        self.progress_var = tk.IntVar(value=0)
+
+        # --Widgets--
+        self.create_widgets()
+        self.configure_widgets()
+        self.place_widgets()
+
+        self.update_available_models()
+        self.update_stack_state()
+
+    # -Widget Methods-
+    def create_widgets(self):
+        """Create window widgets"""
+        self.title_Label = tk.Label(master=self, bg='white',
+                                    image=self.logo_img, compound=tk.TOP)
+        self.filePaths_Frame = tk.Frame(master=self, bg='white')
+        self.fill_filePaths_Frame()
+
+        self.options_Frame = tk.Frame(master=self, bg='white')
+        self.fill_options_Frame()
+
+        self.conversion_Button = ttk.Button(master=self,
+                                            text='Start Conversion',
+                                            command=self.start_conversion)
+
+        self.progressbar = ttk.Progressbar(master=self,
+                                           variable=self.progress_var)
+
+        self.command_Text = ThreadSafeConsole(master=self,
+                                              background='#EFEFEF',
+                                              borderwidth=0,)
+        self.command_Text.write(f'COMMAND LINE [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]')  # nopep8
+
+    def configure_widgets(self):
+        """Change widget styling and appearance"""
+        ttk.Style().configure('TCheckbutton', background='white')
+
+    def place_widgets(self):
+        """Place main widgets"""
+        self.title_Label.place(x=-2, y=-2)
+
+        self.filePaths_Frame.place(x=10, y=0, width=-20, height=0,
+                                   relx=0, rely=0.19, relwidth=1, relheight=0.14)
+        self.options_Frame.place(x=25, y=15, width=-50, height=-30,
+                                 relx=0, rely=0.33, relwidth=1, relheight=0.23)
+        self.conversion_Button.place(x=10, y=5, width=-20, height=-10,
+                                     relx=0, rely=0.56, relwidth=1, relheight=0.07)
+        self.command_Text.place(x=15, y=10, width=-30, height=-10,
+                                relx=0, rely=0.63, relwidth=1, relheight=0.28)
+        self.progressbar.place(x=25, y=15, width=-50, height=-30,
+                               relx=0, rely=0.91, relwidth=1, relheight=0.09)
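+
+    # Layout note (added for clarity): place() is used with relative
+    # coordinates -- relx/rely/relwidth/relheight describe fractions of the
+    # parent window, while x/y/width/height add pixel offsets on top.
+    # Negative width/height values shrink a full-size box to leave margins,
+    # e.g. relwidth=1 with x=10 and width=-20 leaves 10 px on each side.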
+
+    def fill_filePaths_Frame(self):
+        """Fill the Frame with the necessary widgets"""
+        # -Create Widgets-
+        # Save To Option
+        self.filePaths_saveTo_Button = ttk.Button(master=self.filePaths_Frame,
+                                                  text='Save to',
+                                                  command=self.open_export_filedialog)
+        self.filePaths_saveTo_Entry = ttk.Entry(master=self.filePaths_Frame,
+                                                textvariable=self.exportPath_var,
+                                                state=tk.DISABLED)
+        # Select Music Files Option
+        self.filePaths_musicFile_Button = ttk.Button(master=self.filePaths_Frame,
+                                                     text='Select Your Audio File(s)',
+                                                     command=self.open_file_filedialog)
+        self.filePaths_musicFile_Entry = ttk.Entry(master=self.filePaths_Frame,
+                                                   text=self.filePaths,
+                                                   state=tk.DISABLED)
+        # -Place Widgets-
+        # Save To Option
+        self.filePaths_saveTo_Button.place(x=0, y=5, width=0, height=-10,
+                                           relx=0, rely=0, relwidth=0.3, relheight=0.5)
+        self.filePaths_saveTo_Entry.place(x=10, y=7, width=-20, height=-14,
+                                          relx=0.3, rely=0, relwidth=0.7, relheight=0.5)
+        # Select Music Files Option
+        self.filePaths_musicFile_Button.place(x=0, y=5, width=0, height=-10,
+                                              relx=0, rely=0.5, relwidth=0.4, relheight=0.5)
+        self.filePaths_musicFile_Entry.place(x=10, y=7, width=-20, height=-14,
+                                             relx=0.4, rely=0.5, relwidth=0.6, relheight=0.5)
+
+    def fill_options_Frame(self):
+        """Fill the Frame with the necessary widgets"""
+        # -Create Widgets-
+        # GPU Selection
+        self.options_gpu_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
+                                                       text='GPU Conversion',
+                                                       variable=self.gpuConversion_var)
+        # Postprocessing
+        self.options_post_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
+                                                        text='Post-Process (Dev Opt)',
+                                                        variable=self.postprocessing_var)
+        # Mask
+        self.options_mask_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
+                                                        text='Save Mask PNG',
+                                                        variable=self.mask_var)
+        # SR
+        self.options_sr_Entry = ttk.Entry(master=self.options_Frame,
+                                          textvariable=self.srValue_var)
+        self.options_sr_Label = tk.Label(master=self.options_Frame,
+                                         text='SR', anchor=tk.W,
+                                         background='white')
+        # HOP LENGTH
+        self.options_hop_Entry = ttk.Entry(master=self.options_Frame,
+                                           textvariable=self.hopValue_var)
+        self.options_hop_Label = tk.Label(master=self.options_Frame,
+                                          text='HOP LENGTH', anchor=tk.W,
+                                          background='white')
+        # WINDOW SIZE
+        self.options_winSize_Entry = ttk.Entry(master=self.options_Frame,
+                                               textvariable=self.winSize_var)
+        self.options_winSize_Label = tk.Label(master=self.options_Frame,
+                                              text='WINDOW SIZE', anchor=tk.W,
+                                              background='white')
+        # Stack Loops
+        self.options_stack_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
+                                                         text='Stack Passes',
+                                                         variable=self.stackLoops_var)
+        self.options_stack_Entry = ttk.Entry(master=self.options_Frame,
+                                             textvariable=self.stackLoopsNum_var)
+        self.options_stack_Checkbutton.configure(command=self.update_stack_state)  # nopep8
+        # Choose Model
+        self.options_model_Label = tk.Label(master=self.options_Frame,
+                                            text='Choose Your Model',
+                                            background='white')
+        self.options_model_Optionmenu = ttk.OptionMenu(self.options_Frame,
+                                                       self.model_var,
+                                                       1,
+                                                       *[1, 2])
+        self.options_model_Button = ttk.Button(master=self.options_Frame,
+                                               text='Add Your Own Model',
+                                               command=self.open_newModel_filedialog)
+        # -Place Widgets-
+        # GPU Selection
+        self.options_gpu_Checkbutton.place(x=0, y=0, width=0, height=0,
+                                           relx=0, rely=0, relwidth=1/3, relheight=1/4)
+        self.options_post_Checkbutton.place(x=0, y=0, width=0, height=0,
+                                            relx=0, rely=1/4, relwidth=1/3, relheight=1/4)
+        self.options_mask_Checkbutton.place(x=0, y=0, width=0, height=0,
+                                            relx=0, rely=2/4, relwidth=1/3, relheight=1/4)
+        # Stack Loops
+        self.options_stack_Checkbutton.place(x=0, y=0, width=0, height=0,
+                                             relx=0, rely=3/4, relwidth=1/3/4*3, relheight=1/4)
+        self.options_stack_Entry.place(x=0, y=4, width=0, height=-8,
+                                       relx=1/3/4*2.4, rely=3/4, relwidth=1/3/4*0.9, relheight=1/4)
+        # SR
+        self.options_sr_Entry.place(x=-5, y=4, width=5, height=-8,
+                                    relx=1/3, rely=0, relwidth=1/3/4, relheight=1/4)
+        self.options_sr_Label.place(x=10, y=4, width=-10, height=-8,
+                                    relx=1/3/4 + 1/3, rely=0, relwidth=1/3/4*3, relheight=1/4)
+        # HOP LENGTH
+        self.options_hop_Entry.place(x=-5, y=4, width=5, height=-8,
+                                     relx=1/3, rely=1/4, relwidth=1/3/4, relheight=1/4)
+        self.options_hop_Label.place(x=10, y=4, width=-10, height=-8,
+                                     relx=1/3/4 + 1/3, rely=1/4, relwidth=1/3/4*3, relheight=1/4)
+        # WINDOW SIZE
+        self.options_winSize_Entry.place(x=-5, y=4, width=5, height=-8,
+                                         relx=1/3, rely=2/4, relwidth=1/3/4, relheight=1/4)
+        self.options_winSize_Label.place(x=10, y=4, width=-10, height=-8,
+                                         relx=1/3/4 + 1/3, rely=2/4, relwidth=1/3/4*3, relheight=1/4)
+        # Choose Model
+        self.options_model_Label.place(x=0, y=0, width=0, height=-10,
+                                       relx=2/3, rely=0, relwidth=1/3, relheight=1/3)
+        self.options_model_Optionmenu.place(x=15, y=-2.5, width=-30, height=-10,
+                                            relx=2/3, rely=1/3, relwidth=1/3, relheight=1/3)
+        self.options_model_Button.place(x=15, y=0, width=-30, height=-5,
+                                        relx=2/3, rely=2/3, relwidth=1/3, relheight=1/3)
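+
+    # Note (added): the OptionMenu above is created with dummy values
+    # (1, *[1, 2]) only so tkinter has something to build; __init__ calls
+    # update_available_models() right after the widgets are created, which
+    # clears the menu and fills it with the real .pth model files.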
+
+    # Opening filedialogs
+    def open_file_filedialog(self):
+        """Make the user select music files"""
+        paths = tk.filedialog.askopenfilenames(
+            parent=self,
+            title='Select Music Files',
+            initialdir='/',
+            initialfile='',
+            filetypes=[
+                ('; '.join(AVAILABLE_FORMATS).replace('.', ''),
+                 '*' + ' *'.join(AVAILABLE_FORMATS)),
+            ])
+        if paths:  # Path selected
+            for path in paths:
+                if not path.lower().endswith(tuple(AVAILABLE_FORMATS)):
+                    tk.messagebox.showerror(master=self,
+                                            title='Invalid File',
+                                            message='Please select a "{}" audio file!'.format('" or "'.join(AVAILABLE_FORMATS)),  # nopep8
+                                            detail=f'File: {path}')
+                    return
+            self.filePaths = paths
+            # Change the entry text
+            self.filePaths_musicFile_Entry.configure(state=tk.NORMAL)
+            self.filePaths_musicFile_Entry.delete(0, tk.END)
+            self.filePaths_musicFile_Entry.insert(0, self.filePaths)
+            self.filePaths_musicFile_Entry.configure(state=tk.DISABLED)
+
+    def open_export_filedialog(self):
+        """Make the user select a folder to export the converted files to"""
+        path = tk.filedialog.askdirectory(
+            parent=self,
+            title='Select Folder',
+            initialdir='/',)
+        if path:  # Path selected
+            self.exportPath_var.set(path)
+
+    def open_newModel_filedialog(self):
+        """Make the user select a ".pth" model to use for the vocal removing"""
+        path = tk.filedialog.askopenfilename(
+            parent=self,
+            title='Select Model File',
+            initialdir='/',
+            initialfile='',
+            filetypes=[
+                ('pth', '*.pth'),
+            ])
+
+        if path:  # Path selected
+            if path.lower().endswith('.pth'):
+                self.add_available_model(abs_path=path)
+            else:
+                tk.messagebox.showerror(master=self,
+                                        title='Invalid File',
+                                        message='Please select a PyTorch model file ".pth"!',
+                                        detail=f'File: {path}')
+                return
+
+    def start_conversion(self):
+        """
+        Start the conversion for all the selected audio files
+        """
+        # -Get all variables-
+        input_paths = self.filePaths
+        export_path = self.exportPath_var.get()
+        model_path = self.label_to_path[self.model_var.get()]
+        try:
+            sr = self.srValue_var.get()
+            hop_length = self.hopValue_var.get()
+            window_size = self.winSize_var.get()
+            loops_num = self.stackLoopsNum_var.get()
+        except tk.TclError:  # A non-integer was put in an entry box
+            tk.messagebox.showwarning(master=self,
+                                      title='Invalid Input',
+                                      message='Please make sure you only input integer numbers!')
+            return
+        except SyntaxError:  # An invalid music file was selected
+            tk.messagebox.showwarning(master=self,
+                                      title='Invalid Music File',
+                                      message='You have selected an invalid music file!\nPlease make sure that your files still exist and end with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"')
+            return
+
+        # -Check for invalid inputs-
+        if not any([(os.path.isfile(path) and path.endswith(('.mp3', '.mp4', '.m4a', '.flac', '.wav')))
+                    for path in input_paths]):
+            tk.messagebox.showwarning(master=self,
+                                      title='Invalid Music File',
+                                      message='You have selected an invalid music file!\nPlease make sure that your files still exist and end with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"')
+            return
+        if not os.path.isdir(export_path):
+            tk.messagebox.showwarning(master=self,
+                                      title='Invalid Export Directory',
+                                      message='You have selected an invalid export directory!\nPlease make sure that your directory still exists!')
+            return
+        if not os.path.isfile(model_path):
+            tk.messagebox.showwarning(master=self,
+                                      title='Invalid Model File',
+                                      message='You have selected an invalid model file!\nPlease make sure that your model file still exists!')
+            return
+
+        # -Save Data-
+        save_data(data={
+            'exportPath': export_path,
+            'gpuConversion': self.gpuConversion_var.get(),
+            'postprocessing': self.postprocessing_var.get(),
+            'mask': self.mask_var.get(),
+            'stackLoops': self.stackLoops_var.get(),
+            'srValue': sr,
+            'hopValue': hop_length,
+            'winSize': window_size,
+            'stackLoopsNum': loops_num,
+        })
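+
+        # Note (added): the conversion below runs in a daemon thread so the
+        # GUI stays responsive; progress and console output flow back through
+        # the thread-safe widgets passed as kwargs. 'gpu' follows the
+        # convention of inference.main: 0 = first CUDA device, -1 = CPU.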
+
+        # -Run the algorithm-
+        threading.Thread(target=inference.main,
+                         kwargs={
+                             'input_paths': input_paths,
+                             'gpu': 0 if self.gpuConversion_var.get() else -1,
+                             'postprocess': self.postprocessing_var.get(),
+                             'out_mask': self.mask_var.get(),
+                             'model': model_path,
+                             'sr': sr,
+                             'hop_length': hop_length,
+                             'window_size': window_size,
+                             'export_path': export_path,
+                             'loops': loops_num,
+                             # Other Variables (Tkinter)
+                             'window': self,
+                             'command_widget': self.command_Text,
+                             'button_widget': self.conversion_Button,
+                             'progress_var': self.progress_var,
+                         },
+                         daemon=True
+                         ).start()
+
+    # Models
+    def update_available_models(self):
+        """
+        Loop through every model (.pth) in the models directory
+        and add it to the model selection list
+        """
+        # Delete all previous options
+        self.model_var.set('')
+        self.options_model_Optionmenu['menu'].delete(0, 'end')
+
+        for file_name in os.listdir(models_dir):
+            if file_name.endswith('.pth'):
+                # Add Radiobutton to the Options Menu
+                self.options_model_Optionmenu['menu'].add_radiobutton(label=file_name,
+                                                                      command=tk._setit(self.model_var, file_name))
+                # Link the file's name to its absolute path
+                self.label_to_path[file_name] = os.path.join(models_dir, file_name)  # nopep8
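+
+    # Example (added): after add_available_model('/path/to/MyModel.pth') the
+    # menu shows '[CUSTOM] MyModel.pth' and
+    # self.label_to_path['[CUSTOM] MyModel.pth'] == '/path/to/MyModel.pth'.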
+
+    def add_available_model(self, abs_path: str):
+        """
+        Add the given absolute path of the file (.pth) to the available options
+        and set the currently selected model to this one
+        """
+        if abs_path.endswith('.pth'):
+            file_name = f'[CUSTOM] {os.path.basename(abs_path)}'
+            # Add Radiobutton to the Options Menu
+            self.options_model_Optionmenu['menu'].add_radiobutton(label=file_name,
+                                                                  command=tk._setit(self.model_var, file_name))
+            # Set selected model to the newly added one
+            self.model_var.set(file_name)
+            # Link the file's name to its absolute path
+            self.label_to_path[file_name] = abs_path  # nopep8
+        else:
+            tk.messagebox.showerror(master=self,
+                                    title='Invalid File',
+                                    message='Please select a model file with the ".pth" ending!',
+                                    detail=f'File: {abs_path}')
+
+    def update_stack_state(self):
+        """
+        Toggle the stack Entry between disabled/enabled based on the
+        stackLoops variable, which is connected to the checkbutton
+        """
+        if self.stackLoops_var.get():
+            self.options_stack_Entry.configure(state=tk.NORMAL)
+        else:
+            self.options_stack_Entry.configure(state=tk.DISABLED)
+            self.stackLoopsNum_var.set(1)
+
+
+if __name__ == "__main__":
+    root = MainWindow()
+
+    root.mainloop()
diff --git a/augment.py b/augment.py
new file mode 100644
index 0000000..467230a
--- /dev/null
+++ b/augment.py
@@ -0,0 +1,76 @@
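+# Overview (added for clarity): this script augments a paired dataset by
+# pitch-shifting. For each mixture/instrumental pair it derives the vocal
+# track as v = mixture - instrumental, pitch-shifts the instrumental and the
+# vocals separately with SoundTouch's external 'soundstretch' CLI (which must
+# be installed and on PATH), recombines them, and saves the magnitude
+# spectrograms of the shifted mixture and instrumental as .npy files inside
+# the respective dataset directories.
+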
+import argparse
+import os
+import subprocess
+
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+
+from lib import spec_utils
+
+
+if __name__ == '__main__':
+    p = argparse.ArgumentParser()
+    p.add_argument('--sr', '-r', type=int, default=44100)
+    p.add_argument('--hop_length', '-l', type=int, default=1024)
+    p.add_argument('--pitch', '-p', type=int, default=-2)
+    p.add_argument('--mixture_dataset', '-m', required=True)
+    p.add_argument('--instrumental_dataset', '-i', required=True)
+    args = p.parse_args()
+
+    input_exts = ['.wav', '.m4a', '.3gp', '.oma', '.mp3', '.mp4']
+    X_list = sorted([
+        os.path.join(args.mixture_dataset, fname)
+        for fname in os.listdir(args.mixture_dataset)
+        if os.path.splitext(fname)[1] in input_exts])
+    y_list = sorted([
+        os.path.join(args.instrumental_dataset, fname)
+        for fname in os.listdir(args.instrumental_dataset)
+        if os.path.splitext(fname)[1] in input_exts])
+
+    input_i = 'input_i_{}.wav'.format(args.pitch)
+    input_v = 'input_v_{}.wav'.format(args.pitch)
+    output_i = 'output_i_{}.wav'.format(args.pitch)
+    output_v = 'output_v_{}.wav'.format(args.pitch)
+    cmd_i = 'soundstretch {} {} -pitch={}'.format(input_i, output_i, args.pitch)
+    cmd_v = 'soundstretch {} {} -pitch={}'.format(input_v, output_v, args.pitch)
+    suffix = '_pitch{}.npy'.format(args.pitch)
+
+    filelist = list(zip(X_list, y_list))
+    for mix_path, inst_path in tqdm(filelist):
+        X, _ = librosa.load(
+            mix_path, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
+        y, _ = librosa.load(
+            inst_path, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
+
+        X, _ = librosa.effects.trim(X)
+        y, _ = librosa.effects.trim(y)
+        X, y = spec_utils.align_wave_head_and_tail(X, y, args.sr)
+
+        v = X - y
+        sf.write(input_i, y.T, args.sr)
+        sf.write(input_v, v.T, args.sr)
+        subprocess.call(cmd_i, stderr=subprocess.DEVNULL)
+        subprocess.call(cmd_v, stderr=subprocess.DEVNULL)
+
+        y, _ = librosa.load(
+            output_i, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
+        v, _ = librosa.load(
+            output_v, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
+        X = y + v
+
+        spec = spec_utils.calc_spec(X, args.hop_length)
+        basename, _ = os.path.splitext(os.path.basename(mix_path))
+        outpath = os.path.join(args.mixture_dataset, basename + suffix)
+        np.save(outpath, np.abs(spec))
+
+        spec = spec_utils.calc_spec(y, args.hop_length)
+        basename, _ = os.path.splitext(os.path.basename(inst_path))
+        outpath = os.path.join(args.instrumental_dataset, basename + suffix)
+        np.save(outpath, np.abs(spec))
+
+        os.remove(input_i)
+        os.remove(input_v)
+        os.remove(output_i)
+        os.remove(output_v)
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..2b05552
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,200 @@
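+# Overview (added for clarity): given a mixture, main() predicts a soft
+# instrumental mask over spectrogram windows (plus a half-window-shifted copy
+# whose predictions are averaged in), then reconstructs
+# instrumental = |X| * mask and vocals = |X| * (1 - mask) using the mixture's
+# original phase. The tkinter parameters exist so the GUI in VocalRemover.py
+# can stream progress and console text while this runs in a worker thread.
+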
+import argparse
+import os
+
+import cv2
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+
+from lib import dataset
+from lib import nets
+from lib import spec_utils
+
+# Variable manipulation and command line text parsing
+import torch
+import tkinter as tk
+import traceback  # Error message traceback formatting
+
+
+class Namespace:
+    """
+    Replaces ArgumentParser
+    """
+
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+
+def main(window: tk.Wm, input_paths: list, gpu: int = -1,
+         model: str = 'models/baseline.pth', sr: int = 44100, hop_length: int = 1024,
+         window_size: int = 512, out_mask: bool = False, postprocess: bool = False,
+         export_path: str = '', loops: int = 1,
+         # Other Variables (Tkinter)
+         progress_var: tk.Variable = None, button_widget: tk.Button = None, command_widget: tk.Text = None,
+         ):
+    def load_model():
+        args.command_widget.write('Loading model...\n')  # nopep8 Write Command Text
+        device = torch.device('cpu')
+        model = nets.CascadedASPPNet()
+        model.load_state_dict(torch.load(args.model, map_location=device))
+        if torch.cuda.is_available() and args.gpu >= 0:
+            device = torch.device('cuda:{}'.format(args.gpu))
+            model.to(device)
+        args.command_widget.write('Done!\n')  # nopep8 Write Command Text
+
+        return model, device
+
+    def load_wave_source():
+        args.command_widget.write(base_text + 'Loading wave source...\n')  # nopep8 Write Command Text
+        X, sr = librosa.load(music_file,
+                             args.sr,
+                             False,
+                             dtype=np.float32,
+                             res_type='kaiser_fast')
+        args.command_widget.write(base_text + 'Done!\n')  # nopep8 Write Command Text
+
+        return X, sr
+
+    def stft_wave_source(X):
+        args.command_widget.write(base_text + 'STFT of wave source...\n')  # nopep8 Write Command Text
+        X = spec_utils.calc_spec(X, args.hop_length)
+        X, phase = np.abs(X), np.exp(1.j * np.angle(X))
+        coeff = X.max()
+        X /= coeff
+
+        offset = model.offset
+        l, r, roi_size = dataset.make_padding(
+            X.shape[2], args.window_size, offset)
+        X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant')
+        X_roll = np.roll(X_pad, roi_size // 2, axis=2)
+
+        model.eval()
+        with torch.no_grad():
+            masks = []
+            masks_roll = []
+            length = int(np.ceil(X.shape[2] / roi_size))
+            for i in tqdm(range(length)):
+                progress_var.set(base_progress + max_progress * (0.1 + (0.6/length * i)))  # nopep8 Update Progress
+                start = i * roi_size
+                X_window = torch.from_numpy(np.asarray([
+                    X_pad[:, :, start:start + args.window_size],
+                    X_roll[:, :, start:start + args.window_size]
+                ])).to(device)
+                pred = model.predict(X_window)
+                pred = pred.detach().cpu().numpy()
+                masks.append(pred[0])
+                masks_roll.append(pred[1])
+
+        mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]
+        mask_roll = np.concatenate(masks_roll, axis=2)[:, :, :X.shape[2]]
+        mask = (mask + np.roll(mask_roll, -roi_size // 2, axis=2)) / 2
+
+        if args.postprocess:
+            vocal = X * (1 - mask) * coeff
+            mask = spec_utils.mask_uninformative(mask, vocal)
+        args.command_widget.write(base_text + 'Done!\n')  # nopep8 Write Command Text
+
+        inst = X * mask * coeff
+        vocal = X * (1 - mask) * coeff
+
+        return inst, vocal, phase, mask
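+
+    # Note (added): because the same mask is applied as m and (1 - m), the
+    # instrumental and vocal spectrograms sum back to the (rescaled) input,
+    # so the two stems are complementary by construction.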
+
+    def invert_instrum_vocal(inst, vocal, phase):
+        args.command_widget.write(base_text + 'Inverse STFT of instruments and vocals...\n')  # nopep8 Write Command Text
+        wav_instrument = spec_utils.spec_to_wav(inst, phase, args.hop_length)  # nopep8
+        wav_vocals = spec_utils.spec_to_wav(vocal, phase, args.hop_length)  # nopep8
+        args.command_widget.write(base_text + 'Done!\n')  # nopep8 Write Command Text
+
+        return wav_instrument, wav_vocals
+
+    def save_files(wav_instrument, wav_vocals):
+        args.command_widget.write(base_text + 'Saving Files...\n')  # nopep8 Write Command Text
+        sf.write(f'{export_path}/{base_name}_(Instrumental).wav',
+                 wav_instrument.T, sr)
+        if cur_loop == 0:
+            sf.write(f'{export_path}/{base_name}_(Vocals).wav',
+                     wav_vocals.T, sr)
+        if (cur_loop == (args.loops - 1) and
+                args.loops > 1):
+            sf.write(f'{export_path}/{base_name}_(Last_Vocals).wav',
+                     wav_vocals.T, sr)
+
+        args.command_widget.write(base_text + 'Done!\n')  # nopep8 Write Command Text
+
+    def create_mask():
+        args.command_widget.write(base_text + 'Creating Mask...\n')  # nopep8 Write Command Text
+        norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0)
+        norm_mask = np.concatenate([
+            np.max(norm_mask, axis=2, keepdims=True),
+            norm_mask], axis=2)[::-1]
+        _, bin_mask = cv2.imencode('.png', norm_mask)
+        args.command_widget.write(base_text + 'Saving Mask...\n')  # nopep8 Write Command Text
+        with open(f'{export_path}/{base_name}_(Mask).png', mode='wb') as f:
+            bin_mask.tofile(f)
+        args.command_widget.write(base_text + 'Done!\n')  # nopep8 Write Command Text
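+
+    # Note (added): progress is budgeted as 100 / (total_files * loops) per
+    # pass; e.g. 2 files with 2 stacking passes give each pass a 25% slice of
+    # the bar, subdivided by the 0.05/0.1/0.7/0.8/0.9/1.0 checkpoints below.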
+
+    args = Namespace(input=input_paths, gpu=gpu, model=model,
+                     sr=sr, hop_length=hop_length, window_size=window_size,
+                     out_mask=out_mask, postprocess=postprocess, export=export_path,
+                     loops=loops,
+                     # Other Variables (Tkinter)
+                     window=window, progress_var=progress_var,
+                     button_widget=button_widget, command_widget=command_widget,
+                     )
+    args.command_widget.clear()  # Clear Command Text
+    args.button_widget.configure(state=tk.DISABLED)  # Disable Button
+    total_files = len(args.input)  # Used to calculate progress
+
+    model, device = load_model()
+
+    for file_num, music_file in enumerate(args.input, start=1):
+        try:
+            base_name = f'{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'
+            for cur_loop in range(args.loops):
+                if cur_loop > 0:
+                    args.command_widget.write(f'File {file_num}/{total_files}: ' + 'Next Pass!\n')  # nopep8 Write Command Text
+                    music_file = f'{export_path}/{base_name}_(Instrumental).wav'
+                base_progress = 100 / \
+                    (total_files*args.loops) * \
+                    ((file_num*args.loops)-((args.loops-1) - cur_loop)-1)
+                base_text = 'File {file_num}/{total_files}:{loop} '.format(
+                    file_num=file_num,
+                    total_files=total_files,
+                    loop='' if args.loops <= 1 else f' ({cur_loop+1}/{args.loops})')
+                max_progress = 100 / (total_files*args.loops)
+                progress_var.set(base_progress + max_progress * 0.05)  # nopep8 Update Progress
+
+                X, sr = load_wave_source()
+                progress_var.set(base_progress + max_progress * 0.1)  # nopep8 Update Progress
+
+                inst, vocal, phase, mask = stft_wave_source(X)
+                progress_var.set(base_progress + max_progress * 0.7)  # nopep8 Update Progress
+
+                wav_instrument, wav_vocals = invert_instrum_vocal(inst, vocal, phase)  # nopep8
+                progress_var.set(base_progress + max_progress * 0.8)  # nopep8 Update Progress
+
+                save_files(wav_instrument, wav_vocals)
+                progress_var.set(base_progress + max_progress * 0.9)  # nopep8 Update Progress
+
+                if args.out_mask:
+                    create_mask()
+                progress_var.set(base_progress + max_progress * 1)  # nopep8 Update Progress
+
+                args.command_widget.write(base_text + 'Completed Separation!\n\n')  # nopep8 Write Command Text
+        except Exception as e:
+            traceback_text = ''.join(traceback.format_tb(e.__traceback__))
+            print(traceback_text)
+            print(type(e).__name__, e)
+            tk.messagebox.showerror(master=args.window,
+                                    title='Unexpected Error',
+                                    message=f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\n\nPlease contact the creator and attach a screenshot of this error with the file which caused it!')
+            args.button_widget.configure(state=tk.NORMAL)  # Enable Button
+            return
+
+    progress_var.set(100)  # Update Progress
+    args.command_widget.write('Conversion(s) completed and all files saved!')  # nopep8 Write Command Text
+    args.button_widget.configure(state=tk.NORMAL)  # Enable Button
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..420f487
--- /dev/null
+++ b/train.py
@@ -0,0 +1,223 @@
+import argparse
+from datetime import datetime as dt
+import gc
+import json
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from lib import dataset
+from lib import nets
+from lib import spec_utils
+
+
+def train_val_split(mix_dir, inst_dir, val_rate, val_filelist_json):
+    input_exts = ['.wav', '.m4a', '.3gp', '.oma', '.mp3', '.mp4']
+    X_list = sorted([
+        os.path.join(mix_dir, fname)
+        for fname in os.listdir(mix_dir)
+        if os.path.splitext(fname)[1] in input_exts])
+    y_list = sorted([
+        os.path.join(inst_dir, fname)
+        for fname in os.listdir(inst_dir)
+        if os.path.splitext(fname)[1] in input_exts])
+
+    filelist = list(zip(X_list, y_list))
+    random.shuffle(filelist)
+
+    val_filelist = []
+    if val_filelist_json is not None:
+        with open(val_filelist_json, 'r', encoding='utf8') as f:
+            val_filelist = json.load(f)
+
+    if len(val_filelist) == 0:
+        val_size = int(len(filelist) * val_rate)
+        train_filelist = filelist[:-val_size]
+        val_filelist = filelist[-val_size:]
+    else:
+        train_filelist = [
+            pair for pair in filelist
+            if list(pair) not in val_filelist]
+
+    return train_filelist, val_filelist
+
+
+def train_inner_epoch(X_train, y_train, model, optimizer, batchsize, instance_loss):
+    sum_loss = 0
+    model.train()
+    # Send batches to the device the model lives on
+    # (CPU by default, CUDA when --gpu is set)
+    device = next(model.parameters()).device
+    aux_crit = nn.L1Loss()
+    criterion = nn.L1Loss(reduction='none')
+    perm = np.random.permutation(len(X_train))
+    for i in range(0, len(X_train), batchsize):
+        local_perm = perm[i: i + batchsize]
+        X_batch = torch.from_numpy(X_train[local_perm]).to(device)
+        y_batch = torch.from_numpy(y_train[local_perm]).to(device)
+
+        model.zero_grad()
+        mask, aux = model(X_batch)
+
+        aux_loss = aux_crit(X_batch * aux, y_batch)
+        X_batch = spec_utils.crop_center(mask, X_batch, False)
+        y_batch = spec_utils.crop_center(mask, y_batch, False)
+        abs_diff = criterion(X_batch * mask, y_batch)
+
+        loss = abs_diff.mean() * 0.9 + aux_loss * 0.1
+        loss.backward()
+        optimizer.step()
+
+        abs_diff_np = abs_diff.detach().cpu().numpy()
+        instance_loss[local_perm] += abs_diff_np.mean(axis=(1, 2, 3))
+        sum_loss += float(loss.detach().cpu().numpy()) * len(X_batch)
+
+    return sum_loss / len(X_train)
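+
+
+# Note (added): instance_loss accumulates each sample's L1 error across the
+# inner epochs; main() later feeds it to dataset.get_oracle_data() so the
+# hardest examples can be re-injected into the next epoch's training set
+# ("oracle" sampling, controlled by --oracle_rate / --oracle_drop_rate).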
+
+
+def val_inner_epoch(dataloader, model):
+    sum_loss = 0
+    model.eval()
+    # Evaluate on the same device the model lives on
+    device = next(model.parameters()).device
+    criterion = nn.L1Loss()
+    with torch.no_grad():
+        for X_batch, y_batch in dataloader:
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+            mask = model.predict(X_batch)
+            X_batch = spec_utils.crop_center(mask, X_batch, False)
+            y_batch = spec_utils.crop_center(mask, y_batch, False)
+
+            loss = criterion(X_batch * mask, y_batch)
+            sum_loss += float(loss.detach().cpu().numpy()) * len(X_batch)
+
+    return sum_loss / len(dataloader.dataset)
+
+
+def main():
+    p = argparse.ArgumentParser()
+    p.add_argument('--gpu', '-g', type=int, default=-1)
+    p.add_argument('--seed', '-s', type=int, default=2019)
+    p.add_argument('--sr', '-r', type=int, default=44100)
+    p.add_argument('--hop_length', '-l', type=int, default=1024)
+    p.add_argument('--mixture_dataset', '-m', required=True)
+    p.add_argument('--instrumental_dataset', '-i', required=True)
+    p.add_argument('--learning_rate', type=float, default=0.001)
+    p.add_argument('--lr_min', type=float, default=0.0001)
+    p.add_argument('--lr_decay_factor', type=float, default=0.9)
+    p.add_argument('--lr_decay_patience', type=int, default=6)
+    p.add_argument('--batchsize', '-B', type=int, default=4)
+    p.add_argument('--cropsize', '-c', type=int, default=256)
+    p.add_argument('--val_rate', '-v', type=float, default=0.1)
+    p.add_argument('--val_filelist', '-V', type=str, default=None)
+    p.add_argument('--val_batchsize', '-b', type=int, default=4)
+    p.add_argument('--val_cropsize', '-C', type=int, default=512)
+    p.add_argument('--patches', '-p', type=int, default=16)
+    p.add_argument('--epoch', '-E', type=int, default=100)
+    p.add_argument('--inner_epoch', '-e', type=int, default=4)
+    p.add_argument('--oracle_rate', '-O', type=float, default=0)
+    p.add_argument('--oracle_drop_rate', '-o', type=float, default=0.5)
+    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
+    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
+    p.add_argument('--pretrained_model', '-P', type=str, default=None)
+    p.add_argument('--debug', '-d', action='store_true')
+    args = p.parse_args()
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    timestamp = dt.now().strftime('%Y%m%d%H%M%S')
+
+    model = nets.CascadedASPPNet()
+    if args.pretrained_model is not None:
+        model.load_state_dict(torch.load(args.pretrained_model))
+    if args.gpu >= 0:
+        model.cuda()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer,
+        factor=args.lr_decay_factor,
+        patience=args.lr_decay_patience,
+        min_lr=args.lr_min,
+        verbose=True)
+
+    train_filelist, val_filelist = train_val_split(
+        mix_dir=args.mixture_dataset,
+        inst_dir=args.instrumental_dataset,
+        val_rate=args.val_rate,
+        val_filelist_json=args.val_filelist)
+
+    if args.debug:
+        print('### DEBUG MODE')
+        train_filelist = train_filelist[:1]
+        val_filelist = val_filelist[:1]
+
+    with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
+        json.dump(val_filelist, f, ensure_ascii=False)
+
+    for i, (X_fname, y_fname) in enumerate(val_filelist):
+        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))
+
+    val_dataset = dataset.make_validation_set(
+        filelist=val_filelist,
+        cropsize=args.val_cropsize,
+        sr=args.sr,
+        hop_length=args.hop_length,
+        offset=model.offset)
+    val_dataloader = torch.utils.data.DataLoader(
+        dataset=val_dataset,
+        batch_size=args.val_batchsize,
+        shuffle=False,
+        num_workers=4)
+
+    log = []
+    oracle_X = None
+    oracle_y = None
+    best_loss = np.inf
+    for epoch in range(args.epoch):
+        X_train, y_train = dataset.make_training_set(
+            train_filelist, args.cropsize, args.patches, args.sr, args.hop_length, model.offset)
+
+        X_train, y_train = dataset.mixup_generator(
+            X_train, y_train, args.mixup_rate, args.mixup_alpha)
+
+        if oracle_X is not None and oracle_y is not None:
+            perm = np.random.permutation(len(oracle_X))
+            X_train[perm] = oracle_X
+            y_train[perm] = oracle_y
+
+        print('# epoch', epoch)
+        instance_loss = np.zeros(len(X_train), dtype=np.float32)
+        for inner_epoch in range(args.inner_epoch):
+            print(' * inner epoch {}'.format(inner_epoch))
+            train_loss = train_inner_epoch(
+                X_train, y_train, model, optimizer, args.batchsize, instance_loss)
+            val_loss = val_inner_epoch(val_dataloader, model)
+
+            print(' * training loss = {:.6f}, validation loss = {:.6f}'
+                  .format(train_loss * 1000, val_loss * 1000))
+
+            scheduler.step(val_loss)
+
+            if val_loss < best_loss:
+                best_loss = val_loss
+                print(' * best validation loss')
+                model_path = 'models/model_iter{}.pth'.format(epoch)
+                torch.save(model.state_dict(), model_path)
+
+            log.append([train_loss, val_loss])
+            with open('log_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
+                json.dump(log, f, ensure_ascii=False)
+
+        if args.oracle_rate > 0:
+            instance_loss /= args.inner_epoch
+            oracle_X, oracle_y, idx = dataset.get_oracle_data(
+                X_train, y_train, instance_loss, args.oracle_rate, args.oracle_drop_rate)
+            print(' * oracle loss = {:.6f}'.format(instance_loss[idx].mean()))
+
+        del X_train, y_train
+        gc.collect()
+
+
+if __name__ == '__main__':
+    main()
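
Editor's usage sketch (not part of the patch): inference.main() is written
against tkinter widgets, but it only ever calls write()/clear(), configure()
and set() on them, so a headless run can pass minimal stand-ins. Everything
below is illustrative -- 'song.wav' and the export path are assumptions, and
the model path is simply the function's default ('models/baseline.pth').

    import inference

    class _Null:
        """Duck-typed stand-in for the tkinter widgets/variables."""
        def write(self, line): print(line, end='')
        def clear(self): pass
        def configure(self, **kwargs): pass
        def set(self, value): pass

    inference.main(window=None, input_paths=['song.wav'], gpu=-1,
                   sr=44100, hop_length=1024, window_size=512,
                   out_mask=False, postprocess=False, export_path='.',
                   loops=1, progress_var=_Null(), button_widget=_Null(),
                   command_widget=_Null())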