mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060)
This commit is contained in:
parent
aac47c9834
commit
38f2f4d99d
@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
|
|||||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||||
)
|
)
|
||||||
assert result["language"] == "en"
|
assert result["language"] == "en"
|
||||||
|
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||||
|
|
||||||
transcription = result["text"].lower()
|
transcription = result["text"].lower()
|
||||||
assert "my fellow americans" in transcription
|
assert "my fellow americans" in transcription
|
||||||
|
|||||||
@ -290,7 +290,7 @@ def add_word_timestamps(
|
|||||||
if len(segments) == 0:
|
if len(segments) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
text_tokens = [t for segment in segments for t in segment["tokens"]]
|
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
|
||||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||||
|
|
||||||
|
|||||||
@ -200,14 +200,14 @@ def transcribe(
|
|||||||
def new_segment(
|
def new_segment(
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||||
):
|
):
|
||||||
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
|
tokens = tokens.tolist()
|
||||||
|
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||||
return {
|
return {
|
||||||
"id": len(all_segments),
|
|
||||||
"seek": seek,
|
"seek": seek,
|
||||||
"start": start,
|
"start": start,
|
||||||
"end": end,
|
"end": end,
|
||||||
"text": tokenizer.decode(text_tokens),
|
"text": tokenizer.decode(text_tokens),
|
||||||
"tokens": text_tokens,
|
"tokens": tokens,
|
||||||
"temperature": result.temperature,
|
"temperature": result.temperature,
|
||||||
"avg_logprob": result.avg_logprob,
|
"avg_logprob": result.avg_logprob,
|
||||||
"compression_ratio": result.compression_ratio,
|
"compression_ratio": result.compression_ratio,
|
||||||
@ -245,7 +245,6 @@ def transcribe(
|
|||||||
|
|
||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
current_tokens = []
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
@ -275,7 +274,6 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
current_tokens.append(sliced_tokens.tolist())
|
|
||||||
last_slice = current_slice
|
last_slice = current_slice
|
||||||
|
|
||||||
if single_timestamp_ending:
|
if single_timestamp_ending:
|
||||||
@ -287,7 +285,6 @@ def transcribe(
|
|||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
seek += last_timestamp_pos * input_stride
|
seek += last_timestamp_pos * input_stride
|
||||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||||
@ -309,7 +306,6 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
current_tokens.append(tokens.tolist())
|
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
if not condition_on_previous_text or result.temperature > 0.5:
|
if not condition_on_previous_text or result.temperature > 0.5:
|
||||||
@ -348,11 +344,17 @@ def transcribe(
|
|||||||
segment["text"] = ""
|
segment["text"] = ""
|
||||||
segment["tokens"] = []
|
segment["tokens"] = []
|
||||||
segment["words"] = []
|
segment["words"] = []
|
||||||
current_tokens[i] = []
|
|
||||||
|
|
||||||
all_segments.extend(current_segments)
|
all_segments.extend(
|
||||||
|
[
|
||||||
|
{"id": i, **segment}
|
||||||
|
for i, segment in enumerate(
|
||||||
|
current_segments, start=len(all_segments)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
all_tokens.extend(
|
all_tokens.extend(
|
||||||
[token for segment in current_tokens for token in segment]
|
[token for segment in current_segments for token in segment["tokens"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user