Implement max line width and max line count, and make word highlighting optional (#1184)

* Add highlight_words, max_line_width, max_line_count

* Refactor subtitle generator

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
ryanheise 2023-04-11 10:28:35 +10:00 committed by GitHub
parent 255887f219
commit 43940fc978
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 39 deletions

View File

@ -401,6 +401,9 @@ def cli():
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on # fmt: on
@ -433,9 +436,17 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path) writer(result, audio_path, writer_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,8 +1,9 @@
import json import json
import os import os
import re
import sys import sys
import zlib import zlib
from typing import Callable, TextIO from typing import Callable, Optional, TextIO
system_encoding = sys.getdefaultencoding() system_encoding = sys.getdefaultencoding()
@ -73,7 +74,7 @@ class ResultWriter:
def __init__(self, output_dir: str): def __init__(self, output_dir: str):
self.output_dir = output_dir self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str): def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0] audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join( output_path = os.path.join(
@ -81,16 +82,16 @@ class ResultWriter:
) )
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f) self.write_result(result, file=f, options=options)
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError raise NotImplementedError
class WriteTXT(ResultWriter): class WriteTXT(ResultWriter):
extension: str = "txt" extension: str = "txt"
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]: for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True) print(segment["text"].strip(), file=file, flush=True)
@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool always_include_hours: bool
decimal_marker: str decimal_marker: str
def iterate_result(self, result: dict): def iterate_result(self, result: dict, options: dict):
for segment in result["segments"]: raw_max_line_width: Optional[int] = options["max_line_width"]
segment_start = self.format_timestamp(segment["start"]) max_line_count: Optional[int] = options["max_line_count"]
segment_end = self.format_timestamp(segment["end"]) highlight_words: bool = options["highlight_words"]
segment_text = segment["text"].strip().replace("-->", "->") max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None
if word_timings := segment.get("words", None): def iterate_subtitles():
all_words = [timing["word"] for timing in word_timings] line_len = 0
all_words[0] = all_words[0].strip() # remove the leading space, if any line_count = 1
last = segment_start # the next subtitle to yield (a list of word timings with whitespace)
for i, this_word in enumerate(word_timings): subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
if len(subtitle) > 0:
yield subtitle
if "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"]) start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"]) end = self.format_timestamp(this_word["end"])
if last != start: if last != start:
yield last, start, segment_text yield last, start, subtitle_text
yield start, end, "".join( yield start, end, "".join(
[ [
f"<u>{word}</u>" if j == i else word re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words) for j, word in enumerate(all_words)
] ]
) )
last = end last = end
if last != segment_end:
yield last, segment_end, segment_text
else: else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float): def format_timestamp(self, seconds: float):
@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False always_include_hours: bool = False
decimal_marker: str = "." decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result): for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True always_include_hours: bool = True
decimal_marker: str = "," decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):
extension: str = "tsv" extension: str = "tsv"
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file) print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]: for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t") print(round(1000 * segment["start"]), file=file, end="\t")
@ -180,11 +231,13 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter): class WriteJSON(ResultWriter):
extension: str = "json" extension: str = "json"
def write_result(self, result: dict, file: TextIO): def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file) json.dump(result, file)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = { writers = {
"txt": WriteTXT, "txt": WriteTXT,
"vtt": WriteVTT, "vtt": WriteVTT,
@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO): def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers: for writer in all_writers:
writer(result, file) writer(result, file, options)
return write_all return write_all