Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)
Skip silence around hallucinations (#1838)
* Add clip_timestamps option
* Add hallucination_silence_threshold option
* Fix typing for python < 3.9

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
parent 8bc8860694
commit ba3f3cd54b
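The commit extends transcribe() with two options: clip_timestamps restricts decoding to a list of start,end windows, and hallucination_silence_threshold (requires word_timestamps) skips silent gaps longer than the threshold around segments that look hallucinated. A minimal usage sketch, where the audio path and threshold value are illustrative:

    import whisper

    model = whisper.load_model("base")
    result = whisper.transcribe(
        model,
        "audio.mp3",                          # hypothetical input file
        word_timestamps=True,                 # required for hallucination skipping
        clip_timestamps="10,90",              # only process 10 s through 90 s
        hallucination_silence_threshold=2.0,  # skip silences longer than 2 s
    )
    print(result["text"])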
whisper/timing.py
@@ -299,6 +299,7 @@ def add_word_timestamps(
     word_durations = np.array([t.end - t.start for t in alignment])
     word_durations = word_durations[word_durations.nonzero()]
     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    median_duration = min(0.7, float(median_duration))
     max_duration = median_duration * 2

     # hack: truncate long words at sentence boundaries.
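This one-line change caps the median word duration at 0.7 s, so the max_duration used by the word-truncation hack below never exceeds 1.4 s even when hallucinated alignments inflate the median. A quick sketch with made-up durations:

    import numpy as np

    word_durations = np.array([0.2, 0.3, 4.0, 5.0, 6.0])  # hypothetical, hallucination-heavy
    median_duration = np.median(word_durations)            # 4.0 before the clamp
    median_duration = min(0.7, float(median_duration))     # 0.7 after
    max_duration = median_duration * 2                     # 1.4 instead of 8.0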
whisper/transcribe.py
@@ -2,7 +2,7 @@ import argparse
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -23,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
     exact_div,
     format_timestamp,
+    get_end,
     get_writer,
     make_safe,
     optional_float,
@@ -48,6 +49,8 @@ def transcribe(
     word_timestamps: bool = False,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    clip_timestamps: Union[str, List[float]] = "0",
+    hallucination_silence_threshold: Optional[float] = None,
     **decode_options,
 ):
     """
@@ -102,6 +105,14 @@ def transcribe(
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances

+    clip_timestamps: Union[str, List[float]]
+        Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
+        The last end timestamp defaults to the end of the file.
+
+    hallucination_silence_threshold: Optional[float]
+        When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
+        when a possible hallucination is detected
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
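Per the Union[str, List[float]] annotation, the same clips can be passed either as the CLI-style string or as a list of floats; a sketch, assuming model and audio are already set up:

    # equivalent: process 0-30 s, then 60 s through the end of the file
    result = model.transcribe(audio, clip_timestamps="0,30,60")
    result = model.transcribe(audio, clip_timestamps=[0.0, 30.0, 60.0])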
@@ -121,6 +132,7 @@ def transcribe(
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-1] - N_FRAMES
+    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -147,6 +159,19 @@ def transcribe(
         task=task,
     )

+    if isinstance(clip_timestamps, str):
+        clip_timestamps = [
+            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+        ]
+    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+    if len(seek_points) == 0:
+        seek_points.append(0)
+    if len(seek_points) % 2 == 1:
+        seek_points.append(content_frames)
+    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
+
+    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")

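With FRAMES_PER_SECOND equal to 100 (one mel frame per 10 ms), the parsing above maps second-valued timestamps onto frame-indexed clips; a worked example with a hypothetical content_frames of 18000 (a 180-second file):

    clip_timestamps = [float(ts) for ts in "10,90,120".split(",")]  # [10.0, 90.0, 120.0]
    seek_points = [round(ts * 100) for ts in clip_timestamps]       # [1000, 9000, 12000]
    seek_points.append(18000)  # odd count, so the last clip ends at content_frames
    seek_clips = list(zip(seek_points[::2], seek_points[1::2]))     # [(1000, 9000), (12000, 18000)]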
@@ -190,7 +215,8 @@ def transcribe(

         return decode_result

-    seek = 0
+    clip_idx = 0
+    seek = seek_clips[clip_idx][0]
     input_stride = exact_div(
         N_FRAMES, model.dims.n_audio_ctx
     )  # mel frames per output token: 2
@@ -229,10 +255,23 @@ def transcribe(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek : seek + N_FRAMES]
-            segment_size = min(N_FRAMES, content_frames - seek)
+            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+            mel_segment = mel[:, seek : seek + segment_size]
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

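Two details in the rewritten loop: seek now advances clip by clip instead of assuming one contiguous range, and segment_size is additionally capped by the current clip's end so a window never spills into the next clip. With whisper's constants (N_FRAMES = 3000 frames = 30 s), a numeric sketch:

    N_FRAMES = 3000                     # 30-second window
    seek, content_frames = 8500, 18000  # hypothetical position and file length
    seek_clip_end = 9000                # current clip ends at 90 s
    segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
    # segment_size == 500: only the 5 s remaining in this clip are decoded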
@@ -257,6 +296,30 @@ def transcribe(
             previous_seek = seek
             current_segments = []

+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

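To make the heuristic concrete: a word is penalized for low probability (< 0.15), for being shorter than 0.133 s, and for being longer than 2 s, and a segment counts as anomalous when its first eight non-punctuation words score at least 3 in total, or nearly 1 per word. An illustrative score:

    word = {"word": " the", "start": 10.00, "end": 10.05, "probability": 0.10}
    # probability 0.10 < 0.15    -> +1.0
    # duration 0.05 s < 0.133 s  -> +(0.133 - 0.05) * 15 = +1.245
    word_anomaly_score(word)      # 2.245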
@@ -330,17 +393,71 @@ def transcribe(
                     append_punctuations=append_punctuations,
                     last_speech_timestamp=last_speech_timestamp,
                 )
-                word_end_timestamps = [
-                    w["end"] for s in current_segments for w in s["words"]
-                ]
-                if len(word_end_timestamps) > 0:
-                    last_speech_timestamp = word_end_timestamps[-1]
-                if not single_timestamp_ending and len(word_end_timestamps) > 0:
-                    seek_shift = round(
-                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
-                    )
-                    if seek_shift > 0:
-                        seek = previous_seek + seek_shift
+
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                # skip silence before possible hallucinations
+                if hallucination_silence_threshold is not None:
+                    threshold = hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * FRAMES_PER_SECOND)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * FRAMES_PER_SECOND
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end
+
             if verbose:
                 for segment in current_segments:
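The final pass above discards a suspect segment only when it is bracketed on both sides by silence (or by further anomalies), then rewinds seek to just past the segment start. A numeric sketch with threshold = 2.0 and a window starting at time_offset = 60.0 s:

    # suspect segment spans 75.0-75.5 s; last confirmed speech ended at 70.0 s
    # and the next segment's words begin at 79.0 s
    silence_before = 75.0 - 70.0 > 2.0  # True: 5.0 s of silence before it
    silence_after = 79.0 - 75.5 > 2.0   # True: 3.5 s of silence after it
    # both hold, so the segment is dropped and seek jumps to
    # round(max(60.0 + 1, 75.0) * FRAMES_PER_SECOND) == 7500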
@@ -427,6 +544,8 @@ def cli():
     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
     parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
+    parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
     # fmt: on

     args = parser.parse_args().__dict__
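The new flags mirror the Python options; an illustrative command-line invocation (file name made up):

    whisper audio.mp3 --model base --word_timestamps True --clip_timestamps 10,90 --hallucination_silence_threshold 2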
whisper/utils.py
@@ -3,7 +3,7 @@ import os
 import re
 import sys
 import zlib
-from typing import Callable, Optional, TextIO
+from typing import Callable, List, Optional, TextIO

 system_encoding = sys.getdefaultencoding()

@@ -68,6 +68,20 @@ def format_timestamp(
     )


+def get_start(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["start"] for s in segments for w in s["words"]),
+        segments[0]["start"] if segments else None,
+    )
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
+
+
 class ResultWriter:
     extension: str

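Both helpers prefer word-level timings, fall back to segment-level timestamps when no segment has words, and return None only for an empty list; a sketch:

    segments = [
        {"start": 0.0, "end": 2.0, "words": []},
        {"start": 2.0, "end": 4.0, "words": [{"word": " hi", "start": 2.1, "end": 3.9}]},
    ]
    get_start(segments)  # 2.1  (first word start, skipping the wordless segment)
    get_end(segments)    # 3.9  (last word end)
    get_start([])        # None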
@@ -129,8 +143,8 @@ class SubtitlesWriter(ResultWriter):
         line_len = 0
         line_count = 1
         # the next subtitle to yield (a list of word timings with whitespace)
-        subtitle: list[dict] = []
-        last = result["segments"][0]["words"][0]["start"]
+        subtitle: List[dict] = []
+        last: float = get_start(result["segments"]) or 0.0
         for segment in result["segments"]:
             chunk_index = 0
             words_count = max_words_per_line
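The old initialization assumed the first segment had at least one word and crashed otherwise; the replacement degrades gracefully. A sketch of the failure mode being fixed (hypothetical result dict):

    result = {"segments": [{"start": 1.5, "end": 3.0, "words": []}]}
    # before: result["segments"][0]["words"][0]["start"] raises IndexError
    last = get_start(result["segments"]) or 0.0  # 1.5, falling back to the segment start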