attempt to fix the repetition/hallucination issue identified in #1046

* zero-pad the audio instead of spectrogram
* formatting fix
* delete debug print

commit 919a713499 (parent 38e990d853)
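For orientation, a minimal sketch of the code path this commit introduces: the raw audio is zero-padded with 30 seconds of silence before the STFT, and the seek loop only walks over the frames that correspond to real audio. This is not code from the commit; the file path "audio.wav" is a placeholder, the seek advance is simplified, and the snippet assumes a whisper install that already contains this change.

import whisper
from whisper.audio import N_FRAMES, N_SAMPLES, log_mel_spectrogram, pad_or_trim

audio = whisper.load_audio("audio.wav")  # placeholder path

# zero-pad the raw audio with 30 s of silence, then compute the log-Mel spectrogram
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)

# frames that correspond to real audio; the trailing N_FRAMES come from the padding
content_frames = mel.shape[-1] - N_FRAMES

seek = 0
while seek < content_frames:
    # every window is a full 3000-frame (30 s) slice; its tail is genuine silence
    mel_segment = pad_or_trim(mel[:, seek : seek + N_FRAMES], N_FRAMES)
    # ... decode mel_segment here ...
    seek += N_FRAMES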
whisper/audio.py

@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import ffmpeg
 import numpy as np
@@ -15,10 +15,8 @@ N_FFT = 400
 N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
 ):
     """
     Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
     n_mels: int
         The number of Mel-frequency filters, only 80 is supported
 
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
     Returns
     -------
     torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
     magnitudes = stft[..., :-1].abs() ** 2
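A rough sketch of the rationale behind "zero-pad the audio instead of spectrogram": padding the log-Mel spectrogram fills the tail of a short final segment with literal zeros, which is not what silence looks like after the log/normalization step, whereas padding the waveform yields frames that genuinely encode silence. The toy comparison below is only illustrative (the 5-second noise input and the printed means are made up) and assumes a whisper build that already includes the new padding argument.

import torch
from whisper.audio import N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim

audio = torch.randn(5 * SAMPLE_RATE) * 0.1  # 5 s of noise standing in for speech

# old behaviour: spectrogram of the short clip, tail padded with zeros afterwards
spec_padded = pad_or_trim(log_mel_spectrogram(audio), N_FRAMES)

# new behaviour: pad the waveform itself out to 30 s, then take the spectrogram
audio_padded = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[-1])

# the tails differ: literal zeros vs. the normalized log-Mel of digital silence
print(spec_padded[:, -100:].mean().item(), audio_padded[:, -100:].mean().item())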
whisper/transcribe.py

@@ -11,6 +11,7 @@ from .audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -212,14 +215,13 @@ def transcribe(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
@@ -246,20 +248,18 @@ def transcribe(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -278,7 +278,7 @@ def transcribe(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def transcribe(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@ def transcribe(
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
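A toy illustration of the restructured timestamp-slicing logic above, using a hand-made boolean mask in place of real tokenizer output (the values are invented purely to show how single_timestamp_ending and slices interact):

import torch

# pretend decoder output where True marks a timestamp token and False a text token
timestamp_tokens = torch.tensor([True, False, False, True, True, False, True])

# a lone timestamp at the very end means no speech after the last timestamp
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

# positions right after a pair of consecutive timestamp tokens mark segment boundaries
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)

slices = consecutive.tolist()
if single_timestamp_ending:
    slices.append(len(timestamp_tokens))

print(single_timestamp_ending, slices)  # True [4, 7]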