Update dataset.py

This commit is contained in:
aufr33 2021-07-07 09:26:57 +03:00 committed by GitHub
parent 4687ae6708
commit e328910fd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -79,24 +79,26 @@ def train_val_split(dataset_dir, split_mode, val_rate, val_filelist):
return train_filelist, val_filelist
def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha):
def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha, mp):
perm = np.random.permutation(len(X))
for i, idx in enumerate(tqdm(perm)):
if np.random.uniform() < reduction_rate:
y[idx] = spec_utils.reduce_vocal_aggressively(X[idx], y[idx], reduction_mask)
if mixup_rate > 0.0:
if not any([mp.param["mid_side"], mp.param["mid_side_b"], mp.param["mid_side_c"]]):
if np.random.uniform() < 0.5:
# swap channel
X[idx] = X[idx, ::-1]
y[idx] = y[idx, ::-1]
if np.random.uniform() < 0.02:
# mono
X[idx] = X[idx].mean(axis=0, keepdims=True)
y[idx] = y[idx].mean(axis=0, keepdims=True)
if np.random.uniform() < 0.02:
# inst
X[idx] = y[idx]
if np.random.uniform() < 0.02:
# inst
X[idx] = y[idx]
if np.random.uniform() < mixup_rate and i < len(perm) - 1:
lam = np.random.beta(mixup_alpha, mixup_alpha)