diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..be006de
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,13 @@
+# Keep GitHub Actions up to date with GitHub's Dependabot...
+# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
+# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
+version: 2
+updates:
+ - package-ecosystem: github-actions
+ directory: /
+ groups:
+ github-actions:
+ patterns:
+ - "*" # Group all Actions updates into a single larger pull request
+ schedule:
+ interval: weekly
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 4b91a2a..ff8f122 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -8,23 +8,23 @@ jobs:
deploy:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions-ecosystem/action-regex-match@v2
id: regex-match
with:
text: ${{ github.event.head_commit.message }}
regex: '^Release ([^ ]+)'
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.12'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install setuptools wheel twine
+ pip install setuptools wheel twine build
- name: Release
if: ${{ steps.regex-match.outputs.match != '' }}
- uses: softprops/action-gh-release@v1
+ uses: softprops/action-gh-release@v2
with:
tag_name: v${{ steps.regex-match.outputs.group1 }}
- name: Build and publish
@@ -33,5 +33,5 @@ jobs:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
- python setup.py sdist
+ python -m build --sdist
twine upload dist/*
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index dffc17c..3b53de8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,19 +11,19 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Fetch base branch
run: git fetch origin ${{ github.base_ref }}
- - uses: actions/setup-python@v4
+ - uses: actions/setup-python@v5
with:
- python-version: "3.8"
+ python-version: "3.9"
architecture: x64
- name: Get pip cache dir
id: pip-cache
run: |
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip/pre-commit cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: |
${{ steps.pip-cache.outputs.dir }}
@@ -33,24 +33,47 @@ jobs:
${{ runner.os }}-pip-pre-commit
- name: pre-commit
run: |
- pip install -U pre-commit
+ pip install --upgrade pre-commit
pre-commit install --install-hooks
pre-commit run --all-files
whisper-test:
needs: pre-commit
runs-on: ubuntu-latest
strategy:
+ fail-fast: false
matrix:
- python-version: ['3.8', '3.9', '3.10', '3.11']
- pytorch-version: [1.13.1, 2.0.0]
- exclude:
- - python-version: '3.11'
+ include:
+ - python-version: '3.8'
+ pytorch-version: 1.10.1
+ numpy-requirement: "'numpy<2'"
+ - python-version: '3.8'
pytorch-version: 1.13.1
+ numpy-requirement: "'numpy<2'"
+ - python-version: '3.8'
+ pytorch-version: 2.0.1
+ numpy-requirement: "'numpy<2'"
+ - python-version: '3.9'
+ pytorch-version: 2.1.2
+ numpy-requirement: "'numpy<2'"
+ - python-version: '3.10'
+ pytorch-version: 2.2.2
+ numpy-requirement: "'numpy<2'"
+ - python-version: '3.11'
+ pytorch-version: 2.3.1
+ numpy-requirement: "'numpy'"
+ - python-version: '3.12'
+ pytorch-version: 2.4.1
+ numpy-requirement: "'numpy'"
+ - python-version: '3.12'
+ pytorch-version: 2.5.1
+ numpy-requirement: "'numpy'"
+ - python-version: '3.13'
+ pytorch-version: 2.5.1
+ numpy-requirement: "'numpy'"
steps:
- - uses: conda-incubator/setup-miniconda@v2
+ - uses: conda-incubator/setup-miniconda@v3
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
- - run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- - run: pip install .["dev"]
+ - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3f5a74b..514f940 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.0.1
+ rev: v5.0.0
hooks:
- id: check-json
- id: end-of-file-fixer
@@ -11,17 +11,17 @@ repos:
- id: check-added-large-files
args: [--maxkb=4096]
- repo: https://github.com/psf/black
- rev: 23.7.0
+ rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
- rev: 5.12.0
+ rev: 6.0.0
hooks:
- id: isort
name: isort (python)
args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
- repo: https://github.com/pycqa/flake8.git
- rev: 6.0.0
+ rev: 7.1.1
hooks:
- id: flake8
types: [python]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5895541..0876010 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,40 @@
# CHANGELOG
+## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
+
+* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
+* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
+* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
+* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
+* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
+* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
+* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
+* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
+* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
+* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
+* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
+* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
+* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
+* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
+* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
+* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
+* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
+* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
+
+## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
+
+* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
+* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
+* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
+* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
+
+## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
+
+* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
+* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
+* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
+* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
+
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
diff --git a/README.md b/README.md
index afca9c9..196b48f 100644
--- a/README.md
+++ b/README.md
@@ -57,41 +57,55 @@ pip install setuptools-rust
## Available models and languages
-There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model; actual speed may vary depending on many factors including the available hardware.
+There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
+Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
+The relative speeds below are measured by transcribing English speech on an A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
-| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
-| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
-| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
+| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
+| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
+| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
+| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
+Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.

-
-
## Command-line usage
-The following command will transcribe speech in audio files, using the `medium` model:
+The following command will transcribe speech in audio files, using the `turbo` model:
- whisper audio.flac audio.mp3 audio.wav --model medium
+```bash
+whisper audio.flac audio.mp3 audio.wav --model turbo
+```
-The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
+The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.
- whisper japanese.wav --language Japanese
+For example, to transcribe an audio file containing non-English speech, you can specify the language:
-Adding `--task translate` will translate the speech into English:
+```bash
+whisper japanese.wav --language Japanese
+```
- whisper japanese.wav --language Japanese --task translate
+To **translate** speech into English, use:
+
+```bash
+whisper japanese.wav --model medium --language Japanese --task translate
+```
+
+> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.
Run the following to view all available options:
- whisper --help
+```bash
+whisper --help
+```
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
@@ -103,7 +117,7 @@ Transcription can also be performed within Python:
```python
import whisper
-model = whisper.load_model("base")
+model = whisper.load_model("turbo")
result = model.transcribe("audio.mp3")
print(result["text"])
```
@@ -115,14 +129,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
```python
import whisper
-model = whisper.load_model("base")
+model = whisper.load_model("turbo")
# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)
# make log-Mel spectrogram and move to the same device as the model
-mel = whisper.log_mel_spectrogram(audio).to(model.device)
+mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
# detect the spoken language
_, probs = model.detect_language(mel)
diff --git a/data/README.md b/data/README.md
index 3b4aea1..fcb3200 100644
--- a/data/README.md
+++ b/data/README.md
@@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
### AMI-IHM, AMI-SDM1
-We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
+We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
## Long-form English-only datasets
diff --git a/model-card.md b/model-card.md
index 3c041a1..291bc4b 100644
--- a/model-card.md
+++ b/model-card.md
@@ -16,13 +16,15 @@ The Whisper models are trained for speech recognition and translation tasks, cap
| small | 244 M | ✓ | ✓ |
| medium | 769 M | ✓ | ✓ |
| large | 1550 M | | ✓ |
+| turbo | 809 M | | ✓ |
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
+Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.
### Release date
-September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`)
+September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), and September 2024 (`large-v3-turbo`)
### Model type
diff --git a/notebooks/LibriSpeech.ipynb b/notebooks/LibriSpeech.ipynb
index 3d90e65..602bbe4 100644
--- a/notebooks/LibriSpeech.ipynb
+++ b/notebooks/LibriSpeech.ipynb
@@ -949,7 +949,8 @@
"style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
"value": " 164/164 [05:08<00:00, 1.86s/it]"
}
- }
+ },
+ "state": {}
}
}
},
diff --git a/notebooks/Multilingual_ASR.ipynb b/notebooks/Multilingual_ASR.ipynb
index 2d32e0e..f19e3e0 100644
--- a/notebooks/Multilingual_ASR.ipynb
+++ b/notebooks/Multilingual_ASR.ipynb
@@ -4219,7 +4219,8 @@
"_view_name": "StyleView",
"description_width": ""
}
- }
+ },
+ "state": {}
}
}
},
diff --git a/pyproject.toml b/pyproject.toml
index 84637eb..21b90e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,50 @@
+[build-system]
+build-backend = "setuptools.build_meta"
+
+requires = [ "setuptools>=61.2" ]
+
+[project]
+name = "openai-whisper"
+description = "Robust Speech Recognition via Large-Scale Weak Supervision"
+readme.content-type = "text/markdown"
+readme.file = "README.md"
+license = { text = "MIT" }
+authors = [ { name = "OpenAI" } ]
+requires-python = ">=3.8"
+classifiers = [
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+]
+dynamic = [ "version" ]
+dependencies = [
+ "more-itertools",
+ "numba",
+ "numpy",
+ "tiktoken",
+ "torch",
+ "tqdm",
+ "triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
+]
+optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
+urls = { Homepage = "https://github.com/openai/whisper" }
+scripts.whisper = "whisper.transcribe:cli"
+
+[tool.setuptools]
+py-modules = [ "whisper" ]
+include-package-data = true
+
+[tool.setuptools.dynamic]
+version = { attr = "whisper.version.__version__" }
+
+[tool.setuptools.packages.find]
+exclude = [ "tests*" ]
+namespaces = false
+
[tool.black]
[tool.isort]
@@ -5,4 +52,3 @@ profile = "black"
include_trailing_comma = true
line_length = 88
multi_line_output = 3
-
diff --git a/requirements.txt b/requirements.txt
index 62f5f9d..8ee5920 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,4 @@ torch
tqdm
more-itertools
tiktoken
-triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
+triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 183b527..0000000
--- a/setup.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import platform
-import sys
-from pathlib import Path
-
-import pkg_resources
-from setuptools import find_packages, setup
-
-
-def read_version(fname="whisper/version.py"):
- exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
- return locals()["__version__"]
-
-
-requirements = []
-if sys.platform.startswith("linux") and platform.machine() == "x86_64":
- requirements.append("triton>=2.0.0,<3")
-
-setup(
- name="openai-whisper",
- py_modules=["whisper"],
- version=read_version(),
- description="Robust Speech Recognition via Large-Scale Weak Supervision",
- long_description=open("README.md", encoding="utf-8").read(),
- long_description_content_type="text/markdown",
- readme="README.md",
- python_requires=">=3.8",
- author="OpenAI",
- url="https://github.com/openai/whisper",
- license="MIT",
- packages=find_packages(exclude=["tests*"]),
- install_requires=[
- str(r)
- for r in pkg_resources.parse_requirements(
- Path(__file__).with_name("requirements.txt").open()
- )
- ],
- entry_points={
- "console_scripts": ["whisper=whisper.transcribe:cli"],
- },
- include_package_data=True,
- extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
-)
diff --git a/whisper/__init__.py b/whisper/__init__.py
index d7fbba3..f284ec0 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -27,6 +27,8 @@ _MODELS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+ "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
+ "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@@ -44,6 +46,8 @@ _ALIGNMENT_HEADS = {
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+ "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+ "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
@@ -143,7 +147,11 @@ def load_model(
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
- checkpoint = torch.load(fp, map_location=device)
+ # weights_only was added in torch 1.13. Compare (major, minor) as integers:
+ # as plain strings, "1.9" >= "1.13" evaluates to True.
+ major_minor = tuple(int(part) for part in torch.__version__.split(".")[:2])
+ kwargs = {"weights_only": True} if major_minor >= (1, 13) else {}
+ checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
diff --git a/whisper/audio.py b/whisper/audio.py
index cf6c66a..826250f 100644
--- a/whisper/audio.py
+++ b/whisper/audio.py
@@ -122,7 +122,7 @@ def log_mel_spectrogram(
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
- The number of Mel-frequency filters, only 80 is supported
+ The number of Mel-frequency filters, only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
@@ -132,7 +132,7 @@ def log_mel_spectrogram(
Returns
-------
- torch.Tensor, shape = (80, n_frames)
+ torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
diff --git a/whisper/model.py b/whisper/model.py
index a678283..e537447 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -1,7 +1,8 @@
import base64
import gzip
+from contextlib import contextmanager
from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
@@ -12,6 +13,14 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
+try:
+ from torch.nn.functional import scaled_dot_product_attention
+
+ SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+ scaled_dot_product_attention = None
+ SDPA_AVAILABLE = False
+
@dataclass
class ModelDimensions:
@@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+@contextmanager
+def disable_sdpa():
+ prev_state = MultiHeadAttention.use_sdpa  # remember the class-wide flag
+ try:
+ MultiHeadAttention.use_sdpa = False  # qkv_attention then takes the manual path
+ yield
+ finally:
+ MultiHeadAttention.use_sdpa = prev_state  # restore even if the body raised
+
+
class MultiHeadAttention(nn.Module):
+ use_sdpa = True
+
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
@@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
- ):
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
- q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
- k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
- qk = q @ k
- if mask is not None:
- qk = qk + mask[:n_ctx, :n_ctx]
- qk = qk.float()
+ if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+ a = scaled_dot_product_attention(
+ q, k, v, is_causal=mask is not None and n_ctx > 1
+ )
+ out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+ qk = None
+ else:
+ qk = (q * scale) @ (k * scale).transpose(-1, -2)
+ if mask is not None:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ qk = qk.float()
- w = F.softmax(qk, dim=-1).to(q.dtype)
- return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+ w = F.softmax(qk, dim=-1).to(q.dtype)
+ out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+ qk = qk.detach()
+
+ return out, qk
class ResidualAttentionBlock(nn.Module):
diff --git a/whisper/normalizers/basic.py b/whisper/normalizers/basic.py
index a824032..8690ae7 100644
--- a/whisper/normalizers/basic.py
+++ b/whisper/normalizers/basic.py
@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
- c
- if c in keep
- else ADDITIONAL_DIACRITICS[c]
- if c in ADDITIONAL_DIACRITICS
- else ""
- if unicodedata.category(c) == "Mn"
- else " "
- if unicodedata.category(c)[0] in "MSP"
- else c
+ (
+ c
+ if c in keep
+ else (
+ ADDITIONAL_DIACRITICS[c]
+ if c in ADDITIONAL_DIACRITICS
+ else (
+ ""
+ if unicodedata.category(c) == "Mn"
+ else " " if unicodedata.category(c)[0] in "MSP" else c
+ )
+ )
+ )
for c in unicodedata.normalize("NFKD", s)
)
diff --git a/whisper/timing.py b/whisper/timing.py
index e7604fa..233b3f6 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
- cost = cost.cuda()
+ cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
@@ -191,7 +191,9 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks)
]
- with torch.no_grad():
+ from .model import disable_sdpa
+
+ with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index df063cb..5928d20 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -47,6 +47,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
+ carry_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
@@ -106,6 +107,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
+ carry_initial_prompt: bool
+ If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
+ `decode()` call. If there is not enough context space at the start of the prompt, it is
+ left-sliced to make space.
+
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
@@ -212,6 +218,8 @@ def transcribe(
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
+ and logprob_threshold is not None
+ and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = False # silence
if not needs_fallback:
@@ -231,9 +239,11 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
+ remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
+ remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []
@@ -279,7 +289,13 @@ def transcribe(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
+ if carry_initial_prompt:
+ # keep the last remaining_prompt_length tokens; none when it is <= 0
+ nignored = max(len(initial_prompt_tokens), prompt_reset_since)
+ start = max(nignored, len(all_tokens) - max(remaining_prompt_length, 0))
+ decode_options["prompt"] = initial_prompt_tokens + all_tokens[start:]
+ else:
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
@@ -516,7 +532,7 @@ def cli():
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
- parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
+ parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@@ -534,6 +550,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+ parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py
index edd4564..13d417b 100644
--- a/whisper/triton_ops.py
+++ b/whisper/triton_ops.py
@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
- kernel.src = kernel.src.replace(
+ new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
]
),
)
- kernel.src = kernel.src.replace(
+
+ new_kernel = new_kernel.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
]
),
)
- kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+ new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+ if hasattr(kernel, "_unsafe_update_src"):  # not every triton version has this hook
+ kernel._unsafe_update_src(new_kernel)
+ kernel.hash = None
+ else:
+ kernel.src = new_kernel
return kernel
diff --git a/whisper/utils.py b/whisper/utils.py
index 9b9b138..13792f7 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -209,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
yield start, end, "".join(
[
- re.sub(r"^(\s*)(.*)$", r"\1\2", word)
- if j == i
- else word
+ (
+ re.sub(r"^(\s*)(.*)$", r"\1\2", word)
+ if j == i
+ else word
+ )
for j, word in enumerate(all_words)
]
)
diff --git a/whisper/version.py b/whisper/version.py
index c96dd9c..67426aa 100644
--- a/whisper/version.py
+++ b/whisper/version.py
@@ -1 +1 @@
-__version__ = "20231117"
+__version__ = "20250625"