Add files via upload

This commit is contained in:
Anjok07 2020-07-20 16:52:35 -05:00 committed by GitHub
parent 6c886b73ef
commit 8f30a09d68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 1041 additions and 0 deletions

542
VocalRemover.py Normal file
View File

@ -0,0 +1,542 @@
# GUI modules
import tkinter as tk
import tkinter.ttk as ttk
import tkinter.messagebox
import tkinter.filedialog
import tkinter.font
from datetime import datetime
# Images
from PIL import Image
from PIL import ImageTk
import pickle # Save Data
# Other Modules
import subprocess # Run python file
# Pathfinding
import pathlib
import os
from collections import defaultdict
# Used for live text displaying
import queue
import threading # Run the algorithm inside a thread
import torch
import inference
# --Global Variables--
base_path = os.path.dirname(__file__)
os.chdir(base_path) # Change the current working directory to the base path
models_dir = os.path.join(base_path, 'models')
logo_path = os.path.join(base_path, 'Images/UVR-logo.png')
DEFAULT_DATA = {
'exportPath': '',
'gpuConversion': False,
'postprocessing': False,
'mask': False,
'stackLoops': False,
'srValue': 44100,
'hopValue': 1024,
'stackLoopsNum': 1,
'winSize': 512,
}
# Supported Music Files
AVAILABLE_FORMATS = ['.mp3', '.mp4', '.m4a', '.flac', '.wav']
def open_image(path: str, size: tuple = None, keep_aspect: bool = True, rotate: int = 0) -> tuple:
"""
Open the image on the path and apply given settings\n
Paramaters:
path(str):
Absolute path of the image
size(tuple):
first value - width
second value - height
keep_aspect(bool):
keep aspect ratio of image and resize
to maximum possible width and height
(maxima are given by size)
rotate(int):
clockwise rotation of image
Returns(tuple):
(ImageTk.PhotoImage, Image)
"""
img = Image.open(path)
ratio = img.height/img.width
img = img.rotate(angle=-rotate)
if size is not None:
size = (int(size[0]), int(size[1]))
if keep_aspect:
img = img.resize((size[0], int(size[0] * ratio)), Image.ANTIALIAS)
else:
img = img.resize(size, Image.ANTIALIAS)
img = img.convert(mode='RGBA')
return ImageTk.PhotoImage(img), img
def save_data(data):
"""
Saves given data as a .pkl (pickle) file
Paramters:
data(dict):
Dictionary containing all the necessary data to save
"""
# Open data file, create it if it does not exist
with open('data.pkl', 'wb') as data_file:
pickle.dump(data, data_file)
def load_data() -> dict:
"""
Loads saved pkl file and returns the stored data
Returns(dict):
Dictionary containing all the saved data
"""
try:
with open('data.pkl', 'rb') as data_file: # Open data file
data = pickle.load(data_file)
return data
except (ValueError, FileNotFoundError):
# Data File is corrupted or not found so recreate it
save_data(data=DEFAULT_DATA)
return load_data()
class ThreadSafeConsole(tk.Text):
"""
Text Widget which is thread safe for tkinter
"""
def __init__(self, master, **options):
tk.Text.__init__(self, master, **options)
self.queue = queue.Queue()
self.update_me()
def write(self, line):
self.queue.put(line)
def clear(self):
self.queue.put(None)
def update_me(self):
self.configure(state=tk.NORMAL)
try:
while 1:
line = self.queue.get_nowait()
if line is None:
self.delete(1.0, tk.END)
else:
self.insert(tk.END, str(line))
self.see(tk.END)
self.update_idletasks()
except queue.Empty:
pass
self.configure(state=tk.DISABLED)
self.after(100, self.update_me)
class MainWindow(tk.Tk):
# --Constants--
# None
def __init__(self):
# Run the __init__ method on the tk.Tk class
super().__init__()
# --Window Settings--
self.title('Desktop Application')
# Set Geometry and Center Window
self.geometry('{width}x{height}+{xpad}+{ypad}'.format(
width=530,
height=690,
xpad=int(self.winfo_screenwidth()/2 - 530/2),
ypad=int(self.winfo_screenheight()/2 - 690/2)))
self.configure(bg='#FFFFFF') # Set background color to white
self.resizable(False, False)
self.update()
# --Variables--
self.logo_img = open_image(path=logo_path,
size=(self.winfo_width(), 9999),
keep_aspect=True)[0]
self.label_to_path = defaultdict(lambda: '')
# -Tkinter Value Holders-
data = load_data()
self.exportPath_var = tk.StringVar(value=data['exportPath'])
self.filePaths = ''
self.gpuConversion_var = tk.BooleanVar(value=data['gpuConversion'])
self.postprocessing_var = tk.BooleanVar(value=data['postprocessing'])
self.mask_var = tk.BooleanVar(value=data['mask'])
self.stackLoops_var = tk.IntVar(value=data['stackLoops'])
self.srValue_var = tk.IntVar(value=data['srValue'])
self.hopValue_var = tk.IntVar(value=data['hopValue'])
self.winSize_var = tk.IntVar(value=data['winSize'])
self.stackLoopsNum_var = tk.IntVar(value=data['stackLoopsNum'])
self.model_var = tk.StringVar(value='')
self.progress_var = tk.IntVar(value=0)
# --Widgets--
self.create_widgets()
self.configure_widgets()
self.place_widgets()
self.update_available_models()
self.update_stack_state()
# -Widget Methods-
def create_widgets(self):
"""Create window widgets"""
self.title_Label = tk.Label(master=self, bg='white',
image=self.logo_img, compound=tk.TOP)
self.filePaths_Frame = tk.Frame(master=self, bg='white')
self.fill_filePaths_Frame()
self.options_Frame = tk.Frame(master=self, bg='white')
self.fill_options_Frame()
self.conversion_Button = ttk.Button(master=self,
text='Start Conversion',
command=self.start_conversion)
self.progressbar = ttk.Progressbar(master=self,
variable=self.progress_var)
self.command_Text = ThreadSafeConsole(master=self,
background='#EFEFEF',
borderwidth=0,)
self.command_Text.write(f'COMMAND LINE [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]') # nopep8
def configure_widgets(self):
"""Change widget styling and appearance"""
ttk.Style().configure('TCheckbutton', background='white')
def place_widgets(self):
"""Place main widgets"""
self.title_Label.place(x=-2, y=-2)
self.filePaths_Frame.place(x=10, y=0, width=-20, height=0,
relx=0, rely=0.19, relwidth=1, relheight=0.14)
self.options_Frame.place(x=25, y=15, width=-50, height=-30,
relx=0, rely=0.33, relwidth=1, relheight=0.23)
self.conversion_Button.place(x=10, y=5, width=-20, height=-10,
relx=0, rely=0.56, relwidth=1, relheight=0.07)
self.command_Text.place(x=15, y=10, width=-30, height=-10,
relx=0, rely=0.63, relwidth=1, relheight=0.28)
self.progressbar.place(x=25, y=15, width=-50, height=-30,
relx=0, rely=0.91, relwidth=1, relheight=0.09)
def fill_filePaths_Frame(self):
"""Fill Frame with neccessary widgets"""
# -Create Widgets-
# Save To Option
self.filePaths_saveTo_Button = ttk.Button(master=self.filePaths_Frame,
text='Save to',
command=self.open_export_filedialog)
self.filePaths_saveTo_Entry = ttk.Entry(master=self.filePaths_Frame,
textvariable=self.exportPath_var,
state=tk.DISABLED
)
# Select Music Files Option
self.filePaths_musicFile_Button = ttk.Button(master=self.filePaths_Frame,
text='Select Your Audio File(s)',
command=self.open_file_filedialog)
self.filePaths_musicFile_Entry = ttk.Entry(master=self.filePaths_Frame,
text=self.filePaths,
state=tk.DISABLED
)
# -Place Widgets-
# Save To Option
self.filePaths_saveTo_Button.place(x=0, y=5, width=0, height=-10,
relx=0, rely=0, relwidth=0.3, relheight=0.5)
self.filePaths_saveTo_Entry.place(x=10, y=7, width=-20, height=-14,
relx=0.3, rely=0, relwidth=0.7, relheight=0.5)
# Select Music Files Option
self.filePaths_musicFile_Button.place(x=0, y=5, width=0, height=-10,
relx=0, rely=0.5, relwidth=0.4, relheight=0.5)
self.filePaths_musicFile_Entry.place(x=10, y=7, width=-20, height=-14,
relx=0.4, rely=0.5, relwidth=0.6, relheight=0.5)
def fill_options_Frame(self):
"""Fill Frame with neccessary widgets"""
# -Create Widgets-
# GPU Selection
self.options_gpu_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
text='GPU Conversion',
variable=self.gpuConversion_var,
)
# Postprocessing
self.options_post_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
text='Post-Process (Dev Opt)',
variable=self.postprocessing_var,
)
# Mask
self.options_mask_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
text='Save Mask PNG',
variable=self.mask_var,
)
# SR
self.options_sr_Entry = ttk.Entry(master=self.options_Frame,
textvariable=self.srValue_var,)
self.options_sr_Label = tk.Label(master=self.options_Frame,
text='SR', anchor=tk.W,
background='white')
# HOP LENGTH
self.options_hop_Entry = ttk.Entry(master=self.options_Frame,
textvariable=self.hopValue_var,)
self.options_hop_Label = tk.Label(master=self.options_Frame,
text='HOP LENGTH', anchor=tk.W,
background='white')
# WINDOW SIZE
self.options_winSize_Entry = ttk.Entry(master=self.options_Frame,
textvariable=self.winSize_var,)
self.options_winSize_Label = tk.Label(master=self.options_Frame,
text='WINDOW SIZE', anchor=tk.W,
background='white')
# Stack Loops
self.options_stack_Checkbutton = ttk.Checkbutton(master=self.options_Frame,
text='Stack Passes',
variable=self.stackLoops_var,
)
self.options_stack_Entry = ttk.Entry(master=self.options_Frame,
textvariable=self.stackLoopsNum_var,)
self.options_stack_Checkbutton.configure(command=self.update_stack_state) # nopep8
# Choose Model
self.options_model_Label = tk.Label(master=self.options_Frame,
text='Choose Your Model',
background='white')
self.options_model_Optionmenu = ttk.OptionMenu(self.options_Frame,
self.model_var,
1,
*[1, 2])
self.options_model_Button = ttk.Button(master=self.options_Frame,
text='Add Your Own Model',
command=self.open_newModel_filedialog)
# -Place Widgets-
# GPU Selection
self.options_gpu_Checkbutton.place(x=0, y=0, width=0, height=0,
relx=0, rely=0, relwidth=1/3, relheight=1/4)
self.options_post_Checkbutton.place(x=0, y=0, width=0, height=0,
relx=0, rely=1/4, relwidth=1/3, relheight=1/4)
self.options_mask_Checkbutton.place(x=0, y=0, width=0, height=0,
relx=0, rely=2/4, relwidth=1/3, relheight=1/4)
# Stack Loops
self.options_stack_Checkbutton.place(x=0, y=0, width=0, height=0,
relx=0, rely=3/4, relwidth=1/3/4*3, relheight=1/4)
self.options_stack_Entry.place(x=0, y=4, width=0, height=-8,
relx=1/3/4*2.4, rely=3/4, relwidth=1/3/4*0.9, relheight=1/4)
# SR
self.options_sr_Entry.place(x=-5, y=4, width=5, height=-8,
relx=1/3, rely=0, relwidth=1/3/4, relheight=1/4)
self.options_sr_Label.place(x=10, y=4, width=-10, height=-8,
relx=1/3/4 + 1/3, rely=0, relwidth=1/3/4*3, relheight=1/4)
# HOP LENGTH
self.options_hop_Entry.place(x=-5, y=4, width=5, height=-8,
relx=1/3, rely=1/4, relwidth=1/3/4, relheight=1/4)
self.options_hop_Label.place(x=10, y=4, width=-10, height=-8,
relx=1/3/4 + 1/3, rely=1/4, relwidth=1/3/4*3, relheight=1/4)
# WINDOW SIZE
self.options_winSize_Entry.place(x=-5, y=4, width=5, height=-8,
relx=1/3, rely=2/4, relwidth=1/3/4, relheight=1/4)
self.options_winSize_Label.place(x=10, y=4, width=-10, height=-8,
relx=1/3/4 + 1/3, rely=2/4, relwidth=1/3/4*3, relheight=1/4)
# Choose Model
self.options_model_Label.place(x=0, y=0, width=0, height=-10,
relx=2/3, rely=0, relwidth=1/3, relheight=1/3)
self.options_model_Optionmenu.place(x=15, y=-2.5, width=-30, height=-10,
relx=2/3, rely=1/3, relwidth=1/3, relheight=1/3)
self.options_model_Button.place(x=15, y=0, width=-30, height=-5,
relx=2/3, rely=2/3, relwidth=1/3, relheight=1/3)
# Opening filedialogs
def open_file_filedialog(self):
"""Make user select music files"""
paths = tk.filedialog.askopenfilenames(
parent=self,
title=f'Select Music Files',
initialdir='/',
initialfile='',
filetypes=[
('; '.join(AVAILABLE_FORMATS).replace('.', ''),
'*' + ' *'.join(AVAILABLE_FORMATS)),
])
if paths: # Path selected
for path in paths:
if not path.lower().endswith(tuple(AVAILABLE_FORMATS)):
tk.messagebox.showerror(master=self,
title='Invalid File',
message='Please select a \"{}\" audio file!'.format('" or "'.join(AVAILABLE_FORMATS)), # nopep8
detail=f'File: {path}')
return
self.filePaths = paths
# Change the entry text
self.filePaths_musicFile_Entry.configure(state=tk.NORMAL)
self.filePaths_musicFile_Entry.delete(0, tk.END)
self.filePaths_musicFile_Entry.insert(0, self.filePaths)
self.filePaths_musicFile_Entry.configure(state=tk.DISABLED)
def open_export_filedialog(self):
"""Make user select a folder to export the converted files in"""
path = tk.filedialog.askdirectory(
parent=self,
title=f'Select Folder',
initialdir='/',)
if path: # Path selected
self.exportPath_var.set(path)
def open_newModel_filedialog(self):
"""Make user select a ".pth" model to use for the vocal removing"""
path = tk.filedialog.askopenfilename(
parent=self,
title=f'Select Model File',
initialdir='/',
initialfile='',
filetypes=[
('pth', '*.pth'),
])
if path: # Path selected
if path.lower().endswith(('.pth')):
self.add_available_model(abs_path=path)
else:
tk.messagebox.showerror(master=self,
title='Invalid File',
message=f'Please select a PyTorch model file ".pth"!',
detail=f'File: {path}')
return
def start_conversion(self):
"""
Start the conversion for all the given mp3 and wav files
"""
# -Get all variables-
input_paths = self.filePaths
export_path = self.exportPath_var.get()
model_path = self.label_to_path[self.model_var.get()]
try:
sr = self.srValue_var.get()
hop_length = self.hopValue_var.get()
window_size = self.winSize_var.get()
loops_num = self.stackLoopsNum_var.get()
except tk.TclError: # Non integer was put in entry box
tk.messagebox.showwarning(master=self,
title='Invalid Input',
message='Please make sure you only input integer numbers!')
return
except SyntaxError: # Non integer was put in entry box
tk.messagebox.showwarning(master=self,
title='Invalid Music File',
message='You have selected an invalid music file!\nPlease make sure that your files still exist and end with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"')
return
# -Check for invalid inputs-
if not any([(os.path.isfile(path) and path.endswith(('.mp3', '.mp4', '.m4a', '.flac', '.wav')))
for path in input_paths]):
tk.messagebox.showwarning(master=self,
title='Invalid Music File',
message='You have selected an invalid music file!\nPlease make sure that your files still exist and end with either ".mp3", ".mp4", ".m4a", ".flac", ".wav"')
return
if not os.path.isdir(export_path):
tk.messagebox.showwarning(master=self,
title='Invalid Export Directory',
message='You have selected an invalid export directory!\nPlease make sure that your directory still exists!')
return
if not os.path.isfile(model_path):
tk.messagebox.showwarning(master=self,
title='Invalid Model File',
message='You have selected an invalid model file!\nPlease make sure that your model file still exists!')
return
# -Save Data-
save_data(data={
'exportPath': export_path,
'gpuConversion': self.gpuConversion_var.get(),
'postprocessing': self.postprocessing_var.get(),
'mask': self.mask_var.get(),
'stackLoops': self.stackLoops_var.get(),
'gpuConversion': self.gpuConversion_var.get(),
'srValue': sr,
'hopValue': hop_length,
'winSize': window_size,
'stackLoopsNum': loops_num,
})
# -Run the algorithm-
threading.Thread(target=inference.main,
kwargs={
'input_paths': input_paths,
'gpu': 0 if self.gpuConversion_var.get() else -1,
'postprocess': self.postprocessing_var.get(),
'out_mask': self.mask_var.get(),
'model': model_path,
'sr': sr,
'hop_length': hop_length,
'window_size': window_size,
'export_path': export_path,
'loops': loops_num,
# Other Variables (Tkinter)
'window': self,
'command_widget': self.command_Text,
'button_widget': self.conversion_Button,
'progress_var': self.progress_var,
},
daemon=True
).start()
# Models
def update_available_models(self):
"""
Loop through every model (.pth) in the models directory
and add to the select your model list
"""
# Delete all previous options
self.model_var.set('')
self.options_model_Optionmenu['menu'].delete(0, 'end')
for file_name in os.listdir(models_dir):
if file_name.endswith('.pth'):
# Add Radiobutton to the Options Menu
self.options_model_Optionmenu['menu'].add_radiobutton(label=file_name,
command=tk._setit(self.model_var, file_name))
# Link the files name to its absolute path
self.label_to_path[file_name] = os.path.join(models_dir, file_name) # nopep8
def add_available_model(self, abs_path: str):
"""
Add the given absolute path of the file (.pth) to the available options
and set the currently selected model to this one
"""
if abs_path.endswith('.pth'):
file_name = f'[CUSTOM] {os.path.basename(abs_path)}'
# Add Radiobutton to the Options Menu
self.options_model_Optionmenu['menu'].add_radiobutton(label=file_name,
command=tk._setit(self.model_var, file_name))
# Set selected model to the newly added one
self.model_var.set(file_name)
# Link the files name to its absolute path
self.label_to_path[file_name] = abs_path # nopep8
else:
tk.messagebox.showerror(master=self,
title='Invalid File',
message='Please select a model file with the ".pth" ending!',
detail=f'File: {abs_path}')
def update_stack_state(self):
"""
Vary the stack Entry fro disabled/enabled based on the
stackLoops variable, which is connected to the checkbutton
"""
if self.stackLoops_var.get():
self.options_stack_Entry.configure(state=tk.NORMAL)
else:
self.options_stack_Entry.configure(state=tk.DISABLED)
self.stackLoopsNum_var.set(1)
if __name__ == "__main__":
root = MainWindow()
root.mainloop()

