# 2020-11-09 11:32:56 +01:00
import pprint
import argparse
import os
import cv2
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from lib_v4 import dataset
from lib_v4 import nets
from lib_v4 import spec_utils
import torch
# Command line text parsing and widget manipulation
from collections import defaultdict
import tkinter as tk
import traceback # Error Message Recent Calls
import time # Timer
class VocalRemover(object):
    """Load the configured separation models and run them on spectrograms.

    Models and the devices they live on are stored in ``self.models`` and
    ``self.devices`` under the keys ``'instrumental'``, ``'vocal'`` and
    ``'stack'``. Both are ``defaultdict``s that yield ``None`` for a model
    that was not loaded.
    """

    def __init__(self, data, text_widget: tk.Text):
        """Store the settings and immediately load the available models.

        Args:
            data: Settings mapping. Reads the ``'<name>Model'`` checkpoint
                paths, ``'n_fft'``, ``'gpu'`` and ``'window_size'``.
            text_widget: Widget used for status output; only its
                ``write(str)`` method is called here.
        """
        self.data = data
        self.text_widget = text_widget
        self.models = defaultdict(lambda: None)
        self.devices = defaultdict(lambda: None)
        self._load_models()

    @staticmethod
    def _is_model_file(path):
        """Return True when *path* names an existing file.

        ``None`` and ``''`` (the defaults for unset model paths) count as
        "no file" instead of making ``os.path.isfile`` raise ``TypeError``.
        """
        return bool(path) and os.path.isfile(path)

    def _load_model(self, name):
        """Load the checkpoint at ``self.data['<name>Model']`` and register it."""
        device = torch.device('cpu')
        model = nets.CascadedASPPNet(self.data['n_fft'])
        # Always deserialize onto the CPU first; move to the GPU afterwards.
        model.load_state_dict(torch.load(self.data[f'{name}Model'],
                                         map_location=device))
        if torch.cuda.is_available() and self.data['gpu'] >= 0:
            device = torch.device('cuda:{}'.format(self.data['gpu']))
            model.to(device)

        self.models[name] = model
        self.devices[name] = device

    def _load_models(self):
        """Load whichever of the instrumental/vocal/stack checkpoints exist."""
        self.text_widget.write('Loading models...\n')  # nopep8 Write Command Text

        # Use the instance settings rather than a module-level global so the
        # class works with whatever mapping it was constructed with.
        # NOTE(review): the vocal model is only loaded when no instrumental
        # checkpoint exists (instrumental takes precedence) - confirm that
        # callers never request 'vocal' while an instrumental path is set.
        if self._is_model_file(self.data['instrumentalModel']):
            self._load_model('instrumental')
        elif self._is_model_file(self.data['vocalModel']):
            self._load_model('vocal')

        # The stack model is independent of the choice above.
        if self._is_model_file(self.data['stackModel']):
            self._load_model('stack')

        self.text_widget.write('Done!\n')

    def _execute(self, X_mag_pad, roi_size, n_window, device, model):
        """Run *model* over consecutive windows of *X_mag_pad*.

        Window ``i`` starts at ``i * roi_size`` on the last axis and is
        ``self.data['window_size']`` wide. The per-window predictions are
        concatenated along axis 2 and returned as a NumPy array.
        """
        model.eval()
        with torch.no_grad():
            preds = []
            for i in tqdm(range(n_window)):
                start = i * roi_size
                # Leading None adds the batch axis expected by predict().
                X_mag_window = X_mag_pad[None, :, :,
                                         start:start + self.data['window_size']]
                X_mag_window = torch.from_numpy(X_mag_window).to(device)

                pred = model.predict(X_mag_window)

                pred = pred.detach().cpu().numpy()
                preds.append(pred[0])

            pred = np.concatenate(preds, axis=2)

        return pred

    def preprocess(self, X_spec):
        """Split a complex spectrogram into ``(magnitude, phase)`` arrays."""
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)

        return X_mag, X_phase

    def _predict_padded(self, X_mag_pre, pad_l, pad_r, roi_size, n_window,
                        device, model):
        """Zero-pad the last axis by (pad_l, pad_r) and run the windowed model."""
        X_mag_pad = np.pad(
            X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
        return self._execute(X_mag_pad, roi_size, n_window, device, model)

    def inference(self, X_spec, device, model):
        """Predict the model output for *X_spec* in a single pass.

        Returns:
            ``(pred, X_mag, phase)`` where ``pred`` is the prediction rescaled
            by the input's peak magnitude, ``X_mag`` is the input magnitude
            and ``phase`` is ``exp(1j * angle(X_spec))``.
        """
        X_mag, X_phase = self.preprocess(X_spec)

        # Normalize by the peak magnitude; undone on the way out.
        coef = X_mag.max()
        X_mag_pre = X_mag / coef

        n_frame = X_mag_pre.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.data['window_size'], model.offset)
        n_window = int(np.ceil(n_frame / roi_size))

        pred = self._predict_padded(X_mag_pre, pad_l, pad_r, roi_size,
                                    n_window, device, model)
        pred = pred[:, :, :n_frame]

        return pred * coef, X_mag, np.exp(1.j * X_phase)

    def inference_tta(self, X_spec, device, model):
        """Like :meth:`inference`, but average two passes.

        The second pass shifts the window grid by half a region of interest
        (extra ``roi_size // 2`` padding on both sides, one more window) and
        its output is re-aligned before the two predictions are averaged.
        Returns the same ``(pred, X_mag, phase)`` triple as :meth:`inference`.
        """
        X_mag, X_phase = self.preprocess(X_spec)

        # Normalize by the peak magnitude; undone on the way out.
        coef = X_mag.max()
        X_mag_pre = X_mag / coef

        n_frame = X_mag_pre.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.data['window_size'], model.offset)
        n_window = int(np.ceil(n_frame / roi_size))

        pred = self._predict_padded(X_mag_pre, pad_l, pad_r, roi_size,
                                    n_window, device, model)
        pred = pred[:, :, :n_frame]

        # Second pass on a grid shifted by half a region of interest.
        half_roi = roi_size // 2
        pred_tta = self._predict_padded(X_mag_pre, pad_l + half_roi,
                                        pad_r + half_roi, roi_size,
                                        n_window + 1, device, model)
        # Drop the extra left padding so both predictions line up.
        pred_tta = pred_tta[:, :, half_roi:]
        pred_tta = pred_tta[:, :, :n_frame]

        return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase)
