add progress bar for transcribe loop (#100)

* add progress bar to transcribe loop

* improved warning message for English-only models

* add --condition_on_previous_text

* progressbar renames

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
fatih authored 2022-09-26 13:24:13 +03:00, committed by GitHub
parent 5d8d3e75a4
commit 9e7e418ff1
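For context, this is roughly how the options touched by this commit surface from the Python API (a minimal sketch; "base" and "audio.mp3" are placeholder choices, and model.transcribe forwards keyword arguments to the transcribe function changed below):

import whisper

model = whisper.load_model("base")  # placeholder model size

# With verbose=False no per-segment text is printed, so the new tqdm bar is shown;
# condition_on_previous_text=False stops feeding prior output back as the prompt.
result = model.transcribe(
    "audio.mp3",  # placeholder input path
    verbose=False,
    condition_on_previous_text=False,
)
print(result["text"])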

whisper/transcribe.py

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
+import tqdm
 
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
@@ -87,7 +88,7 @@ def transcribe(
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
-        print(f"Detected language: {LANGUAGES[decode_options['language']]}")
+        print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
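The only change in this hunk is cosmetic: str.title() capitalizes the detected language name when it is printed. A tiny illustration, with a made-up probability dict standing in for the detect_language output:

probs = {"en": 0.93, "de": 0.04, "tr": 0.03}  # hypothetical per-language probabilities
print(max(probs, key=probs.get))              # "en": the key with the highest probability
print("english".title())                      # "English": what LANGUAGES["en"] prints after .title()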
@@ -160,72 +161,81 @@ def transcribe(
         if verbose:
             print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
 
-    while seek < mel.shape[-1]:
-        timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-        segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
-        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-
-        decode_options["prompt"] = all_tokens[prompt_reset_since:]
-        result = decode_with_fallback(segment)[0]
-        tokens = torch.tensor(result.tokens)
-
-        if no_speech_threshold is not None:
-            # no voice activity check
-            should_skip = result.no_speech_prob > no_speech_threshold
-            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_speech_prob
-                should_skip = False
-
-            if should_skip:
-                seek += segment.shape[-1]  # fast-forward to the next segment boundary
-                continue
-
-        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-            last_slice = 0
-            for current_slice in consecutive:
-                sliced_tokens = tokens[last_slice:current_slice]
-                start_timestamp_position = (
-                    sliced_tokens[0].item() - tokenizer.timestamp_begin
-                )
-                end_timestamp_position = (
-                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                )
-                add_segment(
-                    start=timestamp_offset + start_timestamp_position * time_precision,
-                    end=timestamp_offset + end_timestamp_position * time_precision,
-                    text_tokens=sliced_tokens[1:-1],
-                    result=result,
-                )
-                last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-            )
-            seek += last_timestamp_position * input_stride
-            all_tokens.extend(tokens[: last_slice + 1].tolist())
-        else:
-            duration = segment_duration
-            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0:
-                # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                duration = last_timestamp_position * time_precision
-
-            add_segment(
-                start=timestamp_offset,
-                end=timestamp_offset + duration,
-                text_tokens=tokens,
-                result=result,
-            )
-
-            seek += segment.shape[-1]
-            all_tokens.extend(tokens.tolist())
-
-        if not condition_on_previous_text or result.temperature > 0.5:
-            # do not feed the prompt tokens if a high temperature was used
-            prompt_reset_since = len(all_tokens)
+    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    num_frames = mel.shape[-1]
+    previous_seek_value = seek
+
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+        while seek < num_frames:
+            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            result = decode_with_fallback(segment)[0]
+            tokens = torch.tensor(result.tokens)
+
+            if no_speech_threshold is not None:
+                # no voice activity check
+                should_skip = result.no_speech_prob > no_speech_threshold
+                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                    # don't skip if the logprob is high enough, despite the no_speech_prob
+                    should_skip = False
+
+                if should_skip:
+                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    continue
+
+            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+                last_slice = 0
+                for current_slice in consecutive:
+                    sliced_tokens = tokens[last_slice:current_slice]
+                    start_timestamp_position = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_position = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    add_segment(
+                        start=timestamp_offset + start_timestamp_position * time_precision,
+                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        text_tokens=sliced_tokens[1:-1],
+                        result=result,
+                    )
+                    last_slice = current_slice
+                last_timestamp_position = (
+                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                )
+                seek += last_timestamp_position * input_stride
+                all_tokens.extend(tokens[: last_slice + 1].tolist())
+            else:
+                duration = segment_duration
+                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                if len(timestamps) > 0:
+                    # no consecutive timestamps but it has a timestamp; use the last one.
+                    # single timestamp at the end means no speech after the last timestamp.
+                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                    duration = last_timestamp_position * time_precision
+
+                add_segment(
+                    start=timestamp_offset,
+                    end=timestamp_offset + duration,
+                    text_tokens=tokens,
+                    result=result,
+                )
+
+                seek += segment.shape[-1]
+                all_tokens.extend(tokens.tolist())
+
+            if not condition_on_previous_text or result.temperature > 0.5:
+                # do not feed the prompt tokens if a high temperature was used
+                prompt_reset_since = len(all_tokens)
+
+            # update progress bar
+            pbar.update(min(num_frames, seek) - previous_seek_value)
+            previous_seek_value = seek
 
     return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
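The progress-bar wiring above follows a standard tqdm pattern: size the bar by the total number of mel frames, disable it when verbose mode is already printing each segment, and advance it by the change in seek, which can jump a variable number of frames per iteration. A self-contained sketch of that pattern, with made-up frame counts:

import tqdm

num_frames = 1000  # made-up total, standing in for mel.shape[-1]
seek = previous_seek_value = 0

with tqdm.tqdm(total=num_frames, unit="frames", disable=False) as pbar:
    while seek < num_frames:
        seek += 173  # each "segment" advances by a variable number of frames
        # update by the delta of progress, clamped so the bar never exceeds its total
        pbar.update(min(num_frames, seek) - previous_seek_value)
        previous_seek_value = seek

Updating by the delta (rather than calling pbar.update(seek)) matters because tqdm.update adds to the running count; the min() clamp keeps the final overshooting step from pushing the bar past 100%.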