mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
Use ndimage.median_filter instead of signal.medfilt (#812)
For a 30-second audio file that contained no silence, ndimage.median_filter took 7 s, whereas signal.medfilt took 30 s. Co-authored-by: Umar Farooqi <umar@paystash.com> Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
This commit is contained in:
parent
a84191faae
commit
f0083e7eb2
@ -874,7 +874,7 @@
|
|||||||
"from IPython.display import display, HTML\n",
|
"from IPython.display import display, HTML\n",
|
||||||
"from whisper.tokenizer import get_tokenizer\n",
|
"from whisper.tokenizer import get_tokenizer\n",
|
||||||
"from dtw import dtw\n",
|
"from dtw import dtw\n",
|
||||||
"from scipy.signal import medfilt\n",
|
"from scipy.ndimage import median_filter\n",
|
||||||
"\n",
|
"\n",
|
||||||
"%matplotlib inline\n",
|
"%matplotlib inline\n",
|
||||||
"%config InlineBackend.figure_format = \"retina\""
|
"%config InlineBackend.figure_format = \"retina\""
|
||||||
@ -3610,7 +3610,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" weights = torch.cat(QKs) # layers * heads * tokens * frames \n",
|
" weights = torch.cat(QKs) # layers * heads * tokens * frames \n",
|
||||||
" weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
|
" weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
|
||||||
" weights = medfilt(weights, (1, 1, 1, medfilt_width))\n",
|
" weights = median_filter(weights, (1, 1, 1, medfilt_width))\n",
|
||||||
" weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",
|
" weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" w = weights / weights.norm(dim=-2, keepdim=True)\n",
|
" w = weights / weights.norm(dim=-2, keepdim=True)\n",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user