Support longer audio files by reducing memory usage with chunking

Gustavo Garcia 2024-07-01 19:46:15 +02:00
parent ba3f3cd54b
commit 20e323895d
3 changed files with 359 additions and 339 deletions
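The idea behind the change, as a minimal standalone sketch rather than the committed code (the real load_audio keeps whisper's full ffmpeg invocation and error handling; load_audio_chunks and the bare-bones ffmpeg flags below are illustrative): instead of reading ffmpeg's entire 16-bit PCM output at once, the decoder's stdout is read in fixed-size pieces and each piece is yielded as a float32 array. At 16 kHz, one 2-hour chunk of int16 samples is roughly 2 * 60 * 60 * 16000 * 2 bytes, about 230 MB, so peak memory no longer grows with the length of the file.

import subprocess

import numpy as np

SAMPLE_RATE = 16000
MAX_CHUNK_DURATION = 2 * 60 * 60  # seconds of audio per yielded chunk


def load_audio_chunks(path: str, sr: int = SAMPLE_RATE):
    # ffmpeg must be on PATH; decode to mono 16-bit PCM at the target sample rate
    cmd = [
        "ffmpeg", "-nostdin", "-i", path,
        "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-",
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    while True:
        # holds at most MAX_CHUNK_DURATION * sr int16 samples (2 bytes each) at a time
        raw = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
        if not raw:
            break
        yield np.frombuffer(raw, np.int16).astype(np.float32) / 32768.0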

tests/test_audio.py · View File

@@ -7,13 +7,13 @@ from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
 
 def test_audio():
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio = next(load_audio(audio_path))
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1
 
-    mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_audio = next(log_mel_spectrogram(audio))
+    mel_from_file = next(log_mel_spectrogram(audio_path))
 
     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0

whisper/audio.py · View File

@@ -1,7 +1,8 @@
 import os
+import subprocess
 from functools import lru_cache
 from subprocess import CalledProcessError, run
-from typing import Optional, Union
+from typing import Generator, Optional, Union
 
 import numpy as np
 import torch
@@ -21,6 +22,7 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
+MAX_CHUNK_DURATION = 2 * 60 * 60  # 2 hour maximum chunk duration
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
@@ -55,11 +57,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     ]
     # fmt: on
     try:
-        out = run(cmd, capture_output=True, check=True).stdout
-    except CalledProcessError as e:
-        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
-
-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        while True:
+            out = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
+            if not out:
+                break
+            yield np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+    except Exception as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
 
 
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@@ -108,7 +114,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
@@ -135,13 +141,26 @@ def log_mel_spectrogram(
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
-
-    if device is not None:
-        audio = audio.to(device)
+    if isinstance(audio, str):
+        audio = load_audio(audio)
+    elif isinstance(audio, np.ndarray):
+        audio = [audio]
+    elif isinstance(audio, torch.Tensor):
+        audio = [audio]
+
+    for chunk in audio:
+        if not isinstance(chunk, torch.Tensor):
+            chunk = torch.from_numpy(chunk)
+        if device is not None:
+            chunk = chunk.to(device)
+        yield _log_mel_spectrogram(chunk, n_mels, padding)
+
+
+def _log_mel_spectrogram(
+    audio: torch.Tensor,
+    n_mels: int = 80,
+    padding: int = 0,
+):
     if padding > 0:
         audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
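With this change load_audio and log_mel_spectrogram are generators, which is why the test file now wraps them in next(). A short consumption sketch, assuming the whisper package from this commit is installed (the long-file path below is a placeholder):

from whisper.audio import load_audio, log_mel_spectrogram

# short file (under MAX_CHUNK_DURATION): the single chunk is the whole clip
audio = next(load_audio("tests/jfk.flac"))
mel = next(log_mel_spectrogram(audio))

# long file: iterate chunk by chunk instead of materializing everything at once
for mel_chunk in log_mel_spectrogram("several_hours_of_audio.mp3"):
    print(mel_chunk.shape)  # (80, n_frames) for the default n_mels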

whisper/transcribe.py · View File

@@ -2,7 +2,7 @@ import argparse
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 
 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@@ -129,367 +129,368 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
-    content_frames = mel.shape[-1] - N_FRAMES
-    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
-
-    if decode_options.get("language", None) is None:
-        if not model.is_multilingual:
-            decode_options["language"] = "en"
-        else:
-            if verbose:
-                print(
-                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
-                )
-            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(mel_segment)
-            decode_options["language"] = max(probs, key=probs.get)
-            if verbose is not None:
-                print(
-                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
-                )
-
-    language: str = decode_options["language"]
-    task: str = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(
-        model.is_multilingual,
-        num_languages=model.num_languages,
-        language=language,
-        task=task,
-    )
-
-    if isinstance(clip_timestamps, str):
-        clip_timestamps = [
-            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
-        ]
-    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
-    if len(seek_points) == 0:
-        seek_points.append(0)
-    if len(seek_points) % 2 == 1:
-        seek_points.append(content_frames)
-    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
-
-    punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
-
-    if word_timestamps and task == "translate":
-        warnings.warn("Word-level timestamps on translations may not be reliable.")
-
-    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = (
-            [temperature] if isinstance(temperature, (int, float)) else temperature
-        )
-        decode_result = None
-
-        for t in temperatures:
-            kwargs = {**decode_options}
-            if t > 0:
-                # disable beam_size and patience when t > 0
-                kwargs.pop("beam_size", None)
-                kwargs.pop("patience", None)
-            else:
-                # disable best_of when t == 0
-                kwargs.pop("best_of", None)
-
-            options = DecodingOptions(**kwargs, temperature=t)
-            decode_result = model.decode(segment, options)
-
-            needs_fallback = False
-            if (
-                compression_ratio_threshold is not None
-                and decode_result.compression_ratio > compression_ratio_threshold
-            ):
-                needs_fallback = True  # too repetitive
-            if (
-                logprob_threshold is not None
-                and decode_result.avg_logprob < logprob_threshold
-            ):
-                needs_fallback = True  # average log probability is too low
-            if (
-                no_speech_threshold is not None
-                and decode_result.no_speech_prob > no_speech_threshold
-            ):
-                needs_fallback = False  # silence
-            if not needs_fallback:
-                break
-
-        return decode_result
-
-    clip_idx = 0
-    seek = seek_clips[clip_idx][0]
-    input_stride = exact_div(
-        N_FRAMES, model.dims.n_audio_ctx
-    )  # mel frames per output token: 2
-    time_precision = (
-        input_stride * HOP_LENGTH / SAMPLE_RATE
-    )  # time per output token: 0.02 (seconds)
     all_tokens = []
     all_segments = []
-    prompt_reset_since = 0
-
-    if initial_prompt is not None:
-        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt_tokens)
-    else:
-        initial_prompt_tokens = []
-
-    def new_segment(
-        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
-    ):
-        tokens = tokens.tolist()
-        text_tokens = [token for token in tokens if token < tokenizer.eot]
-        return {
-            "seek": seek,
-            "start": start,
-            "end": end,
-            "text": tokenizer.decode(text_tokens),
-            "tokens": tokens,
-            "temperature": result.temperature,
-            "avg_logprob": result.avg_logprob,
-            "compression_ratio": result.compression_ratio,
-            "no_speech_prob": result.no_speech_prob,
-        }
-
-    # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    with tqdm.tqdm(
-        total=content_frames, unit="frames", disable=verbose is not False
-    ) as pbar:
-        last_speech_timestamp = 0.0
-        # NOTE: This loop is obscurely flattened to make the diff readable.
-        # A later commit should turn this into a simpler nested loop.
-        # for seek_clip_start, seek_clip_end in seek_clips:
-        #     while seek < seek_clip_end
-        while clip_idx < len(seek_clips):
-            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
-            if seek < seek_clip_start:
-                seek = seek_clip_start
-            if seek >= seek_clip_end:
-                clip_idx += 1
-                if clip_idx < len(seek_clips):
-                    seek = seek_clips[clip_idx][0]
-                continue
-            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
-            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
-            mel_segment = mel[:, seek : seek + segment_size]
-            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
-            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
-
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(mel_segment)
-            tokens = torch.tensor(result.tokens)
-
-            if no_speech_threshold is not None:
-                # no voice activity check
-                should_skip = result.no_speech_prob > no_speech_threshold
-                if (
-                    logprob_threshold is not None
-                    and result.avg_logprob > logprob_threshold
-                ):
-                    # don't skip if the logprob is high enough, despite the no_speech_prob
-                    should_skip = False
-
-                if should_skip:
-                    seek += segment_size  # fast-forward to the next segment boundary
-                    continue
-
-            previous_seek = seek
-            current_segments = []
-
-            # anomalous words are very long/short/improbable
-            def word_anomaly_score(word: dict) -> float:
-                probability = word.get("probability", 0.0)
-                duration = word["end"] - word["start"]
-                score = 0.0
-                if probability < 0.15:
-                    score += 1.0
-                if duration < 0.133:
-                    score += (0.133 - duration) * 15
-                if duration > 2.0:
-                    score += duration - 2.0
-                return score
-
-            def is_segment_anomaly(segment: Optional[dict]) -> bool:
-                if segment is None or not segment["words"]:
-                    return False
-                words = [w for w in segment["words"] if w["word"] not in punctuation]
-                words = words[:8]
-                score = sum(word_anomaly_score(w) for w in words)
-                return score >= 3 or score + 0.01 >= len(words)
-
-            def next_words_segment(segments: List[dict]) -> Optional[dict]:
-                return next((s for s in segments if s["words"]), None)
-
-            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
-
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
-            consecutive.add_(1)
-            if len(consecutive) > 0:
-                # if the output contains two consecutive timestamp tokens
-                slices = consecutive.tolist()
-                if single_timestamp_ending:
-                    slices.append(len(tokens))
-
-                last_slice = 0
-                for current_slice in slices:
-                    sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_pos = (
-                        sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    )
-                    end_timestamp_pos = (
-                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    )
-                    current_segments.append(
-                        new_segment(
-                            start=time_offset + start_timestamp_pos * time_precision,
-                            end=time_offset + end_timestamp_pos * time_precision,
-                            tokens=sliced_tokens,
-                            result=result,
-                        )
-                    )
-                    last_slice = current_slice
-
-                if single_timestamp_ending:
-                    # single timestamp at the end means no speech after the last timestamp.
-                    seek += segment_size
-                else:
-                    # otherwise, ignore the unfinished segment and seek to the last timestamp
-                    last_timestamp_pos = (
-                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-                    )
-                    seek += last_timestamp_pos * input_stride
-            else:
-                duration = segment_duration
-                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if (
-                    len(timestamps) > 0
-                    and timestamps[-1].item() != tokenizer.timestamp_begin
-                ):
-                    # no consecutive timestamps but it has a timestamp; use the last one.
-                    last_timestamp_pos = (
-                        timestamps[-1].item() - tokenizer.timestamp_begin
-                    )
-                    duration = last_timestamp_pos * time_precision
-
-                current_segments.append(
-                    new_segment(
-                        start=time_offset,
-                        end=time_offset + duration,
-                        tokens=tokens,
-                        result=result,
-                    )
-                )
-                seek += segment_size
-
-            if word_timestamps:
-                add_word_timestamps(
-                    segments=current_segments,
-                    model=model,
-                    tokenizer=tokenizer,
-                    mel=mel_segment,
-                    num_frames=segment_size,
-                    prepend_punctuations=prepend_punctuations,
-                    append_punctuations=append_punctuations,
-                    last_speech_timestamp=last_speech_timestamp,
-                )
-
-                if not single_timestamp_ending:
-                    last_word_end = get_end(current_segments)
-                    if last_word_end is not None and last_word_end > time_offset:
-                        seek = round(last_word_end * FRAMES_PER_SECOND)
-
-                # skip silence before possible hallucinations
-                if hallucination_silence_threshold is not None:
-                    threshold = hallucination_silence_threshold
-                    if not single_timestamp_ending:
-                        last_word_end = get_end(current_segments)
-                        if last_word_end is not None and last_word_end > time_offset:
-                            remaining_duration = window_end_time - last_word_end
-                            if remaining_duration > threshold:
-                                seek = round(last_word_end * FRAMES_PER_SECOND)
-                            else:
-                                seek = previous_seek + segment_size
-
-                    # if first segment might be a hallucination, skip leading silence
-                    first_segment = next_words_segment(current_segments)
-                    if first_segment is not None and is_segment_anomaly(first_segment):
-                        gap = first_segment["start"] - time_offset
-                        if gap > threshold:
-                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
-                            continue
-
-                    # skip silence before any possible hallucination that is surrounded
-                    # by silence or more hallucinations
-                    hal_last_end = last_speech_timestamp
-                    for si in range(len(current_segments)):
-                        segment = current_segments[si]
-                        if not segment["words"]:
-                            continue
-                        if is_segment_anomaly(segment):
-                            next_segment = next_words_segment(
-                                current_segments[si + 1 :]
-                            )
-                            if next_segment is not None:
-                                hal_next_start = next_segment["words"][0]["start"]
-                            else:
-                                hal_next_start = time_offset + segment_duration
-                            silence_before = (
-                                segment["start"] - hal_last_end > threshold
-                                or segment["start"] < threshold
-                                or segment["start"] - time_offset < 2.0
-                            )
-                            silence_after = (
-                                hal_next_start - segment["end"] > threshold
-                                or is_segment_anomaly(next_segment)
-                                or window_end_time - segment["end"] < 2.0
-                            )
-                            if silence_before and silence_after:
-                                seek = round(
-                                    max(time_offset + 1, segment["start"])
-                                    * FRAMES_PER_SECOND
-                                )
-                                if content_duration - segment["end"] < threshold:
-                                    seek = content_frames
-                                current_segments[si:] = []
-                                break
-                        hal_last_end = segment["end"]
-
-                last_word_end = get_end(current_segments)
-                if last_word_end is not None:
-                    last_speech_timestamp = last_word_end
-
-            if verbose:
-                for segment in current_segments:
-                    start, end, text = segment["start"], segment["end"], segment["text"]
-                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
-                    print(make_safe(line))
-
-            # if a segment is instantaneous or does not contain text, clear it
-            for i, segment in enumerate(current_segments):
-                if segment["start"] == segment["end"] or segment["text"].strip() == "":
-                    segment["text"] = ""
-                    segment["tokens"] = []
-                    segment["words"] = []
-
-            all_segments.extend(
-                [
-                    {"id": i, **segment}
-                    for i, segment in enumerate(
-                        current_segments, start=len(all_segments)
-                    )
-                ]
-            )
-            all_tokens.extend(
-                [token for segment in current_segments for token in segment["tokens"]]
-            )
-
-            if not condition_on_previous_text or result.temperature > 0.5:
-                # do not feed the prompt tokens if a high temperature was used
-                prompt_reset_since = len(all_tokens)
-
-            # update progress bar
-            pbar.update(min(content_frames, seek) - previous_seek)
+    punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
+
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mels = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+    for mel in mels:
+        content_frames = mel.shape[-1] - N_FRAMES
+        content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+
+        if decode_options.get("language", None) is None:
+            if not model.is_multilingual:
+                decode_options["language"] = "en"
+            else:
+                if verbose:
+                    print(
+                        "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                    )
+                mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+                _, probs = model.detect_language(mel_segment)
+                decode_options["language"] = max(probs, key=probs.get)
+                if verbose is not None:
+                    print(
+                        f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                    )
+
+        language: str = decode_options["language"]
+        task: str = decode_options.get("task", "transcribe")
+        tokenizer = get_tokenizer(
+            model.is_multilingual,
+            num_languages=model.num_languages,
+            language=language,
+            task=task,
+        )
+
+        if isinstance(clip_timestamps, str):
+            clip_timestamps = [
+                float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+            ]
+        seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+        if len(seek_points) == 0:
+            seek_points.append(0)
+        if len(seek_points) % 2 == 1:
+            seek_points.append(content_frames)
+        seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
+
+        if word_timestamps and task == "translate":
+            warnings.warn("Word-level timestamps on translations may not be reliable.")
+
+        def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
+            temperatures = (
+                [temperature] if isinstance(temperature, (int, float)) else temperature
+            )
+            decode_result = None
+
+            for t in temperatures:
+                kwargs = {**decode_options}
+                if t > 0:
+                    # disable beam_size and patience when t > 0
+                    kwargs.pop("beam_size", None)
+                    kwargs.pop("patience", None)
+                else:
+                    # disable best_of when t == 0
+                    kwargs.pop("best_of", None)
+
+                options = DecodingOptions(**kwargs, temperature=t)
+                decode_result = model.decode(segment, options)
+
+                needs_fallback = False
+                if (
+                    compression_ratio_threshold is not None
+                    and decode_result.compression_ratio > compression_ratio_threshold
+                ):
+                    needs_fallback = True  # too repetitive
+                if (
+                    logprob_threshold is not None
+                    and decode_result.avg_logprob < logprob_threshold
+                ):
+                    needs_fallback = True  # average log probability is too low
+                if (
+                    no_speech_threshold is not None
+                    and decode_result.no_speech_prob > no_speech_threshold
+                ):
+                    needs_fallback = False  # silence
+                if not needs_fallback:
+                    break
+
+            return decode_result
+
+        clip_idx = 0
+        seek = seek_clips[clip_idx][0]
+        input_stride = exact_div(
+            N_FRAMES, model.dims.n_audio_ctx
+        )  # mel frames per output token: 2
+        time_precision = (
+            input_stride * HOP_LENGTH / SAMPLE_RATE
+        )  # time per output token: 0.02 (seconds)
+        prompt_reset_since = 0
+
+        if initial_prompt is not None:
+            initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+            all_tokens.extend(initial_prompt_tokens)
+        else:
+            initial_prompt_tokens = []
+
+        def new_segment(
+            *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+        ):
+            tokens = tokens.tolist()
+            text_tokens = [token for token in tokens if token < tokenizer.eot]
+            return {
+                "seek": seek,
+                "start": start,
+                "end": end,
+                "text": tokenizer.decode(text_tokens),
+                "tokens": tokens,
+                "temperature": result.temperature,
+                "avg_logprob": result.avg_logprob,
+                "compression_ratio": result.compression_ratio,
+                "no_speech_prob": result.no_speech_prob,
+            }
+
+        # show the progress bar when verbose is False (if True, transcribed text will be printed)
+        with tqdm.tqdm(
+            total=content_frames, unit="frames", disable=verbose is not False
+        ) as pbar:
+            last_speech_timestamp = 0.0
+            # NOTE: This loop is obscurely flattened to make the diff readable.
+            # A later commit should turn this into a simpler nested loop.
+            # for seek_clip_start, seek_clip_end in seek_clips:
+            #     while seek < seek_clip_end
+            while clip_idx < len(seek_clips):
+                seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+                if seek < seek_clip_start:
+                    seek = seek_clip_start
+                if seek >= seek_clip_end:
+                    clip_idx += 1
+                    if clip_idx < len(seek_clips):
+                        seek = seek_clips[clip_idx][0]
+                    continue
+                time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+                window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+                segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+                mel_segment = mel[:, seek : seek + segment_size]
+                segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+                mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
+
+                decode_options["prompt"] = all_tokens[prompt_reset_since:]
+                result: DecodingResult = decode_with_fallback(mel_segment)
+                tokens = torch.tensor(result.tokens)
+
+                if no_speech_threshold is not None:
+                    # no voice activity check
+                    should_skip = result.no_speech_prob > no_speech_threshold
+                    if (
+                        logprob_threshold is not None
+                        and result.avg_logprob > logprob_threshold
+                    ):
+                        # don't skip if the logprob is high enough, despite the no_speech_prob
+                        should_skip = False
+
+                    if should_skip:
+                        seek += segment_size  # fast-forward to the next segment boundary
+                        continue
+
+                previous_seek = seek
+                current_segments = []
+
+                # anomalous words are very long/short/improbable
+                def word_anomaly_score(word: dict) -> float:
+                    probability = word.get("probability", 0.0)
+                    duration = word["end"] - word["start"]
+                    score = 0.0
+                    if probability < 0.15:
+                        score += 1.0
+                    if duration < 0.133:
+                        score += (0.133 - duration) * 15
+                    if duration > 2.0:
+                        score += duration - 2.0
+                    return score
+
+                def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                    if segment is None or not segment["words"]:
+                        return False
+                    words = [w for w in segment["words"] if w["word"] not in punctuation]
+                    words = words[:8]
+                    score = sum(word_anomaly_score(w) for w in words)
+                    return score >= 3 or score + 0.01 >= len(words)
+
+                def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                    return next((s for s in segments if s["words"]), None)
+
+                timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+                single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+                consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+                consecutive.add_(1)
+                if len(consecutive) > 0:
+                    # if the output contains two consecutive timestamp tokens
+                    slices = consecutive.tolist()
+                    if single_timestamp_ending:
+                        slices.append(len(tokens))
+
+                    last_slice = 0
+                    for current_slice in slices:
+                        sliced_tokens = tokens[last_slice:current_slice]
+                        start_timestamp_pos = (
+                            sliced_tokens[0].item() - tokenizer.timestamp_begin
+                        )
+                        end_timestamp_pos = (
+                            sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                        )
+                        current_segments.append(
+                            new_segment(
+                                start=time_offset + start_timestamp_pos * time_precision,
+                                end=time_offset + end_timestamp_pos * time_precision,
+                                tokens=sliced_tokens,
+                                result=result,
+                            )
+                        )
+                        last_slice = current_slice
+
+                    if single_timestamp_ending:
+                        # single timestamp at the end means no speech after the last timestamp.
+                        seek += segment_size
+                    else:
+                        # otherwise, ignore the unfinished segment and seek to the last timestamp
+                        last_timestamp_pos = (
+                            tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                        )
+                        seek += last_timestamp_pos * input_stride
+                else:
+                    duration = segment_duration
+                    timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                    if (
+                        len(timestamps) > 0
+                        and timestamps[-1].item() != tokenizer.timestamp_begin
+                    ):
+                        # no consecutive timestamps but it has a timestamp; use the last one.
+                        last_timestamp_pos = (
+                            timestamps[-1].item() - tokenizer.timestamp_begin
+                        )
+                        duration = last_timestamp_pos * time_precision
+
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset,
+                            end=time_offset + duration,
+                            tokens=tokens,
+                            result=result,
+                        )
+                    )
+                    seek += segment_size
+
+                if word_timestamps:
+                    add_word_timestamps(
+                        segments=current_segments,
+                        model=model,
+                        tokenizer=tokenizer,
+                        mel=mel_segment,
+                        num_frames=segment_size,
+                        prepend_punctuations=prepend_punctuations,
+                        append_punctuations=append_punctuations,
+                        last_speech_timestamp=last_speech_timestamp,
+                    )
+
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                    # skip silence before possible hallucinations
+                    if hallucination_silence_threshold is not None:
+                        threshold = hallucination_silence_threshold
+                        if not single_timestamp_ending:
+                            last_word_end = get_end(current_segments)
+                            if last_word_end is not None and last_word_end > time_offset:
+                                remaining_duration = window_end_time - last_word_end
+                                if remaining_duration > threshold:
+                                    seek = round(last_word_end * FRAMES_PER_SECOND)
+                                else:
+                                    seek = previous_seek + segment_size
+
+                        # if first segment might be a hallucination, skip leading silence
+                        first_segment = next_words_segment(current_segments)
+                        if first_segment is not None and is_segment_anomaly(first_segment):
+                            gap = first_segment["start"] - time_offset
+                            if gap > threshold:
+                                seek = previous_seek + round(gap * FRAMES_PER_SECOND)
+                                continue
+
+                        # skip silence before any possible hallucination that is surrounded
+                        # by silence or more hallucinations
+                        hal_last_end = last_speech_timestamp
+                        for si in range(len(current_segments)):
+                            segment = current_segments[si]
+                            if not segment["words"]:
+                                continue
+                            if is_segment_anomaly(segment):
+                                next_segment = next_words_segment(
+                                    current_segments[si + 1 :]
+                                )
+                                if next_segment is not None:
+                                    hal_next_start = next_segment["words"][0]["start"]
+                                else:
+                                    hal_next_start = time_offset + segment_duration
+                                silence_before = (
+                                    segment["start"] - hal_last_end > threshold
+                                    or segment["start"] < threshold
+                                    or segment["start"] - time_offset < 2.0
+                                )
+                                silence_after = (
+                                    hal_next_start - segment["end"] > threshold
+                                    or is_segment_anomaly(next_segment)
+                                    or window_end_time - segment["end"] < 2.0
+                                )
+                                if silence_before and silence_after:
+                                    seek = round(
+                                        max(time_offset + 1, segment["start"])
+                                        * FRAMES_PER_SECOND
+                                    )
+                                    if content_duration - segment["end"] < threshold:
+                                        seek = content_frames
+                                    current_segments[si:] = []
+                                    break
+                            hal_last_end = segment["end"]
+
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None:
+                        last_speech_timestamp = last_word_end
+
+                if verbose:
+                    for segment in current_segments:
+                        start, end, text = segment["start"], segment["end"], segment["text"]
+                        line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                        print(make_safe(line))
+
+                # if a segment is instantaneous or does not contain text, clear it
+                for i, segment in enumerate(current_segments):
+                    if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                        segment["text"] = ""
+                        segment["tokens"] = []
+                        segment["words"] = []
+
+                all_segments.extend(
+                    [
+                        {"id": i, **segment}
+                        for i, segment in enumerate(
+                            current_segments, start=len(all_segments)
+                        )
+                    ]
+                )
+                all_tokens.extend(
+                    [token for segment in current_segments for token in segment["tokens"]]
+                )
+
+                if not condition_on_previous_text or result.temperature > 0.5:
+                    # do not feed the prompt tokens if a high temperature was used
+                    prompt_reset_since = len(all_tokens)
+
+                # update progress bar
+                pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
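For callers the entry point looks the same: transcribe() still accepts a file path (and now also a chunk generator), drives the mel generator internally, and keeps extending all_tokens and all_segments across chunks before building the result. A hedged usage sketch, with the model size and file name as placeholders:

import whisper

model = whisper.load_model("base")
result = model.transcribe("several_hours_of_audio.mp3")  # decoded chunk by chunk internally
print(result["text"][:200])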