mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Merge branch 'main' into transcribe-argument
This commit is contained in:
commit
888dd61cba
31
.github/workflows/test.yml
vendored
31
.github/workflows/test.yml
vendored
@ -41,16 +41,35 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
include:
|
||||||
pytorch-version: [1.13.1, 2.0.0]
|
- python-version: '3.8'
|
||||||
exclude:
|
pytorch-version: 1.10.1
|
||||||
- python-version: '3.11'
|
numpy-requirement: "'numpy<2'"
|
||||||
|
- python-version: '3.8'
|
||||||
pytorch-version: 1.13.1
|
pytorch-version: 1.13.1
|
||||||
|
numpy-requirement: "'numpy<2'"
|
||||||
|
- python-version: '3.8'
|
||||||
|
pytorch-version: 2.0.1
|
||||||
|
numpy-requirement: "'numpy<2'"
|
||||||
|
- python-version: '3.9'
|
||||||
|
pytorch-version: 2.1.2
|
||||||
|
numpy-requirement: "'numpy<2'"
|
||||||
|
- python-version: '3.10'
|
||||||
|
pytorch-version: 2.2.2
|
||||||
|
numpy-requirement: "'numpy<2'"
|
||||||
|
- python-version: '3.11'
|
||||||
|
pytorch-version: 2.3.1
|
||||||
|
numpy-requirement: "'numpy'"
|
||||||
|
- python-version: '3.12'
|
||||||
|
pytorch-version: 2.4.1
|
||||||
|
numpy-requirement: "'numpy'"
|
||||||
|
- python-version: '3.12'
|
||||||
|
pytorch-version: 2.5.0
|
||||||
|
numpy-requirement: "'numpy'"
|
||||||
steps:
|
steps:
|
||||||
- uses: conda-incubator/setup-miniconda@v2
|
- uses: conda-incubator/setup-miniconda@v2
|
||||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
|
- run: conda install -n test ffmpeg python=${{ matrix.python-version }}
|
||||||
- run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||||
- run: pip install .["dev"]
|
- run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
|
||||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
||||||
|
|||||||
14
CHANGELOG.md
14
CHANGELOG.md
@ -1,5 +1,19 @@
|
|||||||
# CHANGELOG
|
# CHANGELOG
|
||||||
|
|
||||||
|
## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
|
||||||
|
|
||||||
|
* allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362))
|
||||||
|
* large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361))
|
||||||
|
* test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360))
|
||||||
|
* using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359))
|
||||||
|
|
||||||
|
## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927)
|
||||||
|
|
||||||
|
* pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332))
|
||||||
|
* Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307))
|
||||||
|
* Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838))
|
||||||
|
* Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887))
|
||||||
|
|
||||||
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
|
## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
|
||||||
|
|
||||||
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
|
* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802))
|
||||||
|
|||||||
24
README.md
24
README.md
@ -57,17 +57,21 @@ pip install setuptools-rust
|
|||||||
|
|
||||||
## Available models and languages
|
## Available models and languages
|
||||||
|
|
||||||
There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model; actual speed may vary depending on many factors including the available hardware.
|
There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs.
|
||||||
|
Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model.
|
||||||
|
The relative speeds below are measured by transcribing English speech on an A100, and the real-world speed may vary significantly depending on many factors including the language, the speaking speed, and the available hardware.
|
||||||
|
|
||||||
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
||||||
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
|
|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
|
||||||
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x |
|
||||||
| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
|
| base | 74 M | `base.en` | `base` | ~1 GB | ~7x |
|
||||||
| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
|
| small | 244 M | `small.en` | `small` | ~2 GB | ~4x |
|
||||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
||||||
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
||||||
|
| turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x |
|
||||||
|
|
||||||
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||||
|
Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with a minimal degradation in accuracy.
|
||||||
|
|
||||||
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
|
Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
|
||||||
|
|
||||||
@ -77,11 +81,11 @@ Whisper's performance varies widely depending on the language. The figure below
|
|||||||
|
|
||||||
## Command-line usage
|
## Command-line usage
|
||||||
|
|
||||||
The following command will transcribe speech in audio files, using the `medium` model:
|
The following command will transcribe speech in audio files, using the `turbo` model:
|
||||||
|
|
||||||
whisper audio.flac audio.mp3 audio.wav --model medium
|
whisper audio.flac audio.mp3 audio.wav --model turbo
|
||||||
|
|
||||||
The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
|
The default setting (which selects the `turbo` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
|
||||||
|
|
||||||
whisper japanese.wav --language Japanese
|
whisper japanese.wav --language Japanese
|
||||||
|
|
||||||
@ -103,7 +107,7 @@ Transcription can also be performed within Python:
|
|||||||
```python
|
```python
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
model = whisper.load_model("base")
|
model = whisper.load_model("turbo")
|
||||||
result = model.transcribe("audio.mp3")
|
result = model.transcribe("audio.mp3")
|
||||||
print(result["text"])
|
print(result["text"])
|
||||||
```
|
```
|
||||||
@ -115,14 +119,14 @@ Below is an example usage of `whisper.detect_language()` and `whisper.decode()`
|
|||||||
```python
|
```python
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
model = whisper.load_model("base")
|
model = whisper.load_model("turbo")
|
||||||
|
|
||||||
# load audio and pad/trim it to fit 30 seconds
|
# load audio and pad/trim it to fit 30 seconds
|
||||||
audio = whisper.load_audio("audio.mp3")
|
audio = whisper.load_audio("audio.mp3")
|
||||||
audio = whisper.pad_or_trim(audio)
|
audio = whisper.pad_or_trim(audio)
|
||||||
|
|
||||||
# make log-Mel spectrogram and move to the same device as the model
|
# make log-Mel spectrogram and move to the same device as the model
|
||||||
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
|
||||||
|
|
||||||
# detect the spoken language
|
# detect the spoken language
|
||||||
_, probs = model.detect_language(mel)
|
_, probs = model.detect_language(mel)
|
||||||
|
|||||||
@ -45,7 +45,7 @@ We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challen
|
|||||||
|
|
||||||
### AMI-IHM, AMI-SDM1
|
### AMI-IHM, AMI-SDM1
|
||||||
|
|
||||||
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
|
||||||
|
|
||||||
|
|
||||||
## Long-form English-only datasets
|
## Long-form English-only datasets
|
||||||
|
|||||||
@ -16,13 +16,15 @@ The Whisper models are trained for speech recognition and translation tasks, cap
|
|||||||
| small | 244 M | ✓ | ✓ |
|
| small | 244 M | ✓ | ✓ |
|
||||||
| medium | 769 M | ✓ | ✓ |
|
| medium | 769 M | ✓ | ✓ |
|
||||||
| large | 1550 M | | ✓ |
|
| large | 1550 M | | ✓ |
|
||||||
|
| turbo | 798 M | | ✓ |
|
||||||
|
|
||||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
|
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
|
||||||
|
Additionally, we added a `turbo` model in September 2024, which is optimized for inference speed.
|
||||||
|
|
||||||
|
|
||||||
### Release date
|
### Release date
|
||||||
|
|
||||||
September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`)
|
September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), and September 2024 (`large-v3-turbo`)
|
||||||
|
|
||||||
### Model type
|
### Model type
|
||||||
|
|
||||||
|
|||||||
@ -4,4 +4,4 @@ torch
|
|||||||
tqdm
|
tqdm
|
||||||
more-itertools
|
more-itertools
|
||||||
tiktoken
|
tiktoken
|
||||||
triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
||||||
|
|||||||
2
setup.py
2
setup.py
@ -13,7 +13,7 @@ def read_version(fname="whisper/version.py"):
|
|||||||
|
|
||||||
requirements = []
|
requirements = []
|
||||||
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||||
requirements.append("triton>=2.0.0,<3")
|
requirements.append("triton>=2.0.0")
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="openai-whisper",
|
name="openai-whisper",
|
||||||
|
|||||||
@ -27,6 +27,8 @@ _MODELS = {
|
|||||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
|
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||||
|
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||||
@ -44,6 +46,8 @@ _ALIGNMENT_HEADS = {
|
|||||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
|
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||||
|
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -122,7 +122,7 @@ def log_mel_spectrogram(
|
|||||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||||
|
|
||||||
n_mels: int
|
n_mels: int
|
||||||
The number of Mel-frequency filters, only 80 is supported
|
The number of Mel-frequency filters, only 80 and 128 are supported
|
||||||
|
|
||||||
padding: int
|
padding: int
|
||||||
Number of zero samples to pad to the right
|
Number of zero samples to pad to the right
|
||||||
@ -132,7 +132,7 @@ def log_mel_spectrogram(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor, shape = (80, n_frames)
|
torch.Tensor, shape = (n_mels, n_frames)
|
||||||
A Tensor that contains the Mel spectrogram
|
A Tensor that contains the Mel spectrogram
|
||||||
"""
|
"""
|
||||||
if not torch.is_tensor(audio):
|
if not torch.is_tensor(audio):
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import gzip
|
import gzip
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, Optional
|
from typing import Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -12,6 +13,14 @@ from .decoding import decode as decode_function
|
|||||||
from .decoding import detect_language as detect_language_function
|
from .decoding import detect_language as detect_language_function
|
||||||
from .transcribe import transcribe as transcribe_function
|
from .transcribe import transcribe as transcribe_function
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
|
SDPA_AVAILABLE = True
|
||||||
|
except (ImportError, RuntimeError, OSError):
|
||||||
|
scaled_dot_product_attention = None
|
||||||
|
SDPA_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelDimensions:
|
class ModelDimensions:
|
||||||
@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
|
|||||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_sdpa():
|
||||||
|
prev_state = MultiHeadAttention.use_sdpa
|
||||||
|
try:
|
||||||
|
MultiHeadAttention.use_sdpa = False
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
MultiHeadAttention.use_sdpa = prev_state
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
|
use_sdpa = True
|
||||||
|
|
||||||
def __init__(self, n_state: int, n_head: int):
|
def __init__(self, n_state: int, n_head: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
):
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
n_batch, n_ctx, n_state = q.shape
|
n_batch, n_ctx, n_state = q.shape
|
||||||
scale = (n_state // self.n_head) ** -0.25
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
qk = q @ k
|
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
||||||
|
a = scaled_dot_product_attention(
|
||||||
|
q, k, v, is_causal=mask is not None and n_ctx > 1
|
||||||
|
)
|
||||||
|
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
qk = None
|
||||||
|
else:
|
||||||
|
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
qk = qk + mask[:n_ctx, :n_ctx]
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
qk = qk.float()
|
qk = qk.float()
|
||||||
|
|
||||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
qk = qk.detach()
|
||||||
|
|
||||||
|
return out, qk
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
|||||||
@ -191,7 +191,9 @@ def find_alignment(
|
|||||||
for i, block in enumerate(model.decoder.blocks)
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
]
|
]
|
||||||
|
|
||||||
with torch.no_grad():
|
from .model import disable_sdpa
|
||||||
|
|
||||||
|
with torch.no_grad(), disable_sdpa():
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
|
|||||||
@ -46,6 +46,7 @@ def transcribe(
|
|||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
|
carry_initial_prompt: bool = False,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
@ -102,6 +103,11 @@ def transcribe(
|
|||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
to make it more likely to predict those words correctly.
|
to make it more likely to predict those words correctly.
|
||||||
|
|
||||||
|
carry_initial_prompt: bool
|
||||||
|
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||||
|
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||||
|
left-sliced to make space.
|
||||||
|
|
||||||
decode_options: dict
|
decode_options: dict
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
@ -208,6 +214,8 @@ def transcribe(
|
|||||||
if (
|
if (
|
||||||
no_speech_threshold is not None
|
no_speech_threshold is not None
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
|
and logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = False # silence
|
needs_fallback = False # silence
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
@ -227,9 +235,11 @@ def transcribe(
|
|||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
|
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
@ -275,7 +285,13 @@ def transcribe(
|
|||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
|
if carry_initial_prompt:
|
||||||
|
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||||
|
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||||
|
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||||
|
else:
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
@ -518,7 +534,7 @@ def cli():
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
|
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default=default_device, help="device to use for PyTorch inference")
|
parser.add_argument("--device", default=default_device, help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
@ -536,6 +552,8 @@ def cli():
|
|||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||||
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
__version__ = "20231117"
|
__version__ = "20240930"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user