fix: transcribe verbosity (#140)

Nick Konovalchuk 2022-09-26 21:46:21 +03:00 committed by GitHub
parent 9c8183a179
commit b4308c4782


@@ -20,7 +20,7 @@ def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
     *,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
@@ -40,7 +40,8 @@ def transcribe(
         The path to the audio file to open, or the audio waveform
 
     verbose: bool
-        Whether to display the text being decoded to the console
+        Whether to display the text being decoded to the console. If True, displays all the details,
+        If False, displays minimal details. If None, does not display anything
 
     temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
@@ -88,7 +89,8 @@ def transcribe(
             segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             _, probs = model.detect_language(segment)
             decode_options["language"] = max(probs, key=probs.get)
-            print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+            if verbose is not None:
+                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
@@ -170,7 +172,7 @@ def transcribe(
     num_frames = mel.shape[-1]
     previous_seek_value = seek
 
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
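
For context, a minimal usage sketch of the new tri-state flag, assuming the openai-whisper package's public API (whisper.load_model and Whisper.transcribe, which forward keyword arguments to the function patched above); the model size and audio file name are placeholders, not part of this commit:

import whisper  # assumes the openai-whisper package; "base" and "audio.mp3" are placeholders

model = whisper.load_model("base")

# verbose=None (the new default): nothing is printed and the tqdm progress bar is hidden
result = model.transcribe("audio.mp3")

# verbose=False, minimal details: the detected-language line and the progress bar are shown
# (disable=verbose is not False evaluates to False), but segment text is not printed
result = model.transcribe("audio.mp3", verbose=False)

# verbose=True, all the details: the detected language and each decoded segment are printed
# as they are produced, while the progress bar stays disabled
result = model.transcribe("audio.mp3", verbose=True)

# the full transcription is returned in all three modes
print(result["text"])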