From 66a693dcdb92c06c20a463f64f3262260ff5846c Mon Sep 17 00:00:00 2001 From: aufr33 <65520685+aufr33@users.noreply.github.com> Date: Thu, 8 Jul 2021 14:42:46 +0300 Subject: [PATCH] Update dataset.py --- lib/dataset.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/dataset.py b/lib/dataset.py index d58d48a..fa00b2c 100644 --- a/lib/dataset.py +++ b/lib/dataset.py @@ -89,13 +89,19 @@ def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha, mp): if np.random.uniform() < 0.5: # swap channel X[idx] = X[idx, ::-1] - y[idx] = y[idx, ::-1] - + y[idx] = y[idx, ::-1] + if np.random.uniform() < 0.02: # mono X[idx] = X[idx].mean(axis=0, keepdims=True) y[idx] = y[idx].mean(axis=0, keepdims=True) + if np.random.uniform() < 0.02: + # delay / echo + d = np.random.randint(1, 10, size=2) + v = X[idx] - y[idx] + X[idx, 0, :, d[0]:] += v[0, :, :-d[0]] * random.uniform(0.1, 0.3) + X[idx, 1, :, d[1]:] += v[1, :, :-d[1]] * random.uniform(0.1, 0.3) if np.random.uniform() < 0.02: # inst X[idx] = y[idx]