mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Merge branch 'main' into main
This commit is contained in:
commit
d15213d561
13
.github/dependabot.yml
vendored
Normal file
13
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
# Keep GitHub Actions up to date with GitHub's Dependabot...
|
||||
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
|
||||
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: github-actions
|
||||
directory: /
|
||||
groups:
|
||||
github-actions:
|
||||
patterns:
|
||||
- "*" # Group all Actions updates into a single larger pull request
|
||||
schedule:
|
||||
interval: weekly
|
||||
12
.github/workflows/python-publish.yml
vendored
12
.github/workflows/python-publish.yml
vendored
@ -8,23 +8,23 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-ecosystem/action-regex-match@v2
|
||||
id: regex-match
|
||||
with:
|
||||
text: ${{ github.event.head_commit.message }}
|
||||
regex: '^Release ([^ ]+)'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.8'
|
||||
python-version: '3.12'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
pip install setuptools wheel twine build
|
||||
- name: Release
|
||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||
uses: softprops/action-gh-release@v1
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
||||
- name: Build and publish
|
||||
@ -33,5 +33,5 @@ jobs:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
python setup.py sdist
|
||||
python -m build --sdist
|
||||
twine upload dist/*
|
||||
|
||||
49
.github/workflows/test.yml
vendored
49
.github/workflows/test.yml
vendored
@ -11,19 +11,19 @@ jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Fetch base branch
|
||||
run: git fetch origin ${{ github.base_ref }}
|
||||
- uses: actions/setup-python@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.9"
|
||||
architecture: x64
|
||||
- name: Get pip cache dir
|
||||
id: pip-cache
|
||||
run: |
|
||||
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
|
||||
- name: pip/pre-commit cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
${{ steps.pip-cache.outputs.dir }}
|
||||
@ -33,24 +33,47 @@ jobs:
|
||||
${{ runner.os }}-pip-pre-commit
|
||||
- name: pre-commit
|
||||
run: |
|
||||
pip install -U pre-commit
|
||||
pip install --upgrade pre-commit
|
||||
pre-commit install --install-hooks
|
||||
pre-commit run --all-files
|
||||
whisper-test:
|
||||
needs: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: [1.13.1, 2.0.0]
|
||||
exclude:
|
||||
- python-version: '3.11'
|
||||
include:
|
||||
- python-version: '3.8'
|
||||
pytorch-version: 1.10.1
|
||||
numpy-requirement: "'numpy<2'"
|
||||
- python-version: '3.8'
|
||||
pytorch-version: 1.13.1
|
||||
numpy-requirement: "'numpy<2'"
|
||||
- python-version: '3.8'
|
||||
pytorch-version: 2.0.1
|
||||
numpy-requirement: "'numpy<2'"
|
||||
- python-version: '3.9'
|
||||
pytorch-version: 2.1.2
|
||||
numpy-requirement: "'numpy<2'"
|
||||
- python-version: '3.10'
|
||||
pytorch-version: 2.2.2
|
||||
numpy-requirement: "'numpy<2'"
|
||||
- python-version: '3.11'
|
||||
pytorch-version: 2.3.1
|
||||
numpy-requirement: "'numpy'"
|
||||
- python-version: '3.12'
|
||||
pytorch-version: 2.4.1
|
||||
numpy-requirement: "'numpy'"
|
||||
- python-version: '3.12'
|
||||
pytorch-version: 2.5.1
|
||||
numpy-requirement: "'numpy'"
|
||||
- python-version: '3.13'
|
||||
pytorch-version: 2.5.1
|
||||
numpy-requirement: "'numpy'"
|
||||
steps:
|
||||
- uses: conda-incubator/setup-miniconda@v2
|
||||
- uses: conda-incubator/setup-miniconda@v3
|
||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
|
||||
- run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||
- run: pip install .["dev"]
|
||||
- run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
|
||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-json
|
||||
- id: end-of-file-fixer
|
||||
@ -11,17 +11,17 @@ repos:
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=4096]
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.7.0
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
||||
args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
|
||||
- repo: https://github.com/pycqa/flake8.git
|
||||
rev: 6.0.0
|
||||
rev: 7.1.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
types: [python]
|
||||
|
||||
52
CHANGELOG.md
52
CHANGELOG.md
@ -1,5 +1,57 @@
|
||||
# CHANGELOG
|
||||
|
||||
## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
|
||||
|
||||
* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
|
||||
* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
|
||||
* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
|
||||
* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
|
||||
* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
|
||||
* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
|
||||
* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
|
||||
* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
|
||||
* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
|
||||
* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
|
||||
* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
|
||||
* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
|
||||
* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
|
||||
* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
|
||||
* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
|
||||
* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
|
||||
* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
|
||||
* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
|
||||
|
||||
## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
|
||||
|
||||
* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
|
||||
* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
|
||||
* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
|
||||
* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
|
||||
|
||||
## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
|
||||
|
||||
* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
|
||||
* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
|
||||
* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
|
||||
* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
|
||||
|
||||
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
|
||||
|
||||
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
|
||||
|
||||
## [v20231106](https://github.com/openai/whisper/releases/tag/v20231106)
|
||||
|
||||
* large-v3 ([#1761](https://github.com/openai/whisper/pull/1761))
|
||||
|
||||
## [v20231105](https://github.com/openai/whisper/releases/tag/v20231105)
|
||||
|
||||
* remove tiktoken pin ([#1759](https://github.com/openai/whisper/pull/1759))
|
||||
* docs: Disambiguation of the term "relative speed" in the README ([#1751](https://github.com/openai/whisper/pull/1751))
|
||||
* allow_pickle=False while loading of mel matrix IN audio.py ([#1511](https://github.com/openai/whisper/pull/1511))
|
||||
* handling transcribe exceptions. ([#1682](https://github.com/openai/whisper/pull/1682))
|
||||
* Add new option to generate subtitles by a specific number of words ([#1729](https://github.com/openai/whisper/pull/1729))
|
||||
* Fix exception when an audio file with no speech is provided ([#1396](https://github.com/openai/whisper/pull/1396))
|
||||
|
||||
## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918)
|
||||
|
||||
* Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528))
|
||||
|
||||
51
README.md
51
README.md
@ -57,42 +57,55 @@ pip install setuptools-rust
|
||||
|
||||
## Available models and languages
|
||||
|
||||
There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed.
|
||||
|
||||
There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
|
||||
Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
|
||||
The relative speeds below are measured by transcribing English speech on a A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
|
||||
|
||||
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
||||
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
|
||||
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
|
||||
| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
|
||||
| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
|
||||
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
|
||||
| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
|
||||
| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
|
||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
||||
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
||||
| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |
|
||||
|
||||
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
|
||||
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model (The smaller the numbers, the better the performance). Additional WER scores corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4. Meanwhile, more BLEU (Bilingual Evaluation Understudy) scores can be found in Appendix D.3. Both are found in [the paper](https://arxiv.org/abs/2212.04356).
|
||||
|
||||

|
||||
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
|
||||
|
||||

|
||||
|
||||
## Command-line usage
|
||||
|
||||
The following command will transcribe speech in audio files, using the `medium` model:
|
||||
The following command will transcribe speech in audio files, using the `turbo` model:
|
||||
|
||||
whisper audio.flac audio.mp3 audio.wav --model medium
|
||||
```bash
|
||||
whisper audio.flac audio.mp3 audio.wav --model turbo
|
||||
```
|
||||
|
||||
The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
|
||||
The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.
|
||||
|
||||
whisper japanese.wav --language Japanese
|
||||
For example, to transcribe an audio file containing non-English speech, you can specify the language:
|
||||
|
||||
Adding `--task translate` will translate the speech into English:
|
||||
```bash
|
||||
whisper japanese.wav --language Japanese
|
||||
```
|
||||
|
||||
whisper japanese.wav --language Japanese --task translate
|
||||
To **translate** speech into English, use:
|
||||
|
||||
```bash
|
||||
whisper japanese.wav --model medium --language Japanese --task translate
|
||||
```
|
||||
|
||||
> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.
|
||||
|
||||
Run the following to view all available options:
|
||||
|
||||
whisper --help
|
||||
```bash
|
||||
whisper --help
|
||||
```
|
||||
|
||||
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
|
||||
|
||||
@ -104,7 +117,7 @@ Transcription can also be performed within Python:
|
||||
```python
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("base")
|
||||
model = whisper.load_model("turbo")
|
||||
result = model.transcribe("audio.mp3")
|
||||
print(result["text"])
|
||||
```
|
||||
@ -116,14 +129,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
|
||||
```python
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("base")
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# load audio and pad/trim it to fit 30 seconds
|
||||
audio = whisper.load_audio("audio.mp3")
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
|
||||
# make log-Mel spectrogram and move to the same device as the model
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
||||
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
|
||||
|
||||
# detect the spoken language
|
||||
_, probs = model.detect_language(mel)
|
||||
|
||||
@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
|
||||
|
||||
### AMI-IHM, AMI-SDM1
|
||||
|
||||
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
||||
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
||||
|
||||
|
||||
## Long-form English-only datasets
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 100 KiB After Width: | Height: | Size: 272 KiB |
@ -16,13 +16,15 @@ The Whisper models are trained for speech recognition and translation tasks, cap
|
||||
| small | 244 M | ✓ | ✓ |
|
||||
| medium | 769 M | ✓ | ✓ |
|
||||
| large | 1550 M | | ✓ |
|
||||
| turbo | 798 M | | ✓ |
|
||||
|
||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
|
||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
|
||||
Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.
|
||||
|
||||
|
||||
### Release date
|
||||
|
||||
September 2022 (original series) and December 2022 (`large-v2`)
|
||||
September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`)
|
||||
|
||||
### Model type
|
||||
|
||||
|
||||
3
notebooks/LibriSpeech.ipynb
generated
3
notebooks/LibriSpeech.ipynb
generated
@ -949,7 +949,8 @@
|
||||
"style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
|
||||
"value": " 164/164 [05:08<00:00, 1.86s/it]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"state": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
3
notebooks/Multilingual_ASR.ipynb
generated
3
notebooks/Multilingual_ASR.ipynb
generated
@ -4219,7 +4219,8 @@
|
||||
"_view_name": "StyleView",
|
||||
"description_width": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"state": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@ -1,3 +1,50 @@
|
||||
[build-system]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
requires = [ "setuptools>=61.2" ]
|
||||
|
||||
[project]
|
||||
name = "openai-whisper"
|
||||
description = "Robust Speech Recognition via Large-Scale Weak Supervision"
|
||||
readme.content-type = "text/markdown"
|
||||
readme.file = "README.md"
|
||||
license = { text = "MIT" }
|
||||
authors = [ { name = "OpenAI" } ]
|
||||
requires-python = ">=3.8"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
dynamic = [ "version" ]
|
||||
dependencies = [
|
||||
"more-itertools",
|
||||
"numba",
|
||||
"numpy",
|
||||
"tiktoken",
|
||||
"torch",
|
||||
"tqdm",
|
||||
"triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
|
||||
]
|
||||
optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
|
||||
urls = { Homepage = "https://github.com/openai/whisper" }
|
||||
scripts.whisper = "whisper.transcribe:cli"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = [ "whisper" ]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { attr = "whisper.version.__version__" }
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
exclude = [ "tests*" ]
|
||||
namespaces = false
|
||||
|
||||
[tool.black]
|
||||
|
||||
[tool.isort]
|
||||
@ -5,4 +52,3 @@ profile = "black"
|
||||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
|
||||
@ -3,4 +3,5 @@ numpy
|
||||
torch
|
||||
tqdm
|
||||
more-itertools
|
||||
tiktoken==0.3.3
|
||||
tiktoken
|
||||
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
||||
|
||||
43
setup.py
43
setup.py
@ -1,43 +0,0 @@
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def read_version(fname="whisper/version.py"):
|
||||
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
|
||||
return locals()["__version__"]
|
||||
|
||||
|
||||
requirements = []
|
||||
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||
requirements.append("triton==2.0.0")
|
||||
|
||||
setup(
|
||||
name="openai-whisper",
|
||||
py_modules=["whisper"],
|
||||
version=read_version(),
|
||||
description="Robust Speech Recognition via Large-Scale Weak Supervision",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
python_requires=">=3.8",
|
||||
author="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=requirements
|
||||
+ [
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
],
|
||||
entry_points={
|
||||
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
|
||||
)
|
||||
@ -1,7 +1,17 @@
|
||||
import pytest
|
||||
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def test_tokenizer():
|
||||
@pytest.mark.parametrize("multilingual", [True, False])
|
||||
def test_tokenizer(multilingual):
|
||||
tokenizer = get_tokenizer(multilingual=False)
|
||||
assert tokenizer.sot in tokenizer.sot_sequence
|
||||
assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens)
|
||||
assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens)
|
||||
|
||||
|
||||
def test_multilingual_tokenizer():
|
||||
gpt2_tokenizer = get_tokenizer(multilingual=False)
|
||||
multilingual_tokenizer = get_tokenizer(multilingual=True)
|
||||
|
||||
@ -20,5 +30,5 @@ def test_split_on_unicode():
|
||||
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
|
||||
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
|
||||
|
||||
assert words == [" elle", " est", " l", "'", "<EFBFBD>", "é", "rit", "oire"]
|
||||
assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
|
||||
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
|
||||
|
||||
@ -25,7 +25,7 @@ def test_transcribe(model_name: str):
|
||||
assert "your country" in transcription
|
||||
assert "do for you" in transcription
|
||||
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
|
||||
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
||||
assert tokenizer.decode(all_tokens) == result["text"]
|
||||
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
||||
|
||||
@ -28,7 +28,10 @@ _MODELS = {
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
@ -44,7 +47,10 @@ _ALIGNMENT_HEADS = {
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
}
|
||||
|
||||
|
||||
@ -150,7 +156,8 @@ def load_model(
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
||||
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
|
||||
Binary file not shown.
@ -12,7 +12,6 @@ from .utils import exact_div
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
@ -98,18 +97,19 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
with np.load(filters_path, allow_pickle=False) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
n_mels: int = 80,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
@ -122,7 +122,7 @@ def log_mel_spectrogram(
|
||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
The number of Mel-frequency filters, only 80 and 128 are supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
@ -132,7 +132,7 @@ def log_mel_spectrogram(
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
torch.Tensor, shape = (n_mels, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
|
||||
@ -32,7 +32,9 @@ def detect_language(
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
@ -514,7 +516,10 @@ class DecodingTask:
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, language=language, task=options.task
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=options.task,
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import base64
|
||||
import gzip
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,6 +13,14 @@ from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
try:
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
SDPA_AVAILABLE = True
|
||||
except (ImportError, RuntimeError, OSError):
|
||||
scaled_dot_product_attention = None
|
||||
SDPA_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_sdpa():
|
||||
prev_state = MultiHeadAttention.use_sdpa
|
||||
try:
|
||||
MultiHeadAttention.use_sdpa = False
|
||||
yield
|
||||
finally:
|
||||
MultiHeadAttention.use_sdpa = prev_state
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
use_sdpa = True
|
||||
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
||||
a = scaled_dot_product_attention(
|
||||
q, k, v, is_causal=mask is not None and n_ctx > 1
|
||||
)
|
||||
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = None
|
||||
else:
|
||||
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = qk.detach()
|
||||
|
||||
return out, qk
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
@ -236,7 +267,8 @@ class Whisper(nn.Module):
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
# use the last half among the decoder layers for time alignment by default;
|
||||
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
@ -269,7 +301,11 @@ class Whisper(nn.Module):
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab == 51865
|
||||
return self.dims.n_vocab >= 51865
|
||||
|
||||
@property
|
||||
def num_languages(self):
|
||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
|
||||
@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
else (
|
||||
ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
else (
|
||||
""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||
)
|
||||
)
|
||||
)
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
cost = cost.to(x.device)
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
@ -191,7 +191,9 @@ def find_alignment(
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
from .model import disable_sdpa
|
||||
|
||||
with torch.no_grad(), disable_sdpa():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
@ -299,6 +301,7 @@ def add_word_timestamps(
|
||||
word_durations = np.array([t.end - t.start for t in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
|
||||
@ -107,6 +107,7 @@ LANGUAGES = {
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
@ -123,6 +124,7 @@ TO_LANGUAGE_CODE = {
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
"mandarin": "zh",
|
||||
}
|
||||
|
||||
|
||||
@ -131,6 +133,7 @@ class Tokenizer:
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
encoding: tiktoken.Encoding
|
||||
num_languages: int
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
@ -145,7 +148,7 @@ class Tokenizer:
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
@ -211,10 +214,13 @@ class Tokenizer:
|
||||
if self.language is None:
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
||||
return self.to_language_token(self.language)
|
||||
|
||||
def to_language_token(self, language):
|
||||
if token := self.special_tokens.get(f"<|{language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
raise KeyError(f"Language {language} not found in tokenizer.")
|
||||
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
@ -222,7 +228,7 @@ class Tokenizer:
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
return tuple(result)[: self.num_languages]
|
||||
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
@ -269,7 +275,7 @@ class Tokenizer:
|
||||
return tuple(sorted(result))
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
@ -322,7 +328,7 @@ class Tokenizer:
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2"):
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
@ -334,7 +340,7 @@ def get_encoding(name: str = "gpt2"):
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
@ -361,6 +367,7 @@ def get_encoding(name: str = "gpt2"):
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
num_languages: int = 99,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
@ -381,6 +388,8 @@ def get_tokenizer(
|
||||
language = None
|
||||
task = None
|
||||
|
||||
encoding = get_encoding(name=encoding_name)
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||
|
||||
return Tokenizer(encoding=encoding, language=language, task=task)
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
from importlib.util import find_spec
|
||||
|
||||
import numpy as np
|
||||
@ -23,6 +24,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_end,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
@ -45,9 +47,12 @@ def transcribe(
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
carry_initial_prompt: bool = False,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -99,9 +104,22 @@ def transcribe(
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
carry_initial_prompt: bool
|
||||
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||
left-sliced to make space.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||
The last end timestamp defaults to the end of the file.
|
||||
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
@ -121,8 +139,9 @@ def transcribe(
|
||||
decode_options["fp16"] = False
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
@ -142,7 +161,25 @@ def transcribe(
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=task,
|
||||
)
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
]
|
||||
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||
if len(seek_points) == 0:
|
||||
seek_points.append(0)
|
||||
if len(seek_points) % 2 == 1:
|
||||
seek_points.append(content_frames)
|
||||
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||
|
||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
@ -180,6 +217,8 @@ def transcribe(
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
and logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
@ -187,7 +226,8 @@ def transcribe(
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
clip_idx = 0
|
||||
seek = seek_clips[clip_idx][0]
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
@ -198,9 +238,11 @@ def transcribe(
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
@ -226,14 +268,33 @@ def transcribe(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames:
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||
# while seek < seek_clip_end
|
||||
while clip_idx < len(seek_clips):
|
||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||
if seek < seek_clip_start:
|
||||
seek = seek_clip_start
|
||||
if seek >= seek_clip_end:
|
||||
clip_idx += 1
|
||||
if clip_idx < len(seek_clips):
|
||||
seek = seek_clips[clip_idx][0]
|
||||
continue
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||
mel_segment = mel[:, seek : seek + segment_size]
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
if carry_initial_prompt:
|
||||
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||
else:
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
@ -254,6 +315,30 @@ def transcribe(
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
# anomalous words are very long/short/improbable
|
||||
def word_anomaly_score(word: dict) -> float:
|
||||
probability = word.get("probability", 0.0)
|
||||
duration = word["end"] - word["start"]
|
||||
score = 0.0
|
||||
if probability < 0.15:
|
||||
score += 1.0
|
||||
if duration < 0.133:
|
||||
score += (0.133 - duration) * 15
|
||||
if duration > 2.0:
|
||||
score += duration - 2.0
|
||||
return score
|
||||
|
||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||
if segment is None or not segment["words"]:
|
||||
return False
|
||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||
words = words[:8]
|
||||
score = sum(word_anomaly_score(w) for w in words)
|
||||
return score >= 3 or score + 0.01 >= len(words)
|
||||
|
||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||
return next((s for s in segments if s["words"]), None)
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
@ -327,17 +412,71 @@ def transcribe(
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(word_end_timestamps) > 0:
|
||||
last_speech_timestamp = word_end_timestamps[-1]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
|
||||
# skip silence before possible hallucinations
|
||||
if hallucination_silence_threshold is not None:
|
||||
threshold = hallucination_silence_threshold
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
remaining_duration = window_end_time - last_word_end
|
||||
if remaining_duration > threshold:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
else:
|
||||
seek = previous_seek + segment_size
|
||||
|
||||
# if first segment might be a hallucination, skip leading silence
|
||||
first_segment = next_words_segment(current_segments)
|
||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||
gap = first_segment["start"] - time_offset
|
||||
if gap > threshold:
|
||||
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||
continue
|
||||
|
||||
# skip silence before any possible hallucination that is surrounded
|
||||
# by silence or more hallucinations
|
||||
hal_last_end = last_speech_timestamp
|
||||
for si in range(len(current_segments)):
|
||||
segment = current_segments[si]
|
||||
if not segment["words"]:
|
||||
continue
|
||||
if is_segment_anomaly(segment):
|
||||
next_segment = next_words_segment(
|
||||
current_segments[si + 1 :]
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
if next_segment is not None:
|
||||
hal_next_start = next_segment["words"][0]["start"]
|
||||
else:
|
||||
hal_next_start = time_offset + segment_duration
|
||||
silence_before = (
|
||||
segment["start"] - hal_last_end > threshold
|
||||
or segment["start"] < threshold
|
||||
or segment["start"] - time_offset < 2.0
|
||||
)
|
||||
silence_after = (
|
||||
hal_next_start - segment["end"] > threshold
|
||||
or is_segment_anomaly(next_segment)
|
||||
or window_end_time - segment["end"] < 2.0
|
||||
)
|
||||
if silence_before and silence_after:
|
||||
seek = round(
|
||||
max(time_offset + 1, segment["start"])
|
||||
* FRAMES_PER_SECOND
|
||||
)
|
||||
if content_duration - segment["end"] < threshold:
|
||||
seek = content_frames
|
||||
current_segments[si:] = []
|
||||
break
|
||||
hal_last_end = segment["end"]
|
||||
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
@ -381,10 +520,17 @@ def transcribe(
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
def valid_model_name(name):
|
||||
if name in available_models() or os.path.exists(name):
|
||||
return name
|
||||
raise ValueError(
|
||||
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
@ -402,6 +548,8 @@ def cli():
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
@ -415,7 +563,10 @@ def cli():
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
@ -447,17 +598,28 @@ def cli():
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
word_options = [
|
||||
"highlight_words",
|
||||
"max_line_count",
|
||||
"max_line_width",
|
||||
"max_words_per_line",
|
||||
]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
if args["max_words_per_line"] and args["max_line_width"]:
|
||||
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
try:
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, writer_args)
|
||||
writer(result, audio_path, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||
|
||||
kernel = triton.JITFunction(kernel.fn)
|
||||
kernel.src = kernel.src.replace(
|
||||
new_kernel = kernel.src.replace(
|
||||
" LOAD_ALL_ROWS_HERE",
|
||||
"\n".join(
|
||||
[
|
||||
@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace(
|
||||
|
||||
new_kernel = new_kernel.replace(
|
||||
" BUBBLESORT_HERE",
|
||||
"\n\n".join(
|
||||
[
|
||||
@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||
kernel._unsafe_update_src(new_kernel)
|
||||
kernel.hash = None
|
||||
else:
|
||||
kernel.src = new_kernel
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
110
whisper/utils.py
110
whisper/utils.py
@ -3,7 +3,7 @@ import os
import re
import sys
import zlib
from typing import Callable, Optional, TextIO
from typing import Callable, List, Optional, TextIO

system_encoding = sys.getdefaultencoding()
@ -68,13 +68,29 @@ def format_timestamp(
    )


def get_start(segments: List[dict]) -> Optional[float]:
    return next(
        (w["start"] for s in segments for w in s["words"]),
        segments[0]["start"] if segments else None,
    )


def get_end(segments: List[dict]) -> Optional[float]:
    return next(
        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
        segments[-1]["end"] if segments else None,
    )


class ResultWriter:
    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str, options: dict):
    def __call__(
        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
    ):
        audio_basename = os.path.basename(audio_path)
        audio_basename = os.path.splitext(audio_basename)[0]
        output_path = os.path.join(
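The new `get_start`/`get_end` helpers prefer word-level timestamps, fall back to the first/last segment timestamps, and return None for an empty result. A small usage sketch, once this change is installed; the segment dicts are invented examples:

from whisper.utils import get_start, get_end

segments = [
    {"start": 0.0, "end": 2.5, "words": [{"word": " Hello", "start": 0.3, "end": 0.9}]},
    {"start": 2.5, "end": 4.0, "words": [{"word": " world", "start": 2.6, "end": 3.1}]},
]
print(get_start(segments))  # 0.3  (first word's start)
print(get_end(segments))    # 3.1  (last word's end)

# With empty word lists the helpers fall back to the segment timestamps.
print(get_start([{"start": 1.0, "end": 2.0, "words": []}]))  # 1.0
print(get_end([]))  # None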
@ -82,16 +98,20 @@ class ResultWriter:
        )

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options)
            self.write_result(result, file=f, options=options, **kwargs)

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        raise NotImplementedError


class WriteTXT(ResultWriter):
    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)
@ -100,26 +120,53 @@ class SubtitlesWriter(ResultWriter):
    always_include_hours: bool
    decimal_marker: str

    def iterate_result(self, result: dict, options: dict):
        raw_max_line_width: Optional[int] = options["max_line_width"]
        max_line_count: Optional[int] = options["max_line_count"]
        highlight_words: bool = options["highlight_words"]
        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
        preserve_segments = max_line_count is None or raw_max_line_width is None
    def iterate_result(
        self,
        result: dict,
        options: Optional[dict] = None,
        *,
        max_line_width: Optional[int] = None,
        max_line_count: Optional[int] = None,
        highlight_words: bool = False,
        max_words_per_line: Optional[int] = None,
    ):
        options = options or {}
        max_line_width = max_line_width or options.get("max_line_width")
        max_line_count = max_line_count or options.get("max_line_count")
        highlight_words = highlight_words or options.get("highlight_words", False)
        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
        preserve_segments = max_line_count is None or max_line_width is None
        max_line_width = max_line_width or 1000
        max_words_per_line = max_words_per_line or 1000

        def iterate_subtitles():
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: list[dict] = []
            last = result["segments"][0]["words"][0]["start"]
            subtitle: List[dict] = []
            last: float = get_start(result["segments"]) or 0.0
            for segment in result["segments"]:
                for i, original_timing in enumerate(segment["words"]):
                chunk_index = 0
                words_count = max_words_per_line
                while chunk_index < len(segment["words"]):
                    remaining_words = len(segment["words"]) - chunk_index
                    if max_words_per_line > len(segment["words"]) - chunk_index:
                        words_count = remaining_words
                    for i, original_timing in enumerate(
                        segment["words"][chunk_index : chunk_index + words_count]
                    ):
                        timing = original_timing.copy()
                        long_pause = not preserve_segments and timing["start"] - last > 3.0
                        long_pause = (
                            not preserve_segments and timing["start"] - last > 3.0
                        )
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if line_len > 0 and has_room and not long_pause and not seg_break:
                        if (
                            line_len > 0
                            and has_room
                            and not long_pause
                            and not seg_break
                        ):
                            # line continuation
                            line_len += len(timing["word"])
                        else:
@ -142,10 +189,11 @@ class SubtitlesWriter(ResultWriter):
                            line_len = len(timing["word"].strip())
                        subtitle.append(timing)
                        last = timing["start"]
                    chunk_index += max_words_per_line
            if len(subtitle) > 0:
                yield subtitle

        if "words" in result["segments"][0]:
        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
            for subtitle in iterate_subtitles():
                subtitle_start = self.format_timestamp(subtitle[0]["start"])
                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
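The new while loop in the two hunks above walks each segment's word list in windows of at most `max_words_per_line` words, advancing `chunk_index` by that stride. A stripped-down sketch of just the chunking (using `min` in place of the diff's if-statement, with the timing and line-width logic omitted):

def chunk_words(words, max_words_per_line):
    chunk_index = 0
    while chunk_index < len(words):
        words_count = min(max_words_per_line, len(words) - chunk_index)
        yield words[chunk_index : chunk_index + words_count]
        chunk_index += max_words_per_line

print(list(chunk_words(["a", "b", "c", "d", "e"], 2)))
# [['a', 'b'], ['c', 'd'], ['e']]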
@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):

                        yield start, end, "".join(
                            [
                                (
                                    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                    if j == i
                                    else word
                                )
                                for j, word in enumerate(all_words)
                            ]
                        )
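The `re.sub` in this hunk wraps the currently spoken word in `<u>` tags while keeping its leading whitespace, which is how `--highlight_words` marks the active word in VTT/SRT output. A small worked example with an invented word list:

import re

words = [" Hello", " world"]
i = 1  # index of the word being spoken
highlighted = "".join(
    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) if j == i else word
    for j, word in enumerate(words)
)
print(highlighted)  # -> " Hello <u>world</u>"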
@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("WEBVTT\n", file=file)
        for start, end, text in self.iterate_result(result, options):
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for i, (start, end, text) in enumerate(
            self.iterate_result(result, options), start=1
            self.iterate_result(result, options, **kwargs), start=1
        ):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            print(round(1000 * segment["start"]), file=file, end="\t")
@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter):
    extension: str = "json"

    def write_result(self, result: dict, file: TextIO, options: dict):
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        json.dump(result, file)
@ -249,9 +307,11 @@ def get_writer(
    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, file: TextIO, options: dict):
        def write_all(
            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
        ):
            for writer in all_writers:
                writer(result, file, options)
                writer(result, file, options, **kwargs)

        return write_all
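With `output_format == "all"`, the `write_all` closure above fans one call out to every registered writer, forwarding both the legacy `options` dict and the new keyword arguments unchanged. A toy sketch of the same fan-out pattern; the print-based writers and `make_fanout` are placeholders, not whisper classes:

from typing import Callable, List, Optional

def make_fanout(writers: List[Callable]) -> Callable:
    def write_all(result: dict, options: Optional[dict] = None, **kwargs):
        for write in writers:
            write(result, options, **kwargs)  # every backend sees the same options/kwargs
    return write_all

def txt_writer(result, options=None, **kwargs):
    print("txt:", result["text"], kwargs)

def srt_writer(result, options=None, **kwargs):
    print("srt:", result["text"], kwargs)

write_all = make_fanout([txt_writer, srt_writer])
write_all({"text": "hello"}, max_words_per_line=6)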
@ -1 +1 @@
__version__ = "20230918"
__version__ = "20250625"