Merge branch 'main' into main

This commit is contained in:
Erfan Tarighi 2025-10-02 23:41:34 +02:00 committed by GitHub
commit f38acdff61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 283 additions and 120 deletions

13
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,13 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@ -8,23 +8,23 @@ jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions-ecosystem/action-regex-match@v2 - uses: actions-ecosystem/action-regex-match@v2
id: regex-match id: regex-match
with: with:
text: ${{ github.event.head_commit.message }} text: ${{ github.event.head_commit.message }}
regex: '^Release ([^ ]+)' regex: '^Release ([^ ]+)'
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: '3.8' python-version: '3.12'
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install setuptools wheel twine pip install setuptools wheel twine build
- name: Release - name: Release
if: ${{ steps.regex-match.outputs.match != '' }} if: ${{ steps.regex-match.outputs.match != '' }}
uses: softprops/action-gh-release@v1 uses: softprops/action-gh-release@v2
with: with:
tag_name: v${{ steps.regex-match.outputs.group1 }} tag_name: v${{ steps.regex-match.outputs.group1 }}
- name: Build and publish - name: Build and publish
@ -33,5 +33,5 @@ jobs:
TWINE_USERNAME: __token__ TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: | run: |
python setup.py sdist python -m build --sdist
twine upload dist/* twine upload dist/*

View File

@ -11,19 +11,19 @@ jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Fetch base branch - name: Fetch base branch
run: git fetch origin ${{ github.base_ref }} run: git fetch origin ${{ github.base_ref }}
- uses: actions/setup-python@v4 - uses: actions/setup-python@v5
with: with:
python-version: "3.8" python-version: "3.9"
architecture: x64 architecture: x64
- name: Get pip cache dir - name: Get pip cache dir
id: pip-cache id: pip-cache
run: | run: |
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip/pre-commit cache - name: pip/pre-commit cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
${{ steps.pip-cache.outputs.dir }} ${{ steps.pip-cache.outputs.dir }}
@ -33,24 +33,47 @@ jobs:
${{ runner.os }}-pip-pre-commit ${{ runner.os }}-pip-pre-commit
- name: pre-commit - name: pre-commit
run: | run: |
pip install -U pre-commit pip install --upgrade pre-commit
pre-commit install --install-hooks pre-commit install --install-hooks
pre-commit run --all-files pre-commit run --all-files
whisper-test: whisper-test:
needs: pre-commit needs: pre-commit
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false
matrix: matrix:
python-version: ['3.8', '3.9', '3.10', '3.11'] include:
pytorch-version: [1.13.1, 2.0.0] - python-version: '3.8'
exclude: pytorch-version: 1.10.1
- python-version: '3.11' numpy-requirement: "'numpy<2'"
- python-version: '3.8'
pytorch-version: 1.13.1 pytorch-version: 1.13.1
numpy-requirement: "'numpy<2'"
- python-version: '3.8'
pytorch-version: 2.0.1
numpy-requirement: "'numpy<2'"
- python-version: '3.9'
pytorch-version: 2.1.2
numpy-requirement: "'numpy<2'"
- python-version: '3.10'
pytorch-version: 2.2.2
numpy-requirement: "'numpy<2'"
- python-version: '3.11'
pytorch-version: 2.3.1
numpy-requirement: "'numpy'"
- python-version: '3.12'
pytorch-version: 2.4.1
numpy-requirement: "'numpy'"
- python-version: '3.12'
pytorch-version: 2.5.1
numpy-requirement: "'numpy'"
- python-version: '3.13'
pytorch-version: 2.5.1
numpy-requirement: "'numpy'"
steps: steps:
- uses: conda-incubator/setup-miniconda@v2 - uses: conda-incubator/setup-miniconda@v3
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
- run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu - uses: actions/checkout@v4
- uses: actions/checkout@v3
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install .["dev"] - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

View File

@ -1,6 +1,6 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1 rev: v5.0.0
hooks: hooks:
- id: check-json - id: check-json
- id: end-of-file-fixer - id: end-of-file-fixer
@ -11,17 +11,17 @@ repos:
- id: check-added-large-files - id: check-added-large-files
args: [--maxkb=4096] args: [--maxkb=4096]
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 23.7.0 rev: 25.1.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.12.0 rev: 6.0.0
hooks: hooks:
- id: isort - id: isort
name: isort (python) name: isort (python)
args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"] args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
- repo: https://github.com/pycqa/flake8.git - repo: https://github.com/pycqa/flake8.git
rev: 6.0.0 rev: 7.1.1
hooks: hooks:
- id: flake8 - id: flake8
types: [python] types: [python]

View File

@ -1,5 +1,40 @@
# CHANGELOG # CHANGELOG
## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117) ## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802)) * Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))

View File

@ -57,41 +57,55 @@ pip install setuptools-rust
## Available models and languages ## Available models and languages
There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model; actual speed may vary depending on many factors including the available hardware. There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
The relative speeds below are measured by transcribing English speech on a A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed | | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:| |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x | | tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
| base | 74 M | `base.en` | `base` | ~1 GB | ~16x | | base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
| small | 244 M | `small.en` | `small` | ~2 GB | ~6x | | small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x | | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
| large | 1550 M | N/A | `large` | ~10 GB | 1x | | large | 1550 M | N/A | `large` | ~10 GB | 1x |
| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models. The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3. Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62) ![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62)
## Command-line usage ## Command-line usage
The following command will transcribe speech in audio files, using the `medium` model: The following command will transcribe speech in audio files, using the `turbo` model:
whisper audio.flac audio.mp3 audio.wav --model medium ```bash
whisper audio.flac audio.mp3 audio.wav --model turbo
```
The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option: The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.
whisper japanese.wav --language Japanese For example, to transcribe an audio file containing non-English speech, you can specify the language:
Adding `--task translate` will translate the speech into English: ```bash
whisper japanese.wav --language Japanese
```
whisper japanese.wav --language Japanese --task translate To **translate** speech into English, use:
```bash
whisper japanese.wav --model medium --language Japanese --task translate
```
> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.
Run the following to view all available options: Run the following to view all available options:
whisper --help ```bash
whisper --help
```
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages. See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
@ -103,7 +117,7 @@ Transcription can also be performed within Python:
```python ```python
import whisper import whisper
model = whisper.load_model("base") model = whisper.load_model("turbo")
result = model.transcribe("audio.mp3") result = model.transcribe("audio.mp3")
print(result["text"]) print(result["text"])
``` ```
@ -115,14 +129,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
```python ```python
import whisper import whisper
model = whisper.load_model("base") model = whisper.load_model("turbo")
# load audio and pad/trim it to fit 30 seconds # load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("audio.mp3") audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio) audio = whisper.pad_or_trim(audio)
# make log-Mel spectrogram and move to the same device as the model # make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device) mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
# detect the spoken language # detect the spoken language
_, probs = model.detect_language(mel) _, probs = model.detect_language(mel)

View File

@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
### AMI-IHM, AMI-SDM1 ### AMI-IHM, AMI-SDM1
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b). We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
## Long-form English-only datasets ## Long-form English-only datasets

View File

@ -16,13 +16,15 @@ The Whisper models are trained for speech recognition and translation tasks, cap
| small | 244 M | ✓ | ✓ | | small | 244 M | ✓ | ✓ |
| medium | 769 M | ✓ | ✓ | | medium | 769 M | ✓ | ✓ |
| large | 1550 M | | ✓ | | large | 1550 M | | ✓ |
| turbo | 798 M | | ✓ |
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023. In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.
### Release date ### Release date
September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`) September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`)
### Model type ### Model type

View File

@ -949,7 +949,8 @@
"style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588", "style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
"value": " 164/164 [05:08&lt;00:00, 1.86s/it]" "value": " 164/164 [05:08&lt;00:00, 1.86s/it]"
} }
} },
"state": {}
} }
} }
}, },

View File

@ -4219,7 +4219,8 @@
"_view_name": "StyleView", "_view_name": "StyleView",
"description_width": "" "description_width": ""
} }
} },
"state": {}
} }
} }
}, },

View File

@ -1,3 +1,50 @@
[build-system]
build-backend = "setuptools.build_meta"
requires = [ "setuptools>=61.2" ]
[project]
name = "openai-whisper"
description = "Robust Speech Recognition via Large-Scale Weak Supervision"
readme.content-type = "text/markdown"
readme.file = "README.md"
license = { text = "MIT" }
authors = [ { name = "OpenAI" } ]
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dynamic = [ "version" ]
dependencies = [
"more-itertools",
"numba",
"numpy",
"tiktoken",
"torch",
"tqdm",
"triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
]
optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
urls = { Homepage = "https://github.com/openai/whisper" }
scripts.whisper = "whisper.transcribe:cli"
[tool.setuptools]
py-modules = [ "whisper" ]
include-package-data = true
[tool.setuptools.dynamic]
version = { attr = "whisper.version.__version__" }
[tool.setuptools.packages.find]
exclude = [ "tests*" ]
namespaces = false
[tool.black] [tool.black]
[tool.isort] [tool.isort]
@ -5,4 +52,3 @@ profile = "black"
include_trailing_comma = true include_trailing_comma = true
line_length = 88 line_length = 88
multi_line_output = 3 multi_line_output = 3

View File

@ -4,4 +4,4 @@ torch
tqdm tqdm
more-itertools more-itertools
tiktoken tiktoken
triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2" triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"

View File

@ -1,42 +0,0 @@
import platform
import sys
from pathlib import Path
import pkg_resources
from setuptools import find_packages, setup
def read_version(fname="whisper/version.py"):
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
return locals()["__version__"]
requirements = []
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
requirements.append("triton>=2.0.0,<3")
setup(
name="openai-whisper",
py_modules=["whisper"],
version=read_version(),
description="Robust Speech Recognition via Large-Scale Weak Supervision",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
python_requires=">=3.8",
author="OpenAI",
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
Path(__file__).with_name("requirements.txt").open()
)
],
entry_points={
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
)

View File

@ -27,6 +27,8 @@ _MODELS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
} }
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@ -44,6 +46,8 @@ _ALIGNMENT_HEADS = {
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
} }
@ -143,7 +147,8 @@ def load_model(
with ( with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp: ) as fp:
checkpoint = torch.load(fp, map_location=device) kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])

View File

@ -122,7 +122,7 @@ def log_mel_spectrogram(
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int n_mels: int
The number of Mel-frequency filters, only 80 is supported The number of Mel-frequency filters, only 80 and 128 are supported
padding: int padding: int
Number of zero samples to pad to the right Number of zero samples to pad to the right
@ -132,7 +132,7 @@ def log_mel_spectrogram(
Returns Returns
------- -------
torch.Tensor, shape = (80, n_frames) torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram A Tensor that contains the Mel spectrogram
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio):

View File

@ -1,7 +1,8 @@
import base64 import base64
import gzip import gzip
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Iterable, Optional from typing import Dict, Iterable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -12,6 +13,14 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function from .transcribe import transcribe as transcribe_function
try:
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
@dataclass @dataclass
class ModelDimensions: class ModelDimensions:
@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
use_sdpa = True
def __init__(self, n_state: int, n_head: int): def __init__(self, n_state: int, n_head: int):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
def qkv_attention( def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25 scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(
q, k, v, is_causal=mask is not None and n_ctx > 1
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
qk = None
else:
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if mask is not None: if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx] qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float() qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype) w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
qk = qk.detach()
return out, qk
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):

View File

@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
and drop any diacritics (category 'Mn' and some manual mappings) and drop any diacritics (category 'Mn' and some manual mappings)
""" """
return "".join( return "".join(
(
c c
if c in keep if c in keep
else ADDITIONAL_DIACRITICS[c] else (
ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS if c in ADDITIONAL_DIACRITICS
else "" else (
""
if unicodedata.category(c) == "Mn" if unicodedata.category(c) == "Mn"
else " " else " " if unicodedata.category(c)[0] in "MSP" else c
if unicodedata.category(c)[0] in "MSP" )
else c )
)
for c in unicodedata.normalize("NFKD", s) for c in unicodedata.normalize("NFKD", s)
) )

View File

@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous() x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0 cost[0, 0] = 0
cost = cost.cuda() cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32) trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)]( dtw_kernel[(1,)](
@ -191,7 +191,9 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks) for i, block in enumerate(model.decoder.blocks)
] ]
with torch.no_grad(): from .model import disable_sdpa
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1) token_probs = sampled_logits.softmax(dim=-1)

View File

@ -47,6 +47,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True, condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None, initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
@ -106,6 +107,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly. to make it more likely to predict those word correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict decode_options: dict
Keyword arguments to construct `DecodingOptions` instances Keyword arguments to construct `DecodingOptions` instances
@ -212,6 +218,8 @@ def transcribe(
if ( if (
no_speech_threshold is not None no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold and decode_result.no_speech_prob > no_speech_threshold
and logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
): ):
needs_fallback = False # silence needs_fallback = False # silence
if not needs_fallback: if not needs_fallback:
@ -231,9 +239,11 @@ def transcribe(
all_segments = [] all_segments = []
prompt_reset_since = 0 prompt_reset_since = 0
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None: if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens) all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else: else:
initial_prompt_tokens = [] initial_prompt_tokens = []
@ -279,7 +289,13 @@ def transcribe(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:] decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment) result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
@ -516,7 +532,7 @@ def cli():
# fmt: off # fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@ -534,6 +550,8 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

View File

@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn) kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace( new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE", " LOAD_ALL_ROWS_HERE",
"\n".join( "\n".join(
[ [
@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace(
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE", " BUBBLESORT_HERE",
"\n\n".join( "\n\n".join(
[ [
@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src") is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel return kernel

View File

@ -209,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
yield start, end, "".join( yield start, end, "".join(
[ [
(
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i if j == i
else word else word
)
for j, word in enumerate(all_words) for j, word in enumerate(all_words)
] ]
) )

View File

@ -1 +1 @@
__version__ = "20231117" __version__ = "20250625"