mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
add progress bar for transcribe loop (#100)
* add progress bar to transcribe loop * improved warning message for English-only models * add --condition_on_previous_text * progressbar renames Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
This commit is contained in:
parent
5d8d3e75a4
commit
9e7e418ff1
@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||||
from .decoding import DecodingOptions, DecodingResult
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
@ -87,7 +88,7 @@ def transcribe(
|
|||||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
_, probs = model.detect_language(segment)
|
_, probs = model.detect_language(segment)
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
print(f"Detected language: {LANGUAGES[decode_options['language']]}")
|
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||||
|
|
||||||
mel = mel.unsqueeze(0)
|
mel = mel.unsqueeze(0)
|
||||||
language = decode_options["language"]
|
language = decode_options["language"]
|
||||||
@ -160,7 +161,12 @@ def transcribe(
|
|||||||
if verbose:
|
if verbose:
|
||||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||||
|
|
||||||
while seek < mel.shape[-1]:
|
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||||
|
num_frames = mel.shape[-1]
|
||||||
|
previous_seek_value = seek
|
||||||
|
|
||||||
|
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
|
||||||
|
while seek < num_frames:
|
||||||
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
|
segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||||
@ -227,6 +233,10 @@ def transcribe(
|
|||||||
# do not feed the prompt tokens if a high temperature was used
|
# do not feed the prompt tokens if a high temperature was used
|
||||||
prompt_reset_since = len(all_tokens)
|
prompt_reset_since = len(all_tokens)
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
|
pbar.update(min(num_frames, seek) - previous_seek_value)
|
||||||
|
previous_seek_value = seek
|
||||||
|
|
||||||
return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
|
return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user