fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060)

Jong Wook Kim 2023-03-08 18:34:07 -05:00 committed by GitHub
parent aac47c9834
commit 38f2f4d99d
3 changed files with 14 additions and 11 deletions

tests/test_transcribe.py

@@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
         audio_path, language=language, temperature=0.0, word_timestamps=True
     )
     assert result["language"] == "en"
+    assert result["text"] == "".join([s["text"] for s in result["segments"]])

     transcription = result["text"].lower()
     assert "my fellow americans" in transcription

whisper/timing.py

@@ -290,7 +290,7 @@ def add_word_timestamps(
     if len(segments) == 0:
         return

-    text_tokens = [t for segment in segments for t in segment["tokens"]]
+    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
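
The added "if t < tokenizer.eot" filter is needed because segment["tokens"] now carries the timestamp tokens as well (see the transcribe.py changes below), while find_alignment expects plain text tokens only. A rough illustration of the token-id layout this relies on, assuming the multilingual tokenizer from whisper.tokenizer (the sample ids are hypothetical):

    from whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True)
    # eot and every timestamp token sit at or above tokenizer.eot,
    # so "t < tokenizer.eot" keeps only the text tokens.
    assert tokenizer.eot < tokenizer.timestamp_begin

    tokens = [50364, 2425, 13, 50414]  # hypothetical ids: timestamp, text, text, timestamp
    text_tokens = [t for t in tokens if t < tokenizer.eot]
    print(tokenizer.decode(text_tokens))  # timestamps stripped, text preserved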

whisper/transcribe.py

@@ -200,14 +200,14 @@ def transcribe(
     def new_segment(
         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        tokens = tokens.tolist()
+        text_tokens = [token for token in tokens if token < tokenizer.eot]
         return {
-            "id": len(all_segments),
             "seek": seek,
             "start": start,
             "end": end,
             "text": tokenizer.decode(text_tokens),
-            "tokens": text_tokens,
+            "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
@@ -245,7 +245,6 @@ def transcribe(

             previous_seek = seek
             current_segments = []
-            current_tokens = []

             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -275,7 +274,6 @@ def transcribe(
                             result=result,
                         )
                     )
-                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice

                 if single_timestamp_ending:
@@ -287,7 +285,6 @@ def transcribe(
                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                     )
                     seek += last_timestamp_pos * input_stride
-                    all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
                         result=result,
                     )
                 )
-                current_tokens.append(tokens.tolist())
                 seek += segment_size

             if not condition_on_previous_text or result.temperature > 0.5:
@@ -348,11 +344,17 @@ def transcribe(
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
-                    current_tokens[i] = []

-            all_segments.extend(current_segments)
+            all_segments.extend(
+                [
+                    {"id": i, **segment}
+                    for i, segment in enumerate(
+                        current_segments, start=len(all_segments)
+                    )
+                ]
+            )
             all_tokens.extend(
-                [token for segment in current_tokens for token in segment]
+                [token for segment in current_segments for token in segment["tokens"]]
             )

             # update progress bar
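
Taken together, these hunks drop the parallel current_tokens bookkeeping: segment ids are assigned only when segments are appended to all_segments, and all_tokens is rebuilt straight from the segments' own token lists, which transcribe() later slices into the prompt for the next window when condition_on_previous_text is enabled. Keeping that prompt in sync with the emitted segments addresses both the repetitions and the JSON discrepancy named in the commit title. A toy sketch of the resulting bookkeeping (not the verbatim transcribe() code; the token ids are placeholders):

    all_segments, all_tokens = [], []
    current_segments = [{"text": " Hello.", "tokens": [50364, 2425, 13, 50414]}]

    # Assign ids at append time and mirror the segments' full token lists.
    all_segments.extend(
        {"id": i, **segment}
        for i, segment in enumerate(current_segments, start=len(all_segments))
    )
    all_tokens.extend(t for s in current_segments for t in s["tokens"])

    prompt_reset_since = 0
    prompt = all_tokens[prompt_reset_since:]  # reused to condition the next window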