# Module-wide settings. main() overlays the caller's keyword arguments on top
# of these defaults, and update_constants() rewrites the four constants per
# model.
data = {
    # Paths
    'input_paths': None,
    'export_path': None,
    # Processing Options
    'gpu': -1,
    'postprocess': True,
    'tta': True,
    'output_image': True,
    # Models
    'instrumentalModel': None,
    'vocalModel': None,
    'stackModel': None,
    'useModel': None,
    # Model Test Mode: read by determineModelFolderName(); falsy means the
    # outputs are not placed in a model-named folder. Defined here so the
    # lookup cannot fail when the caller does not pass it.
    'modelFolder': False,
    # Stack Options
    'stackPasses': 0,
    'stackOnly': False,
    'saveAllStacked': False,
    # Constants
    'sr': 44_100,
    'hop_length': 1_024,
    'window_size': 512,
    'n_fft': 2_048,
}
# Fallback values restored by update_constants() before it applies the
# settings encoded in a model's file name.
default_sr = data['sr']
default_hop_length = data['hop_length']
default_window_size = data['window_size']
default_n_fft = data['n_fft']
def update_progress(progress_var, total_files, total_loops, file_num, loop_num, step: float = 1):
    """Push the overall completion percentage to the GUI progress variable.

    Each file owns an equal share of the 0-100 range and each loop an equal
    share of its file's range; *step* is the fraction of the current loop
    that is finished. *file_num* is 1-based, *loop_num* is 0-based.
    """
    per_file = (100 / total_files)
    per_loop = (per_file / total_loops)
    finished_files = per_file * (file_num - 1)
    progress_var.set(finished_files + per_loop * (loop_num + step))
def get_baseText(total_files, total_loops, file_num, loop_num):
    """Build the 'File n/m:' prefix used for lines in the command widget.

    The ' (loop/total)' counter is only included when more than one loop is
    run; *loop_num* is 0-based and shown 1-based.
    """
    if total_loops <= 1:
        loop_text = ''
    else:
        loop_text = f' ({loop_num + 1}/{total_loops})'
    return f'File {file_num}/{total_files}:{loop_text} '
