Compare commits


99 Commits

Author SHA1 Message Date
Jong Wook Kim
c0d2f624c0 Release 20250625 2025-06-25 18:05:47 -07:00
Jong Wook Kim
db7fbc75fe Release 20250625 2025-06-25 18:03:25 -07:00
Jong Wook Kim
31243bad24 Release 20250625 2025-06-25 18:00:48 -07:00
Dridi Yassin
1f8fc975d3
Fix: Update torch.load to use weights_only=True to prevent security w… (#2451)
* Fix: Update torch.load to use weights_only=True to prevent security warning

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2025-06-25 17:54:30 -07:00
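For context, this change addresses PyTorch's warning that `torch.load` with default settings unpickles arbitrary objects. A minimal sketch of the safer call (the file name is illustrative, not the exact patch):

```python
import torch

# weights_only=True restricts torch.load to tensors and plain containers,
# so loading an untrusted checkpoint cannot execute arbitrary pickled code
checkpoint = torch.load("large-v3.pt", map_location="cpu", weights_only=True)
```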
Nathan Harmon
679ae1d141
Fix: Ensure DTW cost tensor is on the same device as input tensor (#2561)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2025-06-25 17:42:09 -07:00
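The gist of the fix is allocating the DTW cost tensor on the device of the incoming tensor rather than implicitly on the CPU; a hedged sketch of the idea (function name and shapes are illustrative):

```python
import torch

def make_dtw_cost(x: torch.Tensor) -> torch.Tensor:
    # allocate the cost matrix on x.device so CPU and CUDA inputs both work
    N, M = x.shape
    cost = torch.full((N + 1, M + 1), float("inf"), device=x.device)
    cost[0, 0] = 0.0
    return cost
```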
Nicholas Nadeau, Ph.D., P.Eng.
f50c4f264e
docs: updated README to specify translation model limitation (#2547)
Updated README given info from https://github.com/openai/whisper/discussions/2483
2025-06-25 17:03:47 -07:00
ExtReMLapin
86899243e9
Fixed triton kernel update to support latest triton versions (#2588)
* Update triton kernel using _unsafe_update_src

* support old triton versions

* refactored changes to update triton kernel only once

* Update triton_ops.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <ilikekjw@gmail.com>
2025-06-25 17:02:54 -07:00
Learpcs
5dff4db81a
Fix: GitHub display errors for Jupyter notebooks (#2589)
* Update LibriSpeech.ipynb

Update LibriSpeech.ipynb

* Update Multilingual_ASR.ipynb
2025-06-25 16:55:15 -07:00
dependabot[bot]
dd985ac4b9
Bump the github-actions group with 3 updates (#2592)
Bumps the github-actions group with 3 updates: [actions/checkout](https://github.com/actions/checkout), [actions/setup-python](https://github.com/actions/setup-python) and [softprops/action-gh-release](https://github.com/softprops/action-gh-release).


Updates `actions/checkout` from 3 to 4
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

Updates `actions/setup-python` from 4 to 5
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v4...v5)

Updates `softprops/action-gh-release` from 1 to 2
- [Release notes](https://github.com/softprops/action-gh-release/releases)
- [Changelog](https://github.com/softprops/action-gh-release/blob/master/CHANGELOG.md)
- [Commits](https://github.com/softprops/action-gh-release/compare/v1...v2)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: softprops/action-gh-release
  dependency-version: '2'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-13 11:22:31 -07:00
Christian Clauss
e1e6aa60ff
Keep GitHub Actions up to date with GitHub's Dependabot (#2486)
Automates the creation of pull requests like
* #2430 

* [Keeping your actions up to date with Dependabot](https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot)
* [Configuration options for the dependabot.yml file - package-ecosystem](https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem)
2025-05-13 11:10:43 -07:00
Christian Clauss
e6a5fc0ff0
pre-commit: Upgrade black v25.1.0 and isort v6.0.0 (#2514) 2025-05-13 09:43:34 -07:00
Christian Clauss
13907bed90
GitHub Actions: Add Python 3.13 to the testing (#2487)
* GitHub Actions: Add Python 3.13 to the testing

* GitHub Actions: Add Python 3.13 to the testing

* numba==0.61.0rc2; python_version=='3.13'

* triton>=2; python_version<'3.13'

* fail-fast: false

* Numba v0.61.0 is released

https://github.com/numba/numba/releases

* Update pyproject.toml
2025-05-12 21:10:40 -07:00
Jong Wook Kim
517a43ecd1
Update python-publish.yml
using `-m build --sdist` instead of `setup.py sdist`
2025-01-04 12:56:16 -08:00
Christian Clauss
dd4d010d2c
PEP 621: Migrate from setup.py to pyproject.toml (#2435) 2025-01-04 01:38:35 -08:00
Christian Clauss
26a7cacc83
pre-commit autoupdate && pre-commit run --all-files (#2484)
* pre-commit autoupdate && pre-commit run --all-files

* Black formatter needs a current version of Python
2025-01-04 01:02:18 -08:00
Christian Clauss
6c1d8f1ea1
Upgrade GitHub Actions (#2430) 2025-01-04 00:47:12 -08:00
Purfview
90db0de189
Bugfix: Illogical "Avoid computing higher temperatures on no_speech" (#1903)
* Bugfix: Illogical "Avoid computing higher temperatures on no_speech"

Bugfix for https://github.com/openai/whisper/pull/1279

It's "silence" when decoding has failed due to `compression_ratio_threshold` too, when further down the code it's not "silence" anymore.

"Silence" should be only when decoding has failed due to `logprob_threshold`.

Like described there:
8bc8860694/whisper/transcribe.py (L421)

And in code there:
8bc8860694/whisper/transcribe.py (L243-L251)

* Fix if "logprob_threshold=None"

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2024-11-30 21:47:01 -08:00
Lowell Vaughn
fc5ded7d90
Updating README and doc strings to reflect that n_mels can now be 128 (#2049) 2024-11-26 09:37:01 -08:00
f1sh
173ff7dd1d
fix typo data/README.md (#2433) 2024-11-12 16:35:54 -08:00
BotMaster3000
271445b2f2
Update README.md (#2379)
Default now uses Turbo instead of Small
2024-11-03 23:00:30 -08:00
kittsil
5979f03701
Add option to carry initial_prompt with the sliding window (#2343)
* Add option to carry initial_prompt with the sliding window

Add an option `carry_initial_prompt = False` to `whisper.transcribe()`.
When set to `True`, `initial_prompt` is prepended to each internal `decode()` call's `prompt`.
If there is not enough context space at the start of the prompt, the prompt is left-sliced to make space.

* Prevent redundant initial_prompt_tokens

* Revert unnecessary .gitignore change

---------

Co-authored-by: Kittsil <kittsil@gmail.com>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2024-10-26 07:17:31 -07:00
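Based on the description above, usage looks roughly like this (a sketch; the prompt text is illustrative):

```python
import whisper

model = whisper.load_model("turbo")
result = model.transcribe(
    "audio.mp3",
    initial_prompt="Glossary: DTW, tiktoken, SDPA.",
    # prepend initial_prompt to every window's prompt, not just the first;
    # it is left-sliced when there is not enough context space
    carry_initial_prompt=True,
)
print(result["text"])
```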
Jong Wook Kim
cdb8147962
more pytorch versions in tests (#2408) 2024-10-25 17:30:02 -07:00
Jong Wook Kim
25639fc17d Release 20240930 2024-09-30 11:20:53 -07:00
Jong Wook Kim
260bbcfcb3
allowing numpy 2 in tests (#2362)
* allowing numpy 2 in tests

* allowing numpy 2 in tests
2024-09-30 11:18:17 -07:00
Jong Wook Kim
25e5c364e0
large-v3-turbo model (#2361) 2024-09-30 10:59:51 -07:00
Jong Wook Kim
b66b46f32d
test on python/pytorch versions up to 3.12 and 2.4.1 (#2360) 2024-09-30 10:33:56 -07:00
Jong Wook Kim
27f971320a
using sdpa if available (#2359)
* using sdpa if available

* Update model.py
2024-09-30 10:27:14 -07:00
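SDPA here refers to PyTorch's fused `scaled_dot_product_attention`; a minimal sketch of the call the model can fall back to when it is available (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1500, 64)  # (batch, heads, sequence, head_dim)
k = torch.randn(1, 8, 1500, 64)
v = torch.randn(1, 8, 1500, 64)

# fused attention kernel, equivalent to softmax(q @ k.T / sqrt(d)) @ v
out = F.scaled_dot_product_attention(q, k, v)
```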
Jong Wook Kim
423492dda7 Release 20240927 2024-09-27 16:43:58 -07:00
Jong Wook Kim
279133e310
pinning numpy<2 in tests (#2332)
* pinning numpy<2 in tests

* pip install together

* pip install together
2024-09-10 10:43:21 -07:00
Jianan Xing
32d55d5d76
Relax triton requirements for compatibility with pytorch 2.4 and newer (#2307)
* Relax triton requirements for compatibility with pytorch 2.4 and newer

Similar to https://github.com/openai/whisper/pull/1802, but now that pytorch has moved to 2.4, it requires triton==3.0.0. I am not sure whether it makes sense to remove the upper-bound version constraints.

* Update requirements.txt
2024-09-10 09:53:08 -07:00
ryanheise
ba3f3cd54b
Skip silence around hallucinations (#1838)
* Add clip_timestamps option

* Add hallucination_silence_threshold option

* Fix typing for python < 3.9

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-12-18 12:11:16 -08:00
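Both options added here are plumbed through `transcribe()`; a hedged usage sketch (values are illustrative):

```python
import whisper

model = whisper.load_model("turbo")
result = model.transcribe(
    "audio.mp3",
    word_timestamps=True,        # the silence heuristic needs word timings
    clip_timestamps="0,300",     # only decode this start,end range (seconds)
    # skip silent gaps longer than 2 s around segments that look hallucinated
    hallucination_silence_threshold=2.0,
)
```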
Bob Lin
8bc8860694
Fix triton env marker (#1887) 2023-12-11 10:39:08 -05:00
Jong Wook Kim
e58f288045 Release 20231117 2023-11-17 11:59:28 -08:00
Eugene Indenbom
1cea435768
Relax triton requirements for compatibility with pytorch 2.1 and newer (#1802) 2023-11-13 09:43:42 -08:00
Jong Wook Kim
fcfeaf1b61 Release 20231106 2023-11-06 10:14:04 -08:00
Jong Wook Kim
c5d4256076
large-v3 (#1761)
* mel_filters() loads 128 mel bins

* can load 100-language models

* large-v3 checkpoint and evals

* add mandarin alias

* remove unused path

* flake8 fix

* formatting fix
2023-11-06 10:10:30 -08:00
Jong Wook Kim
f6f01c561c Release 20231105 2023-11-06 03:08:56 -08:00
Jong Wook Kim
746aaaeafa
remove tiktoken pin (#1759) 2023-11-06 03:05:21 -08:00
Philippe Hebert
b9f17e1f2d
docs: Disambiguation of the term "relative speed" in the README (#1751)
* docs: defines relative speed in README

* combined paragraphs

---------

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-11-06 02:43:07 -08:00
Mohamad Zamini
7dfcd56304
allow_pickle=False while loading of mel matrix IN audio.py (#1511)
* Update audio.py

The `mel_filters` function uses `np.load` to load a pre-computed mel filterbank matrix. This function is not thread-safe, which means that if it is called from multiple threads at the same time, it may corrupt the data.

To fix this, you can use the `torch.load` function instead. That function is thread-safe, so it will not corrupt the data if it is called from multiple threads at the same time.

* Update audio.py

updated the docstring

* allow_pickle=False

* newline

---------

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-11-06 02:28:51 -08:00
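The fix that was ultimately merged is a single flag on the `np.load` call; a sketch of the pattern (the key name follows the repo's mel_filters convention):

```python
import numpy as np

# allow_pickle=False refuses to deserialize pickled objects, so reading
# the pre-computed filterbank cannot execute arbitrary code
with np.load("whisper/assets/mel_filters.npz", allow_pickle=False) as f:
    mel_80 = f["mel_80"]
```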
Marco Zucconelli
b7d277acd5
handling transcribe exceptions. (#1682)
* handling transcribe() exceptions.

* printing stacktrace

---------

Co-authored-by: invalid <invalid@email.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-11-06 02:06:19 -08:00
amosal
6ed314fe41
Add new option to generate subtitles by a specific number of words (#1729)
* ADD parser for new argument --max_words_count

* ADD max_words_count in words_options
ADD warning for max_line_width compatibility

* ADD logic for max_words_count

* rename to max_words_per_line

* make them kwargs

* allow specifying file path by --model

* black formatting

---------

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-11-06 01:49:33 -08:00
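After the renames above, the option ends up as `max_words_per_line` on the subtitle writers; a hedged sketch of programmatic use (word timestamps are required):

```python
import whisper
from whisper.utils import get_writer

model = whisper.load_model("turbo")
result = model.transcribe("audio.mp3", word_timestamps=True)

writer = get_writer("srt", ".")  # writes ./audio.srt
writer(result, "audio.mp3", {
    "max_words_per_line": 8,   # start a new subtitle line after 8 words
    "max_line_width": None,    # mutually exclusive with max_words_per_line
    "max_line_count": None,
    "highlight_words": False,
})
```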
Jordi Mas
b38a1f20f4
Fix exception when an audio file with no speech is provided (#1396)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-10-10 10:01:01 -07:00
Jong Wook Kim
0a60fcaa9b Release 20230918 2023-09-18 17:13:19 -07:00
Jong Wook Kim
5f957da5ca
Update test.yml 2023-09-18 16:38:17 -07:00
Arthur Kim
8b330df096
Add .pre-commit-config.yaml (#1528)
* Add .pre-commit-config.yaml

Co-authored-by: arthur <arthur@rtzr.ai>

* flake8 E741

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-09-18 16:15:33 -07:00
sqhao
21010ef454
fix doc of TextDecoder (#1526)
Signed-off-by: haoshengqiang <haoshengqiang@xiaohongshu.com>
Co-authored-by: haoshengqiang <haoshengqiang@xiaohongshu.com>
2023-09-18 16:09:59 -07:00
Nino Risteski
29b7df6231
Update model-card.md (#1643)
fixed a few typos
2023-09-18 15:59:49 -07:00
taylorchu
e8622f9afc
word timing tweaks (#1559)
* word timing tweaks

* comment on eot

* clearer comments
2023-08-08 06:48:56 +09:00
WangChou Lu
b91c907694
Avoid rearranging all caches (#1483)
* avoid rearranging all kv_caches

* avoid calculating the same kv_cache from cross attn

* Update decoding.py

* linter fix

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-07-06 12:48:08 -07:00
ryanheise
f572f2161b
Improve timestamp heuristics. (#1461)
* Improve timestamp heuristics.

* Track pauses with last_speech_timestamp
2023-06-29 16:51:24 -07:00
Valentin Berkes
248b6cb124
fix condition_on_previous_text (#1224)
prompt_reset_since is set before all_tokens is extended, and hence does not have the expected effect.
2023-05-05 00:31:35 -07:00
Paul Willot
7ca9fbea86
Fix numba depreceation notice (#1233)
From numba 0.57, a warning is raised if `nopython` is not supplied:
https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit
2023-05-04 23:48:06 -07:00
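The fix amounts to passing `nopython=True` explicitly; a minimal sketch:

```python
import numba
import numpy as np

# numba 0.57+ warns when jit() is used without nopython; passing it
# explicitly keeps the compiled (non-object-mode) behavior
@numba.jit(nopython=True)
def row_min(x: np.ndarray) -> np.ndarray:
    out = np.empty(x.shape[0], dtype=x.dtype)
    for i in range(x.shape[0]):
        out[i] = x[i].min()
    return out
```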
Brett Balquist
b1c0815c79
Updated README.md to provide more insight on BLEU and specific appendices (#1236)
* Updated README.md to provide more insight on BLEU and specific appendices in the research paper

* Update README.md

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-05-04 23:47:45 -07:00
Théo BOYER
e334ff141d
Avoid computing higher temperatures on no_speech segments (#1279)
* Avoid computing higher temperatures on no_speech

In decode_with_fallback, we compute higher temperatures in the case where compression_ratio is too high or avg_logprob is too low.
But since the computation of no_speech_prob doesn't depend on sampling, we can avoid computing higher temperatures if we detect at the first temperature that the no_speech condition is fulfilled.

* Update transcribe.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-05-04 17:02:36 -07:00
petterreinholdtsen
5523722842
Dropped unused execute bit from mel_filters.npz. (#1254) 2023-05-04 10:58:56 -07:00
petterreinholdtsen
8035e9ef48
Drop ffmpeg-python dependency and call ffmpeg directly. (#1242)
* Drop ffmpeg-python dependency and call ffmpeg directly.

The last ffmpeg-python module release was in 2019[1], upstream seems to be unavailable[2], and the project's development seems to have stagnated[3]. As the features it provides are trivial to replace using the Python native subprocess module, drop the dependency.

 [1] <URL: https://github.com/kkroening/ffmpeg-python/tags >
 [2] <URL: https://github.com/kkroening/ffmpeg-python/issues/760 >
 [3] <URL: https://openhub.net/p/ffmpeg-python >

* Rewrote to use subprocess.run() instead of subprocess.Popen().

* formatting changes

* formatting update

* isort fix

* Error checking

* isort 🤦🏻

* flake8 fix

* minor spelling changes

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-05-04 10:53:59 -07:00
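The replacement shells out to the `ffmpeg` binary with `subprocess.run`; a condensed sketch of the approach (error handling trimmed):

```python
import subprocess
import numpy as np

def load_audio(file: str, sr: int = 16000) -> np.ndarray:
    # decode any input to mono 16-bit PCM at the target sample rate
    cmd = [
        "ffmpeg", "-nostdin",
        "-i", file,
        "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le",
        "-ar", str(sr), "-",
    ]
    out = subprocess.run(cmd, capture_output=True, check=True).stdout
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
```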
Johnny
e69930cb9c
Python 3.11 (#1171)
* python 3.11

* python 3.11

* fix

* fix

* fix

* revert changes

* Update requirements.txt

* Trying pip3 install instead

* Excluding cp39 - torch 1.10.2

* Removing 1.10.2 from test

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-05-04 10:42:09 -07:00
Jong Wook Kim
c09a7ae299
Update decoding.py (#1219) 2023-04-11 15:13:13 -07:00
Fernando O. Gallego
b0022b3283
Update decoding.py (#1155)
* Update decoding.py

Following the suggestions of @Jeronymous in https://github.com/openai/whisper/pull/914 and https://github.com/openai/whisper/discussions/924, this solves the endless-loop problem.

* Removed blank line and whitespaces in empty lines.

* Suggested changes according to the linter

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-11 15:06:03 -07:00
Arseniy Bushyn
76c901ab8d
Update README.md to reference tiktoken (#1105)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:39:17 -07:00
ryanheise
43940fc978
Implement max line width and max line count, and make word highlighting optional (#1184)
* Add highlight_words, max_line_width, max_line_count

* Refactor subtitle generator

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:28:35 -07:00
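These writer options pair with word timestamps; a hedged sketch (values are illustrative):

```python
import whisper
from whisper.utils import get_writer

model = whisper.load_model("turbo")
result = model.transcribe("audio.mp3", word_timestamps=True)

writer = get_writer("vtt", ".")
writer(result, "audio.mp3", {
    "highlight_words": True,   # underline each word as it is spoken
    "max_line_width": 42,      # wrap subtitle lines at 42 characters
    "max_line_count": 2,       # at most two lines per cue
    "max_words_per_line": None,
})
```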
ryanheise
255887f219
Squash long words at window and sentence boundaries. (#1114)
* Squash long words at window and sentence boundaries.

* Formatting requirements.

* Fix squashing logic to point to correct words.

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:23:53 -07:00
K.B.Dharun Krishna
a151816b6b
python-publish.yml: bump actions version to fix node warning (#1211) 2023-04-10 13:54:09 -07:00
Jong Wook Kim
b5851c6c40
Update tokenizer.py (#1163) 2023-03-29 13:12:36 -07:00
Jong Wook Kim
6dea21fd7f Release 20230314 2023-03-15 00:39:19 -07:00
Jong Wook Kim
79c43e4859
abort find_alignment on empty input (#1090) 2023-03-14 12:47:58 -07:00
Guillaume Klein
5f9ac653b7
Fix truncated words list when the replacement character is decoded (#1089) 2023-03-14 09:32:41 -07:00
Akash Mahajan
ba88b8e1b3
fix github language stats getting dominated by jupyter notebook (#1076)
Co-authored-by: Akash Mahajan <akash.mahajan@microsoft.com>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-14 00:07:09 -07:00
Guillaume Klein
671ac5a4ce
Fix alignment between the segments and the list of words (#1087)
* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
2023-03-13 16:34:09 -07:00
Jong Wook Kim
839639a223
Use tiktoken (#1044)
* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep

---------

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
2023-03-13 02:34:16 -07:00
Jong Wook Kim
ad3250a846 Release 20230308 2023-03-08 15:48:57 -08:00
Jong Wook Kim
c4b50c0824
kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience

* formatting fix
2023-03-08 15:46:38 -08:00
Jong Wook Kim
38f2f4d99d
fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) 2023-03-08 15:34:07 -08:00
Jong Wook Kim
aac47c9834 fix typo 2023-03-07 20:43:49 -08:00
Jong Wook Kim
26807ec6d3 Release 20230307 2023-03-07 20:36:29 -08:00
Jong Wook Kim
919a713499
attempt to fix the repetition/hallucination issue identified in #1046 (#1052)
* attempt to fix the repetition/hallucination issue identified in #1046

* zero-pad the audio instead of spectrogram

* formatting fix

* delete debug print
2023-03-07 20:08:45 -08:00
Jong Wook Kim
38e990d853
Use triton==2.0.0 (#1053) 2023-03-07 16:56:31 -08:00
Jong Wook Kim
924e1f8e06
Try installing triton only if linux & x86_64 (#1051) 2023-03-07 11:31:40 -08:00
Jong Wook Kim
4b0d5e58d0
Update setup.py 2023-03-07 04:47:46 -08:00
Jong Wook Kim
8180fde939 Release 20230306 2023-03-06 18:53:04 -08:00
Local State
c6e4e5efb3
remove auxiliary audio extension (#1021)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-06 17:48:14 -08:00
Jong Wook Kim
b80bcf610d
apply formatting with black (#1038)
* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
2023-03-06 15:50:37 -08:00
Jong Wook Kim
500d0fe966
word-level timestamps in transcribe() (#869)
* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <ryan@ryanheise.com>
2023-03-06 14:00:49 -08:00
Jong Wook Kim
eab8d920ed
Decoding improvements (#1033)
* suppress task tokens (transcribe/translate)

* not ignoring the last segment ending with one timestamp
2023-03-06 11:32:32 -08:00
Roman Vasilenko
3e1780fd37
Update README.md (#894)
Fixed a few typos and made general improvements for clarity.

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-03 16:41:59 -08:00
Andrey Chernykh
7858aa9c08
Fix infinite loop caused by incorrect timestamp tokens prediction (#914)
* Fix infinite loop caused by incorrect timestamp tokens prediction

https://github.com/openai/whisper/discussions/810

* Update decoding.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-02-01 15:46:51 -08:00
Jong Wook Kim
5c1a8c10e7
clarify that 3.11 is not supported 2023-01-27 00:01:49 -08:00
Jong Wook Kim
4e635c6644
Update README.md about Python 3.8+ requirement 2023-01-24 14:45:56 -08:00
Jong Wook Kim
a6b36ede1f
drop python 3.7 support (#889) 2023-01-24 14:05:57 -08:00
Jong Wook Kim
55f690af79 Release 20230124 2023-01-24 11:11:08 -08:00
Jong Wook Kim
7f1ef223ab
handle printing even if sys.stdout.buffer is not available (#887) 2023-01-24 10:12:04 -08:00
Niels Mayer
f5bfe004ec
Add TSV formatted output in transcript, using integer start/end times in milliseconds. (#228)
* Add CSV format output in transcript, containing lines of characters formatted like: <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, <transcript-including-commas>

* For easier reading by spreadsheets importing CSV, the third column of the CSV file is delimited by quotes, and any quote characters that might be in the transcript (which would interfere with parsing the third column as a string) are converted to "''".

* fix syntax error

* docstring edit

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-22 00:27:17 -08:00
Aaryan YVS
da600abd2b
Added --output_format option (#333)
* Added --output option

The --output option helps select which output files will be generated.

Corrected the logic, which wrongly showed the progress bar when verbose was set to False.

* Changed output_files variable

* Changed back the tqdm verbose

* refactor output format handling

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-21 23:58:38 -08:00
zer0-x
9f7aba6099
Handle XDG_CACHE_HOME properly for download_root (#864)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-01-21 01:09:39 -08:00
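The behavior described resolves the cache directory roughly like this (a sketch matching the description, not necessarily the exact code):

```python
import os

# model cache: $XDG_CACHE_HOME/whisper when set, else ~/.cache/whisper
default_cache = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache), "whisper")
```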
Jong Wook Kim
12e1089462
use stdout for printing transcription progress (#867) 2023-01-20 00:54:05 -08:00
Markus Hennerbichler
ea1c266709
Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm (#659)
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-18 10:41:11 -08:00
Jong Wook Kim
8135a7c31c verbose outputs from pytest 2023-01-18 10:30:18 -08:00
Jong Wook Kim
9d646db9d8
print '?' if a letter can't be encoded using the system default encoding (#859) 2023-01-17 23:28:36 -08:00
50 changed files with 109750 additions and 102570 deletions

.flake8 (new file)

@@ -0,0 +1,4 @@
[flake8]
per-file-ignores =
    */__init__.py: F401

.gitattributes (new file)

@@ -0,0 +1,3 @@
# Override jupyter in Github language stats for more accurate estimate of repo code languages
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
*.ipynb linguist-generated

.github/dependabot.yml (new file)

@@ -0,0 +1,13 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
  - package-ecosystem: github-actions
    directory: /
    groups:
      github-actions:
        patterns:
          - "*" # Group all Actions updates into a single larger pull request
    schedule:
      interval: weekly

.github/workflows/python-publish.yml

@@ -8,23 +8,23 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: actions-ecosystem/action-regex-match@v2
        id: regex-match
         with:
           text: ${{ github.event.head_commit.message }}
           regex: '^Release ([^ ]+)'
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.8'
+          python-version: '3.12'
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install setuptools wheel twine
+          pip install setuptools wheel twine build
       - name: Release
         if: ${{ steps.regex-match.outputs.match != '' }}
-        uses: softprops/action-gh-release@v1
+        uses: softprops/action-gh-release@v2
         with:
           tag_name: v${{ steps.regex-match.outputs.group1 }}
       - name: Build and publish
@@ -33,5 +33,5 @@ jobs:
           TWINE_USERNAME: __token__
           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
         run: |
-          python setup.py sdist
+          python -m build --sdist
           twine upload dist/*

.github/workflows/test.yml

@@ -6,21 +6,74 @@
   pull_request:
     branches:
       - main
 jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Fetch base branch
+        run: git fetch origin ${{ github.base_ref }}
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.9"
+          architecture: x64
+      - name: Get pip cache dir
+        id: pip-cache
+        run: |
+          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+      - name: pip/pre-commit cache
+        uses: actions/cache@v4
+        with:
+          path: |
+            ${{ steps.pip-cache.outputs.dir }}
+            ~/.cache/pre-commit
+          key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-pre-commit
+      - name: pre-commit
+        run: |
+          pip install --upgrade pre-commit
+          pre-commit install --install-hooks
+          pre-commit run --all-files
   whisper-test:
+    needs: pre-commit
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
-        python-version: ['3.8', '3.9', '3.10']
-        pytorch-version: [1.10.2, 1.13.1]
-        exclude:
+        include:
+          - python-version: '3.8'
+            pytorch-version: 1.10.1
+            numpy-requirement: "'numpy<2'"
+          - python-version: '3.8'
+            pytorch-version: 1.13.1
+            numpy-requirement: "'numpy<2'"
+          - python-version: '3.8'
+            pytorch-version: 2.0.1
+            numpy-requirement: "'numpy<2'"
+          - python-version: '3.9'
+            pytorch-version: 2.1.2
+            numpy-requirement: "'numpy<2'"
           - python-version: '3.10'
-            pytorch-version: 1.10.2
+            pytorch-version: 2.2.2
+            numpy-requirement: "'numpy<2'"
+          - python-version: '3.11'
+            pytorch-version: 2.3.1
+            numpy-requirement: "'numpy'"
+          - python-version: '3.12'
+            pytorch-version: 2.4.1
+            numpy-requirement: "'numpy'"
+          - python-version: '3.12'
+            pytorch-version: 2.5.1
+            numpy-requirement: "'numpy'"
+          - python-version: '3.13'
+            pytorch-version: 2.5.1
+            numpy-requirement: "'numpy'"
     steps:
-      - uses: conda-incubator/setup-miniconda@v2
-      - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
-      - uses: actions/checkout@v2
+      - uses: conda-incubator/setup-miniconda@v3
+      - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
+      - uses: actions/checkout@v4
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip install pytest
-      - run: pip install .
-      - run: pytest -k 'not test_transcribe or test_transcribe[tiny]'
+      - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
+      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

.pre-commit-config.yaml (new file)

@@ -0,0 +1,28 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-json
      - id: end-of-file-fixer
        types: [file, python]
      - id: trailing-whitespace
        types: [file, python]
      - id: mixed-line-ending
      - id: check-added-large-files
        args: [--maxkb=4096]
  - repo: https://github.com/psf/black
    rev: 25.1.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
    rev: 6.0.0
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
  - repo: https://github.com/pycqa/flake8.git
    rev: 7.1.1
    hooks:
      - id: flake8
        types: [python]
        args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"]

CHANGELOG.md (new file)

@@ -0,0 +1,121 @@
# CHANGELOG
## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
## [v20231106](https://github.com/openai/whisper/releases/tag/v20231106)
* large-v3 ([#1761](https://github.com/openai/whisper/pull/1761))
## [v20231105](https://github.com/openai/whisper/releases/tag/v20231105)
* remove tiktoken pin ([#1759](https://github.com/openai/whisper/pull/1759))
* docs: Disambiguation of the term "relative speed" in the README ([#1751](https://github.com/openai/whisper/pull/1751))
* allow_pickle=False while loading of mel matrix IN audio.py ([#1511](https://github.com/openai/whisper/pull/1511))
* handling transcribe exceptions. ([#1682](https://github.com/openai/whisper/pull/1682))
* Add new option to generate subtitles by a specific number of words ([#1729](https://github.com/openai/whisper/pull/1729))
* Fix exception when an audio file with no speech is provided ([#1396](https://github.com/openai/whisper/pull/1396))
## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918)
* Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528))
* fix doc of TextDecoder ([#1526](https://github.com/openai/whisper/pull/1526))
* Update model-card.md ([#1643](https://github.com/openai/whisper/pull/1643))
* word timing tweaks ([#1559](https://github.com/openai/whisper/pull/1559))
* Avoid rearranging all caches ([#1483](https://github.com/openai/whisper/pull/1483))
* Improve timestamp heuristics. ([#1461](https://github.com/openai/whisper/pull/1461))
* fix condition_on_previous_text ([#1224](https://github.com/openai/whisper/pull/1224))
* Fix numba depreceation notice ([#1233](https://github.com/openai/whisper/pull/1233))
* Updated README.md to provide more insight on BLEU and specific appendices ([#1236](https://github.com/openai/whisper/pull/1236))
* Avoid computing higher temperatures on no_speech segments ([#1279](https://github.com/openai/whisper/pull/1279))
* Dropped unused execute bit from mel_filters.npz. ([#1254](https://github.com/openai/whisper/pull/1254))
* Drop ffmpeg-python dependency and call ffmpeg directly. ([#1242](https://github.com/openai/whisper/pull/1242))
* Python 3.11 ([#1171](https://github.com/openai/whisper/pull/1171))
* Update decoding.py ([#1219](https://github.com/openai/whisper/pull/1219))
* Update decoding.py ([#1155](https://github.com/openai/whisper/pull/1155))
* Update README.md to reference tiktoken ([#1105](https://github.com/openai/whisper/pull/1105))
* Implement max line width and max line count, and make word highlighting optional ([#1184](https://github.com/openai/whisper/pull/1184))
* Squash long words at window and sentence boundaries. ([#1114](https://github.com/openai/whisper/pull/1114))
* python-publish.yml: bump actions version to fix node warning ([#1211](https://github.com/openai/whisper/pull/1211))
* Update tokenizer.py ([#1163](https://github.com/openai/whisper/pull/1163))
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
* fix typo in CHANGELOG.md
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
* update setup.py to specify python >= 3.8 requirement
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/)

MANIFEST.in

@@ -2,6 +2,4 @@ include requirements.txt
 include README.md
 include LICENSE
 include whisper/assets/*
-include whisper/assets/gpt2/*
-include whisper/assets/multilingual/*
 include whisper/normalizers/english.json

README.md

@@ -5,19 +5,19 @@
 [[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
 [[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)

-Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
+Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.

 ## Approach

 ![Approach](https://raw.githubusercontent.com/openai/whisper/main/approach.png)

-A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
+A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.

 ## Setup

-We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
+We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.11 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [OpenAI's tiktoken](https://github.com/openai/tiktoken) for their fast tokenizer implementation. You can download and install (or update to) the latest release of Whisper with the following command:

     pip install -U openai-whisper
@@ -48,7 +48,7 @@ choco install ffmpeg
 scoop install ffmpeg
 ```

-You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
+You may need [`rust`](http://rust-lang.org) installed as well, in case [tiktoken](https://github.com/openai/tiktoken) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:

 ```bash
 pip install setuptools-rust
@@ -57,42 +57,55 @@ pip install setuptools-rust

 ## Available models and languages

-There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed.
+There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
+Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
+The relative speeds below are measured by transcribing English speech on a A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.

 | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
 |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
-| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
-| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
-| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
+| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
+| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
+| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
 | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
 | large | 1550 M | N/A | `large` | ~10 GB | 1x |
+| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |

-For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
-
-Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of Fleurs dataset, using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller is better.
-
-![WER breakdown by language](https://raw.githubusercontent.com/openai/whisper/main/language-breakdown.svg)
+The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
+Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
+
+Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
+
+![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62)

 ## Command-line usage

-The following command will transcribe speech in audio files, using the `medium` model:
+The following command will transcribe speech in audio files, using the `turbo` model:

-    whisper audio.flac audio.mp3 audio.wav --model medium
+```bash
+whisper audio.flac audio.mp3 audio.wav --model turbo
+```

-The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
+The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.

-    whisper japanese.wav --language Japanese
+For example, to transcribe an audio file containing non-English speech, you can specify the language:

-Adding `--task translate` will translate the speech into English:
+```bash
+whisper japanese.wav --language Japanese
+```

-    whisper japanese.wav --language Japanese --task translate
+To **translate** speech into English, use:
+
+```bash
+whisper japanese.wav --model medium --language Japanese --task translate
+```
+
+> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.

 Run the following to view all available options:

-    whisper --help
+```bash
+whisper --help
+```

 See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
@@ -104,7 +117,7 @@ Transcription can also be performed within Python:

 ```python
 import whisper

-model = whisper.load_model("base")
+model = whisper.load_model("turbo")
 result = model.transcribe("audio.mp3")
 print(result["text"])
 ```
@@ -116,14 +129,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`

 ```python
 import whisper

-model = whisper.load_model("base")
+model = whisper.load_model("turbo")

 # load audio and pad/trim it to fit 30 seconds
 audio = whisper.load_audio("audio.mp3")
 audio = whisper.pad_or_trim(audio)

 # make log-Mel spectrogram and move to the same device as the model
-mel = whisper.log_mel_spectrogram(audio).to(model.device)
+mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

 # detect the spoken language
 _, probs = model.detect_language(mel)
@@ -144,4 +157,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion

 ## License

-The code and the model weights of Whisper are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
+Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.

data/README.md

@@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
 ### AMI-IHM, AMI-SDM1

-We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
+We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).

 ## Long-form English-only datasets

File diff suppressed because it is too large

(image changed: 100 KiB before, 272 KiB after)

model-card.md

@@ -16,13 +16,15 @@ The Whisper models are trained for speech recognition and translation tasks, cap
 | small | 244 M | ✓ | ✓ |
 | medium | 769 M | ✓ | ✓ |
 | large | 1550 M | | ✓ |
+| turbo | 798 M | | ✓ |

-In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
+In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
+Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.

 ### Release date

-September 2022 (original series) and December 2022 (`large-v2`)
+September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`)
@@ -37,7 +39,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m
 ### Evaluated Use

-The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.
+The primary intended users of these models are AI researchers studying the robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.

 The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them.
@@ -53,17 +55,17 @@ As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we s
 ## Performance and Limitations

-Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.
+Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, and technical language, as well as zero-shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.

 However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.

-Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
+Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include a higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).

-In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
+In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages.

 ## Broader Implications

 We anticipate that Whisper models transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications.

-There are also potential dual use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects.
+There are also potential dual-use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects.

notebooks/LibriSpeech.ipynb

@@ -949,7 +949,8 @@
       "style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
       "value": " 164/164 [05:08<00:00, 1.86s/it]"
      }
-     }
+     },
+     "state": {}
    }
   }
  },

notebooks/Multilingual_ASR.ipynb

@@ -4219,7 +4219,8 @@
      "_view_name": "StyleView",
      "description_width": ""
     }
-    }
+    },
+    "state": {}
   }
  }
 },

54
pyproject.toml Normal file
View File

@ -0,0 +1,54 @@
[build-system]
build-backend = "setuptools.build_meta"
requires = [ "setuptools>=61.2" ]
[project]
name = "openai-whisper"
description = "Robust Speech Recognition via Large-Scale Weak Supervision"
readme.content-type = "text/markdown"
readme.file = "README.md"
license = { text = "MIT" }
authors = [ { name = "OpenAI" } ]
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dynamic = [ "version" ]
dependencies = [
"more-itertools",
"numba",
"numpy",
"tiktoken",
"torch",
"tqdm",
"triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
]
optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
urls = { Homepage = "https://github.com/openai/whisper" }
scripts.whisper = "whisper.transcribe:cli"
[tool.setuptools]
py-modules = [ "whisper" ]
include-package-data = true
[tool.setuptools.dynamic]
version = { attr = "whisper.version.__version__" }
[tool.setuptools.packages.find]
exclude = [ "tests*" ]
namespaces = false
[tool.black]
[tool.isort]
profile = "black"
include_trailing_comma = true
line_length = 88
multi_line_output = 3
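
As a sanity check of what this packaging configuration produces: the `scripts.whisper` entry installs a `whisper` console command, and the `[tool.setuptools.dynamic]` table reads the package version from `whisper/version.py`. A sketch, assuming the package has been installed with `pip install .`:

```python
# After `pip install .`, the console script `whisper` (-> whisper.transcribe:cli)
# is on PATH, and the version is importable:
import whisper

print(whisper.__version__)  # the same string setuptools publishes as the project version
```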


@ -1,6 +1,7 @@
numba
numpy
torch
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
tiktoken
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"


@ -1,36 +0,0 @@
import os
import pkg_resources
from setuptools import setup, find_packages
def read_version(fname="whisper/version.py"):
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
return locals()["__version__"]
setup(
name="openai-whisper",
py_modules=["whisper"],
version=read_version(),
description="Robust Speech Recognition via Large-Scale Weak Supervision",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
python_requires=">=3.7",
author="OpenAI",
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
],
entry_points={
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest"]},
)

tests/conftest.py Normal file (14 lines)

@ -0,0 +1,14 @@
import random as rand
import numpy
import pytest
def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")
@pytest.fixture
def random():
rand.seed(42)
numpy.random.seed(42)
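
A sketch of how these hooks are consumed downstream (this test is illustrative, not from the repo): requesting the `random` fixture seeds both RNGs to 42, and the registered `requires_cuda` marker lets GPU-only tests be deselected.

```python
import numpy as np
import pytest
import torch


@pytest.mark.requires_cuda  # deselect on CPU-only machines: pytest -m "not requires_cuda"
def test_gpu_roundtrip(random):  # the fixture makes np.random.* reproducible
    x = np.random.rand(16).astype(np.float32)
    y = torch.from_numpy(x).cuda().cpu().numpy()
    assert np.allclose(x, y)
```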


@ -2,7 +2,7 @@ import os.path
import numpy as np
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
def test_audio():


@ -1,7 +1,10 @@
import pytest
from whisper.normalizers import EnglishTextNormalizer
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
from whisper.normalizers.english import (
EnglishNumberNormalizer,
EnglishSpellingNormalizer,
)
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
@ -84,6 +87,7 @@ def test_text_normalizer():
assert std("he's like") == "he is like"
assert std("she's been like") == "she has been like"
assert std("10km") == "10 km"
assert std("10mm") == "10 mm"
assert std("RC232") == "rc 232"
assert (

tests/test_timing.py Normal file (96 lines)

@ -0,0 +1,96 @@
import numpy as np
import pytest
import scipy.ndimage
import torch
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
sizes = [
(10, 20),
(32, 16),
(123, 1500),
(234, 189),
]
shapes = [
(10,),
(1, 15),
(4, 5, 345),
(6, 12, 240, 512),
]
@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
np.random.shuffle(steps)
x = np.random.random((N, M)).astype(np.float32)
i, j, k = 0, 0, 0
trace = []
while True:
x[i, j] -= 1
trace.append((i, j))
if k == len(steps):
break
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
i += 1
j += 1
k += 2
continue
if steps[k] == 0:
i += 1
if steps[k] == 1:
j += 1
k += 1
trace = np.array(trace).T
dtw_trace = dtw_cpu(x)
assert np.allclose(trace, dtw_trace)
@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
x_numpy = np.random.randn(N, M).astype(np.float32)
x_cuda = torch.from_numpy(x_numpy).cuda()
trace_cpu = dtw_cpu(x_numpy)
trace_cuda = dtw_cuda(x_cuda)
assert np.allclose(trace_cpu, trace_cuda)
@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
x = torch.randn(*shape)
for filter_width in [3, 5, 7, 13]:
filtered = median_filter(x, filter_width)
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
)
scipy_filtered = scipy.ndimage.median_filter(
padded_x, [1] * (x.ndim - 1) + [filter_width]
)
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
assert np.allclose(filtered, scipy_filtered)
@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
x = torch.randn(*shape)
for filter_width in [3, 5, 7, 13]:
filtered_cpu = median_filter(x, filter_width)
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
assert np.allclose(filtered_cpu, filtered_gpu)


@ -1,7 +1,17 @@
import pytest
from whisper.tokenizer import get_tokenizer
def test_tokenizer():
@pytest.mark.parametrize("multilingual", [True, False])
def test_tokenizer(multilingual):
tokenizer = get_tokenizer(multilingual=False)
assert tokenizer.sot in tokenizer.sot_sequence
assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens)
assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens)
def test_multilingual_tokenizer():
gpt2_tokenizer = get_tokenizer(multilingual=False)
multilingual_tokenizer = get_tokenizer(multilingual=True)
@ -12,3 +22,13 @@ def test_tokenizer():
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
def test_split_on_unicode():
multilingual_tokenizer = get_tokenizer(multilingual=True)
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]


@ -4,6 +4,7 @@ import pytest
import torch
import whisper
from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models())
@ -13,10 +14,29 @@ def test_transcribe(model_name: str):
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
result = model.transcribe(audio_path, language=language, temperature=0.0)
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower()
assert "my fellow americans" in transcription
assert "your country" in transcription
assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
timing_checked = True
assert timing_checked


@ -10,11 +10,10 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@ -26,7 +25,29 @@ _MODELS = {
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
@ -45,10 +66,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@ -59,7 +88,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
@ -69,7 +100,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
@ -94,24 +130,32 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
checkpoint = torch.load(fp, map_location=device)
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)
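
Putting the pieces of this function together, a typical call resolves the checkpoint URL, verifies its SHA256, and loads it with `weights_only=True` on sufficiently recent PyTorch. A minimal sketch:

```python
import whisper

# Downloads to $XDG_CACHE_HOME/whisper (default: ~/.cache/whisper) unless
# download_root is given; alignment heads are set for known model names.
model = whisper.load_model("tiny", device="cpu")
print(model.is_multilingual, model.num_languages)
```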


@ -1,4 +1,3 @@
from .transcribe import cli
cli()

whisper/assets/gpt2.tiktoken Normal file (50256 lines)

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}


@ -1 +0,0 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

whisper/assets/mel_filters.npz (binary) Executable file → Normal file

Binary file not shown.

File diff suppressed because it is too large


@ -1 +0,0 @@
{"<|endoftext|>": 50257}

File diff suppressed because it is too large


@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}


@ -1 +0,0 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long


@ -1,8 +1,8 @@
import os
from functools import lru_cache
from typing import Union
from subprocess import CalledProcessError, run
from typing import Optional, Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
@ -12,11 +12,14 @@ from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
@ -35,15 +38,25 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@ -55,7 +68,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
@ -74,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -82,14 +97,22 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
@ -99,11 +122,17 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
The number of Mel-frequency filters, only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (80, n_frames)
torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
@ -111,6 +140,10 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
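
Taken together, the updated audio front end chains as follows; a minimal sketch (the file path is a placeholder, and `ffmpeg` must be on PATH now that `ffmpeg-python` is no longer used):

```python
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("audio.mp3")   # decodes via the ffmpeg CLI in a subprocess
audio = pad_or_trim(audio)        # pad/trim to 30 s (N_SAMPLES)
mel = log_mel_spectrogram(audio, n_mels=128, device="cpu")  # 128 for large-v3; 80 otherwise
print(mel.shape)                  # torch.Size([128, 3000])
```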


@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -16,7 +16,9 @@ if TYPE_CHECKING:
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
@ -30,9 +32,16 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
@ -70,31 +79,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@ -134,6 +148,10 @@ class PyTorchInference(Inference):
self.kv_cache = {}
self.hooks = []
key_modules = [block.attn.key for block in self.model.decoder.blocks]
value_modules = [block.attn.value for block in self.model.decoder.blocks]
self.kv_modules = key_modules + value_modules
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
@ -152,13 +170,16 @@ class PyTorchInference(Inference):
self.hooks = []
def rearrange_kv_cache(self, source_indices):
for module, tensor in self.kv_cache.items():
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = tensor[source_indices].detach()
if source_indices != list(range(len(source_indices))):
for module in self.kv_modules:
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
def rank(
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
@ -196,7 +217,9 @@ class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
@ -251,12 +274,13 @@ class GreedyDecoder(TokenDecoder):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
next_tokens = Categorical(logits=logits / self.temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
@ -275,7 +299,13 @@ class GreedyDecoder(TokenDecoder):
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
def __init__(
self,
beam_size: int,
eot: int,
inference: Inference,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
@ -283,12 +313,16 @@ class BeamSearchDecoder(TokenDecoder):
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
@ -332,7 +366,9 @@ class BeamSearchDecoder(TokenDecoder):
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
@ -340,7 +376,8 @@ class BeamSearchDecoder(TokenDecoder):
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
@ -348,7 +385,9 @@ class BeamSearchDecoder(TokenDecoder):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@ -356,7 +395,8 @@ class BeamSearchDecoder(TokenDecoder):
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
@ -400,7 +440,10 @@ class SuppressTokens(LogitFilter):
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
@ -413,9 +456,14 @@ class ApplyTimestampRules(LogitFilter):
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
@ -423,19 +471,35 @@ class ApplyTimestampRules(LogitFilter):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[
sampled_tokens.ge(self.tokenizer.timestamp_begin)
]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
# also force each segment to have a nonzero length, to prevent infinite looping
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
timestamp_last = timestamps[-1] + 1
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
dim=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
@ -451,7 +515,12 @@ class DecodingTask:
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@ -491,9 +560,13 @@ class DecodingTask:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@ -504,30 +577,38 @@ class DecodingTask:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
@ -546,7 +627,13 @@ class DecodingTask:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
[
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
@ -558,14 +645,21 @@ class DecodingTask:
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
return audio_features
@ -574,7 +668,9 @@ class DecodingTask:
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
@ -582,7 +678,6 @@ class DecodingTask:
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
@ -591,7 +686,9 @@ class DecodingTask:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
@ -625,12 +722,15 @@ class DecodingTask:
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
# repeat text tensors by the group size, for beam search or best-of-n sampling
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
@ -647,7 +747,8 @@ class DecodingTask:
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
]
# select the top-ranked sample in each group
@ -656,9 +757,18 @@ class DecodingTask:
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@ -673,12 +783,19 @@ class DecodingTask:
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
def decode(
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@ -698,13 +815,12 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result
return result[0] if single else result
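
The new `**kwargs` path means individual fields of the frozen `DecodingOptions` can be overridden per call via `dataclasses.replace`. A sketch, reusing the audio pipeline above:

```python
import whisper
from whisper.decoding import DecodingOptions

model = whisper.load_model("base")
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio("audio.mp3")))

options = DecodingOptions(task="transcribe", fp16=False)
# beam_size/language are merged into `options` via dataclasses.replace()
result = whisper.decode(model, mel.to(model.device), options, beam_size=5, language="en")
print(result.text)
```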


@ -1,15 +1,25 @@
import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch import Tensor, nn
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
try:
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
@dataclass
@ -34,12 +44,16 @@ class LayerNorm(nn.LayerNorm):
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
@ -54,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module):
use_sdpa = True
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
@ -85,20 +111,32 @@ class MultiHeadAttention(nn.Module):
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(
q, k, v, is_causal=mask is not None and n_ctx > 1
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
qk = None
else:
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
w = F.softmax(qk, dim=-1).to(q.dtype)
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
qk = qk.detach()
return out, qk
class ResidualAttentionBlock(nn.Module):
@ -108,11 +146,15 @@ class ResidualAttentionBlock(nn.Module):
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
@ -130,7 +172,9 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@ -161,14 +205,19 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
@ -179,18 +228,23 @@ class TextDecoder(nn.Module):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
@ -213,6 +267,22 @@ class Whisper(nn.Module):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
@ -220,7 +290,9 @@ class Whisper(nn.Module):
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
@ -229,7 +301,11 @@ class Whisper(nn.Module):
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
@ -249,8 +325,9 @@ class Whisper(nn.Module):
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
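
Two of the additions above can be sanity-checked in isolation. The base85/gzip dumps decode to an `(n_text_layer, n_text_head)` boolean mask, and `num_languages` is plain vocabulary arithmetic. A sketch using the `tiny` dump (assuming tiny's text decoder has 4 layers of 6 heads, i.e. 24 slots):

```python
import base64
import gzip

import numpy as np

# dump copied from _ALIGNMENT_HEADS["tiny"] above
dump = b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO"
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
print(array.size, "layer-head slots;", int(array.sum()), "alignment heads selected")

# num_languages = n_vocab - 51765 - int(is_multilingual):
# 51865 (multilingual v1/v2) -> 99 languages; 51866 (large-v3) -> 100
```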


@ -1,2 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer


@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
(
c
if c in keep
else (
ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else (
""
if unicodedata.category(c) == "Mn"
else " " if unicodedata.category(c)[0] in "MSP" else c
)
)
)
for c in unicodedata.normalize("NFKD", s)
)
@ -48,13 +52,16 @@ def remove_symbols(s: str):
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
@ -66,6 +73,8 @@ class BasicTextNormalizer:
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s


@ -1737,6 +1737,5 @@
"yoghurt": "yogurt",
"yoghurts": "yogurts",
"mhm": "hmm",
"mm": "hmm",
"mmm": "hmm"
}


@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
name.replace("y", "ieth"): (value, "th")
for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
self.multipliers_suffixed = {
**self.multipliers_plural,
**self.multipliers_ordinal,
}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
list(self.preceding_prefixers.values())
+ list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10: # replace the last zero with the digit
if (
prev in self.tens and ones < 10
): # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
@ -522,14 +529,14 @@ class EnglishTextNormalizer:
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
@ -538,6 +545,6 @@ class EnglishTextNormalizer:
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
return s
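
A quick check of the behaviors these normalizer hunks touch, using strings taken from the tests above:

```python
from whisper.normalizers import EnglishTextNormalizer

std = EnglishTextNormalizer()
print(std("he's like"))  # -> "he is like"  (contraction replacers)
print(std("10km"))       # -> "10 km"       (number/unit splitting)
print(std("mhm"))        # -> "hmm"         (spelling mapping in english.json)
```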

whisper/timing.py Normal file (388 lines)

@ -0,0 +1,388 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import numba
import numpy as np
import torch
import torch.nn.functional as F
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
# F.pad requires the padding width to be smaller than the input dimension
return x
if (ndim := x.ndim) <= 2:
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)
if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]
return result
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
result = []
while i > 0 or j > 0:
result.append((i - 1, j - 1))
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise ValueError("Unexpected trace[i, j]")
result = np.array(result)
return result[::-1, :].T
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t
return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)
return dtw_cpu(x.double().cpu().numpy())
@dataclass
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float
def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
num_frames: int,
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
from .model import disable_sdpa
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
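# A hedged sketch of the jump-time arithmetic above: wherever the text index
# advances along the DTW path (a "jump"), the matching time index marks a
# token boundary; word times are then read off at the cumulative token counts.
# TOKENS_PER_SECOND is whisper's 50 output tokens per second (20 ms each):
import numpy as np

TOKENS_PER_SECOND = 50
text_indices = np.array([0, 0, 1, 2, 2, 3])
time_indices = np.array([0, 1, 2, 3, 4, 5])
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
print(jump_times)  # [0.   0.04 0.06 0.1 ] -> start time of each text token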
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1
# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1
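# A hedged usage sketch of merge_punctuations: opening punctuation (with its
# leading space) folds into the following word, closing punctuation into the
# preceding one; emptied entries keep their slots and are skipped downstream
# because `timing.word` is falsy:
alignment = [
    WordTiming(" “", [1], 0.0, 0.1, 1.0),
    WordTiming("Hello", [2], 0.1, 0.5, 0.9),
    WordTiming(",", [3], 0.5, 0.6, 0.8),
]
merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,!?::”)]}、")
print([w.word for w in alignment])  # ['', ' “Hello,', '']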
def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
last_speech_timestamp: float,
**kwargs,
):
if len(segments) == 0:
return
text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(word_durations) > 0:
sentence_end_marks = ".。!?"
# ensure words at sentence boundaries are not longer than twice the median word duration.
for i in range(1, len(alignment)):
if alignment[i].end - alignment[i].start > max_duration:
if alignment[i].word in sentence_end_marks:
alignment[i].end = alignment[i].start + max_duration
elif alignment[i - 1].word in sentence_end_marks:
alignment[i].start = alignment[i].end - max_duration
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
word_index = 0
for segment, text_tokens in zip(segments, text_tokens_per_segment):
saved_tokens = 0
words = []
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
if timing.word:
words.append(
dict(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
probability=timing.probability,
)
)
saved_tokens += len(timing.tokens)
word_index += 1
# hack: truncate long words at segment boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(words) > 0:
# ensure the first and second word after a pause is not longer than
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
or (
len(words) > 1
and words[1]["end"] - words[0]["start"] > max_duration * 2
)
):
if (
len(words) > 1
and words[1]["end"] - words[1]["start"] > max_duration
):
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
if (
segment["start"] < words[0]["end"]
and segment["start"] - 0.5 > words[0]["start"]
):
words[0]["start"] = max(
0, min(words[0]["end"] - median_duration, segment["start"])
)
else:
segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
if (
segment["end"] > words[-1]["start"]
and segment["end"] + 0.5 < words[-1]["end"]
):
words[-1]["end"] = max(
words[-1]["start"] + median_duration, segment["end"]
)
else:
segment["end"] = words[-1]["end"]
last_speech_timestamp = segment["end"]
segment["words"] = words

View File

@@ -1,11 +1,11 @@
import base64
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import GPT2TokenizerFast
import tiktoken
LANGUAGES = {
"en": "english",
@@ -107,6 +107,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@@ -123,117 +124,121 @@ TO_LANGUAGE_CODE = {
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
@dataclass(frozen=True)
@dataclass
class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast"
language: Optional[str]
sot_sequence: Tuple[int]
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
return self.tokenizer.decode(token_ids, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
return self.encoding.decode(token_ids, **kwargs)
@property
@lru_cache()
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id
return self.encoding.eot_token
@property
@lru_cache()
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
return self.special_tokens["<|startoftranscript|>"]
@property
@lru_cache()
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
return self.special_tokens["<|startoflm|>"]
@property
@lru_cache()
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
return self.special_tokens["<|startofprev|>"]
@property
@lru_cache()
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
return self.special_tokens["<|nospeech|>"]
@property
@lru_cache()
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
return self.special_tokens["<|notimestamps|>"]
@property
@lru_cache()
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
return self.special_tokens["<|0.00|>"]
@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
return self.to_language_token(self.language)
raise KeyError(f"Language {self.language} not found in tokenizer.")
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
@property
@lru_cache()
raise KeyError(f"Language {language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
return tuple(result)[: self.num_languages]
@property
@lru_cache()
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@property
@lru_cache()
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@@ -245,8 +250,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -256,47 +263,113 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
for tokens in [
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
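# A hedged sketch of the unicode-splitting idea, using a stock tiktoken
# encoding rather than whisper's vocab assets, and omitting the decoded_full
# check that guards against a legitimate U+FFFD in the text: decode a growing
# token buffer and cut a word whenever the bytes form valid code points:
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("こんにちは")
words, current = [], []
for token in tokens:
    current.append(token)
    decoded = enc.decode(current)  # incomplete UTF-8 decodes to U+FFFD
    if "\ufffd" not in decoded:
        words.append(decoded)
        current = []
print(words)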
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
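# A hedged sketch of what the layout above implies (assuming the multilingual
# .tiktoken asset is on disk): special ids are appended consecutively after
# the base vocabulary, so the language tokens directly follow
# <|startoftranscript|>, which is what Tokenizer.__post_init__ relies on:
enc = get_encoding("multilingual", num_languages=100)
sot = enc.encode_single_token("<|startoftranscript|>")
assert enc.encode_single_token("<|en|>") == sot + 1     # first language
assert enc.encode_single_token("<|yue|>") == sot + 100  # 100th language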
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
if language is not None:
language = language.lower()
@@ -307,25 +380,16 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
tokenizer_name = "gpt2"
task = None
encoding_name = "gpt2"
language = None
task = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task
)

View File

@@ -1,16 +1,35 @@
import argparse
import os
import traceback
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
from .utils import (
exact_div,
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from .model import Whisper
@@ -26,6 +45,13 @@ def transcribe(
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
@@ -62,9 +88,37 @@ def transcribe(
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -81,26 +135,56 @@ def transcribe(
if dtype == torch.float32:
decode_options["fp16"] = False
mel = log_mel_spectrogram(audio)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
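# A hedged sketch of the clip parsing above: "10,30,45" becomes one clip from
# 10 s to 30 s and another from 45 s to the end of the file. FRAMES_PER_SECOND
# is whisper's 100 mel frames per second; content_frames here is illustrative:
FRAMES_PER_SECOND = 100
clip_timestamps = [10.0, 30.0, 45.0]   # parsed from "10,30,45"
content_frames = 6000                  # e.g. a 60-second file
seek_points = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) % 2 == 1:
    seek_points.append(content_frames)
seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
print(seek_clips)  # [(1000, 3000), (4500, 6000)]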
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
@@ -117,17 +201,30 @@ def transcribe(
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
and logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
return decode_result
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
@@ -138,123 +235,303 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt)
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[-1] # fast-forward to the next segment boundary
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in consecutive:
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_position = (
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
pbar.update(min(content_frames, seek) - previous_seek)
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def cli():
from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -268,6 +545,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -275,51 +554,70 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
audio_basename = os.path.basename(audio_path)
# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)
# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)
# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
try:
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
if __name__ == '__main__':
if __name__ == "__main__":
cli()

117
whisper/triton_ops.py Normal file
View File

@@ -0,0 +1,117 @@
from functools import lru_cache
import numpy as np
import torch
try:
import triton
import triton.language as tl
except ImportError:
raise RuntimeError("triton import failed; try `pip install --pre triton`")
@triton.jit
def dtw_kernel(
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < M
for k in range(1, N + M + 1): # k = i + j
tl.debug_barrier()
p0 = cost + (k - 1) * cost_stride
p1 = cost + k * cost_stride
p2 = cost + k * cost_stride + 1
c0 = tl.load(p0 + offsets, mask=mask)
c1 = tl.load(p1 + offsets, mask=mask)
c2 = tl.load(p2 + offsets, mask=mask)
x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
cost_ptr = cost + (k + 1) * cost_stride + 1
tl.store(cost_ptr + offsets, cost_row, mask=mask)
trace_ptr = trace + (k + 1) * trace_stride + 1
tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
@triton.jit
def kernel(
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
): # x.shape[-1] == filter_width
row_idx = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < y_stride
x_ptr = x + row_idx * x_stride # noqa: F841
y_ptr = y + row_idx * y_stride
LOAD_ALL_ROWS_HERE # noqa: F821
BUBBLESORT_HERE # noqa: F821
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
for i in range(filter_width)
]
),
)
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
"\n\n".join(
[
"\n".join(
[
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f" row{j} = smaller",
f" row{j + 1} = larger",
]
)
for j in range(filter_width - i - 1)
]
)
for i in range(filter_width // 2 + 1)
]
),
)
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src"):
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel
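# A hedged illustration of what the templating above generates for
# filter_width == 3: an unrolled partial bubble sort over row registers,
# after which the middle register holds each window's median:
#
#   row0 = tl.load(x_ptr + offsets + 0, mask=mask)
#   row1 = tl.load(x_ptr + offsets + 1, mask=mask)
#   row2 = tl.load(x_ptr + offsets + 2, mask=mask)
#   ... filter_width // 2 + 1 passes of pairwise min/max compare-swaps ...
#   tl.store(y_ptr + offsets, row1, mask=mask)  # row1 is the median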
def median_filter_cuda(x: torch.Tensor, filter_width: int):
"""Apply a median filter of given width along the last dimension of x"""
slices = x.contiguous().unfold(-1, filter_width, 1)
grid = np.prod(slices.shape[:-2])
kernel = median_kernel(filter_width)
y = torch.empty_like(slices[..., 0])
BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
return y

View File

@@ -1,5 +1,24 @@
import json
import os
import re
import sys
import zlib
from typing import Iterator, TextIO
from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with a '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
@@ -28,7 +47,9 @@ def compression_ratio(text) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@@ -42,47 +63,256 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
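# Hedged doctest-style examples of the formatting above:
print(format_timestamp(3.5))                  # 00:03.500
print(format_timestamp(3661.007, always_include_hours=True,
                       decimal_marker=","))   # 01:01:01,007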
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def iterate_result(
self,
result: dict,
options: Optional[dict] = None,
*,
max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
chunk_index = 0
words_count = max_words_per_line
while chunk_index < len(segment["words"]):
remaining_words = len(segment["words"]) - chunk_index
if max_words_per_line > len(segment["words"]) - chunk_index:
words_count = remaining_words
for i, original_timing in enumerate(
segment["words"][chunk_index : chunk_index + words_count]
):
timing = original_timing.copy()
long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if (
line_len > 0
and has_room
and not long_pause
and not seg_break
):
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0:
yield subtitle
if len(result["segments"]) > 0 and "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, "".join(
[
(
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
)
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
def write_srt(transcript: Iterator[dict], file: TextIO):
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in SRT format.
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
extension: str = "tsv"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers:
writer(result, file, options, **kwargs)
return write_all
return writers[output_format](output_dir)
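# A hedged usage sketch of get_writer; the result dict here is illustrative,
# trimmed to the keys the SRT writer needs:
writer = get_writer("srt", output_dir=".")
result = {"segments": [{"start": 0.0, "end": 1.5, "text": " Hello world"}]}
writer(result, "audio.wav")  # writes ./audio.srt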

View File

@@ -1 +1 @@
__version__ = "20230117"
__version__ = "20250625"