Merge branch 'main' into transcribe-argument

This commit is contained in:
Nripesh Niketan 2024-05-29 03:26:44 +04:00 committed by GitHub
commit c2d30994c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 157 additions and 23 deletions

View File

@ -4,3 +4,4 @@ torch
tqdm tqdm
more-itertools more-itertools
tiktoken tiktoken
triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"

View File

@ -1,6 +1,6 @@
import os
import platform import platform
import sys import sys
from pathlib import Path
import pkg_resources import pkg_resources
from setuptools import find_packages, setup from setuptools import find_packages, setup
@ -28,11 +28,10 @@ setup(
url="https://github.com/openai/whisper", url="https://github.com/openai/whisper",
license="MIT", license="MIT",
packages=find_packages(exclude=["tests*"]), packages=find_packages(exclude=["tests*"]),
install_requires=requirements install_requires=[
+ [
str(r) str(r)
for r in pkg_resources.parse_requirements( for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt")) Path(__file__).with_name("requirements.txt").open()
) )
], ],
entry_points={ entry_points={

View File

@ -299,6 +299,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment]) word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()] word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2 max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries. # hack: truncate long words at sentence boundaries.

View File

@ -2,7 +2,7 @@ import argparse
import os import os
import traceback import traceback
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -23,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import ( from .utils import (
exact_div, exact_div,
format_timestamp, format_timestamp,
get_end,
get_writer, get_writer,
make_safe, make_safe,
optional_float, optional_float,
@ -48,6 +49,8 @@ def transcribe(
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options, **decode_options,
): ):
""" """
@ -102,6 +105,14 @@ def transcribe(
decode_options: dict decode_options: dict
Keyword arguments to construct `DecodingOptions` instances Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@ -121,6 +132,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
if not model.is_multilingual: if not model.is_multilingual:
@ -147,6 +159,19 @@ def transcribe(
task=task, task=task,
) )
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate": if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.") warnings.warn("Word-level timestamps on translations may not be reliable.")
@ -190,7 +215,8 @@ def transcribe(
return decode_result return decode_result
seek = 0 clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div( input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2 ) # mel frames per output token: 2
@ -229,10 +255,23 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False total=content_frames, unit="frames", disable=verbose is not False
) as pbar: ) as pbar:
last_speech_timestamp = 0.0 last_speech_timestamp = 0.0
while seek < content_frames: # NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES] window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek) segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
@ -257,6 +296,30 @@ def transcribe(
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@ -330,17 +393,71 @@ def transcribe(
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp, last_speech_timestamp=last_speech_timestamp,
) )
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"] if not single_timestamp_ending:
] last_word_end = get_end(current_segments)
if len(word_end_timestamps) > 0: if last_word_end is not None and last_word_end > time_offset:
last_speech_timestamp = word_end_timestamps[-1] seek = round(last_word_end * FRAMES_PER_SECOND)
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round( # skip silence before possible hallucinations
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
) )
if seek_shift > 0: if next_segment is not None:
seek = previous_seek + seek_shift hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose: if verbose:
for segment in current_segments: for segment in current_segments:
@ -434,6 +551,8 @@ def cli():
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__

View File

@ -3,7 +3,7 @@ import os
import re import re
import sys import sys
import zlib import zlib
from typing import Callable, Optional, TextIO from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding() system_encoding = sys.getdefaultencoding()
@ -68,6 +68,20 @@ def format_timestamp(
) )
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter: class ResultWriter:
extension: str extension: str
@ -129,8 +143,8 @@ class SubtitlesWriter(ResultWriter):
line_len = 0 line_len = 0
line_count = 1 line_count = 1
# the next subtitle to yield (a list of word timings with whitespace) # the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = [] subtitle: List[dict] = []
last = result["segments"][0]["words"][0]["start"] last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]: for segment in result["segments"]:
chunk_index = 0 chunk_index = 0
words_count = max_words_per_line words_count = max_words_per_line