mirror of
https://github.com/openai/whisper.git
synced 2025-09-13 19:20:10 +00:00
Compare commits
No commits in common. "main" and "v20231117" have entirely different histories.
13
.github/dependabot.yml
vendored
13
.github/dependabot.yml
vendored
@ -1,13 +0,0 @@
|
|||||||
# Keep GitHub Actions up to date with GitHub's Dependabot...
|
|
||||||
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
|
|
||||||
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
|
|
||||||
version: 2
|
|
||||||
updates:
|
|
||||||
- package-ecosystem: github-actions
|
|
||||||
directory: /
|
|
||||||
groups:
|
|
||||||
github-actions:
|
|
||||||
patterns:
|
|
||||||
- "*" # Group all Actions updates into a single larger pull request
|
|
||||||
schedule:
|
|
||||||
interval: weekly
|
|
12
.github/workflows/python-publish.yml
vendored
12
.github/workflows/python-publish.yml
vendored
@ -8,23 +8,23 @@ jobs:
|
|||||||
deploy:
|
deploy:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
- uses: actions-ecosystem/action-regex-match@v2
|
- uses: actions-ecosystem/action-regex-match@v2
|
||||||
id: regex-match
|
id: regex-match
|
||||||
with:
|
with:
|
||||||
text: ${{ github.event.head_commit.message }}
|
text: ${{ github.event.head_commit.message }}
|
||||||
regex: '^Release ([^ ]+)'
|
regex: '^Release ([^ ]+)'
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.8'
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install setuptools wheel twine build
|
pip install setuptools wheel twine
|
||||||
- name: Release
|
- name: Release
|
||||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||||
uses: softprops/action-gh-release@v2
|
uses: softprops/action-gh-release@v1
|
||||||
with:
|
with:
|
||||||
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
||||||
- name: Build and publish
|
- name: Build and publish
|
||||||
@ -33,5 +33,5 @@ jobs:
|
|||||||
TWINE_USERNAME: __token__
|
TWINE_USERNAME: __token__
|
||||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
python -m build --sdist
|
python setup.py sdist
|
||||||
twine upload dist/*
|
twine upload dist/*
|
||||||
|
49
.github/workflows/test.yml
vendored
49
.github/workflows/test.yml
vendored
@ -11,19 +11,19 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
- name: Fetch base branch
|
- name: Fetch base branch
|
||||||
run: git fetch origin ${{ github.base_ref }}
|
run: git fetch origin ${{ github.base_ref }}
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.9"
|
python-version: "3.8"
|
||||||
architecture: x64
|
architecture: x64
|
||||||
- name: Get pip cache dir
|
- name: Get pip cache dir
|
||||||
id: pip-cache
|
id: pip-cache
|
||||||
run: |
|
run: |
|
||||||
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
|
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
|
||||||
- name: pip/pre-commit cache
|
- name: pip/pre-commit cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ steps.pip-cache.outputs.dir }}
|
${{ steps.pip-cache.outputs.dir }}
|
||||||
@ -33,47 +33,24 @@ jobs:
|
|||||||
${{ runner.os }}-pip-pre-commit
|
${{ runner.os }}-pip-pre-commit
|
||||||
- name: pre-commit
|
- name: pre-commit
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade pre-commit
|
pip install -U pre-commit
|
||||||
pre-commit install --install-hooks
|
pre-commit install --install-hooks
|
||||||
pre-commit run --all-files
|
pre-commit run --all-files
|
||||||
whisper-test:
|
whisper-test:
|
||||||
needs: pre-commit
|
needs: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||||
- python-version: '3.8'
|
pytorch-version: [1.13.1, 2.0.0]
|
||||||
pytorch-version: 1.10.1
|
exclude:
|
||||||
numpy-requirement: "'numpy<2'"
|
|
||||||
- python-version: '3.8'
|
|
||||||
pytorch-version: 1.13.1
|
|
||||||
numpy-requirement: "'numpy<2'"
|
|
||||||
- python-version: '3.8'
|
|
||||||
pytorch-version: 2.0.1
|
|
||||||
numpy-requirement: "'numpy<2'"
|
|
||||||
- python-version: '3.9'
|
|
||||||
pytorch-version: 2.1.2
|
|
||||||
numpy-requirement: "'numpy<2'"
|
|
||||||
- python-version: '3.10'
|
|
||||||
pytorch-version: 2.2.2
|
|
||||||
numpy-requirement: "'numpy<2'"
|
|
||||||
- python-version: '3.11'
|
- python-version: '3.11'
|
||||||
pytorch-version: 2.3.1
|
pytorch-version: 1.13.1
|
||||||
numpy-requirement: "'numpy'"
|
|
||||||
- python-version: '3.12'
|
|
||||||
pytorch-version: 2.4.1
|
|
||||||
numpy-requirement: "'numpy'"
|
|
||||||
- python-version: '3.12'
|
|
||||||
pytorch-version: 2.5.1
|
|
||||||
numpy-requirement: "'numpy'"
|
|
||||||
- python-version: '3.13'
|
|
||||||
pytorch-version: 2.5.1
|
|
||||||
numpy-requirement: "'numpy'"
|
|
||||||
steps:
|
steps:
|
||||||
- uses: conda-incubator/setup-miniconda@v3
|
- uses: conda-incubator/setup-miniconda@v2
|
||||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
|
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
|
||||||
- uses: actions/checkout@v4
|
- run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
- uses: actions/checkout@v3
|
||||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||||
- run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
|
- run: pip install .["dev"]
|
||||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v5.0.0
|
rev: v4.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-json
|
- id: check-json
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@ -11,17 +11,17 @@ repos:
|
|||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
args: [--maxkb=4096]
|
args: [--maxkb=4096]
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 25.1.0
|
rev: 23.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 6.0.0
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
name: isort (python)
|
name: isort (python)
|
||||||
args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
|
args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
|
||||||
- repo: https://github.com/pycqa/flake8.git
|
- repo: https://github.com/pycqa/flake8.git
|
||||||
rev: 7.1.1
|
rev: 6.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
types: [python]
|
types: [python]
|
||||||
|
35
CHANGELOG.md
35
CHANGELOG.md
@ -1,40 +1,5 @@
|
|||||||
# CHANGELOG
|
# CHANGELOG
|
||||||
|
|
||||||
## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
|
|
||||||
|
|
||||||
* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
|
|
||||||
* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
|
|
||||||
* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
|
|
||||||
* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
|
|
||||||
* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
|
|
||||||
* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
|
|
||||||
* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
|
|
||||||
* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
|
|
||||||
* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
|
|
||||||
* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
|
|
||||||
* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
|
|
||||||
* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
|
|
||||||
* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
|
|
||||||
* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
|
|
||||||
* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
|
|
||||||
* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
|
|
||||||
* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
|
|
||||||
* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
|
|
||||||
|
|
||||||
## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
|
|
||||||
|
|
||||||
* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
|
|
||||||
* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
|
|
||||||
* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
|
|
||||||
* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
|
|
||||||
|
|
||||||
## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
|
|
||||||
|
|
||||||
* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
|
|
||||||
* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
|
|
||||||
* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
|
|
||||||
* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
|
|
||||||
|
|
||||||
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
|
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
|
||||||
|
|
||||||
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
|
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
|
||||||
|
46
README.md
46
README.md
@ -57,55 +57,41 @@ pip install setuptools-rust
|
|||||||
|
|
||||||
## Available models and languages
|
## Available models and languages
|
||||||
|
|
||||||
There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
|
There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model; actual speed may vary depending on many factors including the available hardware.
|
||||||
Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
|
|
||||||
The relative speeds below are measured by transcribing English speech on a A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
|
|
||||||
|
|
||||||
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
||||||
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
|
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
|
||||||
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
|
||||||
| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
|
| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
|
||||||
| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
|
| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
|
||||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
||||||
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
||||||
| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |
|
|
||||||
|
|
||||||
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||||
Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
|
|
||||||
|
|
||||||
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
|
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Command-line usage
|
## Command-line usage
|
||||||
|
|
||||||
The following command will transcribe speech in audio files, using the `turbo` model:
|
The following command will transcribe speech in audio files, using the `medium` model:
|
||||||
|
|
||||||
```bash
|
whisper audio.flac audio.mp3 audio.wav --model medium
|
||||||
whisper audio.flac audio.mp3 audio.wav --model turbo
|
|
||||||
```
|
|
||||||
|
|
||||||
The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.
|
The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
|
||||||
|
|
||||||
For example, to transcribe an audio file containing non-English speech, you can specify the language:
|
whisper japanese.wav --language Japanese
|
||||||
|
|
||||||
```bash
|
Adding `--task translate` will translate the speech into English:
|
||||||
whisper japanese.wav --language Japanese
|
|
||||||
```
|
|
||||||
|
|
||||||
To **translate** speech into English, use:
|
whisper japanese.wav --language Japanese --task translate
|
||||||
|
|
||||||
```bash
|
|
||||||
whisper japanese.wav --model medium --language Japanese --task translate
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.
|
|
||||||
|
|
||||||
Run the following to view all available options:
|
Run the following to view all available options:
|
||||||
|
|
||||||
```bash
|
whisper --help
|
||||||
whisper --help
|
|
||||||
```
|
|
||||||
|
|
||||||
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
|
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
|
||||||
|
|
||||||
@ -117,7 +103,7 @@ Transcription can also be performed within Python:
|
|||||||
```python
|
```python
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
model = whisper.load_model("turbo")
|
model = whisper.load_model("base")
|
||||||
result = model.transcribe("audio.mp3")
|
result = model.transcribe("audio.mp3")
|
||||||
print(result["text"])
|
print(result["text"])
|
||||||
```
|
```
|
||||||
@ -129,14 +115,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
|
|||||||
```python
|
```python
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
model = whisper.load_model("turbo")
|
model = whisper.load_model("base")
|
||||||
|
|
||||||
# load audio and pad/trim it to fit 30 seconds
|
# load audio and pad/trim it to fit 30 seconds
|
||||||
audio = whisper.load_audio("audio.mp3")
|
audio = whisper.load_audio("audio.mp3")
|
||||||
audio = whisper.pad_or_trim(audio)
|
audio = whisper.pad_or_trim(audio)
|
||||||
|
|
||||||
# make log-Mel spectrogram and move to the same device as the model
|
# make log-Mel spectrogram and move to the same device as the model
|
||||||
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
|
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
||||||
|
|
||||||
# detect the spoken language
|
# detect the spoken language
|
||||||
_, probs = model.detect_language(mel)
|
_, probs = model.detect_language(mel)
|
||||||
|
@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
|
|||||||
|
|
||||||
### AMI-IHM, AMI-SDM1
|
### AMI-IHM, AMI-SDM1
|
||||||
|
|
||||||
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
||||||
|
|
||||||
|
|
||||||
## Long-form English-only datasets
|
## Long-form English-only datasets
|
||||||
|
@ -16,15 +16,13 @@ The Whisper models are trained for speech recognition and translation tasks, cap
|
|||||||
| small | 244 M | ✓ | ✓ |
|
| small | 244 M | ✓ | ✓ |
|
||||||
| medium | 769 M | ✓ | ✓ |
|
| medium | 769 M | ✓ | ✓ |
|
||||||
| large | 1550 M | | ✓ |
|
| large | 1550 M | | ✓ |
|
||||||
| turbo | 798 M | | ✓ |
|
|
||||||
|
|
||||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
|
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
|
||||||
Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.
|
|
||||||
|
|
||||||
|
|
||||||
### Release date
|
### Release date
|
||||||
|
|
||||||
September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`)
|
September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`)
|
||||||
|
|
||||||
### Model type
|
### Model type
|
||||||
|
|
||||||
|
3
notebooks/LibriSpeech.ipynb
generated
3
notebooks/LibriSpeech.ipynb
generated
@ -949,8 +949,7 @@
|
|||||||
"style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
|
"style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
|
||||||
"value": " 164/164 [05:08<00:00, 1.86s/it]"
|
"value": " 164/164 [05:08<00:00, 1.86s/it]"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"state": {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
3
notebooks/Multilingual_ASR.ipynb
generated
3
notebooks/Multilingual_ASR.ipynb
generated
@ -4219,8 +4219,7 @@
|
|||||||
"_view_name": "StyleView",
|
"_view_name": "StyleView",
|
||||||
"description_width": ""
|
"description_width": ""
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"state": {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1,50 +1,3 @@
|
|||||||
[build-system]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
requires = [ "setuptools>=61.2" ]
|
|
||||||
|
|
||||||
[project]
|
|
||||||
name = "openai-whisper"
|
|
||||||
description = "Robust Speech Recognition via Large-Scale Weak Supervision"
|
|
||||||
readme.content-type = "text/markdown"
|
|
||||||
readme.file = "README.md"
|
|
||||||
license = { text = "MIT" }
|
|
||||||
authors = [ { name = "OpenAI" } ]
|
|
||||||
requires-python = ">=3.8"
|
|
||||||
classifiers = [
|
|
||||||
"Programming Language :: Python :: 3 :: Only",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
|
||||||
"Programming Language :: Python :: 3.12",
|
|
||||||
"Programming Language :: Python :: 3.13",
|
|
||||||
]
|
|
||||||
dynamic = [ "version" ]
|
|
||||||
dependencies = [
|
|
||||||
"more-itertools",
|
|
||||||
"numba",
|
|
||||||
"numpy",
|
|
||||||
"tiktoken",
|
|
||||||
"torch",
|
|
||||||
"tqdm",
|
|
||||||
"triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
|
|
||||||
]
|
|
||||||
optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
|
|
||||||
urls = { Homepage = "https://github.com/openai/whisper" }
|
|
||||||
scripts.whisper = "whisper.transcribe:cli"
|
|
||||||
|
|
||||||
[tool.setuptools]
|
|
||||||
py-modules = [ "whisper" ]
|
|
||||||
include-package-data = true
|
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
|
||||||
version = { attr = "whisper.version.__version__" }
|
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
|
||||||
exclude = [ "tests*" ]
|
|
||||||
namespaces = false
|
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
@ -52,3 +5,4 @@ profile = "black"
|
|||||||
include_trailing_comma = true
|
include_trailing_comma = true
|
||||||
line_length = 88
|
line_length = 88
|
||||||
multi_line_output = 3
|
multi_line_output = 3
|
||||||
|
|
||||||
|
@ -4,4 +4,3 @@ torch
|
|||||||
tqdm
|
tqdm
|
||||||
more-itertools
|
more-itertools
|
||||||
tiktoken
|
tiktoken
|
||||||
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
|
||||||
|
43
setup.py
Normal file
43
setup.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pkg_resources
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
def read_version(fname="whisper/version.py"):
|
||||||
|
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
|
||||||
|
return locals()["__version__"]
|
||||||
|
|
||||||
|
|
||||||
|
requirements = []
|
||||||
|
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||||
|
requirements.append("triton>=2.0.0,<3")
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="openai-whisper",
|
||||||
|
py_modules=["whisper"],
|
||||||
|
version=read_version(),
|
||||||
|
description="Robust Speech Recognition via Large-Scale Weak Supervision",
|
||||||
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
readme="README.md",
|
||||||
|
python_requires=">=3.8",
|
||||||
|
author="OpenAI",
|
||||||
|
url="https://github.com/openai/whisper",
|
||||||
|
license="MIT",
|
||||||
|
packages=find_packages(exclude=["tests*"]),
|
||||||
|
install_requires=requirements
|
||||||
|
+ [
|
||||||
|
str(r)
|
||||||
|
for r in pkg_resources.parse_requirements(
|
||||||
|
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||||
|
)
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||||
|
},
|
||||||
|
include_package_data=True,
|
||||||
|
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
|
||||||
|
)
|
@ -27,8 +27,6 @@ _MODELS = {
|
|||||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
|
||||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||||
@ -46,8 +44,6 @@ _ALIGNMENT_HEADS = {
|
|||||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
|
||||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -147,8 +143,7 @@ def load_model(
|
|||||||
with (
|
with (
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||||
) as fp:
|
) as fp:
|
||||||
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
checkpoint = torch.load(fp, map_location=device)
|
||||||
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
|
||||||
del checkpoint_file
|
del checkpoint_file
|
||||||
|
|
||||||
dims = ModelDimensions(**checkpoint["dims"])
|
dims = ModelDimensions(**checkpoint["dims"])
|
||||||
|
@ -122,7 +122,7 @@ def log_mel_spectrogram(
|
|||||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||||
|
|
||||||
n_mels: int
|
n_mels: int
|
||||||
The number of Mel-frequency filters, only 80 and 128 are supported
|
The number of Mel-frequency filters, only 80 is supported
|
||||||
|
|
||||||
padding: int
|
padding: int
|
||||||
Number of zero samples to pad to the right
|
Number of zero samples to pad to the right
|
||||||
@ -132,7 +132,7 @@ def log_mel_spectrogram(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor, shape = (n_mels, n_frames)
|
torch.Tensor, shape = (80, n_frames)
|
||||||
A Tensor that contains the Mel spectrogram
|
A Tensor that contains the Mel spectrogram
|
||||||
"""
|
"""
|
||||||
if not torch.is_tensor(audio):
|
if not torch.is_tensor(audio):
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import gzip
|
import gzip
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, Optional, Tuple
|
from typing import Dict, Iterable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -13,14 +12,6 @@ from .decoding import decode as decode_function
|
|||||||
from .decoding import detect_language as detect_language_function
|
from .decoding import detect_language as detect_language_function
|
||||||
from .transcribe import transcribe as transcribe_function
|
from .transcribe import transcribe as transcribe_function
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
|
||||||
|
|
||||||
SDPA_AVAILABLE = True
|
|
||||||
except (ImportError, RuntimeError, OSError):
|
|
||||||
scaled_dot_product_attention = None
|
|
||||||
SDPA_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelDimensions:
|
class ModelDimensions:
|
||||||
@ -68,19 +59,7 @@ def sinusoids(length, channels, max_timescale=10000):
|
|||||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def disable_sdpa():
|
|
||||||
prev_state = MultiHeadAttention.use_sdpa
|
|
||||||
try:
|
|
||||||
MultiHeadAttention.use_sdpa = False
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
MultiHeadAttention.use_sdpa = prev_state
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
use_sdpa = True
|
|
||||||
|
|
||||||
def __init__(self, n_state: int, n_head: int):
|
def __init__(self, n_state: int, n_head: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
@ -113,30 +92,20 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
):
|
||||||
n_batch, n_ctx, n_state = q.shape
|
n_batch, n_ctx, n_state = q.shape
|
||||||
scale = (n_state // self.n_head) ** -0.25
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
qk = q @ k
|
||||||
a = scaled_dot_product_attention(
|
if mask is not None:
|
||||||
q, k, v, is_causal=mask is not None and n_ctx > 1
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
)
|
qk = qk.float()
|
||||||
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
|
||||||
qk = None
|
|
||||||
else:
|
|
||||||
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
|
||||||
if mask is not None:
|
|
||||||
qk = qk + mask[:n_ctx, :n_ctx]
|
|
||||||
qk = qk.float()
|
|
||||||
|
|
||||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||||
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||||
qk = qk.detach()
|
|
||||||
|
|
||||||
return out, qk
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
@ -30,19 +30,15 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
|||||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
"""
|
"""
|
||||||
return "".join(
|
return "".join(
|
||||||
(
|
c
|
||||||
c
|
if c in keep
|
||||||
if c in keep
|
else ADDITIONAL_DIACRITICS[c]
|
||||||
else (
|
if c in ADDITIONAL_DIACRITICS
|
||||||
ADDITIONAL_DIACRITICS[c]
|
else ""
|
||||||
if c in ADDITIONAL_DIACRITICS
|
if unicodedata.category(c) == "Mn"
|
||||||
else (
|
else " "
|
||||||
""
|
if unicodedata.category(c)[0] in "MSP"
|
||||||
if unicodedata.category(c) == "Mn"
|
else c
|
||||||
else " " if unicodedata.category(c)[0] in "MSP" else c
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for c in unicodedata.normalize("NFKD", s)
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
|||||||
x_skew = x_skew.T.contiguous()
|
x_skew = x_skew.T.contiguous()
|
||||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
cost = cost.to(x.device)
|
cost = cost.cuda()
|
||||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||||
|
|
||||||
dtw_kernel[(1,)](
|
dtw_kernel[(1,)](
|
||||||
@ -191,9 +191,7 @@ def find_alignment(
|
|||||||
for i, block in enumerate(model.decoder.blocks)
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
]
|
]
|
||||||
|
|
||||||
from .model import disable_sdpa
|
with torch.no_grad():
|
||||||
|
|
||||||
with torch.no_grad(), disable_sdpa():
|
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
@ -301,7 +299,6 @@ def add_word_timestamps(
|
|||||||
word_durations = np.array([t.end - t.start for t in alignment])
|
word_durations = np.array([t.end - t.start for t in alignment])
|
||||||
word_durations = word_durations[word_durations.nonzero()]
|
word_durations = word_durations[word_durations.nonzero()]
|
||||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||||
median_duration = min(0.7, float(median_duration))
|
|
||||||
max_duration = median_duration * 2
|
max_duration = median_duration * 2
|
||||||
|
|
||||||
# hack: truncate long words at sentence boundaries.
|
# hack: truncate long words at sentence boundaries.
|
||||||
|
@ -2,7 +2,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -23,7 +23,6 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
exact_div,
|
exact_div,
|
||||||
format_timestamp,
|
format_timestamp,
|
||||||
get_end,
|
|
||||||
get_writer,
|
get_writer,
|
||||||
make_safe,
|
make_safe,
|
||||||
optional_float,
|
optional_float,
|
||||||
@ -46,12 +45,9 @@ def transcribe(
|
|||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
carry_initial_prompt: bool = False,
|
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
clip_timestamps: Union[str, List[float]] = "0",
|
|
||||||
hallucination_silence_threshold: Optional[float] = None,
|
|
||||||
**decode_options,
|
**decode_options,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -103,22 +99,9 @@ def transcribe(
|
|||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
to make it more likely to predict those word correctly.
|
to make it more likely to predict those word correctly.
|
||||||
|
|
||||||
carry_initial_prompt: bool
|
|
||||||
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
|
||||||
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
|
||||||
left-sliced to make space.
|
|
||||||
|
|
||||||
decode_options: dict
|
decode_options: dict
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
clip_timestamps: Union[str, List[float]]
|
|
||||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
|
||||||
The last end timestamp defaults to the end of the file.
|
|
||||||
|
|
||||||
hallucination_silence_threshold: Optional[float]
|
|
||||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
|
||||||
when a possible hallucination is detected
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
@ -138,7 +121,6 @@ def transcribe(
|
|||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
content_frames = mel.shape[-1] - N_FRAMES
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
if not model.is_multilingual:
|
if not model.is_multilingual:
|
||||||
@ -165,19 +147,6 @@ def transcribe(
|
|||||||
task=task,
|
task=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(clip_timestamps, str):
|
|
||||||
clip_timestamps = [
|
|
||||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
|
||||||
]
|
|
||||||
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
|
||||||
if len(seek_points) == 0:
|
|
||||||
seek_points.append(0)
|
|
||||||
if len(seek_points) % 2 == 1:
|
|
||||||
seek_points.append(content_frames)
|
|
||||||
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
|
||||||
|
|
||||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
|
||||||
|
|
||||||
if word_timestamps and task == "translate":
|
if word_timestamps and task == "translate":
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
|
|
||||||
@ -214,8 +183,6 @@ def transcribe(
|
|||||||
if (
|
if (
|
||||||
no_speech_threshold is not None
|
no_speech_threshold is not None
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
and logprob_threshold is not None
|
|
||||||
and decode_result.avg_logprob < logprob_threshold
|
|
||||||
):
|
):
|
||||||
needs_fallback = False # silence
|
needs_fallback = False # silence
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
@ -223,8 +190,7 @@ def transcribe(
|
|||||||
|
|
||||||
return decode_result
|
return decode_result
|
||||||
|
|
||||||
clip_idx = 0
|
seek = 0
|
||||||
seek = seek_clips[clip_idx][0]
|
|
||||||
input_stride = exact_div(
|
input_stride = exact_div(
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
@ -235,11 +201,9 @@ def transcribe(
|
|||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
remaining_prompt_length -= len(initial_prompt_tokens)
|
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
@ -265,33 +229,14 @@ def transcribe(
|
|||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
last_speech_timestamp = 0.0
|
last_speech_timestamp = 0.0
|
||||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
while seek < content_frames:
|
||||||
# A later commit should turn this into a simpler nested loop.
|
|
||||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
|
||||||
# while seek < seek_clip_end
|
|
||||||
while clip_idx < len(seek_clips):
|
|
||||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
|
||||||
if seek < seek_clip_start:
|
|
||||||
seek = seek_clip_start
|
|
||||||
if seek >= seek_clip_end:
|
|
||||||
clip_idx += 1
|
|
||||||
if clip_idx < len(seek_clips):
|
|
||||||
seek = seek_clips[clip_idx][0]
|
|
||||||
continue
|
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
segment_size = min(N_FRAMES, content_frames - seek)
|
||||||
mel_segment = mel[:, seek : seek + segment_size]
|
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
if carry_initial_prompt:
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
|
||||||
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
|
||||||
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
|
||||||
else:
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
@ -312,30 +257,6 @@ def transcribe(
|
|||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
|
|
||||||
# anomalous words are very long/short/improbable
|
|
||||||
def word_anomaly_score(word: dict) -> float:
|
|
||||||
probability = word.get("probability", 0.0)
|
|
||||||
duration = word["end"] - word["start"]
|
|
||||||
score = 0.0
|
|
||||||
if probability < 0.15:
|
|
||||||
score += 1.0
|
|
||||||
if duration < 0.133:
|
|
||||||
score += (0.133 - duration) * 15
|
|
||||||
if duration > 2.0:
|
|
||||||
score += duration - 2.0
|
|
||||||
return score
|
|
||||||
|
|
||||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
|
||||||
if segment is None or not segment["words"]:
|
|
||||||
return False
|
|
||||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
|
||||||
words = words[:8]
|
|
||||||
score = sum(word_anomaly_score(w) for w in words)
|
|
||||||
return score >= 3 or score + 0.01 >= len(words)
|
|
||||||
|
|
||||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
|
||||||
return next((s for s in segments if s["words"]), None)
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
|
|
||||||
@ -409,71 +330,17 @@ def transcribe(
|
|||||||
append_punctuations=append_punctuations,
|
append_punctuations=append_punctuations,
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
last_speech_timestamp=last_speech_timestamp,
|
||||||
)
|
)
|
||||||
|
word_end_timestamps = [
|
||||||
if not single_timestamp_ending:
|
w["end"] for s in current_segments for w in s["words"]
|
||||||
last_word_end = get_end(current_segments)
|
]
|
||||||
if last_word_end is not None and last_word_end > time_offset:
|
if len(word_end_timestamps) > 0:
|
||||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
last_speech_timestamp = word_end_timestamps[-1]
|
||||||
|
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||||
# skip silence before possible hallucinations
|
seek_shift = round(
|
||||||
if hallucination_silence_threshold is not None:
|
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||||
threshold = hallucination_silence_threshold
|
)
|
||||||
if not single_timestamp_ending:
|
if seek_shift > 0:
|
||||||
last_word_end = get_end(current_segments)
|
seek = previous_seek + seek_shift
|
||||||
if last_word_end is not None and last_word_end > time_offset:
|
|
||||||
remaining_duration = window_end_time - last_word_end
|
|
||||||
if remaining_duration > threshold:
|
|
||||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
|
||||||
else:
|
|
||||||
seek = previous_seek + segment_size
|
|
||||||
|
|
||||||
# if first segment might be a hallucination, skip leading silence
|
|
||||||
first_segment = next_words_segment(current_segments)
|
|
||||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
|
||||||
gap = first_segment["start"] - time_offset
|
|
||||||
if gap > threshold:
|
|
||||||
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# skip silence before any possible hallucination that is surrounded
|
|
||||||
# by silence or more hallucinations
|
|
||||||
hal_last_end = last_speech_timestamp
|
|
||||||
for si in range(len(current_segments)):
|
|
||||||
segment = current_segments[si]
|
|
||||||
if not segment["words"]:
|
|
||||||
continue
|
|
||||||
if is_segment_anomaly(segment):
|
|
||||||
next_segment = next_words_segment(
|
|
||||||
current_segments[si + 1 :]
|
|
||||||
)
|
|
||||||
if next_segment is not None:
|
|
||||||
hal_next_start = next_segment["words"][0]["start"]
|
|
||||||
else:
|
|
||||||
hal_next_start = time_offset + segment_duration
|
|
||||||
silence_before = (
|
|
||||||
segment["start"] - hal_last_end > threshold
|
|
||||||
or segment["start"] < threshold
|
|
||||||
or segment["start"] - time_offset < 2.0
|
|
||||||
)
|
|
||||||
silence_after = (
|
|
||||||
hal_next_start - segment["end"] > threshold
|
|
||||||
or is_segment_anomaly(next_segment)
|
|
||||||
or window_end_time - segment["end"] < 2.0
|
|
||||||
)
|
|
||||||
if silence_before and silence_after:
|
|
||||||
seek = round(
|
|
||||||
max(time_offset + 1, segment["start"])
|
|
||||||
* FRAMES_PER_SECOND
|
|
||||||
)
|
|
||||||
if content_duration - segment["end"] < threshold:
|
|
||||||
seek = content_frames
|
|
||||||
current_segments[si:] = []
|
|
||||||
break
|
|
||||||
hal_last_end = segment["end"]
|
|
||||||
|
|
||||||
last_word_end = get_end(current_segments)
|
|
||||||
if last_word_end is not None:
|
|
||||||
last_speech_timestamp = last_word_end
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
for segment in current_segments:
|
for segment in current_segments:
|
||||||
@ -527,7 +394,7 @@ def cli():
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
@ -545,8 +412,6 @@ def cli():
|
|||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
|
||||||
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
@ -562,8 +427,6 @@ def cli():
|
|||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||||
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
|
||||||
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
|
@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
|||||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
kernel = triton.JITFunction(kernel.fn)
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
new_kernel = kernel.src.replace(
|
kernel.src = kernel.src.replace(
|
||||||
" LOAD_ALL_ROWS_HERE",
|
" LOAD_ALL_ROWS_HERE",
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
@ -69,8 +69,7 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
kernel.src = kernel.src.replace(
|
||||||
new_kernel = new_kernel.replace(
|
|
||||||
" BUBBLESORT_HERE",
|
" BUBBLESORT_HERE",
|
||||||
"\n\n".join(
|
"\n\n".join(
|
||||||
[
|
[
|
||||||
@ -91,14 +90,7 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
|
||||||
|
|
||||||
if hasattr(kernel, "_unsafe_update_src") is True:
|
|
||||||
kernel._unsafe_update_src(new_kernel)
|
|
||||||
kernel.hash = None
|
|
||||||
else:
|
|
||||||
kernel.src = new_kernel
|
|
||||||
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
from typing import Callable, List, Optional, TextIO
|
from typing import Callable, Optional, TextIO
|
||||||
|
|
||||||
system_encoding = sys.getdefaultencoding()
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
@ -68,20 +68,6 @@ def format_timestamp(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_start(segments: List[dict]) -> Optional[float]:
|
|
||||||
return next(
|
|
||||||
(w["start"] for s in segments for w in s["words"]),
|
|
||||||
segments[0]["start"] if segments else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_end(segments: List[dict]) -> Optional[float]:
|
|
||||||
return next(
|
|
||||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
|
||||||
segments[-1]["end"] if segments else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
@ -143,8 +129,8 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
line_len = 0
|
line_len = 0
|
||||||
line_count = 1
|
line_count = 1
|
||||||
# the next subtitle to yield (a list of word timings with whitespace)
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
subtitle: List[dict] = []
|
subtitle: list[dict] = []
|
||||||
last: float = get_start(result["segments"]) or 0.0
|
last = result["segments"][0]["words"][0]["start"]
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
words_count = max_words_per_line
|
words_count = max_words_per_line
|
||||||
@ -209,11 +195,9 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
|
|
||||||
yield start, end, "".join(
|
yield start, end, "".join(
|
||||||
[
|
[
|
||||||
(
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
if j == i
|
||||||
if j == i
|
else word
|
||||||
else word
|
|
||||||
)
|
|
||||||
for j, word in enumerate(all_words)
|
for j, word in enumerate(all_words)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "20250625"
|
__version__ = "20231117"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user