fix: extract freture cannot run on pure cpu
This commit is contained in:
parent
9e59375311
commit
f27a991794
@ -51,7 +51,8 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|||||||
)
|
)
|
||||||
model = models[0]
|
model = models[0]
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.half()
|
if torch.cuda.is_available():
|
||||||
|
model = model.half()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
|
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
|
||||||
@ -70,7 +71,7 @@ else:
|
|||||||
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
|
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
|
||||||
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
||||||
inputs = {
|
inputs = {
|
||||||
"source": feats.half().to(device),
|
"source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device),
|
||||||
"padding_mask": padding_mask.to(device),
|
"padding_mask": padding_mask.to(device),
|
||||||
"output_layer": 9, # layer 9
|
"output_layer": 9, # layer 9
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user