def update_constants(model_name):
    """
    Decode the conversion settings from the model's name

    The name (minus '.pth') is split on '_' and the first chunk is skipped.
    A chunk containing one of the tags 'sr', 'hl', 'w' or 'nf' and otherwise
    consisting of decimal digits sets the sample rate, hop length, window
    size or n_fft in the module-level settings. Settings without a matching
    chunk fall back to the stored defaults.
    """
    stem = model_name.replace('.pth', '')

    # Reset first so values from a previously used model are not kept.
    data['sr'] = default_sr
    data['hop_length'] = default_hop_length
    data['window_size'] = default_window_size
    data['n_fft'] = default_n_fft

    # Tags are tried in this order; the chunk keeps any tag removal done by
    # an earlier, unsuccessful attempt.
    tag_to_setting = (
        ('sr', 'sr'),
        ('hl', 'hop_length'),
        ('w', 'window_size'),
        ('nf', 'n_fft'),
    )
    for chunk in stem.split('_')[1:]:
        for tag, setting in tag_to_setting:
            if tag not in chunk:
                continue
            chunk = chunk.replace(tag, '')
            if not chunk.isdecimal():
                continue
            try:
                data[setting] = int(chunk)
            except ValueError:
                # Cannot convert string to int
                continue
            # Value stored; move on to the next chunk
            break
def determineModelFolderName():
    """
    Determine the name that is used for the folder and appended
    to the back of the music files

    Returns:
        '' when Model Test Mode (``data['modelFolder']``) is off or unset, or
        when no configured model path is an existing file. Otherwise
        '/<model>' or '/<model>-<stack>', built from the extension-less file
        names of the instrumental model (or, if that is not a file, the
        vocal model) and the stack model.
    """
    def _stem_if_file(path):
        """Return the extension-less base name of *path* if it is a file, else ''."""
        # Guard against None/'' so os.path.isfile is never handed a non-path.
        if path and os.path.isfile(path):
            return os.path.splitext(os.path.basename(path))[0]
        return ''

    # .get(): callers that never pass 'modelFolder' simply get no folder.
    if not data.get('modelFolder'):
        # Model Test Mode not selected
        return ''

    # -Instrumental- takes precedence over -Vocal-
    if data['instrumentalModel'] and os.path.isfile(data['instrumentalModel']):
        primary = _stem_if_file(data['instrumentalModel'])
    else:
        primary = _stem_if_file(data['vocalModel'])
    # -Stack-
    stack = _stem_if_file(data['stackModel'])

    modelFolderName = '-'.join(part for part in (primary, stack) if part)
    if modelFolderName:
        modelFolderName = '/' + modelFolderName

    return modelFolderName
