Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 06:26:03 +00:00)

allowing nonzero initial temperature

parent 30dc5c581b
commit 7cb4cc21bf
whisper/decoding.py

@@ -94,7 +94,7 @@ class DecodingOptions:
 
     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
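Alongside the temperature change, the default max_initial_timestamp is relaxed from 0.0 to 1.0, so the first predicted timestamp may fall up to one second into the audio window. A minimal sketch of how these options combine (assuming the package layout at this commit; the concrete values are illustrative, not taken from the diff):

    from whisper.decoding import DecodingOptions

    # t == 0: greedy or beam-search decoding; the fallback loop below drops best_of
    greedy = DecodingOptions(temperature=0.0, beam_size=5, max_initial_timestamp=1.0)

    # t > 0: sampling; beam_size and patience are dropped, best_of keeps the best sample
    sampled = DecodingOptions(temperature=0.5, best_of=5, max_initial_timestamp=1.0)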
whisper/transcribe.py

@@ -92,41 +92,37 @@ def transcribe(
         if verbose is not None:
             print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
-    mel = mel.unsqueeze(0)
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
-    def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
+    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
-        kwargs = {**decode_options}
-        t = temperatures[0]
-        if t == 0:
-            best_of = kwargs.pop("best_of", None)
-        else:
-            best_of = kwargs.get("best_of", None)
-
-        options = DecodingOptions(**kwargs, temperature=t)
-        results = model.decode(segment, options)
-
-        kwargs.pop("beam_size", None)  # no beam search for t > 0
-        kwargs.pop("patience", None)  # no patience for t > 0
-        kwargs["best_of"] = best_of  # enable best_of for t > 0
-        for t in temperatures[1:]:
-            needs_fallback = [
-                compression_ratio_threshold is not None
-                and result.compression_ratio > compression_ratio_threshold
-                or logprob_threshold is not None
-                and result.avg_logprob < logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                options = DecodingOptions(**kwargs, temperature=t)
-                retries = model.decode(segment[needs_fallback], options)
-                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
-                    results[original_index] = retries[retry_index]
-
-        return results
+        decode_result = None
+
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)
+
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options)
+
+            needs_fallback = False
+            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+                needs_fallback = True  # too repetitive
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result
 
     seek = 0
     input_stride = exact_div(
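The rewrite replaces the batched retry above, which always ran the first pass with the first temperature's settings and only re-decoded failing segments at later temperatures, with a single loop that applies the same option hygiene at every temperature, so the schedule can now start above zero. A self-contained sketch of that control flow, with a stubbed decoder standing in for model.decode (Result and fake_decode are hypothetical stand-ins, not whisper APIs):

    from dataclasses import dataclass

    @dataclass
    class Result:
        compression_ratio: float
        avg_logprob: float

    def fake_decode(t: float) -> Result:
        # pretend that raising the temperature breaks up repetitive output
        return Result(compression_ratio=3.0 - t, avg_logprob=-0.5)

    def decode_with_fallback(temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
                             compression_ratio_threshold=2.4,
                             logprob_threshold=-1.0) -> Result:
        result = None
        for t in temperatures:
            result = fake_decode(t)
            too_repetitive = (compression_ratio_threshold is not None
                              and result.compression_ratio > compression_ratio_threshold)
            too_improbable = (logprob_threshold is not None
                              and result.avg_logprob < logprob_threshold)
            if not (too_repetitive or too_improbable):
                break  # accept the first result that passes both checks
        return result

    print(decode_with_fallback())  # Result(compression_ratio=2.4, avg_logprob=-0.5)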
@@ -175,11 +171,11 @@ def transcribe(
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
 
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result = decode_with_fallback(segment)[0]
+            result: DecodingResult = decode_with_fallback(segment)
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None:
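Since decode_with_fallback now handles one unbatched segment, mel no longer needs the extra batch dimension (hence the removed unsqueeze and the mel[:, seek:] slice), and its return value is a single DecodingResult rather than a list. With the loop form, the temperature schedule handed to transcribe() no longer has to begin at zero. A hedged usage sketch ("base" and "audio.mp3" are placeholder choices):

    import whisper

    model = whisper.load_model("base")

    # The first attempt samples at t=0.3 instead of running greedy/beam decoding;
    # hotter temperatures are tried only if a result is too repetitive or too improbable.
    result = model.transcribe("audio.mp3", temperature=(0.3, 0.5, 0.7, 0.9, 1.0))
    print(result["text"])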