Mirror of https://github.com/openai/whisper.git, synced 2025-09-13 19:20:10 +00:00
Compare commits
No commits in common. "main" and "v20230307" have entirely different histories.
.gitattributes (vendored), 3 lines

@@ -1,3 +0,0 @@
-# Override jupyter in Github language stats for more accurate estimate of repo code languages
-# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
-*.ipynb linguist-generated
.github/dependabot.yml (vendored), 13 lines

@@ -1,13 +0,0 @@
-# Keep GitHub Actions up to date with GitHub's Dependabot...
-# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
-# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
-version: 2
-updates:
-  - package-ecosystem: github-actions
-    directory: /
-    groups:
-      github-actions:
-        patterns:
-          - "*" # Group all Actions updates into a single larger pull request
-    schedule:
-      interval: weekly
.github/workflows/python-publish.yml (vendored), 12 changes

@@ -8,23 +8,23 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v2
       - uses: actions-ecosystem/action-regex-match@v2
        id: regex-match
        with:
          text: ${{ github.event.head_commit.message }}
          regex: '^Release ([^ ]+)'
       - name: Set up Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v2
        with:
-          python-version: '3.12'
+          python-version: '3.8'
       - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install setuptools wheel twine build
+          pip install setuptools wheel twine
       - name: Release
        if: ${{ steps.regex-match.outputs.match != '' }}
-        uses: softprops/action-gh-release@v2
+        uses: softprops/action-gh-release@v1
        with:
          tag_name: v${{ steps.regex-match.outputs.group1 }}
       - name: Build and publish
@@ -33,5 +33,5 @@ jobs:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
-          python -m build --sdist
+          python setup.py sdist
          twine upload dist/*
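The release step in both versions keys off the head commit message via `action-regex-match`. As a rough illustration only (not part of the repository), the same `^Release ([^ ]+)` pattern can be exercised in Python to see what ends up in `group1` and, in turn, in the published tag name; the commit messages below are hypothetical:

```python
import re

# Pattern matching the one passed to action-regex-match in the workflow above.
pattern = re.compile(r"^Release ([^ ]+)")

for message in ["Release 20230307", "Fix typo in CHANGELOG.md"]:
    match = pattern.match(message)
    if match:
        # The workflow builds the tag as v${{ steps.regex-match.outputs.group1 }}.
        print("would publish tag:", "v" + match.group(1))
    else:
        print("no release triggered by:", repr(message))
```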
.github/workflows/test.yml (vendored), 73 changes

@@ -6,74 +6,23 @@ on:
   pull_request:
     branches:
       - main

 jobs:
-  pre-commit:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-      - name: Fetch base branch
-        run: git fetch origin ${{ github.base_ref }}
-      - uses: actions/setup-python@v5
-        with:
-          python-version: "3.9"
-          architecture: x64
-      - name: Get pip cache dir
-        id: pip-cache
-        run: |
-          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-      - name: pip/pre-commit cache
-        uses: actions/cache@v4
-        with:
-          path: |
-            ${{ steps.pip-cache.outputs.dir }}
-            ~/.cache/pre-commit
-          key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-pre-commit
-      - name: pre-commit
-        run: |
-          pip install --upgrade pre-commit
-          pre-commit install --install-hooks
-          pre-commit run --all-files
   whisper-test:
-    needs: pre-commit
     runs-on: ubuntu-latest
     strategy:
-      fail-fast: false
       matrix:
-        include:
-          - python-version: '3.8'
-            pytorch-version: 1.10.1
-            numpy-requirement: "'numpy<2'"
-          - python-version: '3.8'
-            pytorch-version: 1.13.1
-            numpy-requirement: "'numpy<2'"
-          - python-version: '3.8'
-            pytorch-version: 2.0.1
-            numpy-requirement: "'numpy<2'"
-          - python-version: '3.9'
-            pytorch-version: 2.1.2
-            numpy-requirement: "'numpy<2'"
-          - python-version: '3.10'
-            pytorch-version: 2.2.2
-            numpy-requirement: "'numpy<2'"
-          - python-version: '3.11'
-            pytorch-version: 2.3.1
-            numpy-requirement: "'numpy'"
-          - python-version: '3.12'
-            pytorch-version: 2.4.1
-            numpy-requirement: "'numpy'"
-          - python-version: '3.12'
-            pytorch-version: 2.5.1
-            numpy-requirement: "'numpy'"
-          - python-version: '3.13'
-            pytorch-version: 2.5.1
-            numpy-requirement: "'numpy'"
+        python-version: ['3.8', '3.9', '3.10']
+        pytorch-version: [1.10.2, 1.13.1]
+        exclude:
+          - python-version: '3.10'
+            pytorch-version: 1.10.2
     steps:
-      - uses: conda-incubator/setup-miniconda@v3
-      - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
-      - uses: actions/checkout@v4
+      - uses: conda-incubator/setup-miniconda@v2
+      - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
+      - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
+      - run: pip install .["dev"]
+      - run: black --check --diff -t py38 --include '(\.pyi?)$' .
+      - run: isort --check --diff .
+      - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
       - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
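Both versions of the test job deselect GPU-only tests with `-m 'not requires_cuda'`. As a sketch only (not taken from the repository's test files, and assuming the marker is registered in the usual way in pytest configuration), a marker like that is typically declared and consumed as follows:

```python
import pytest

# Sketch: a custom marker named to match the -m filter used in the workflow above.
@pytest.mark.requires_cuda
def test_needs_gpu():
    torch = pytest.importorskip("torch")
    assert torch.cuda.is_available()


def test_runs_anywhere():
    assert 1 + 1 == 2
```

With `pytest -m 'not requires_cuda'`, the first test is deselected on CPU-only runners while the second still runs.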
(pre-commit configuration; the file name is not shown in this capture, but the contents match .pre-commit-config.yaml on main)

@@ -1,28 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
-    hooks:
-      - id: check-json
-      - id: end-of-file-fixer
-        types: [file, python]
-      - id: trailing-whitespace
-        types: [file, python]
-      - id: mixed-line-ending
-      - id: check-added-large-files
-        args: [--maxkb=4096]
-  - repo: https://github.com/psf/black
-    rev: 25.1.0
-    hooks:
-      - id: black
-  - repo: https://github.com/pycqa/isort
-    rev: 6.0.0
-    hooks:
-      - id: isort
-        name: isort (python)
-        args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"]
-  - repo: https://github.com/pycqa/flake8.git
-    rev: 7.1.1
-    hooks:
-      - id: flake8
-        types: [python]
-        args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"]
CHANGELOG.md, 91 changes

@@ -1,95 +1,6 @@
 # CHANGELOG

-## [v20250625](https://github.com/openai/whisper/releases/tag/v20250625)
-
-* Fix: Update torch.load to use weights_only=True to prevent security w… ([#2451](https://github.com/openai/whisper/pull/2451))
-* Fix: Ensure DTW cost tensor is on the same device as input tensor ([#2561](https://github.com/openai/whisper/pull/2561))
-* docs: updated README to specify translation model limitation ([#2547](https://github.com/openai/whisper/pull/2547))
-* Fixed triton kernel update to support latest triton versions ([#2588](https://github.com/openai/whisper/pull/2588))
-* Fix: GitHub display errors for Jupyter notebooks ([#2589](https://github.com/openai/whisper/pull/2589))
-* Bump the github-actions group with 3 updates ([#2592](https://github.com/openai/whisper/pull/2592))
-* Keep GitHub Actions up to date with GitHub's Dependabot ([#2486](https://github.com/openai/whisper/pull/2486))
-* pre-commit: Upgrade black v25.1.0 and isort v6.0.0 ([#2514](https://github.com/openai/whisper/pull/2514))
-* GitHub Actions: Add Python 3.13 to the testing ([#2487](https://github.com/openai/whisper/pull/2487))
-* PEP 621: Migrate from setup.py to pyproject.toml ([#2435](https://github.com/openai/whisper/pull/2435))
-* pre-commit autoupdate && pre-commit run --all-files ([#2484](https://github.com/openai/whisper/pull/2484))
-* Upgrade GitHub Actions ([#2430](https://github.com/openai/whisper/pull/2430))
-* Bugfix: Illogical "Avoid computing higher temperatures on no_speech" ([#1903](https://github.com/openai/whisper/pull/1903))
-* Updating README and doc strings to reflect that n_mels can now be 128 ([#2049](https://github.com/openai/whisper/pull/2049))
-* fix typo data/README.md ([#2433](https://github.com/openai/whisper/pull/2433))
-* Update README.md ([#2379](https://github.com/openai/whisper/pull/2379))
-* Add option to carry initial_prompt with the sliding window ([#2343](https://github.com/openai/whisper/pull/2343))
-* more pytorch versions in tests ([#2408](https://github.com/openai/whisper/pull/2408))
-
-## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
-
-* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
-* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
-* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
-* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
-
-## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
-
-* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
-* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
-* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
-* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
-
-## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
-
-* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
-
-## [v20231106](https://github.com/openai/whisper/releases/tag/v20231106)
-
-* large-v3 ([#1761](https://github.com/openai/whisper/pull/1761))
-
-## [v20231105](https://github.com/openai/whisper/releases/tag/v20231105)
-
-* remove tiktoken pin ([#1759](https://github.com/openai/whisper/pull/1759))
-* docs: Disambiguation of the term "relative speed" in the README ([#1751](https://github.com/openai/whisper/pull/1751))
-* allow_pickle=False while loading of mel matrix IN audio.py ([#1511](https://github.com/openai/whisper/pull/1511))
-* handling transcribe exceptions. ([#1682](https://github.com/openai/whisper/pull/1682))
-* Add new option to generate subtitles by a specific number of words ([#1729](https://github.com/openai/whisper/pull/1729))
-* Fix exception when an audio file with no speech is provided ([#1396](https://github.com/openai/whisper/pull/1396))
-
-## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918)
-
-* Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528))
-* fix doc of TextDecoder ([#1526](https://github.com/openai/whisper/pull/1526))
-* Update model-card.md ([#1643](https://github.com/openai/whisper/pull/1643))
-* word timing tweaks ([#1559](https://github.com/openai/whisper/pull/1559))
-* Avoid rearranging all caches ([#1483](https://github.com/openai/whisper/pull/1483))
-* Improve timestamp heuristics. ([#1461](https://github.com/openai/whisper/pull/1461))
-* fix condition_on_previous_text ([#1224](https://github.com/openai/whisper/pull/1224))
-* Fix numba depreceation notice ([#1233](https://github.com/openai/whisper/pull/1233))
-* Updated README.md to provide more insight on BLEU and specific appendices ([#1236](https://github.com/openai/whisper/pull/1236))
-* Avoid computing higher temperatures on no_speech segments ([#1279](https://github.com/openai/whisper/pull/1279))
-* Dropped unused execute bit from mel_filters.npz. ([#1254](https://github.com/openai/whisper/pull/1254))
-* Drop ffmpeg-python dependency and call ffmpeg directly. ([#1242](https://github.com/openai/whisper/pull/1242))
-* Python 3.11 ([#1171](https://github.com/openai/whisper/pull/1171))
-* Update decoding.py ([#1219](https://github.com/openai/whisper/pull/1219))
-* Update decoding.py ([#1155](https://github.com/openai/whisper/pull/1155))
-* Update README.md to reference tiktoken ([#1105](https://github.com/openai/whisper/pull/1105))
-* Implement max line width and max line count, and make word highlighting optional ([#1184](https://github.com/openai/whisper/pull/1184))
-* Squash long words at window and sentence boundaries. ([#1114](https://github.com/openai/whisper/pull/1114))
-* python-publish.yml: bump actions version to fix node warning ([#1211](https://github.com/openai/whisper/pull/1211))
-* Update tokenizer.py ([#1163](https://github.com/openai/whisper/pull/1163))
-
-## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
-
-* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
-* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
-* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
-* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
-* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
-
-## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
-
-* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
-* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
-* fix typo in CHANGELOG.md
-
-## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
+## [v20230307](https://github.com/openai/whisper/releases/tag/v202303067)

 * Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
 * Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
(packaging manifest; the file name is not shown in this capture, but the `include` directives match MANIFEST.in)

@@ -2,4 +2,6 @@ include requirements.txt
 include README.md
 include LICENSE
 include whisper/assets/*
+include whisper/assets/gpt2/*
+include whisper/assets/multilingual/*
 include whisper/normalizers/english.json
README.md, 55 changes

@@ -17,7 +17,7 @@ A Transformer sequence-to-sequence model is trained on various speech processing

 ## Setup

-We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.11 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [OpenAI's tiktoken](https://github.com/openai/tiktoken) for their fast tokenizer implementation. You can download and install (or update to) the latest release of Whisper with the following command:
+We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:

     pip install -U openai-whisper

@@ -48,7 +48,7 @@ choco install ffmpeg
 scoop install ffmpeg
 ```

-You may need [`rust`](http://rust-lang.org) installed as well, in case [tiktoken](https://github.com/openai/tiktoken) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
+You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:

 ```bash
 pip install setuptools-rust
@@ -57,55 +57,42 @@ pip install setuptools-rust

 ## Available models and languages

-There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
-Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
-The relative speeds below are measured by transcribing English speech on a A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
+There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed.

 | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
 |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
-| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
-| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
-| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
+| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
+| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
+| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
 | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
 | large | 1550 M | N/A | `large` | ~10 GB | 1x |
-| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |

 The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
-Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.

-Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
+Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller, the better.

-
+

 ## Command-line usage

-The following command will transcribe speech in audio files, using the `turbo` model:
+The following command will transcribe speech in audio files, using the `medium` model:

-```bash
-whisper audio.flac audio.mp3 audio.wav --model turbo
-```
+    whisper audio.flac audio.mp3 audio.wav --model medium

-The default setting (which selects the `turbo` model) works well for transcribing English. However, **the `turbo` model is not trained for translation tasks**. If you need to **translate non-English speech into English**, use one of the **multilingual models** (`tiny`, `base`, `small`, `medium`, `large`) instead of `turbo`.
+The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:

-For example, to transcribe an audio file containing non-English speech, you can specify the language:
-
-```bash
-whisper japanese.wav --language Japanese
-```
-
-To **translate** speech into English, use:
-
-```bash
-whisper japanese.wav --model medium --language Japanese --task translate
-```
-
-> **Note:** The `turbo` model will return the original language even if `--task translate` is specified. Use `medium` or `large` for the best translation results.
+    whisper japanese.wav --language Japanese
+
+Adding `--task translate` will translate the speech into English:
+
+    whisper japanese.wav --language Japanese --task translate

 Run the following to view all available options:

-```bash
-whisper --help
-```
+    whisper --help

 See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
@@ -117,7 +104,7 @@ Transcription can also be performed within Python:
 ```python
 import whisper

-model = whisper.load_model("turbo")
+model = whisper.load_model("base")
 result = model.transcribe("audio.mp3")
 print(result["text"])
 ```
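As a side note (not part of either README version), the dictionary returned by `transcribe()` also carries per-segment timing, which is a common next step after the snippet above; a minimal sketch:

```python
# Assumes `result` from model.transcribe(...) as in the README example above.
for segment in result["segments"]:
    start, end, text = segment["start"], segment["end"], segment["text"]
    print(f"[{start:7.2f} -> {end:7.2f}] {text.strip()}")
```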
@@ -129,14 +116,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
 ```python
 import whisper

-model = whisper.load_model("turbo")
+model = whisper.load_model("base")

 # load audio and pad/trim it to fit 30 seconds
 audio = whisper.load_audio("audio.mp3")
 audio = whisper.pad_or_trim(audio)

 # make log-Mel spectrogram and move to the same device as the model
-mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
+mel = whisper.log_mel_spectrogram(audio).to(model.device)

 # detect the spoken language
 _, probs = model.detect_language(mel)
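The hunk above is cut off in this capture. For completeness, a hedged sketch of how the README example typically continues (the exact wording in either version is not shown here):

```python
# Continues the snippet above: report the detected language, then decode.
print(f"Detected language: {max(probs, key=probs.get)}")

# decode the audio with default options
options = whisper.DecodingOptions()
result = whisper.decode(model, mel, options)

# print the recognized text
print(result.text)
```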
(dataset notes; the file name is not shown in this capture, but the typo fix matches data/README.md)

@@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen

 ### AMI-IHM, AMI-SDM1

-We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
+We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).

 ## Long-form English-only datasets

(file diff suppressed because it is too large; file name not shown in this capture)
(image file changed: Size 272 KiB before, 100 KiB after)
(model card; the file name is not shown in this capture, but the contents match model-card.md)

@@ -16,15 +16,13 @@ The Whisper models are trained for speech recognition and translation tasks, cap
 | small | 244 M | ✓ | ✓ |
 | medium | 769 M | ✓ | ✓ |
 | large | 1550 M | | ✓ |
-| turbo | 798 M | | ✓ |

-In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
-Additionally, we've added a `turbo` model in September 2024 which is optimized for inference speed.
+In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).

 ### Release date

-September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`)
+September 2022 (original series) and December 2022 (`large-v2`)

 ### Model type

@@ -39,7 +37,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m

 ### Evaluated Use

-The primary intended users of these models are AI researchers studying the robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.
+The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.

 The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them.

@@ -55,17 +53,17 @@ As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we s

 ## Performance and Limitations

-Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, and technical language, as well as zero-shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.
+Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.

 However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.

-Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include a higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
+Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).

-In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages.
+In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.

 ## Broader Implications

 We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box – their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications.

-There are also potential dual-use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects.
+There are also potential dual use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects.
(Jupyter notebook widget-state JSON; file name not shown in this capture)

@@ -949,8 +949,7 @@
 "style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
 "value": " 164/164 [05:08<00:00, 1.86s/it]"
 }
-},
-"state": {}
+}
 }
 }
 },
@@ -4219,8 +4219,7 @@
 "_view_name": "StyleView",
 "description_width": ""
 }
-},
-"state": {}
+}
 }
 }
 },
pyproject.toml

@@ -1,50 +1,3 @@
-[build-system]
-build-backend = "setuptools.build_meta"
-
-requires = [ "setuptools>=61.2" ]
-
-[project]
-name = "openai-whisper"
-description = "Robust Speech Recognition via Large-Scale Weak Supervision"
-readme.content-type = "text/markdown"
-readme.file = "README.md"
-license = { text = "MIT" }
-authors = [ { name = "OpenAI" } ]
-requires-python = ">=3.8"
-classifiers = [
-    "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
-    "Programming Language :: Python :: 3.12",
-    "Programming Language :: Python :: 3.13",
-]
-dynamic = [ "version" ]
-dependencies = [
-    "more-itertools",
-    "numba",
-    "numpy",
-    "tiktoken",
-    "torch",
-    "tqdm",
-    "triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'",
-]
-optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ]
-urls = { Homepage = "https://github.com/openai/whisper" }
-scripts.whisper = "whisper.transcribe:cli"
-
-[tool.setuptools]
-py-modules = [ "whisper" ]
-include-package-data = true
-
-[tool.setuptools.dynamic]
-version = { attr = "whisper.version.__version__" }
-
-[tool.setuptools.packages.find]
-exclude = [ "tests*" ]
-namespaces = false
-
 [tool.black]

 [tool.isort]
@@ -52,3 +5,4 @@ profile = "black"
 include_trailing_comma = true
 line_length = 88
 multi_line_output = 3
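Both packaging setups register the same console entry point, `whisper = whisper.transcribe:cli` (via `scripts.whisper` in pyproject.toml on main, and via `console_scripts` in the setup.py shown further below). As a hedged illustration, invoking that function directly is equivalent to running the installed `whisper` command:

```python
# Equivalent to the `whisper` console script declared in the packaging metadata above;
# cli() parses sys.argv itself, so pass arguments on the command line when running this file.
from whisper.transcribe import cli

if __name__ == "__main__":
    cli()
```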
requirements.txt

@@ -3,5 +3,5 @@ numpy
 torch
 tqdm
 more-itertools
-tiktoken
-triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
+transformers>=4.19.0
+ffmpeg-python==0.2.0
setup.py (new file), 43 lines

@@ -0,0 +1,43 @@
+import os
+import platform
+import sys
+
+import pkg_resources
+from setuptools import find_packages, setup
+
+
+def read_version(fname="whisper/version.py"):
+    exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
+    return locals()["__version__"]
+
+
+requirements = []
+if sys.platform.startswith("linux") and platform.machine() == "x86_64":
+    requirements.append("triton==2.0.0")
+
+setup(
+    name="openai-whisper",
+    py_modules=["whisper"],
+    version=read_version(),
+    description="Robust Speech Recognition via Large-Scale Weak Supervision",
+    long_description=open("README.md", encoding="utf-8").read(),
+    long_description_content_type="text/markdown",
+    readme="README.md",
+    python_requires=">=3.8",
+    author="OpenAI",
+    url="https://github.com/openai/whisper",
+    license="MIT",
+    packages=find_packages(exclude=["tests*"]),
+    install_requires=requirements
+    + [
+        str(r)
+        for r in pkg_resources.parse_requirements(
+            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
+        )
+    ],
+    entry_points={
+        "console_scripts": ["whisper=whisper.transcribe:cli"],
+    },
+    include_package_data=True,
+    extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
+)
(tokenizer tests; the file name is not shown in this capture, but the contents match tests/test_tokenizer.py)

@@ -1,17 +1,7 @@
-import pytest
-
 from whisper.tokenizer import get_tokenizer


-@pytest.mark.parametrize("multilingual", [True, False])
-def test_tokenizer(multilingual):
-    tokenizer = get_tokenizer(multilingual=False)
-    assert tokenizer.sot in tokenizer.sot_sequence
-    assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens)
-    assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens)
-
-
-def test_multilingual_tokenizer():
+def test_tokenizer():
     gpt2_tokenizer = get_tokenizer(multilingual=False)
     multilingual_tokenizer = get_tokenizer(multilingual=True)

@@ -22,13 +12,3 @@ def test_multilingual_tokenizer():
     assert gpt2_tokenizer.decode(gpt2_tokens) == text
     assert multilingual_tokenizer.decode(multilingual_tokens) == text
     assert len(gpt2_tokens) > len(multilingual_tokens)
-
-
-def test_split_on_unicode():
-    multilingual_tokenizer = get_tokenizer(multilingual=True)
-
-    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
-    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
-
-    assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
-    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
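For context on what these tests exercise, a small hedged sketch of round-tripping text through both tokenizers, mirroring the asserts above (the sample sentence is arbitrary, not the one used in the test file, and the backing tokenizer implementation differs between the two versions compared here):

```python
from whisper.tokenizer import get_tokenizer

gpt2_tokenizer = get_tokenizer(multilingual=False)
multilingual_tokenizer = get_tokenizer(multilingual=True)

text = "hello, whisper"  # illustrative text
gpt2_tokens = gpt2_tokenizer.encode(text)
multilingual_tokens = multilingual_tokenizer.encode(text)

# Both tokenizers should reproduce the original text on decode.
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
print(len(gpt2_tokens), len(multilingual_tokens))
```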
(transcription tests; the file name is not shown in this capture, but the contents match tests/test_transcribe.py)

@@ -4,7 +4,6 @@ import pytest
 import torch

 import whisper
-from whisper.tokenizer import get_tokenizer


 @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -18,18 +17,12 @@ def test_transcribe(model_name: str):
         audio_path, language=language, temperature=0.0, word_timestamps=True
     )
     assert result["language"] == "en"
-    assert result["text"] == "".join([s["text"] for s in result["segments"]])

     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
     assert "your country" in transcription
     assert "do for you" in transcription

-    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
-    all_tokens = [t for s in result["segments"] for t in s["tokens"]]
-    assert tokenizer.decode(all_tokens) == result["text"]
-    assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
-
     timing_checked = False
     for segment in result["segments"]:
         for timing in segment["words"]:
@@ -37,6 +30,7 @@ def test_transcribe(model_name: str):
             if timing["word"].strip(" ,") == "Americans":
                 assert timing["start"] <= 1.8
                 assert timing["end"] >= 1.8
+                print(timing)
                 timing_checked = True

     assert timing_checked
whisper/__init__.py

@@ -25,10 +25,7 @@ _MODELS = {
     "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
     "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
     "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
-    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }

 # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@@ -44,10 +41,7 @@ _ALIGNMENT_HEADS = {
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
-    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
 }
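A quick, hedged way to see which of the checkpoint names registered in `_MODELS` a given install actually exposes (the output differs between the two versions compared here, and loading a model downloads its checkpoint, by default into the user cache directory):

```python
import whisper

# Lists the keys of the _MODELS table shown above for the installed version.
print(whisper.available_models())

# Loading by name fetches the corresponding checkpoint; "tiny" keeps the download small.
model = whisper.load_model("tiny")
print(model.dims)
```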
@@ -147,8 +141,7 @@ def load_model(
     with (
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
-        kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
-        checkpoint = torch.load(fp, map_location=device, **kwargs)
+        checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
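The `weights_only` change above is a general PyTorch hardening pattern: restrict unpickling to plain tensors and containers where the installed torch supports it. A standalone sketch of the same idea outside the whisper codebase (the checkpoint path is illustrative):

```python
import torch

def load_checkpoint(path: str, device: str = "cpu") -> dict:
    # Mirror the version gate used in load_model() above; older torch releases
    # do not accept the weights_only keyword.
    kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
    with open(path, "rb") as fp:
        return torch.load(fp, map_location=device, **kwargs)

# checkpoint = load_checkpoint("tiny.pt")  # illustrative file name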
(large file diff suppressed because it is too large; file name not shown in this capture)

whisper/assets/gpt2/merges.txt: new file, 50001 lines; diff suppressed because it is too large

whisper/assets/gpt2/special_tokens_map.json: new file, 1 line
+{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

whisper/assets/gpt2/tokenizer_config.json: new file, 1 line
+{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

whisper/assets/gpt2/vocab.json: new file; diff suppressed because one or more lines are too long

whisper/assets/mel_filters.npz: binary, Normal file → Executable file; binary file not shown

(large file diff suppressed because it is too large; file name not shown in this capture)

whisper/assets/multilingual/added_tokens.json: new file, 1 line
+{"<|endoftext|>": 50257}

whisper/assets/multilingual/merges.txt: new file, 50000 lines; diff suppressed because it is too large

whisper/assets/multilingual/special_tokens_map.json: new file, 1 line
+{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

whisper/assets/multilingual/tokenizer_config.json: new file, 1 line
+{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

whisper/assets/multilingual/vocab.json: new file; diff suppressed because one or more lines are too long
whisper/audio.py

@@ -1,8 +1,8 @@
 import os
 from functools import lru_cache
-from subprocess import CalledProcessError, run
 from typing import Optional, Union

+import ffmpeg
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -12,6 +12,7 @@ from .utils import exact_div
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
+N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -38,25 +39,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     -------
     A NumPy array containing the audio waveform, in float32 dtype.
     """
-
-    # This launches a subprocess to decode audio while down-mixing
-    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
-    # fmt: off
-    cmd = [
-        "ffmpeg",
-        "-nostdin",
-        "-threads", "0",
-        "-i", file,
-        "-f", "s16le",
-        "-ac", "1",
-        "-acodec", "pcm_s16le",
-        "-ar", str(sr),
-        "-"
-    ]
-    # fmt: on
     try:
-        out = run(cmd, capture_output=True, check=True).stdout
-    except CalledProcessError as e:
+        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+        out, _ = (
+            ffmpeg.input(file, threads=0)
+            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+        )
+    except ffmpeg.Error as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
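Both variants above end up invoking the same ffmpeg command line, one through the `ffmpeg-python` wrapper and one through the standard library. A self-contained sketch of the subprocess-based approach, kept outside the package for illustration (requires the `ffmpeg` CLI on PATH; the file name is illustrative):

```python
import subprocess
import numpy as np

def decode_to_float32(path: str, sr: int = 16000) -> np.ndarray:
    # Ask ffmpeg for mono, 16-bit little-endian PCM at the target sample rate on stdout.
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0",
        "-i", path,
        "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr),
        "-",
    ]
    out = subprocess.run(cmd, capture_output=True, check=True).stdout
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0

# waveform = decode_to_float32("audio.mp3")  # illustrative file name
# print(waveform.shape, waveform.dtype)
```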
@ -89,7 +80,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||||
Allows decoupling librosa dependency; saved using:
|
Allows decoupling librosa dependency; saved using:
|
||||||
@ -97,19 +88,18 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
|
|||||||
np.savez_compressed(
|
np.savez_compressed(
|
||||||
"mel_filters.npz",
|
"mel_filters.npz",
|
||||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||||
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||||
|
with np.load(
|
||||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||||
with np.load(filters_path, allow_pickle=False) as f:
|
) as f:
|
||||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||||
|
|
||||||
|
|
||||||
def log_mel_spectrogram(
|
def log_mel_spectrogram(
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
n_mels: int = 80,
|
n_mels: int = N_MELS,
|
||||||
padding: int = 0,
|
padding: int = 0,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
):
|
):
|
||||||
@ -122,7 +112,7 @@ def log_mel_spectrogram(
|
|||||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||||
|
|
||||||
n_mels: int
|
n_mels: int
|
||||||
The number of Mel-frequency filters, only 80 and 128 are supported
|
The number of Mel-frequency filters, only 80 is supported
|
||||||
|
|
||||||
padding: int
|
padding: int
|
||||||
Number of zero samples to pad to the right
|
Number of zero samples to pad to the right
|
||||||
@ -132,7 +122,7 @@ def log_mel_spectrogram(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor, shape = (n_mels, n_frames)
|
torch.Tensor, shape = (80, n_frames)
|
||||||
A Tensor that contains the Mel spectrogram
|
A Tensor that contains the Mel spectrogram
|
||||||
"""
|
"""
|
||||||
if not torch.is_tensor(audio):
|
if not torch.is_tensor(audio):
|
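A short sketch of how the two signatures are typically called (the 128-mel case assumes the newer code together with a checkpoint such as large-v3 whose `model.dims.n_mels` is 128; earlier checkpoints use 80, and the file path is again a placeholder):

    from whisper.audio import N_SAMPLES, log_mel_spectrogram

    mel_80 = log_mel_spectrogram("speech.wav", n_mels=80, padding=N_SAMPLES)
    mel_128 = log_mel_spectrogram("speech.wav", n_mels=128, padding=N_SAMPLES)
    # both return shape (n_mels, n_frames), with 100 mel frames per second of audio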
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -32,9 +32,7 @@ def detect_language(
         list of dictionaries containing the probability distribution over all languages.
     """
     if tokenizer is None:
-        tokenizer = get_tokenizer(
-            model.is_multilingual, num_languages=model.num_languages
-        )
+        tokenizer = get_tokenizer(model.is_multilingual)
     if (
         tokenizer.language is None
         or tokenizer.language_token not in tokenizer.sot_sequence
@@ -148,10 +146,6 @@ class PyTorchInference(Inference):
         self.kv_cache = {}
         self.hooks = []

-        key_modules = [block.attn.key for block in self.model.decoder.blocks]
-        value_modules = [block.attn.value for block in self.model.decoder.blocks]
-        self.kv_modules = key_modules + value_modules
-
     def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
         if not self.kv_cache:
             self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
@@ -170,10 +164,9 @@ class PyTorchInference(Inference):
         self.hooks = []

     def rearrange_kv_cache(self, source_indices):
-        if source_indices != list(range(len(source_indices))):
-            for module in self.kv_modules:
-                # update the key/value cache to contain the selected sequences
-                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
+        for module, tensor in self.kv_cache.items():
+            # update the key/value cache to contain the selected sequences
+            self.kv_cache[module] = tensor[source_indices].detach()


 class SequenceRanker:
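What `rearrange_kv_cache` does, as a standalone sketch: during beam search the surviving hypotheses are a permutation (possibly with repeats) of the previous batch rows, so every cached key/value tensor is re-indexed along its batch dimension. Assuming a cache entry of shape (n_batch, n_ctx, n_state):

    import torch

    cached = torch.randn(4, 10, 512)           # 4 hypotheses currently in the cache
    source_indices = [0, 0, 2, 3]              # beams 0 and 1 both continue from beam 0
    cached = cached[source_indices].detach()   # same row selection as in the diff above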
@@ -476,12 +469,7 @@ class ApplyTimestampRules(LogitFilter):
             ]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
-                # also force each segment to have a nonzero length, to prevent infinite looping
-                if last_was_timestamp and not penultimate_was_timestamp:
-                    timestamp_last = timestamps[-1]
-                else:
-                    timestamp_last = timestamps[-1] + 1
-                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
+                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf

         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
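Read as a worked example: timestamp tokens are ordered, with `timestamp_begin + k` standing for `k * 0.02` seconds, so `<|5.00|>` is `timestamp_begin + 250`. The single-line version masks only tokens strictly below the last timestamp, which still allows the same timestamp to repeat and open a zero-length segment. The branched version additionally masks the last timestamp itself whenever the two most recent tokens are both timestamps, so a segment that has just been opened can only close at a strictly later time.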
@@ -516,10 +504,7 @@ class DecodingTask:

         language = options.language or "en"
         tokenizer = get_tokenizer(
-            model.is_multilingual,
-            num_languages=model.num_languages,
-            language=language,
-            task=options.task,
+            model.is_multilingual, language=language, task=options.task
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -678,6 +663,7 @@ class DecodingTask:
         return languages, lang_probs

     def _main_loop(self, audio_features: Tensor, tokens: Tensor):
+        assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch
@@ -730,7 +716,8 @@ class DecodingTask:
             )
         ]

-        # repeat text tensors by the group size, for beam search or best-of-n sampling
+        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
@@ -791,10 +778,7 @@ def decode(

 @torch.no_grad()
 def decode(
-    model: "Whisper",
-    mel: Tensor,
-    options: DecodingOptions = DecodingOptions(),
-    **kwargs,
+    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
 ) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -818,9 +802,6 @@ def decode(
     if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)

-    if kwargs:
-        options = replace(options, **kwargs)
-
     result = DecodingTask(model, options).run(mel)

     return result[0] if single else result
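A usage sketch of the keyword-argument path present only on the `main` side: any field of `DecodingOptions` can be overridden per call, and `dataclasses.replace` builds the adjusted options object (model name and file path below are only examples):

    import whisper
    from whisper.decoding import DecodingOptions, decode

    model = whisper.load_model("base")
    audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    options = DecodingOptions(language="en", without_timestamps=True)
    result = decode(model, mel, options, temperature=0.2)   # kwargs are applied via dataclasses.replace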
@@ -1,8 +1,7 @@
 import base64
 import gzip
-from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Dict, Iterable, Optional

 import numpy as np
 import torch
@@ -13,14 +12,6 @@ from .decoding import decode as decode_function
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function

-try:
-    from torch.nn.functional import scaled_dot_product_attention
-
-    SDPA_AVAILABLE = True
-except (ImportError, RuntimeError, OSError):
-    scaled_dot_product_attention = None
-    SDPA_AVAILABLE = False
-

 @dataclass
 class ModelDimensions:
@@ -68,19 +59,7 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


-@contextmanager
-def disable_sdpa():
-    prev_state = MultiHeadAttention.use_sdpa
-    try:
-        MultiHeadAttention.use_sdpa = False
-        yield
-    finally:
-        MultiHeadAttention.use_sdpa = prev_state
-
-
 class MultiHeadAttention(nn.Module):
-    use_sdpa = True
-
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head
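The context manager on the `main` side exists because word-level timing needs the raw attention weights, which the fused SDPA kernel in the next hunk does not return. A sketch of the intended use, mirroring `find_alignment` in timing.py (with `model`, `mel` and `tokens` assumed to be prepared as there):

    import torch
    from whisper.model import disable_sdpa

    with torch.no_grad(), disable_sdpa():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]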
@@ -113,30 +92,20 @@ class MultiHeadAttention(nn.Module):

     def qkv_attention(
         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ):
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

-        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
-            a = scaled_dot_product_attention(
-                q, k, v, is_causal=mask is not None and n_ctx > 1
-            )
-            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
-            qk = None
-        else:
-            qk = (q * scale) @ (k * scale).transpose(-1, -2)
-            if mask is not None:
-                qk = qk + mask[:n_ctx, :n_ctx]
-            qk = qk.float()
+        qk = q @ k
+        if mask is not None:
+            qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()

-            w = F.softmax(qk, dim=-1).to(q.dtype)
-            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
-            qk = qk.detach()
-
-        return out, qk
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


 class ResidualAttentionBlock(nn.Module):
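All three code paths produce the same attention output; they only apply the `scale` factor at different points. With `scale = (n_state // n_head) ** -0.25`, scaling q and k separately gives `(q * scale) @ (k * scale).T == (q @ k.T) * scale**2`, i.e. the usual `1 / sqrt(d_head)` softmax scaling that the SDPA kernel applies internally. A small numerical check of that identity:

    import torch

    d_head = 64
    scale = d_head ** -0.25
    q, k = torch.randn(10, d_head), torch.randn(10, d_head)
    assert torch.allclose((q * scale) @ (k * scale).T, (q @ k.T) / d_head ** 0.5, atol=1e-5)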
@@ -228,7 +197,7 @@ class TextDecoder(nn.Module):
         """
         x : torch.LongTensor, shape = (batch_size, <= n_ctx)
             the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
             the encoded audio features to be attended on
         """
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
@@ -267,8 +236,7 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
-        # use the last half among the decoder layers for time alignment by default;
-        # to use a specific set of heads, see `set_alignment_heads()` below.
+        # use the last half layers for alignment by default; see `set_alignment_heads()` below
         all_heads = torch.zeros(
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
@@ -301,11 +269,7 @@ class Whisper(nn.Module):

     @property
     def is_multilingual(self):
-        return self.dims.n_vocab >= 51865
-
-    @property
-    def num_languages(self):
-        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+        return self.dims.n_vocab == 51865

     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """
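The arithmetic behind the new property, worked through for the released checkpoints: the non-language part of the vocabulary is fixed in size, so `n_vocab - 51765 - int(is_multilingual)` counts the language tokens. For the original multilingual models `51865 - 51765 - 1 = 99`; for large-v3, which adds a Cantonese token, `51866 - 51765 - 1 = 100`; and the relaxed `>= 51865` comparison keeps `is_multilingual` true for that larger vocabulary, where the old `== 51865` check would not.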
@@ -30,19 +30,15 @@ def remove_symbols_and_diacritics(s: str, keep=""):
     and drop any diacritics (category 'Mn' and some manual mappings)
     """
     return "".join(
-        (
-            c
-            if c in keep
-            else (
-                ADDITIONAL_DIACRITICS[c]
-                if c in ADDITIONAL_DIACRITICS
-                else (
-                    ""
-                    if unicodedata.category(c) == "Mn"
-                    else " " if unicodedata.category(c)[0] in "MSP" else c
-                )
-            )
-        )
+        c
+        if c in keep
+        else ADDITIONAL_DIACRITICS[c]
+        if c in ADDITIONAL_DIACRITICS
+        else ""
+        if unicodedata.category(c) == "Mn"
+        else " "
+        if unicodedata.category(c)[0] in "MSP"
+        else c
         for c in unicodedata.normalize("NFKD", s)
     )

@@ -1,4 +1,3 @@
-import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -54,7 +53,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
     return result


-@numba.jit(nopython=True)
+@numba.jit
 def backtrace(trace: np.ndarray):
     i = trace.shape[0] - 1
     j = trace.shape[1] - 1
@@ -117,7 +116,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
     x_skew = x_skew.T.contiguous()
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost[0, 0] = 0
-    cost = cost.to(x.device)
+    cost = cost.cuda()
     trace = torch.zeros_like(cost, dtype=torch.int32)

     dtw_kernel[(1,)](
@@ -170,9 +169,6 @@ def find_alignment(
     medfilt_width: int = 7,
     qk_scale: float = 1.0,
 ) -> List[WordTiming]:
-    if len(text_tokens) == 0:
-        return []
-
     tokens = torch.tensor(
         [
             *tokenizer.sot_sequence,
@@ -191,9 +187,7 @@ def find_alignment(
         for i, block in enumerate(model.decoder.blocks)
     ]

-    from .model import disable_sdpa
-
-    with torch.no_grad(), disable_sdpa():
+    with torch.no_grad():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
         token_probs = sampled_logits.softmax(dim=-1)
@@ -204,7 +198,7 @@ def find_alignment(
         hook.remove()

     # heads * tokens * frames
-    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
     weights = weights[:, :, : num_frames // 2]
     weights = (weights * qk_scale).softmax(dim=-1)
     std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
@@ -216,13 +210,6 @@ def find_alignment(
     text_indices, time_indices = dtw(-matrix)

     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
-    if len(word_tokens) <= 1:
-        # return on eot only
-        # >>> np.pad([], (1, 0))
-        # array([0.])
-        # This results in crashes when we lookup jump_times with float, like
-        # IndexError: arrays used as indices must be of integer (or boolean) type
-        return []
     word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

     jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
@@ -234,6 +221,19 @@ def find_alignment(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]

+    # hack: ensure the first and second word is not longer than twice the median word duration.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    word_durations = end_times - start_times
+    word_durations = word_durations[word_durations.nonzero()]
+    if len(word_durations) > 0:
+        median_duration = np.median(word_durations)
+        max_duration = median_duration * 2
+        if len(word_durations) >= 2 and word_durations[1] > max_duration:
+            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+            end_times[0] = start_times[1] = boundary
+        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
+            start_times[0] = max(0, end_times[0] - max_duration)
+
     return [
         WordTiming(word, tokens, start, end, probability)
         for word, tokens, start, end, probability in zip(
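A worked example of the index bookkeeping shared by both versions: with the EOT token appended, the per-word token counts are cumulatively summed, and each word's start and end times are then read off the DTW time track at consecutive boundary positions.

    import numpy as np

    word_tokens = [[101, 102], [103], [50257]]  # hypothetical ids; the last entry is the appended EOT
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    assert word_boundaries.tolist() == [0, 2, 3]  # word 0 -> tokens [0, 2), word 1 -> tokens [2, 3)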
@@ -285,104 +285,39 @@ def add_word_timestamps(
     num_frames: int,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
-    last_speech_timestamp: float,
     **kwargs,
 ):
     if len(segments) == 0:
         return

-    text_tokens_per_segment = [
-        [token for token in segment["tokens"] if token < tokenizer.eot]
-        for segment in segments
-    ]
-
-    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+    text_tokens = [t for segment in segments for t in segment["tokens"]]
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
-    word_durations = np.array([t.end - t.start for t in alignment])
-    word_durations = word_durations[word_durations.nonzero()]
-    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
-    median_duration = min(0.7, float(median_duration))
-    max_duration = median_duration * 2
-
-    # hack: truncate long words at sentence boundaries.
-    # a better segmentation algorithm based on VAD should be able to replace this.
-    if len(word_durations) > 0:
-        sentence_end_marks = ".。!!??"
-        # ensure words at sentence boundaries are not longer than twice the median word duration.
-        for i in range(1, len(alignment)):
-            if alignment[i].end - alignment[i].start > max_duration:
-                if alignment[i].word in sentence_end_marks:
-                    alignment[i].end = alignment[i].start + max_duration
-                elif alignment[i - 1].word in sentence_end_marks:
-                    alignment[i].start = alignment[i].end - max_duration
-
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)

     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    word_index = 0
+    segment_lengths = [len(s["tokens"]) for s in segments]
+    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)

-    for segment, text_tokens in zip(segments, text_tokens_per_segment):
-        saved_tokens = 0
-        words = []
-
-        while word_index < len(alignment) and saved_tokens < len(text_tokens):
-            timing = alignment[word_index]
-
-            if timing.word:
-                words.append(
-                    dict(
-                        word=timing.word,
-                        start=round(time_offset + timing.start, 2),
-                        end=round(time_offset + timing.end, 2),
-                        probability=timing.probability,
-                    )
-                )
-
-            saved_tokens += len(timing.tokens)
-            word_index += 1
-
-        # hack: truncate long words at segment boundaries.
-        # a better segmentation algorithm based on VAD should be able to replace this.
-        if len(words) > 0:
-            # ensure the first and second word after a pause is not longer than
-            # twice the median word duration.
-            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
-                words[0]["end"] - words[0]["start"] > max_duration
-                or (
-                    len(words) > 1
-                    and words[1]["end"] - words[0]["start"] > max_duration * 2
-                )
-            ):
-                if (
-                    len(words) > 1
-                    and words[1]["end"] - words[1]["start"] > max_duration
-                ):
-                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
-                    words[0]["end"] = words[1]["start"] = boundary
-                words[0]["start"] = max(0, words[0]["end"] - max_duration)
-
-            # prefer the segment-level start timestamp if the first word is too long.
-            if (
-                segment["start"] < words[0]["end"]
-                and segment["start"] - 0.5 > words[0]["start"]
-            ):
-                words[0]["start"] = max(
-                    0, min(words[0]["end"] - median_duration, segment["start"])
-                )
-            else:
-                segment["start"] = words[0]["start"]
-
-            # prefer the segment-level end timestamp if the last word is too long.
-            if (
-                segment["end"] > words[-1]["start"]
-                and segment["end"] + 0.5 < words[-1]["end"]
-            ):
-                words[-1]["end"] = max(
-                    words[-1]["start"] + median_duration, segment["end"]
-                )
-            else:
-                segment["end"] = words[-1]["end"]
-
-            last_speech_timestamp = segment["end"]
-
-        segment["words"] = words
+    for segment in segments:
+        segment["words"] = []
+
+    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
+    for i, timing in enumerate(alignment):
+        if timing.word:
+            segment = segments[token_sources[word_boundaries[i]]]
+            start = round(time_offset + timing.start, 2)
+            end = round(time_offset + timing.end, 2)
+            segment["words"].append(
+                dict(
+                    word=timing.word,
+                    start=start,
+                    end=end,
+                    probability=timing.probability,
+                )
+            )
+
+    for segment in segments:
+        if len(words := segment["words"]) > 0:
+            # adjust the segment-level timestamps based on the word-level timestamps
+            segment["start"] = words[0]["start"]
+            segment["end"] = words[-1]["end"]
@@ -1,11 +1,12 @@
-import base64
 import os
 import string
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import cached_property, lru_cache
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

-import tiktoken
+import numpy as np
+import torch
+from transformers import GPT2TokenizerFast

 LANGUAGES = {
     "en": "english",
@@ -107,7 +108,6 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
-    "yue": "cantonese",
 }

 # language code lookup by name, with a few language aliases
||||||
@ -124,89 +124,77 @@ TO_LANGUAGE_CODE = {
|
|||||||
"moldovan": "ro",
|
"moldovan": "ro",
|
||||||
"sinhalese": "si",
|
"sinhalese": "si",
|
||||||
"castilian": "es",
|
"castilian": "es",
|
||||||
"mandarin": "zh",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||||
|
|
||||||
encoding: tiktoken.Encoding
|
tokenizer: "GPT2TokenizerFast"
|
||||||
num_languages: int
|
language: Optional[str]
|
||||||
language: Optional[str] = None
|
sot_sequence: Tuple[int]
|
||||||
task: Optional[str] = None
|
|
||||||
sot_sequence: Tuple[int] = ()
|
|
||||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
for special in self.encoding.special_tokens_set:
|
|
||||||
special_token = self.encoding.encode_single_token(special)
|
|
||||||
self.special_tokens[special] = special_token
|
|
||||||
|
|
||||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
|
||||||
translate: int = self.special_tokens["<|translate|>"]
|
|
||||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
|
||||||
|
|
||||||
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
|
||||||
sot_sequence = [sot]
|
|
||||||
if self.language is not None:
|
|
||||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
|
||||||
if self.task is not None:
|
|
||||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
|
||||||
sot_sequence.append(task_token)
|
|
||||||
|
|
||||||
self.sot_sequence = tuple(sot_sequence)
|
|
||||||
|
|
||||||
def encode(self, text, **kwargs):
|
def encode(self, text, **kwargs):
|
||||||
return self.encoding.encode(text, **kwargs)
|
return self.tokenizer.encode(text, **kwargs)
|
||||||
|
|
||||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
def decode(
|
||||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
|
||||||
return self.encoding.decode(token_ids, **kwargs)
|
):
|
||||||
|
return self.tokenizer.decode(token_ids, **kwargs)
|
||||||
|
|
||||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
def decode_with_timestamps(self, tokens) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||||
"""
|
"""
|
||||||
return self.encoding.decode(token_ids, **kwargs)
|
outputs = [[]]
|
||||||
|
for token in tokens:
|
||||||
|
if token >= self.timestamp_begin:
|
||||||
|
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||||
|
outputs.append(timestamp)
|
||||||
|
outputs.append([])
|
||||||
|
else:
|
||||||
|
outputs[-1].append(token)
|
||||||
|
return "".join(
|
||||||
|
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||||
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def eot(self) -> int:
|
def eot(self) -> int:
|
||||||
return self.encoding.eot_token
|
return self.tokenizer.eos_token_id
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def transcribe(self) -> int:
|
def transcribe(self) -> int:
|
||||||
return self.special_tokens["<|transcribe|>"]
|
return self._get_single_token_id("<|transcribe|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def translate(self) -> int:
|
def translate(self) -> int:
|
||||||
return self.special_tokens["<|translate|>"]
|
return self._get_single_token_id("<|translate|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot(self) -> int:
|
def sot(self) -> int:
|
||||||
return self.special_tokens["<|startoftranscript|>"]
|
return self._get_single_token_id("<|startoftranscript|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_lm(self) -> int:
|
def sot_lm(self) -> int:
|
||||||
return self.special_tokens["<|startoflm|>"]
|
return self._get_single_token_id("<|startoflm|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_prev(self) -> int:
|
def sot_prev(self) -> int:
|
||||||
return self.special_tokens["<|startofprev|>"]
|
return self._get_single_token_id("<|startofprev|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def no_speech(self) -> int:
|
def no_speech(self) -> int:
|
||||||
return self.special_tokens["<|nospeech|>"]
|
return self._get_single_token_id("<|nospeech|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def no_timestamps(self) -> int:
|
def no_timestamps(self) -> int:
|
||||||
return self.special_tokens["<|notimestamps|>"]
|
return self._get_single_token_id("<|notimestamps|>")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def timestamp_begin(self) -> int:
|
def timestamp_begin(self) -> int:
|
||||||
return self.special_tokens["<|0.00|>"]
|
return self.tokenizer.all_special_ids[-1] + 1
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def language_token(self) -> int:
|
def language_token(self) -> int:
|
||||||
@ -214,25 +202,32 @@ class Tokenizer:
|
|||||||
if self.language is None:
|
if self.language is None:
|
||||||
raise ValueError("This tokenizer does not have language token configured")
|
raise ValueError("This tokenizer does not have language token configured")
|
||||||
|
|
||||||
return self.to_language_token(self.language)
|
additional_tokens = dict(
|
||||||
|
zip(
|
||||||
|
self.tokenizer.additional_special_tokens,
|
||||||
|
self.tokenizer.additional_special_tokens_ids,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
candidate = f"<|{self.language}|>"
|
||||||
|
if candidate in additional_tokens:
|
||||||
|
return additional_tokens[candidate]
|
||||||
|
|
||||||
def to_language_token(self, language):
|
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||||
if token := self.special_tokens.get(f"<|{language}|>", None):
|
|
||||||
return token
|
|
||||||
|
|
||||||
raise KeyError(f"Language {language} not found in tokenizer.")
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def all_language_tokens(self) -> Tuple[int]:
|
def all_language_tokens(self) -> Tuple[int]:
|
||||||
result = []
|
result = []
|
||||||
for token, token_id in self.special_tokens.items():
|
for token, token_id in zip(
|
||||||
|
self.tokenizer.additional_special_tokens,
|
||||||
|
self.tokenizer.additional_special_tokens_ids,
|
||||||
|
):
|
||||||
if token.strip("<|>") in LANGUAGES:
|
if token.strip("<|>") in LANGUAGES:
|
||||||
result.append(token_id)
|
result.append(token_id)
|
||||||
return tuple(result)[: self.num_languages]
|
return tuple(result)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def all_language_codes(self) -> Tuple[str]:
|
def all_language_codes(self) -> Tuple[str]:
|
||||||
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||||
@ -263,19 +258,24 @@ class Tokenizer:
|
|||||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||||
for symbol in symbols + list(miscellaneous):
|
for symbol in symbols + list(miscellaneous):
|
||||||
for tokens in [
|
for tokens in [
|
||||||
self.encoding.encode(symbol),
|
self.tokenizer.encode(symbol),
|
||||||
self.encoding.encode(" " + symbol),
|
self.tokenizer.encode(" " + symbol),
|
||||||
]:
|
]:
|
||||||
if len(tokens) == 1 or symbol in miscellaneous:
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
result.add(tokens[0])
|
result.add(tokens[0])
|
||||||
|
|
||||||
return tuple(sorted(result))
|
return tuple(sorted(result))
|
||||||
|
|
||||||
|
def _get_single_token_id(self, text) -> int:
|
||||||
|
tokens = self.tokenizer.encode(text)
|
||||||
|
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||||
|
return tokens[0]
|
||||||
|
|
||||||
def split_to_word_tokens(self, tokens: List[int]):
|
def split_to_word_tokens(self, tokens: List[int]):
|
||||||
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||||
# These languages don't typically use spaces, so it is difficult to split words
|
# These languages don't typically use spaces, so it is difficult to split words
|
||||||
# without morpheme analysis. Here, we instead split words at any
|
# without morpheme analysis. Here, we instead split words at any
|
||||||
# position where the tokens are decoded as valid unicode points
|
# position where the tokens are decoded as valid unicode points
|
||||||
@ -284,27 +284,17 @@ class Tokenizer:
|
|||||||
return self.split_tokens_on_spaces(tokens)
|
return self.split_tokens_on_spaces(tokens)
|
||||||
|
|
||||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||||
decoded_full = self.decode_with_timestamps(tokens)
|
|
||||||
replacement_char = "\ufffd"
|
|
||||||
|
|
||||||
words = []
|
words = []
|
||||||
word_tokens = []
|
word_tokens = []
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
unicode_offset = 0
|
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
decoded = self.decode_with_timestamps(current_tokens)
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
|
if "\ufffd" not in decoded:
|
||||||
if (
|
|
||||||
replacement_char not in decoded
|
|
||||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
|
||||||
== replacement_char
|
|
||||||
):
|
|
||||||
words.append(decoded)
|
words.append(decoded)
|
||||||
word_tokens.append(current_tokens)
|
word_tokens.append(current_tokens)
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
unicode_offset += len(decoded)
|
|
||||||
|
|
||||||
return words, word_tokens
|
return words, word_tokens
|
||||||
|
|
@@ -328,48 +318,32 @@


 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
-    ranks = {
-        base64.b64decode(token): int(rank)
-        for token, rank in (line.split() for line in open(vocab_path) if line)
-    }
-    n_vocab = len(ranks)
-    special_tokens = {}
-
-    specials = [
-        "<|endoftext|>",
-        "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
-        "<|translate|>",
-        "<|transcribe|>",
-        "<|startoflm|>",
-        "<|startofprev|>",
-        "<|nospeech|>",
-        "<|notimestamps|>",
-        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
-    ]
-
-    for token in specials:
-        special_tokens[token] = n_vocab
-        n_vocab += 1
-
-    return tiktoken.Encoding(
-        name=os.path.basename(vocab_path),
-        explicit_n_vocab=n_vocab,
-        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
-        mergeable_ranks=ranks,
-        special_tokens=special_tokens,
-    )
+def build_tokenizer(name: str = "gpt2"):
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    path = os.path.join(os.path.dirname(__file__), "assets", name)
+    tokenizer = GPT2TokenizerFast.from_pretrained(path)
+
+    specials = [
+        "<|startoftranscript|>",
+        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        "<|translate|>",
+        "<|transcribe|>",
+        "<|startoflm|>",
+        "<|startofprev|>",
+        "<|nospeech|>",
+        "<|notimestamps|>",
+    ]
+
+    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+    return tokenizer


 @lru_cache(maxsize=None)
 def get_tokenizer(
     multilingual: bool,
     *,
-    num_languages: int = 99,
-    language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+    language: Optional[str] = None,
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
@@ -380,16 +354,27 @@ def get_tokenizer(
         raise ValueError(f"Unsupported language: {language}")

     if multilingual:
-        encoding_name = "multilingual"
-        language = language or "en"
+        tokenizer_name = "multilingual"
         task = task or "transcribe"
+        language = language or "en"
     else:
-        encoding_name = "gpt2"
-        language = None
+        tokenizer_name = "gpt2"
         task = None
+        language = None

-    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+    tokenizer = build_tokenizer(name=tokenizer_name)
+    all_special_ids: List[int] = tokenizer.all_special_ids
+    sot: int = all_special_ids[1]
+    translate: int = all_special_ids[-6]
+    transcribe: int = all_special_ids[-5]
+
+    langs = tuple(LANGUAGES.keys())
+    sot_sequence = [sot]
+    if language is not None:
+        sot_sequence.append(sot + 1 + langs.index(language))
+    if task is not None:
+        sot_sequence.append(transcribe if task == "transcribe" else translate)

     return Tokenizer(
-        encoding=encoding, num_languages=num_languages, language=language, task=task
+        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
     )
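A usage sketch of what both construction paths aim to produce; the asserted ids are the well-known values of the released multilingual vocabulary (startoftranscript, the `<|en|>` language token, and the transcribe task token):

    from whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
    # <|startoftranscript|> <|en|> <|transcribe|>
    assert tokenizer.sot_sequence == (50258, 50259, 50359)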
@@ -1,8 +1,7 @@
 import argparse
 import os
-import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -23,7 +22,6 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
     exact_div,
     format_timestamp,
-    get_end,
     get_writer,
     make_safe,
     optional_float,
@@ -46,12 +44,9 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
-    carry_initial_prompt: bool = False,
     word_timestamps: bool = False,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
-    clip_timestamps: Union[str, List[float]] = "0",
-    hallucination_silence_threshold: Optional[float] = None,
     **decode_options,
 ):
     """
@@ -103,22 +98,9 @@ def transcribe(
         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
         to make it more likely to predict those word correctly.

-    carry_initial_prompt: bool
-        If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
-        `decode()` call. If there is not enough context space at the start of the prompt, it is
-        left-sliced to make space.
-
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances

-    clip_timestamps: Union[str, List[float]]
-        Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
-        The last end timestamp defaults to the end of the file.
-
-    hallucination_silence_threshold: Optional[float]
-        When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
-        when a possible hallucination is detected
-
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -136,9 +118,8 @@ def transcribe(
         decode_options["fp16"] = False

     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
     content_frames = mel.shape[-1] - N_FRAMES
-    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -158,25 +139,7 @@ def transcribe(

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(
-        model.is_multilingual,
-        num_languages=model.num_languages,
-        language=language,
-        task=task,
-    )
-
-    if isinstance(clip_timestamps, str):
-        clip_timestamps = [
-            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
-        ]
-    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
-    if len(seek_points) == 0:
-        seek_points.append(0)
-    if len(seek_points) % 2 == 1:
-        seek_points.append(content_frames)
-    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
-
-    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")
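A worked example of the clip parsing that exists only on the `main` side: with `FRAMES_PER_SECOND == 100`, the string "10.5,30,45" becomes seek points `[1050, 3000, 4500]`; the odd count is closed with `content_frames`, so the clips processed are `(1050, 3000)` and `(4500, content_frames)`. The default "0" yields a single clip covering the whole file.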
||||||
@ -211,20 +174,13 @@ def transcribe(
|
|||||||
and decode_result.avg_logprob < logprob_threshold
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = True # average log probability is too low
|
needs_fallback = True # average log probability is too low
|
||||||
if (
|
|
||||||
no_speech_threshold is not None
|
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
|
||||||
and logprob_threshold is not None
|
|
||||||
and decode_result.avg_logprob < logprob_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = False # silence
|
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
break
|
break
|
||||||
|
|
||||||
return decode_result
|
return decode_result
|
||||||
|
|
||||||
clip_idx = 0
|
seek = 0
|
||||||
seek = seek_clips[clip_idx][0]
|
|
||||||
input_stride = exact_div(
|
input_stride = exact_div(
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
@ -235,25 +191,23 @@ def transcribe(
|
|||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
remaining_prompt_length -= len(initial_prompt_tokens)
|
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
def new_segment(
|
def new_segment(
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||||
):
|
):
|
||||||
tokens = tokens.tolist()
|
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
|
||||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
|
||||||
return {
|
return {
|
||||||
|
"id": len(all_segments),
|
||||||
"seek": seek,
|
"seek": seek,
|
||||||
"start": start,
|
"start": start,
|
||||||
"end": end,
|
"end": end,
|
||||||
"text": tokenizer.decode(text_tokens),
|
"text": tokenizer.decode(text_tokens),
|
||||||
"tokens": tokens,
|
"tokens": text_tokens,
|
||||||
"temperature": result.temperature,
|
"temperature": result.temperature,
|
||||||
"avg_logprob": result.avg_logprob,
|
"avg_logprob": result.avg_logprob,
|
||||||
"compression_ratio": result.compression_ratio,
|
"compression_ratio": result.compression_ratio,
|
||||||
@ -264,34 +218,14 @@ def transcribe(
|
|||||||
with tqdm.tqdm(
|
with tqdm.tqdm(
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
last_speech_timestamp = 0.0
|
while seek < content_frames:
|
||||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
|
||||||
# A later commit should turn this into a simpler nested loop.
|
|
||||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
|
||||||
# while seek < seek_clip_end
|
|
||||||
while clip_idx < len(seek_clips):
|
|
||||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
|
||||||
if seek < seek_clip_start:
|
|
||||||
seek = seek_clip_start
|
|
||||||
if seek >= seek_clip_end:
|
|
||||||
clip_idx += 1
|
|
||||||
if clip_idx < len(seek_clips):
|
|
||||||
seek = seek_clips[clip_idx][0]
|
|
||||||
continue
|
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
segment_size = min(N_FRAMES, content_frames - seek)
|
||||||
mel_segment = mel[:, seek : seek + segment_size]
|
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
if carry_initial_prompt:
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
|
||||||
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
|
||||||
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
|
||||||
else:
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
@ -311,30 +245,7 @@ def transcribe(
|
|||||||
|
|
||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
|
current_tokens = []
|
||||||
# anomalous words are very long/short/improbable
|
|
||||||
def word_anomaly_score(word: dict) -> float:
|
|
||||||
probability = word.get("probability", 0.0)
|
|
||||||
duration = word["end"] - word["start"]
|
|
||||||
score = 0.0
|
|
||||||
if probability < 0.15:
|
|
||||||
score += 1.0
|
|
||||||
if duration < 0.133:
|
|
||||||
score += (0.133 - duration) * 15
|
|
||||||
if duration > 2.0:
|
|
||||||
score += duration - 2.0
|
|
||||||
return score
|
|
||||||
|
|
||||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
|
||||||
if segment is None or not segment["words"]:
|
|
||||||
return False
|
|
||||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
|
||||||
words = words[:8]
|
|
||||||
score = sum(word_anomaly_score(w) for w in words)
|
|
||||||
return score >= 3 or score + 0.01 >= len(words)
|
|
||||||
|
|
||||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
|
||||||
return next((s for s in segments if s["words"]), None)
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
@ -364,6 +275,7 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
current_tokens.append(sliced_tokens.tolist())
|
||||||
last_slice = current_slice
|
last_slice = current_slice
|
||||||
|
|
||||||
if single_timestamp_ending:
|
if single_timestamp_ending:
|
||||||
@ -375,6 +287,7 @@ def transcribe(
|
|||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
seek += last_timestamp_pos * input_stride
|
seek += last_timestamp_pos * input_stride
|
||||||
|
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||||
@ -396,8 +309,13 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
current_tokens.append(tokens.tolist())
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
|
if not condition_on_previous_text or result.temperature > 0.5:
|
||||||
|
# do not feed the prompt tokens if a high temperature was used
|
||||||
|
prompt_reset_since = len(all_tokens)
|
||||||
|
|
||||||
if word_timestamps:
|
if word_timestamps:
|
||||||
add_word_timestamps(
|
add_word_timestamps(
|
||||||
segments=current_segments,
|
segments=current_segments,
|
||||||
@@ -407,73 +325,16 @@ def transcribe(
                     num_frames=segment_size,
                     prepend_punctuations=prepend_punctuations,
                     append_punctuations=append_punctuations,
-                    last_speech_timestamp=last_speech_timestamp,
                 )
-
-                if not single_timestamp_ending:
-                    last_word_end = get_end(current_segments)
-                    if last_word_end is not None and last_word_end > time_offset:
-                        seek = round(last_word_end * FRAMES_PER_SECOND)
-
-                # skip silence before possible hallucinations
-                if hallucination_silence_threshold is not None:
-                    threshold = hallucination_silence_threshold
-                    if not single_timestamp_ending:
-                        last_word_end = get_end(current_segments)
-                        if last_word_end is not None and last_word_end > time_offset:
-                            remaining_duration = window_end_time - last_word_end
-                            if remaining_duration > threshold:
-                                seek = round(last_word_end * FRAMES_PER_SECOND)
-                            else:
-                                seek = previous_seek + segment_size
-
-                    # if first segment might be a hallucination, skip leading silence
-                    first_segment = next_words_segment(current_segments)
-                    if first_segment is not None and is_segment_anomaly(first_segment):
-                        gap = first_segment["start"] - time_offset
-                        if gap > threshold:
-                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
-                            continue
-
-                    # skip silence before any possible hallucination that is surrounded
-                    # by silence or more hallucinations
-                    hal_last_end = last_speech_timestamp
-                    for si in range(len(current_segments)):
-                        segment = current_segments[si]
-                        if not segment["words"]:
-                            continue
-                        if is_segment_anomaly(segment):
-                            next_segment = next_words_segment(
-                                current_segments[si + 1 :]
-                            )
-                            if next_segment is not None:
-                                hal_next_start = next_segment["words"][0]["start"]
-                            else:
-                                hal_next_start = time_offset + segment_duration
-                            silence_before = (
-                                segment["start"] - hal_last_end > threshold
-                                or segment["start"] < threshold
-                                or segment["start"] - time_offset < 2.0
-                            )
-                            silence_after = (
-                                hal_next_start - segment["end"] > threshold
-                                or is_segment_anomaly(next_segment)
-                                or window_end_time - segment["end"] < 2.0
-                            )
-                            if silence_before and silence_after:
-                                seek = round(
-                                    max(time_offset + 1, segment["start"])
-                                    * FRAMES_PER_SECOND
-                                )
-                                if content_duration - segment["end"] < threshold:
-                                    seek = content_frames
-                                current_segments[si:] = []
-                                break
-                        hal_last_end = segment["end"]
-
-                last_word_end = get_end(current_segments)
-                if last_word_end is not None:
-                    last_speech_timestamp = last_word_end
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                    )
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift

             if verbose:
                 for segment in current_segments:
@@ -487,23 +348,13 @@ def transcribe(
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
+                    current_tokens[i] = []

-            all_segments.extend(
-                [
-                    {"id": i, **segment}
-                    for i, segment in enumerate(
-                        current_segments, start=len(all_segments)
-                    )
-                ]
-            )
+            all_segments.extend(current_segments)
             all_tokens.extend(
-                [token for segment in current_segments for token in segment["tokens"]]
+                [token for segment in current_tokens for token in segment]
             )

-            if not condition_on_previous_text or result.temperature > 0.5:
-                # do not feed the prompt tokens if a high temperature was used
-                prompt_reset_since = len(all_tokens)
-
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)

@@ -517,17 +368,10 @@ def transcribe(
 def cli():
     from . import available_models

-    def valid_model_name(name):
-        if name in available_models() or os.path.exists(name):
-            return name
-        raise ValueError(
-            f"model should be one of {available_models()} or path to a model checkpoint"
-        )
-
     # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
-    parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
+    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@@ -545,8 +389,6 @@ def cli():

     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
-    parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
-
     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

@@ -557,13 +399,7 @@ def cli():
     parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
-    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
-    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
-    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
-    parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
-    parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
-    parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -595,28 +431,9 @@ def cli():
     model = load_model(model_name, device=device, download_root=model_dir)

     writer = get_writer(output_format, output_dir)
-    word_options = [
-        "highlight_words",
-        "max_line_count",
-        "max_line_width",
-        "max_words_per_line",
-    ]
-    if not args["word_timestamps"]:
-        for option in word_options:
-            if args[option]:
-                parser.error(f"--{option} requires --word_timestamps True")
-    if args["max_line_count"] and not args["max_line_width"]:
-        warnings.warn("--max_line_count has no effect without --max_line_width")
-    if args["max_words_per_line"] and args["max_line_width"]:
-        warnings.warn("--max_words_per_line has no effect with --max_line_width")
-    writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
-        try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
-            writer(result, audio_path, **writer_args)
-        except Exception as e:
-            traceback.print_exc()
-            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+        result = transcribe(model, audio_path, temperature=temperature, **args)
+        writer(result, audio_path)


 if __name__ == "__main__":
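Note: on the main side of this diff, the options referenced above (word_timestamps, hallucination_silence_threshold, carry_initial_prompt, the subtitle word options) are plain keyword arguments of transcribe(). A minimal sketch of passing them from Python; the audio path and threshold value are illustrative, not taken from the diff:

    import whisper

    model = whisper.load_model("turbo")
    result = whisper.transcribe(
        model,
        "audio.mp3",  # illustrative input file
        word_timestamps=True,  # required for the word-level and hallucination features
        hallucination_silence_threshold=2.0,  # illustrative value, in seconds
    )
    for word in result["segments"][0]["words"]:  # assumes at least one segment was produced
        print(word["word"], word["start"], word["end"])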
whisper/triton_ops.py
@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
         tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

     kernel = triton.JITFunction(kernel.fn)
-    new_kernel = kernel.src.replace(
+    kernel.src = kernel.src.replace(
         " LOAD_ALL_ROWS_HERE",
         "\n".join(
             [
@@ -69,8 +69,7 @@ def median_kernel(filter_width: int):
             ]
         ),
     )
-
-    new_kernel = new_kernel.replace(
+    kernel.src = kernel.src.replace(
         " BUBBLESORT_HERE",
         "\n\n".join(
             [
@@ -91,14 +90,7 @@ def median_kernel(filter_width: int):
             ]
         ),
     )
-
-    new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
-
-    if hasattr(kernel, "_unsafe_update_src") is True:
-        kernel._unsafe_update_src(new_kernel)
-        kernel.hash = None
-    else:
-        kernel.src = new_kernel
+    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")

     return kernel

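Note: the branch removed above appears to guard against Triton versions where the JITFunction source can no longer be updated by assigning .src directly. A condensed sketch of that guard, using the same names as the diff (kernel is a triton.JITFunction, new_kernel the rewritten source string):

    # main-side compatibility guard, assuming `kernel` and `new_kernel` as in the diff
    if hasattr(kernel, "_unsafe_update_src"):
        kernel._unsafe_update_src(new_kernel)  # newer Triton exposes this updater
        kernel.hash = None  # clear the cached hash so the patched source is recompiled
    else:
        kernel.src = new_kernel  # older Triton still allows direct assignment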
189 whisper/utils.py
@@ -1,9 +1,8 @@
 import json
 import os
-import re
 import sys
 import zlib
-from typing import Callable, List, Optional, TextIO
+from typing import Callable, TextIO

 system_encoding = sys.getdefaultencoding()

@@ -68,29 +67,13 @@ def format_timestamp(
     )


-def get_start(segments: List[dict]) -> Optional[float]:
-    return next(
-        (w["start"] for s in segments for w in s["words"]),
-        segments[0]["start"] if segments else None,
-    )
-
-
-def get_end(segments: List[dict]) -> Optional[float]:
-    return next(
-        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
-        segments[-1]["end"] if segments else None,
-    )
-
-
 class ResultWriter:
     extension: str

     def __init__(self, output_dir: str):
         self.output_dir = output_dir

-    def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
-    ):
+    def __call__(self, result: dict, audio_path: str):
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
         output_path = os.path.join(
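Note: the get_start/get_end helpers removed on the right-hand side read the first and last word boundary from a list of segment dicts, falling back to the segment timestamps when no words are present (main side only). A tiny illustrative call with made-up data:

    from whisper.utils import get_end, get_start  # available on the main side only

    segments = [
        {"start": 0.0, "end": 2.5, "words": [{"word": " Hi", "start": 0.3, "end": 0.6}]},
    ]
    print(get_start(segments), get_end(segments))  # -> 0.3 0.6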
@@ -98,20 +81,16 @@ class ResultWriter:
         )

         with open(output_path, "w", encoding="utf-8") as f:
-            self.write_result(result, file=f, options=options, **kwargs)
+            self.write_result(result, file=f)

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO):
         raise NotImplementedError


 class WriteTXT(ResultWriter):
     extension: str = "txt"

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO):
         for segment in result["segments"]:
             print(segment["text"].strip(), file=file, flush=True)

@@ -120,111 +99,33 @@ class SubtitlesWriter(ResultWriter):
     always_include_hours: bool
     decimal_marker: str

-    def iterate_result(
-        self,
-        result: dict,
-        options: Optional[dict] = None,
-        *,
-        max_line_width: Optional[int] = None,
-        max_line_count: Optional[int] = None,
-        highlight_words: bool = False,
-        max_words_per_line: Optional[int] = None,
-    ):
-        options = options or {}
-        max_line_width = max_line_width or options.get("max_line_width")
-        max_line_count = max_line_count or options.get("max_line_count")
-        highlight_words = highlight_words or options.get("highlight_words", False)
-        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
-        preserve_segments = max_line_count is None or max_line_width is None
-        max_line_width = max_line_width or 1000
-        max_words_per_line = max_words_per_line or 1000
-
-        def iterate_subtitles():
-            line_len = 0
-            line_count = 1
-            # the next subtitle to yield (a list of word timings with whitespace)
-            subtitle: List[dict] = []
-            last: float = get_start(result["segments"]) or 0.0
-            for segment in result["segments"]:
-                chunk_index = 0
-                words_count = max_words_per_line
-                while chunk_index < len(segment["words"]):
-                    remaining_words = len(segment["words"]) - chunk_index
-                    if max_words_per_line > len(segment["words"]) - chunk_index:
-                        words_count = remaining_words
-                    for i, original_timing in enumerate(
-                        segment["words"][chunk_index : chunk_index + words_count]
-                    ):
-                        timing = original_timing.copy()
-                        long_pause = (
-                            not preserve_segments and timing["start"] - last > 3.0
-                        )
-                        has_room = line_len + len(timing["word"]) <= max_line_width
-                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
-                        if (
-                            line_len > 0
-                            and has_room
-                            and not long_pause
-                            and not seg_break
-                        ):
-                            # line continuation
-                            line_len += len(timing["word"])
-                        else:
-                            # new line
-                            timing["word"] = timing["word"].strip()
-                            if (
-                                len(subtitle) > 0
-                                and max_line_count is not None
-                                and (long_pause or line_count >= max_line_count)
-                                or seg_break
-                            ):
-                                # subtitle break
-                                yield subtitle
-                                subtitle = []
-                                line_count = 1
-                            elif line_len > 0:
-                                # line break
-                                line_count += 1
-                                timing["word"] = "\n" + timing["word"]
-                            line_len = len(timing["word"].strip())
-                        subtitle.append(timing)
-                        last = timing["start"]
-                    chunk_index += max_words_per_line
-            if len(subtitle) > 0:
-                yield subtitle
-
-        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
-            for subtitle in iterate_subtitles():
-                subtitle_start = self.format_timestamp(subtitle[0]["start"])
-                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
-                subtitle_text = "".join([word["word"] for word in subtitle])
-                if highlight_words:
-                    last = subtitle_start
-                    all_words = [timing["word"] for timing in subtitle]
-                    for i, this_word in enumerate(subtitle):
-                        start = self.format_timestamp(this_word["start"])
-                        end = self.format_timestamp(this_word["end"])
-                        if last != start:
-                            yield last, start, subtitle_text
-
-                        yield start, end, "".join(
-                            [
-                                (
-                                    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
-                                    if j == i
-                                    else word
-                                )
-                                for j, word in enumerate(all_words)
-                            ]
-                        )
-                        last = end
-                else:
-                    yield subtitle_start, subtitle_end, subtitle_text
-        else:
-            for segment in result["segments"]:
-                segment_start = self.format_timestamp(segment["start"])
-                segment_end = self.format_timestamp(segment["end"])
-                segment_text = segment["text"].strip().replace("-->", "->")
+    def iterate_result(self, result: dict):
+        for segment in result["segments"]:
+            segment_start = self.format_timestamp(segment["start"])
+            segment_end = self.format_timestamp(segment["end"])
+            segment_text = segment["text"].strip().replace("-->", "->")
+
+            if word_timings := segment.get("words", None):
+                all_words = [timing["word"] for timing in word_timings]
+                all_words[0] = all_words[0].strip()  # remove the leading space, if any
+                last = segment_start
+                for i, this_word in enumerate(word_timings):
+                    start = self.format_timestamp(this_word["start"])
+                    end = self.format_timestamp(this_word["end"])
+                    if last != start:
+                        yield last, start, segment_text
+
+                    yield start, end, "".join(
+                        [
+                            f"<u>{word}</u>" if j == i else word
+                            for j, word in enumerate(all_words)
+                        ]
+                    )
+                    last = end
+
+                if last != segment_end:
+                    yield last, segment_end, segment_text
+            else:
                 yield segment_start, segment_end, segment_text

     def format_timestamp(self, seconds: float):
@@ -240,11 +141,9 @@ class WriteVTT(SubtitlesWriter):
     always_include_hours: bool = False
     decimal_marker: str = "."

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
-        for start, end, text in self.iterate_result(result, options, **kwargs):
+        for start, end, text in self.iterate_result(result):
             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


@@ -253,12 +152,8 @@ class WriteSRT(SubtitlesWriter):
     always_include_hours: bool = True
     decimal_marker: str = ","

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
-        for i, (start, end, text) in enumerate(
-            self.iterate_result(result, options, **kwargs), start=1
-        ):
+    def write_result(self, result: dict, file: TextIO):
+        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
             print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


@@ -274,9 +169,7 @@ class WriteTSV(ResultWriter):

     extension: str = "tsv"

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO):
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
             print(round(1000 * segment["start"]), file=file, end="\t")
@@ -287,15 +180,11 @@ class WriteTSV(ResultWriter):
 class WriteJSON(ResultWriter):
     extension: str = "json"

-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO):
         json.dump(result, file)


-def get_writer(
-    output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
+def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
     writers = {
         "txt": WriteTXT,
         "vtt": WriteVTT,
@@ -307,11 +196,9 @@ def get_writer(
     if output_format == "all":
         all_writers = [writer(output_dir) for writer in writers.values()]

-        def write_all(
-            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-        ):
+        def write_all(result: dict, file: TextIO):
             for writer in all_writers:
-                writer(result, file, options, **kwargs)
+                writer(result, file)

         return write_all

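Note: with the main-side signatures above, the word-level options that cli() collects into writer_args are forwarded to write_result via **kwargs, so a writer can also be driven directly. A small sketch; the output format, directory, file name, and option values are illustrative, and result is assumed to come from transcribe() with word_timestamps=True:

    from whisper.utils import get_writer

    writer = get_writer("srt", ".")  # illustrative format and output directory
    writer(
        result,
        "audio.mp3",  # only used to derive the output file name
        highlight_words=True,  # underline each word as it is spoken
        max_line_width=42,  # illustrative wrap width
        max_line_count=2,  # illustrative line limit per subtitle
    )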
whisper/version.py
@@ -1 +1 @@
-__version__ = "20250625"
+__version__ = "20230307"