def main(window: tk.Wm, text_widget: tk.Text, button_widget: tk.Button, progress_var: tk.Variable,
         **kwargs: dict):
    """Separate every input file and write the results to the export path.

    Args:
        window: Parent window for the error dialog.
        text_widget: Status output; its ``clear()`` and ``write(str)``
            methods are called.
        button_widget: Button that is disabled while running and re-enabled
            when the run finishes or an error is reported.
        progress_var: Variable receiving the completion percentage.
        **kwargs: Settings merged into the module-level ``data`` mapping;
            the resulting ``sr``/``hop_length``/``window_size``/``n_fft``
            become the new defaults.

    Any exception raised while processing a file is shown in a message box
    and printed, after which the function returns without processing the
    remaining files.
    """
    # `import tkinter as tk` does not load the messagebox submodule, so make
    # sure it is available for the error handler below.
    from tkinter import messagebox

    # Update default settings
    global default_sr
    global default_hop_length
    global default_window_size
    global default_n_fft

    def save_files(wav_instrument, wav_vocals):
        """Save output music files"""
        vocal_name = None
        instrumental_name = None
        save_path = os.path.dirname(base_name)

        # Get the Suffix Name
        if (not loop_num or
                loop_num == (total_loops - 1)):  # First or Last Loop
            if data['stackOnly']:
                if loop_num == (total_loops - 1):  # Last Loop
                    if not (total_loops - 1):  # Only 1 Loop
                        vocal_name = '(Vocals)'
                        instrumental_name = '(Instrumental)'
                    else:
                        vocal_name = '(Vocal_Final_Stacked_Output)'
                        instrumental_name = '(Instrumental_Final_Stacked_Output)'
            elif data['useModel'] == 'instrumental':
                if not loop_num:  # First Loop
                    vocal_name = '(Vocals)'
                if loop_num == (total_loops - 1):  # Last Loop
                    if not (total_loops - 1):  # Only 1 Loop
                        instrumental_name = '(Instrumental)'
                    else:
                        instrumental_name = '(Instrumental_Final_Stacked_Output)'
            elif data['useModel'] == 'vocal':
                if not loop_num:  # First Loop
                    instrumental_name = '(Instrumental)'
                if loop_num == (total_loops - 1):  # Last Loop
                    if not (total_loops - 1):  # Only 1 Loop
                        vocal_name = '(Vocals)'
                    else:
                        vocal_name = '(Vocals_Final_Stacked_Output)'
            if data['useModel'] == 'vocal':
                # Reverse names
                vocal_name, instrumental_name = instrumental_name, vocal_name
        elif data['saveAllStacked']:
            # Intermediate passes go into their own sub-folder
            folder_name = os.path.basename(base_name) + ' Stacked Outputs'  # nopep8
            save_path = os.path.join(save_path, folder_name)

            if not os.path.isdir(save_path):
                os.mkdir(save_path)

            if data['stackOnly']:
                vocal_name = f'(Vocal_{loop_num}_Stacked_Output)'
                instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)'
            elif (data['useModel'] == 'vocal' or
                    data['useModel'] == 'instrumental'):
                vocal_name = f'(Vocals_{loop_num}_Stacked_Output)'
                instrumental_name = f'(Instrumental_{loop_num}_Stacked_Output)'
            if data['useModel'] == 'vocal':
                # Reverse names
                vocal_name, instrumental_name = instrumental_name, vocal_name

        # Save Temp File
        # The model's primary output is always written to temp.wav; it is
        # the input of the next stack pass.
        sf.write('temp.wav',
                 wav_instrument.T, sr)

        appendModelFolderName = modelFolderName.replace('/', '_')
        file_prefix = os.path.basename(base_name)
        # -Save files-
        # Instrumental
        if instrumental_name is not None:
            instrumental_path = f'{save_path}/{file_prefix}_{instrumental_name}{appendModelFolderName}.wav'
            sf.write(instrumental_path,
                     wav_instrument.T, sr)
        # Vocal
        if vocal_name is not None:
            vocal_path = f'{save_path}/{file_prefix}_{vocal_name}{appendModelFolderName}.wav'
            sf.write(vocal_path,
                     wav_vocals.T, sr)

    data.update(kwargs)

    default_sr = data['sr']
    default_hop_length = data['hop_length']
    default_window_size = data['window_size']
    default_n_fft = data['n_fft']

    stime = time.perf_counter()
    progress_var.set(0)
    text_widget.clear()
    button_widget.configure(state=tk.DISABLED)  # Disable Button

    vocal_remover = VocalRemover(data, text_widget)
    modelFolderName = determineModelFolderName()
    if modelFolderName:
        folder_path = f'{data["export_path"]}{modelFolderName}'
        os.makedirs(folder_path, exist_ok=True)

    # Determine Loops
    total_loops = data['stackPasses']
    if not data['stackOnly']:
        total_loops += 1

    total_files = len(data['input_paths'])
    for file_num, music_file in enumerate(data['input_paths'], start=1):
        # Bound up front so the error report below can always reference it
        loop_num = 0
        try:
            # Determine File Name
            base_name = f'{data["export_path"]}{modelFolderName}/{file_num}_{os.path.splitext(os.path.basename(music_file))[0]}'

            # --Seperate Music Files--
            for loop_num in range(total_loops):
                # -Determine which model will be used-
                if not loop_num:
                    # First Iteration
                    wave_path = music_file
                    if data['stackOnly']:
                        if os.path.isfile(data['stackModel']):
                            model_name = os.path.basename(data['stackModel'])
                            model = vocal_remover.models['stack']
                            device = vocal_remover.devices['stack']
                        else:
                            raise ValueError(f'Selected stack only model, however, stack model path file cannot be found\nPath: "{data["stackModel"]}"')  # nopep8
                    else:
                        model_name = os.path.basename(data[f'{data["useModel"]}Model'])
                        model = vocal_remover.models[data['useModel']]
                        device = vocal_remover.devices[data['useModel']]
                else:
                    # Every other iteration
                    model_name = os.path.basename(data['stackModel'])
                    model = vocal_remover.models['stack']
                    device = vocal_remover.devices['stack']
                    # Feed the previous pass' output back in; music_file keeps
                    # naming the real input so error reports stay meaningful.
                    wave_path = 'temp.wav'

                # -Get text and update progress-
                base_text = get_baseText(total_files=total_files,
                                         total_loops=total_loops,
                                         file_num=file_num,
                                         loop_num=loop_num)
                progress_kwargs = {'progress_var': progress_var,
                                   'total_files': total_files,
                                   'total_loops': total_loops,
                                   'file_num': file_num,
                                   'loop_num': loop_num}
                update_progress(**progress_kwargs,
                                step=0)
                update_constants(model_name)

                # -Go through the different steps of seperation-
                # Wave source
                text_widget.write(base_text + 'Loading wave source...\n')
                # sr/mono are passed by keyword: newer librosa releases no
                # longer accept them positionally.
                X, sr = librosa.load(wave_path, sr=data['sr'], mono=False,
                                     dtype=np.float32, res_type='kaiser_fast')
                if X.ndim == 1:
                    # Duplicate a mono signal into two identical channels
                    X = np.asarray([X, X])
                text_widget.write(base_text + 'Done!\n')

                update_progress(**progress_kwargs,
                                step=0.1)
                # Stft of wave source
                text_widget.write(base_text + 'Stft of wave source...\n')
                X = spec_utils.wave_to_spectrogram(X,
                                                   data['hop_length'], data['n_fft'])
                if data['tta']:
                    pred, X_mag, X_phase = vocal_remover.inference_tta(X,
                                                                       device=device,
                                                                       model=model)
                else:
                    pred, X_mag, X_phase = vocal_remover.inference(X,
                                                                   device=device,
                                                                   model=model)
                text_widget.write(base_text + 'Done!\n')

                update_progress(**progress_kwargs,
                                step=0.6)
                # Postprocess
                if data['postprocess']:
                    text_widget.write(base_text + 'Post processing...\n')
                    pred_inv = np.clip(X_mag - pred, 0, np.inf)
                    pred = spec_utils.mask_silence(pred, pred_inv)
                    text_widget.write(base_text + 'Done!\n')

                    update_progress(**progress_kwargs,
                                    step=0.65)

                # Inverse stft
                text_widget.write(base_text + 'Inverse stft of instruments and vocals...\n')  # nopep8
                y_spec = pred * X_phase
                wav_instrument = spec_utils.spectrogram_to_wave(y_spec,
                                                                hop_length=data['hop_length'])
                # Whatever magnitude the model did not claim is the other stem
                v_spec = np.clip(X_mag - pred, 0, np.inf) * X_phase
                wav_vocals = spec_utils.spectrogram_to_wave(v_spec,
                                                            hop_length=data['hop_length'])
                text_widget.write(base_text + 'Done!\n')

                update_progress(**progress_kwargs,
                                step=0.7)
                # Save output music files
                text_widget.write(base_text + 'Saving Files...\n')
                save_files(wav_instrument, wav_vocals)
                text_widget.write(base_text + 'Done!\n')

                update_progress(**progress_kwargs,
                                step=0.8)

            # All passes done (the loop never breaks): the images are built
            # from the spectrograms of the final pass.
            # Save output image
            if data['output_image']:
                with open('{}_Instruments.jpg'.format(base_name), mode='wb') as f:
                    image = spec_utils.spectrogram_to_image(y_spec)
                    _, bin_image = cv2.imencode('.jpg', image)
                    bin_image.tofile(f)
                with open('{}_Vocals.jpg'.format(base_name), mode='wb') as f:
                    image = spec_utils.spectrogram_to_image(v_spec)
                    _, bin_image = cv2.imencode('.jpg', image)
                    bin_image.tofile(f)

            text_widget.write(base_text + 'Completed Seperation!\n\n')
        except Exception as e:
            # Top-level boundary: report the failure to the user and stop.
            traceback_text = ''.join(traceback.format_tb(e.__traceback__))
            message = f'Traceback Error: "{traceback_text}"\n{type(e).__name__}: "{e}"\nFile: {music_file}\nLoop: {loop_num}\nPlease contact the creator and attach a screenshot of this error with the file and settings that caused it!'
            messagebox.showerror(master=window,
                                 title='Untracked Error',
                                 message=message)
            print(traceback_text)
            print(type(e).__name__, e)
            print(message)
            progress_var.set(0)
            button_widget.configure(state=tk.NORMAL)  # Enable Button
            return

        # The intermediate file is only needed between passes of one input
        if os.path.isfile('temp.wav'):
            os.remove('temp.wav')

    progress_var.set(0)
    text_widget.write('Conversion(s) Completed and Saving all Files!\n')
    text_widget.write(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}')  # nopep8
    button_widget.configure(state=tk.NORMAL)  # Enable Button