mirror of https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00

2025/10/11
Added a web service deployment feature

commit 14e5568890
parent c0d2f624c0
1 audio.json Normal file
@@ -0,0 +1 @@
{"text": "\u4eca\u5929\u4e2d\u5348\u5403\u4ec0\u4e48", "segments": [{"id": 0, "seek": 0, "start": 0.0, "end": 2.0, "text": "\u4eca\u5929\u4e2d\u5348\u5403\u4ec0\u4e48", "tokens": [50364, 12074, 5975, 44237, 10123, 10440, 50464], "temperature": 0.0, "avg_logprob": -0.5544378757476807, "compression_ratio": 0.65625, "no_speech_prob": 0.1877238005399704}], "language": "zh"}
95 v2w_service.py Normal file
@@ -0,0 +1,95 @@
import whisper
import time
import datetime
from flask import Flask, request, jsonify
import threading

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False  # make sure Chinese text is returned unescaped

# global variable holding the model
model = None


def load_model():
    global model
    print("Loading the Whisper model...")
    start_time = time.time()

    # manually specify where the model files are stored
    model_path = "./models"  # change this to any path you like
    # choose CPU or GPU depending on your hardware
    model = whisper.load_model("medium", download_root=model_path, device="cpu")

    load_time = time.time() - start_time
    print(f"Model loaded in {str(datetime.timedelta(seconds=load_time))}")
    print(f"Model directory: {model_path}")


# load the model when the application starts
@app.before_request
def before_first_request():
    global model
    if model is None:
        print("First request; loading the model...")
        load_model()


@app.route('/transcribe', methods=['POST'])
def transcribe():
    if model is None:
        return jsonify({"error": "the model has not finished loading"}), 503

    if 'audio' not in request.files:
        return jsonify({"error": "no audio file provided"}), 400

    audio_file = request.files['audio']
    audio_path = f"/{audio_file.filename}"
    audio_file.save(audio_path)

    # start transcribing; fp16=False selects FP32 to avoid NaN
    start_time = time.time()
    result = model.transcribe(audio_path, language="zh", fp16=False)  # use FP32 instead of FP16 to avoid NaN
    transcription_time = time.time() - start_time

    return jsonify({
        "text": result["text"],
        "processing_time": transcription_time
    })


@app.route('/transcribe_text', methods=['POST'])
def transcribe_text():
    """Return the transcription as plain text, convenient for the command line."""
    if model is None:
        return "the model has not finished loading", 503

    if 'audio' not in request.files:
        return "no audio file provided", 400

    audio_file = request.files['audio']
    audio_path = f"/{audio_file.filename}"
    audio_file.save(audio_path)

    # start transcribing; fp16=False selects FP32 to avoid NaN
    start_time = time.time()
    result = model.transcribe(audio_path, language="zh", fp16=False)
    transcription_time = time.time() - start_time

    # return plain text
    return f"{result['text']}\r\nprocessing time: {transcription_time:.2f}s"


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        "status": "ok",
        "model_loaded": model is not None
    })


if __name__ == '__main__':
    # preload the model before starting the app
    print("Preloading the model before starting the service...")
    load_model()
    app.run(host='0.0.0.0', port=5000, threaded=True)
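A quick way to exercise the service above once it is running (a minimal client sketch: it assumes the server is on localhost:5000 as configured above, that the third-party `requests` package is installed, and that a local `test.wav` exists):

import requests

# health check: reports whether the model has been loaded
print(requests.get("http://localhost:5000/health").json())

# upload a file under the "audio" form field expected by the service
with open("test.wav", "rb") as f:
    resp = requests.post("http://localhost:5000/transcribe", files={"audio": f})
print(resp.json())  # {"text": "...", "processing_time": ...}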
40 voice2word.py Normal file
@@ -0,0 +1,40 @@
import whisper
import time
import datetime


def format_time(seconds):
    """Format a number of seconds as a human-readable time string."""
    return str(datetime.timedelta(seconds=seconds))


def transcribe_with_timing():
    # record the start time
    start_time = time.time()

    print("Loading the Whisper model...")
    model_load_start = time.time()
    model = whisper.load_model("medium")
    model_load_time = time.time() - model_load_start
    print(f"Model loaded in {format_time(model_load_time)}")

    print("Starting speech recognition...")
    transcription_start = time.time()
    result = model.transcribe("dingzhen.wav", language="zh")
    transcription_time = time.time() - transcription_start
    print(f"Speech recognition finished in {format_time(transcription_time)}")

    # print the result
    print("\nTranscription:")
    print(result["text"])

    # compute the total time
    total_time = time.time() - start_time
    print(f"\nTotal runtime: {format_time(total_time)}")
    print("Breakdown:")
    print(f"- model loading: {format_time(model_load_time)} ({model_load_time / total_time:.1%})")
    print(f"- speech recognition: {format_time(transcription_time)} ({transcription_time / total_time:.1%})")


if __name__ == "__main__":
    transcribe_with_timing()
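Stripped of the timing instrumentation, the script above boils down to this minimal use of the API (a sketch; "dingzhen.wav" is the file hard-coded in the script):

import whisper

model = whisper.load_model("medium")
print(model.transcribe("dingzhen.wav", language="zh")["text"])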
whisper/audio.py
@@ -1,5 +1,7 @@
 import os
+import sys
 from functools import lru_cache
+from pathlib import Path
 from subprocess import CalledProcessError, run
 from typing import Optional, Union

@@ -102,7 +104,16 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

-    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    # resolve the path with pathlib, supporting both the dev environment and the packaged exe
+    if getattr(sys, 'frozen', False):
+        # running from a packaged exe
+        exe_dir = Path(sys.executable).parent
+        filters_path = exe_dir / "whisper" / "assets" / "mel_filters.npz"
+    else:
+        # development environment
+        filters_path = Path(__file__).parent / "assets" / "mel_filters.npz"
+
+    print(f"filters_path: {filters_path}")
     with np.load(filters_path, allow_pickle=False) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
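The `getattr(sys, 'frozen', False)` test above relies on the convention that bundlers such as PyInstaller set `sys.frozen` on the packaged executable, while a plain CPython interpreter does not define it. A one-line check:

import sys
print(getattr(sys, 'frozen', False))  # False when run from source, truthy in a bundled exe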
whisper/tokenizer.py
@@ -1,8 +1,10 @@
 import base64
 import os
 import string
+import sys
 from dataclasses import dataclass, field
 from functools import cached_property, lru_cache
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import tiktoken
@@ -329,7 +331,16 @@ class Tokenizer:

 @lru_cache(maxsize=None)
 def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+    # resolve the path with pathlib, supporting both the dev environment and the packaged exe
+    if getattr(sys, 'frozen', False):
+        # running from a packaged exe
+        exe_dir = Path(sys.executable).parent
+        vocab_path = exe_dir / "whisper" / "assets" / f"{name}.tiktoken"
+    else:
+        # development environment
+        vocab_path = Path(__file__).parent / "assets" / f"{name}.tiktoken"
+
+    print(f"vocab_path: {vocab_path}")
     ranks = {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in open(vocab_path) if line)
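This hunk and the `mel_filters` hunk above repeat the same lookup; a small helper along these lines could factor it out (a hypothetical sketch, not part of the commit; `asset_path` is an invented name):

import sys
from pathlib import Path

def asset_path(filename: str) -> Path:
    """Resolve a file under whisper/assets in both dev and frozen builds."""
    if getattr(sys, 'frozen', False):
        # packaged exe: assets live next to the executable
        base = Path(sys.executable).parent / "whisper"
    else:
        # development environment: assets live next to this module
        base = Path(__file__).parent
    return base / "assets" / filename

# e.g. asset_path("mel_filters.npz") or asset_path("gpt2.tiktoken")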
94 whisper_service_deploy/v2w_service.py Normal file
@@ -0,0 +1,94 @@
import whisper
import time
import datetime
from flask import Flask, request, jsonify
import threading

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False  # make sure Chinese text is returned unescaped

# global variable holding the model
model = None


def load_model():
    global model
    print("Loading the Whisper model...")
    start_time = time.time()

    # manually specify where the model files are stored
    model_path = "./models"  # change this to any path you like
    model = whisper.load_model("medium", download_root=model_path)

    load_time = time.time() - start_time
    print(f"Model loaded in {str(datetime.timedelta(seconds=load_time))}")
    print(f"Model directory: {model_path}")


# load the model when the application starts
@app.before_request
def before_first_request():
    global model
    if model is None:
        print("First request; loading the model...")
        load_model()


@app.route('/transcribe', methods=['POST'])
def transcribe():
    if model is None:
        return jsonify({"error": "the model has not finished loading"}), 503

    if 'audio' not in request.files:
        return jsonify({"error": "no audio file provided"}), 400

    audio_file = request.files['audio']
    audio_path = f"/{audio_file.filename}"
    audio_file.save(audio_path)

    # start transcribing
    start_time = time.time()
    result = model.transcribe(audio_path, language="zh")
    transcription_time = time.time() - start_time

    return jsonify({
        "text": result["text"],
        "processing_time": transcription_time
    })


@app.route('/transcribe_text', methods=['POST'])
def transcribe_text():
    """Return the transcription as plain text, convenient for the command line."""
    if model is None:
        return "the model has not finished loading", 503

    if 'audio' not in request.files:
        return "no audio file provided", 400

    audio_file = request.files['audio']
    audio_path = f"/{audio_file.filename}"
    audio_file.save(audio_path)

    # start transcribing
    start_time = time.time()
    result = model.transcribe(audio_path, language="zh")
    transcription_time = time.time() - start_time

    # return plain text
    return f"{result['text']}\r\nprocessing time: {transcription_time:.2f}s"


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        "status": "ok",
        "model_loaded": model is not None
    })


if __name__ == '__main__':
    # preload the model before starting the app
    print("Preloading the model before starting the service...")
    load_model()
    app.run(host='0.0.0.0', port=5000, threaded=True)
161 whisper_service_deploy/whisper/__init__.py Normal file
@@ -0,0 +1,161 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
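Note how `_download` recovers the expected checksum: every URL in `_MODELS` embeds the file's SHA-256 as the second-to-last path segment, so no separate checksum table is needed. For example, with the `tiny` URL from the table above:

url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
print(url.split("/")[-2])
# 65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9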
def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
        checkpoint = torch.load(fp, map_location=device, **kwargs)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
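Typical use of this module-level API (a sketch; the paths and names are examples):

import whisper

# the first call downloads the checkpoint into download_root (default ~/.cache/whisper)
model = whisper.load_model("medium", download_root="./models", device="cpu")
print(whisper.available_models())  # ['tiny.en', 'tiny', ..., 'turbo']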
3 whisper_service_deploy/whisper/__main__.py Normal file
@@ -0,0 +1,3 @@
from .transcribe import cli

cli()
157 whisper_service_deploy/whisper/audio.py Normal file
@@ -0,0 +1,157 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from .utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of the given audio

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
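The low-level pieces above compose into the standard preprocessing pipeline (a sketch; "audio.mp3" is a placeholder file):

import whisper

audio = whisper.load_audio("audio.mp3")    # float32 mono at 16 kHz, decoded via ffmpeg
audio = whisper.pad_or_trim(audio)         # exactly N_SAMPLES, i.e. a 30-second window
mel = whisper.log_mel_spectrogram(audio)   # shape (n_mels=80, N_FRAMES=3000)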
826 whisper_service_deploy/whisper/decoding.py Normal file
@@ -0,0 +1,826 @@
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

if TYPE_CHECKING:
    from .model import Whisper


@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs


@dataclass(frozen=True)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of tokens ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
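For orientation, the dataclass above is the knob set consumed by `decode` further below; two illustrative configurations (a sketch; the values are arbitrary):

# defaults: greedy decoding (temperature 0), timestamps on, FP16 on
options = DecodingOptions()

# beam search with 5 beams, translating to English, forcing FP32
options = DecodingOptions(task="translate", beam_size=5, fp16=False)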
@dataclass(frozen=True)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan


class Inference:
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass


class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                # update the key/value cache to contain the selected sequences
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()


class SequenceRanker:
    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]

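In formula form, the ranker above scores each candidate sequence $y$ by its cumulative log probability divided by a penalty: plain length normalization when `length_penalty` is None, otherwise the Google NMT penalty with exponent $\alpha$ = `length_penalty`:

\[
\mathrm{score}(y) = \frac{\sum_{t} \log p(y_t \mid y_{<t})}{\mathrm{penalty}(|y|)},
\qquad
\mathrm{penalty}(|y|) =
\begin{cases}
|y| & \text{length normalization} \\[4pt]
\left(\dfrac{5 + |y|}{6}\right)^{\alpha} & \text{Google NMT penalty}
\end{cases}
\]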
class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above

        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()


class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs


class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf


class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf


class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (
            0 <= options.length_penalty <= 1
        ):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying a penalty to certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]


@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: Whisper
|
||||||
|
the Whisper model instance
|
||||||
|
|
||||||
|
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||||
|
A tensor containing the Mel spectrogram(s)
|
||||||
|
|
||||||
|
options: DecodingOptions
|
||||||
|
A dataclass that contains all necessary options for decoding 30-second segments
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
result: Union[DecodingResult, List[DecodingResult]]
|
||||||
|
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||||
|
"""
|
||||||
|
if single := mel.ndim == 2:
|
||||||
|
mel = mel.unsqueeze(0)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
options = replace(options, **kwargs)
|
||||||
|
|
||||||
|
result = DecodingTask(model, options).run(mel)
|
||||||
|
|
||||||
|
return result[0] if single else result
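# A minimal usage sketch for `decode`, assuming the package's standard top-level
# helpers (`whisper.load_model`, `whisper.load_audio`, `whisper.pad_or_trim`,
# `whisper.log_mel_spectrogram`); the audio filename is hypothetical:
#
#   import whisper
#
#   model = whisper.load_model("base")
#   audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))
#   mel = whisper.log_mel_spectrogram(audio).to(model.device)
#   result = whisper.decode(model, mel, whisper.DecodingOptions(fp16=False))
#   print(result.text, result.avg_logprob)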
345 whisper_service_deploy/whisper/model.py Normal file
@@ -0,0 +1,345 @@
import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

try:
    from torch.nn.functional import scaled_dot_product_attention

    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
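# For example, sinusoids(1500, 384) returns a (1500, 384) tensor whose first 192
# channels are sines and last 192 are cosines over geometrically spaced
# timescales; a quick shape check as a sketch:
#
#   pe = sinusoids(1500, 384)
#   assert pe.shape == (1500, 384)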
@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state
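# A minimal usage sketch: the SDPA fast path below returns no attention weights
# (qk is None), so callers that need the attention matrix (for example, for
# cross-attention alignment) can temporarily force the manual path:
#
#   with disable_sdpa():
#       logits = model.decoder(tokens, audio_features)  # qk is computed inside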
class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(
                q, k, v, is_causal=mask is not None and n_ctx > 1
            )
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        return out, qk
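    # Note on the scale factor above: multiplying both q and k by
    # (n_state // n_head) ** -0.25 is mathematically the standard
    # softmax(q @ k^T / sqrt(d_head)) attention, but it keeps the intermediate
    # products smaller in magnitude, which behaves better under fp16.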
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> torch.Tensor:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
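    # Worked examples of the vocabulary arithmetic above:
    #   multilingual (v1/v2): n_vocab = 51865 -> 51865 - 51765 - 1 = 99 languages
    #   large-v3:             n_vocab = 51866 -> 51866 - 51765 - 1 = 100 languages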
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their caches
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects that can be used to remove the hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
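    # A minimal sketch of the intended call pattern, mirroring the incremental
    # decoding loop (variable names are illustrative):
    #
    #   cache, hooks = model.install_kv_cache_hooks()
    #   logits = model.decoder(tokens, audio_features, kv_cache=cache)  # full prompt
    #   logits = model.decoder(last_token, audio_features, kv_cache=cache)  # one new token
    #   for hook in hooks:
    #       hook.remove()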
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
2 whisper_service_deploy/whisper/normalizers/__init__.py Normal file
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
80 whisper_service_deploy/whisper/normalizers/basic.py Normal file
@@ -0,0 +1,80 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        (
            c
            if c in keep
            else (
                ADDITIONAL_DIACRITICS[c]
                if c in ADDITIONAL_DIACRITICS
                else (
                    ""
                    if unicodedata.category(c) == "Mn"
                    else " " if unicodedata.category(c)[0] in "MSP" else c
                )
            )
        )
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c
        for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
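# A usage sketch (expected output shown modulo leading/trailing whitespace):
#
#   normalizer = BasicTextNormalizer(remove_diacritics=True)
#   normalizer("Héllo, (noise) WORLD!")  # -> "hello world"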
1741 whisper_service_deploy/whisper/normalizers/english.json Normal file
File diff suppressed because it is too large
550 whisper_service_deploy/whisper/normalizers/english.py Normal file
@@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into Arabic numerals, while handling:

    - removing any commas
    - keeping the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spelling out currency symbols after the number, e.g. `$20 million` -> `20000000 dollars`
    - spelling out `one` and `ones`
    - interpreting successive single-digit numbers as nominal, e.g. `one oh one` -> `101`
    """

    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th")
            for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {
            **self.multipliers_plural,
            **self.multipliers_ordinal,
        }
        self.decimals = {*self.ones, *self.tens, *self.zeros}
        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values())
            + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
        self.specials = {"and", "double", "triple", "point"}

        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}
    def process_words(self, words: List[str]) -> Iterator[str]:
        prefix: Optional[str] = None
        value: Optional[Union[str, int]] = None
        skip = False

        def to_fraction(s: str):
            try:
                return Fraction(s)
            except ValueError:
                return None

        def output(result: Union[str, int]):
            nonlocal prefix, value
            result = str(result)
            if prefix is not None:
                result = prefix + result
            value = None
            prefix = None
            return result

        if len(words) == 0:
            return

        for prev, current, next in windowed([None] + words + [None], 3):
            if skip:
                skip = False
                continue

            next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
            has_prefix = current[0] in self.prefixes
            current_without_prefix = current[1:] if has_prefix else current
            if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
                # Arabic numbers (potentially with signs and fractions)
                f = to_fraction(current_without_prefix)
                assert f is not None
                if value is not None:
                    if isinstance(value, str) and value.endswith("."):
                        # concatenate decimals / IP address components
                        value = str(value) + str(current)
                        continue
                    else:
                        yield output(value)

                prefix = current[0] if has_prefix else prefix
                if f.denominator == 1:
                    value = f.numerator  # store integers as int
                else:
                    value = current_without_prefix
            elif current not in self.words:
                # non-numeric words
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current in self.zeros:
                value = str(value or "") + "0"
            elif current in self.ones:
                ones = self.ones[current]

                if value is None:
                    value = ones
                elif isinstance(value, str) or prev in self.ones:
                    if (
                        prev in self.tens and ones < 10
                    ):  # replace the last zero with the digit
                        assert value[-1] == "0"
                        value = value[:-1] + str(ones)
                    else:
                        value = str(value) + str(ones)
                elif ones < 10:
                    if value % 10 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
            elif current in self.ones_suffixed:
                # ordinal or cardinal; yield the number right away
                ones, suffix = self.ones_suffixed[current]
                if value is None:
                    yield output(str(ones) + suffix)
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:
                        assert value[-1] == "0"
                        yield output(value[:-1] + str(ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                elif ones < 10:
                    if value % 10 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                value = None
            elif current in self.tens:
                tens = self.tens[current]
                if value is None:
                    value = tens
                elif isinstance(value, str):
                    value = str(value) + str(tens)
                else:
                    if value % 100 == 0:
                        value += tens
                    else:
                        value = str(value) + str(tens)
            elif current in self.tens_suffixed:
                # ordinal or cardinal; yield the number right away
                tens, suffix = self.tens_suffixed[current]
                if value is None:
                    yield output(str(tens) + suffix)
                elif isinstance(value, str):
                    yield output(str(value) + str(tens) + suffix)
                else:
                    if value % 100 == 0:
                        yield output(str(value + tens) + suffix)
                    else:
                        yield output(str(value) + str(tens) + suffix)
            elif current in self.multipliers:
                multiplier = self.multipliers[current]
                if value is None:
                    value = multiplier
                elif isinstance(value, str) or value == 0:
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        value = p.numerator
                    else:
                        yield output(value)
                        value = multiplier
                else:
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
            elif current in self.multipliers_suffixed:
                multiplier, suffix = self.multipliers_suffixed[current]
                if value is None:
                    yield output(str(multiplier) + suffix)
                elif isinstance(value, str):
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        yield output(str(p.numerator) + suffix)
                    else:
                        yield output(value)
                        yield output(str(multiplier) + suffix)
                else:  # int
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
                    yield output(str(value) + suffix)
                value = None
            elif current in self.preceding_prefixers:
                # apply prefix (positive, minus, etc.) if it precedes a number
                if value is not None:
                    yield output(value)

                if next in self.words or next_is_numeric:
                    prefix = self.preceding_prefixers[current]
                else:
                    yield output(current)
            elif current in self.following_prefixers:
                # apply prefix (dollars, cents, etc.) only after a number
                if value is not None:
                    prefix = self.following_prefixers[current]
                    yield output(value)
                else:
                    yield output(current)
            elif current in self.suffixers:
                # apply suffix symbols (percent -> '%')
                if value is not None:
                    suffix = self.suffixers[current]
                    if isinstance(suffix, dict):
                        if next in suffix:
                            yield output(str(value) + suffix[next])
                            skip = True
                        else:
                            yield output(value)
                            yield output(current)
                    else:
                        yield output(str(value) + suffix)
                else:
                    yield output(current)
            elif current in self.specials:
                if next not in self.words and not next_is_numeric:
                    # no special handling needed, since the next word cannot be numeric
                    if value is not None:
                        yield output(value)
                    yield output(current)
                elif current == "and":
                    # ignore "and" after hundreds, thousands, etc.
                    if prev not in self.multipliers:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "double" or current == "triple":
                    if next in self.ones or next in self.zeros:
                        repeats = 2 if current == "double" else 3
                        ones = self.ones.get(next, 0)
                        value = str(value or "") + str(ones) * repeats
                        skip = True
                    else:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "point":
                    if next in self.decimals or next_is_numeric:
                        value = str(value or "") + "."
                else:
                    # should all have been covered at this point
                    raise ValueError(f"Unexpected token: {current}")
            else:
                # should all have been covered at this point
                raise ValueError(f"Unexpected token: {current}")

        if value is not None:
            yield output(value)
    def preprocess(self, s: str):
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s
    def postprocess(self, s: str):
        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                return m.string

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.string

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s)  # escaped dot: match a literal decimal point

        # write "one(s)" instead of "1(s)", just for readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s
    def __call__(self, s: str):
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s
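# A usage sketch with traceable inputs:
#
#   norm = EnglishNumberNormalizer()
#   norm("twenty one dollars")      # -> "$21"
#   norm("one point five million")  # -> "1500000"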
class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        self.mapping = json.load(open(mapping_path))

    def __call__(self, s: str):
        return " ".join(self.mapping.get(word, word) for word in s.split())
class EnglishTextNormalizer:
    def __init__(self):
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses; ideally this would cover any past participle, but that is harder
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        s = s.lower()

        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # remove a space before an apostrophe

        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)

        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols

        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)

        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space

        return s
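# A usage sketch combining the steps above:
#
#   normalizer = EnglishTextNormalizer()
#   normalizer("Mr. Brown wanna eat twenty one apples")
#   # -> "mister brown want to eat 21 apples"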
1 whisper_service_deploy/whisper/version.py Normal file
@@ -0,0 +1 @@
__version__ = "20250625"