76
augment.py Normal file
View File

@ -0,0 +1,76 @@
import argparse
import os
import subprocess
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from lib import spec_utils
if __name__ == '__main__':
p = argparse.ArgumentParser()
p.add_argument('--sr', '-r', type=int, default=44100)
p.add_argument('--hop_length', '-l', type=int, default=1024)
p.add_argument('--pitch', '-p', type=int, default=-2)
p.add_argument('--mixture_dataset', '-m', required=True)
p.add_argument('--instrumental_dataset', '-i', required=True)
args = p.parse_args()
input_exts = ['.wav', '.m4a', '.3gp', '.oma', '.mp3', '.mp4']
X_list = sorted([
os.path.join(args.mixture_dataset, fname)
for fname in os.listdir(args.mixture_dataset)
if os.path.splitext(fname)[1] in input_exts])
y_list = sorted([
os.path.join(args.instrumental_dataset, fname)
for fname in os.listdir(args.instrumental_dataset)
if os.path.splitext(fname)[1] in input_exts])
input_i = 'input_i_{}.wav'.format(args.pitch)
input_v = 'input_v_{}.wav'.format(args.pitch)
output_i = 'output_i_{}.wav'.format(args.pitch)
output_v = 'output_v_{}.wav'.format(args.pitch)
cmd_i = 'soundstretch {} {} -pitch={}'.format(input_i, output_i, args.pitch)
cmd_v = 'soundstretch {} {} -pitch={}'.format(input_v, output_v, args.pitch)
suffix = '_pitch{}.npy'.format(args.pitch)
filelist = list(zip(X_list, y_list))
for mix_path, inst_path in tqdm(filelist):
X, _ = librosa.load(
mix_path, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
y, _ = librosa.load(
inst_path, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
X, _ = librosa.effects.trim(X)
y, _ = librosa.effects.trim(y)
X, y = spec_utils.align_wave_head_and_tail(X, y, args.sr)
v = X - y
sf.write(input_i, y.T, args.sr)
sf.write(input_v, v.T, args.sr)
subprocess.call(cmd_i, stderr=subprocess.DEVNULL)
subprocess.call(cmd_v, stderr=subprocess.DEVNULL)
y, _ = librosa.load(
output_i, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
v, _ = librosa.load(
output_v, args.sr, False, dtype=np.float32, res_type='kaiser_fast')
X = y + v
spec = spec_utils.calc_spec(X, args.hop_length)
basename, _ = os.path.splitext(os.path.basename(mix_path))
outpath = os.path.join(args.mixture_dataset, basename + suffix)
np.save(outpath, np.abs(spec))
spec = spec_utils.calc_spec(y, args.hop_length)
basename, _ = os.path.splitext(os.path.basename(inst_path))
outpath = os.path.join(args.instrumental_dataset, basename + suffix)
np.save(outpath, np.abs(spec))
os.remove(input_i)
os.remove(input_v)
os.remove(output_i)
os.remove(output_v)

200
inference.py Normal file
View File

@ -0,0 +1,200 @@
import argparse
import os
import cv2
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from lib import dataset
from lib import nets
from lib import spec_utils
# Variable manipulation and command line text parsing
import torch
import tkinter as tk
import traceback # Error Message Recent Calls
class Namespace:
"""
Replaces ArgumentParser
"""
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def main(window: tk.Wm, input_paths: list, gpu: bool = -1,
model: str = 'models/baseline.pth', sr: int = 44100, hop_length: int = 1024,
window_size: int = 512, out_mask: bool = False, postprocess: bool = False,
export_path: str = '', loops: int = 1,
# Other Variables (Tkinter)
progress_var: tk.Variable = None, button_widget: tk.Button = None, command_widget: tk.Text = None,
):
def load_model():
args.command_widget.write('Loading model...\n') # nopep8 Write Command Text
device = torch.device('cpu')
model = nets.CascadedASPPNet()
model.load_state_dict(torch.load(args.model, map_location=device))
if torch.cuda.is_available() and args.gpu >= 0:
device = torch.device('cuda:{}'.format(args.gpu))
model.to(device)
args.command_widget.write('Done!\n') # nopep8 Write Command Text
return model, device
def load_wave_source():
args.command_widget.write(base_text + 'Loading wave source...\n') # nopep8 Write Command Text
X, sr = librosa.load(music_file,
args.sr,
False,
dtype=np.float32,
res_type='kaiser_fast')
args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text
return X, sr
def stft_wave_source(X):
args.command_widget.write(base_text + 'Stft of wave source...\n') # nopep8 Write Command Text
X = spec_utils.calc_spec(X, args.hop_length)
X, phase = np.abs(X), np.exp(1.j * np.angle(X))
coeff = X.max()
X /= coeff
offset = model.offset
l, r, roi_size = dataset.make_padding(
X.shape[2], args.window_size, offset)
X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant')
X_roll = np.roll(X_pad, roi_size // 2, axis=2)
model.eval()
with torch.no_grad():
masks = []
masks_roll = []
length = int(np.ceil(X.shape[2] / roi_size))
for i in tqdm(range(length)):
progress_var.set(base_progress + max_progress * (0.1 + (0.6/length * i))) # nopep8 Update Progress
start = i * roi_size
X_window = torch.from_numpy(np.asarray([
X_pad[:, :, start:start + args.window_size],
X_roll[:, :, start:start + args.window_size]
])).to(device)
pred = model.predict(X_window)
pred = pred.detach().cpu().numpy()
masks.append(pred[0])
masks_roll.append(pred[1])
mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]
mask_roll = np.concatenate(masks_roll, axis=2)[
:, :, :X.shape[2]]
mask = (mask + np.roll(mask_roll, -roi_size // 2, axis=2)) / 2
if args.postprocess:
vocal = X * (1 - mask) * coeff
mask = spec_utils.mask_uninformative(mask, vocal)
args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text
inst = X * mask * coeff
vocal = X * (1 - mask) * coeff
return inst, vocal, phase, mask
def invert_instrum_vocal(inst, vocal, phase):
args.command_widget.write(base_text + 'Inverse stft of instruments and vocals...\n') # nopep8 Write Command Text
wav_instrument = spec_utils.spec_to_wav(inst, phase, args.hop_length) # nopep8
wav_vocals = spec_utils.spec_to_wav(vocal, phase, args.hop_length) # nopep8
args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text
return wav_instrument, wav_vocals
def save_files(wav_instrument, wav_vocals):
args.command_widget.write(base_text + 'Saving Files...\n') # nopep8 Write Command Text
sf.write(f'{export_path}/{base_name}_(Instrumental).wav',
wav_instrument.T, sr)
if cur_loop == 0:
sf.write(f'{export_path}/{base_name}_(Vocals).wav',
wav_vocals.T, sr)
if (cur_loop == (args.loops - 1) and
args.loops > 1):
sf.write(f'{export_path}/{base_name}_(Last_Vocals).wav',
wav_vocals.T, sr)
args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text
def create_mask():
args.command_widget.write(base_text + 'Creating Mask...\n') # nopep8 Write Command Text
norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0)
norm_mask = np.concatenate([
np.max(norm_mask, axis=2, keepdims=True),
norm_mask], axis=2)[::-1]
_, bin_mask = cv2.imencode('.png', norm_mask)
args.command_widget.write(base_text + 'Saving Mask...\n') # nopep8 Write Command Text
with open(f'{export_path}/{base_name}_(Mask).png', mode='wb') as f:
bin_mask.tofile(f)
args.command_widget.write(base_text + 'Done!\n') # nopep8 Write Command Text
args = Namespace(input=input_paths, gpu=gpu, model=model,
sr=sr, hop_length=hop_length, window_size=window_size,
out_mask=out_mask, postprocess=postprocess, export=export_path,
loops=loops,
# Other Variables (Tkinter)
window=window, progress_var=progress_var,
button_widget=button_widget, command_widget=command_widget,
)
args.command_widget.clear() # Clear Command Text
args.button_widget.configure(state=tk.DISABLED) # Disable Button
total_files = len(args.input) # Used to calculate progress
model, device = load_model()
for file_num, music_file in enumerate(args.input, start=1):
try:
base_name = f'{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'
for cur_loop in range(args.loops):
if cur_loop > 0:
args.command_widget.write(f'File {file_num}/{total_files}: ' + 'Next Pass!\n') # nopep8 Write Command Text
music_file = f'{export_path}/{base_name}_(Instrumental).wav'
base_progress = 100 / \
(total_files*args.loops) * \
((file_num*args.loops)-((args.loops-1) - cur_loop)-1)
base_text = 'File {file_num}/{total_files}:{loop} '.format(
file_num=file_num,
total_files=total_files,
loop='' if args.loops <= 1 else f' ({cur_loop+1}/{args.loops})')
max_progress = 100 / (total_files*args.loops)
progress_var.set(base_progress + max_progress * 0.05) # nopep8 Update Progress
X, sr = load_wave_source()
progress_var.set(base_progress + max_progress * 0.1) # nopep8 Update Progress
inst, vocal, phase, mask = stft_wave_source(X)
progress_var.set(base_progress + max_progress * 0.7) # nopep8 Update Progress
wav_instrument, wav_vocals = invert_instrum_vocal(inst, vocal, phase) # nopep8
progress_var.set(base_progress + max_progress * 0.8) # nopep8 Update Progress
save_files(wav_instrument, wav_vocals)
progress_var.set(base_progress + max_progress * 0.9) # nopep8 Update Progress
if args.out_mask:
create_mask()
progress_var.set(base_progress + max_progress * 1) # nopep8 Update Progress
args.command_widget.write(base_text + 'Completed Seperation!\n\n') # nopep8 Write Command Text
except Exception as e:
traceback_text = ''.join(traceback.format_tb(e.__traceback__))
print(traceback_text)
print(type(e).__name__, e)
tk.messagebox.showerror(master=args.window,
title='Untracked Error',
message=f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\n\nPlease contact the creator and attach a screenshot of this error with the file which caused it!')
args.button_widget.configure(state=tk.NORMAL) # Enable Button
return
progress_var.set(100) # Update Progress
args.command_widget.write(f'Conversion(s) Completed and Saving all Files!') # nopep8 Write Command Text
args.button_widget.configure(state=tk.NORMAL) # Enable Button

223
train.py Normal file
View File

@ -0,0 +1,223 @@
import argparse
from datetime import datetime as dt
import gc
import json
import os
import random
import numpy as np
import torch
import torch.nn as nn
from lib import dataset
from lib import nets
from lib import spec_utils
def train_val_split(mix_dir, inst_dir, val_rate, val_filelist_json):
input_exts = ['.wav', '.m4a', '.3gp', '.oma', '.mp3', '.mp4']
X_list = sorted([
os.path.join(mix_dir, fname)
for fname in os.listdir(mix_dir)
if os.path.splitext(fname)[1] in input_exts])
y_list = sorted([
os.path.join(inst_dir, fname)
for fname in os.listdir(inst_dir)
if os.path.splitext(fname)[1] in input_exts])
filelist = list(zip(X_list, y_list))
random.shuffle(filelist)
val_filelist = []
if val_filelist_json is not None:
with open(val_filelist_json, 'r', encoding='utf8') as f:
val_filelist = json.load(f)
if len(val_filelist) == 0:
val_size = int(len(filelist) * val_rate)
train_filelist = filelist[:-val_size]
val_filelist = filelist[-val_size:]
else:
train_filelist = [
pair for pair in filelist
if list(pair) not in val_filelist]
return train_filelist, val_filelist
def train_inner_epoch(X_train, y_train, model, optimizer, batchsize, instance_loss):
sum_loss = 0
model.train()
aux_crit = nn.L1Loss()
criterion = nn.L1Loss(reduction='none')
perm = np.random.permutation(len(X_train))
for i in range(0, len(X_train), batchsize):
local_perm = perm[i: i + batchsize]
X_batch = torch.from_numpy(X_train[local_perm]).cpu()
y_batch = torch.from_numpy(y_train[local_perm]).cpu()
model.zero_grad()
mask, aux = model(X_batch)
aux_loss = aux_crit(X_batch * aux, y_batch)
X_batch = spec_utils.crop_center(mask, X_batch, False)
y_batch = spec_utils.crop_center(mask, y_batch, False)
abs_diff = criterion(X_batch * mask, y_batch)
loss = abs_diff.mean() * 0.9 + aux_loss * 0.1
loss.backward()
optimizer.step()
abs_diff_np = abs_diff.detach().cpu().numpy()
instance_loss[local_perm] += abs_diff_np.mean(axis=(1, 2, 3))
sum_loss += float(loss.detach().cpu().numpy()) * len(X_batch)
return sum_loss / len(X_train)
def val_inner_epoch(dataloader, model):
sum_loss = 0
model.eval()
criterion = nn.L1Loss()
with torch.no_grad():
for X_batch, y_batch in dataloader:
X_batch = X_batch.cpu()
y_batch = y_batch.cpu()
mask = model.predict(X_batch)
X_batch = spec_utils.crop_center(mask, X_batch, False)
y_batch = spec_utils.crop_center(mask, y_batch, False)
loss = criterion(X_batch * mask, y_batch)
sum_loss += float(loss.detach().cpu().numpy()) * len(X_batch)
return sum_loss / len(dataloader.dataset)
def main():
p = argparse.ArgumentParser()
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--seed', '-s', type=int, default=2019)
p.add_argument('--sr', '-r', type=int, default=44100)
p.add_argument('--hop_length', '-l', type=int, default=1024)
p.add_argument('--mixture_dataset', '-m', required=True)
p.add_argument('--instrumental_dataset', '-i', required=True)
p.add_argument('--learning_rate', type=float, default=0.001)
p.add_argument('--lr_min', type=float, default=0.0001)
p.add_argument('--lr_decay_factor', type=float, default=0.9)
p.add_argument('--lr_decay_patience', type=int, default=6)
p.add_argument('--batchsize', '-B', type=int, default=4)
p.add_argument('--cropsize', '-c', type=int, default=256)
p.add_argument('--val_rate', '-v', type=float, default=0.1)
p.add_argument('--val_filelist', '-V', type=str, default=None)
p.add_argument('--val_batchsize', '-b', type=int, default=4)
p.add_argument('--val_cropsize', '-C', type=int, default=512)
p.add_argument('--patches', '-p', type=int, default=16)
p.add_argument('--epoch', '-E', type=int, default=100)
p.add_argument('--inner_epoch', '-e', type=int, default=4)
p.add_argument('--oracle_rate', '-O', type=float, default=0)
p.add_argument('--oracle_drop_rate', '-o', type=float, default=0.5)
p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
p.add_argument('--pretrained_model', '-P', type=str, default=None)
p.add_argument('--debug', '-d', action='store_true')
args = p.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
timestamp = dt.now().strftime('%Y%m%d%H%M%S')
model = nets.CascadedASPPNet()
if args.pretrained_model is not None:
model.load_state_dict(torch.load(args.pretrained_model))
if args.gpu >= 0:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
factor=args.lr_decay_factor,
patience=args.lr_decay_patience,
min_lr=args.lr_min,
verbose=True)
train_filelist, val_filelist = train_val_split(
mix_dir=args.mixture_dataset,
inst_dir=args.instrumental_dataset,
val_rate=args.val_rate,
val_filelist_json=args.val_filelist)
if args.debug:
print('### DEBUG MODE')
train_filelist = train_filelist[:1]
val_filelist = val_filelist[:1]
with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
json.dump(val_filelist, f, ensure_ascii=False)
for i, (X_fname, y_fname) in enumerate(val_filelist):
print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))
val_dataset = dataset.make_validation_set(
filelist=val_filelist,
cropsize=args.val_cropsize,
sr=args.sr,
hop_length=args.hop_length,
offset=model.offset)
val_dataloader = torch.utils.data.DataLoader(
dataset=val_dataset,
batch_size=args.val_batchsize,
shuffle=False,
num_workers=4)
log = []
oracle_X = None
oracle_y = None
best_loss = np.inf
for epoch in range(args.epoch):
X_train, y_train = dataset.make_training_set(
train_filelist, args.cropsize, args.patches, args.sr, args.hop_length, model.offset)
X_train, y_train = dataset.mixup_generator(
X_train, y_train, args.mixup_rate, args.mixup_alpha)
if oracle_X is not None and oracle_y is not None:
perm = np.random.permutation(len(oracle_X))
X_train[perm] = oracle_X
y_train[perm] = oracle_y
print('# epoch', epoch)
instance_loss = np.zeros(len(X_train), dtype=np.float32)
for inner_epoch in range(args.inner_epoch):
print(' * inner epoch {}'.format(inner_epoch))
train_loss = train_inner_epoch(
X_train, y_train, model, optimizer, args.batchsize, instance_loss)
val_loss = val_inner_epoch(val_dataloader, model)
print(' * training loss = {:.6f}, validation loss = {:.6f}'
.format(train_loss * 1000, val_loss * 1000))
scheduler.step(val_loss)
if val_loss < best_loss:
best_loss = val_loss
print(' * best validation loss')
model_path = 'models/model_iter{}.pth'.format(epoch)
torch.save(model.state_dict(), model_path)
log.append([train_loss, val_loss])
with open('log_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
json.dump(log, f, ensure_ascii=False)
if args.oracle_rate > 0:
instance_loss /= args.inner_epoch
oracle_X, oracle_y, idx = dataset.get_oracle_data(
X_train, y_train, instance_loss, args.oracle_rate, args.oracle_drop_rate)
print(' * oracle loss = {:.6f}'.format(instance_loss[idx].mean()))
del X_train, y_train
gc.collect()
if __name__ == '__main__':
main()