diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index ed6d820..84feb12 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -401,6 +401,9 @@ def cli():
     parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
+    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     # fmt: on
 
@@ -433,9 +436,17 @@ def cli():
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
+    word_options = ["highlight_words", "max_line_count", "max_line_width"]
+    if not args["word_timestamps"]:
+        for option in word_options:
+            if args[option]:
+                parser.error(f"--{option} requires --word_timestamps True")
+    if args["max_line_count"] and not args["max_line_width"]:
+        warnings.warn("--max_line_count has no effect without --max_line_width")
+    writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        writer(result, audio_path)
+        writer(result, audio_path, writer_args)
 
 
 if __name__ == "__main__":
diff --git a/whisper/utils.py b/whisper/utils.py
index 490bdd1..ba5a10c 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -1,8 +1,9 @@
 import json
 import os
+import re
 import sys
 import zlib
-from typing import Callable, TextIO
+from typing import Callable, Optional, TextIO
 
 system_encoding = sys.getdefaultencoding()
 
@@ -73,7 +74,7 @@ class ResultWriter:
     def __init__(self, output_dir: str):
         self.output_dir = output_dir
 
-    def __call__(self, result: dict, audio_path: str):
+    def __call__(self, result: dict, audio_path: str, options: dict):
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
         output_path = os.path.join(
@@ -81,16 +82,16 @@ class ResultWriter:
         )
 
         with open(output_path, "w", encoding="utf-8") as f:
-            self.write_result(result, file=f)
+            self.write_result(result, file=f, options=options)
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         raise NotImplementedError
 
 
 class WriteTXT(ResultWriter):
     extension: str = "txt"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         for segment in result["segments"]:
             print(segment["text"].strip(), file=file, flush=True)
 
@@ -99,33 +100,93 @@ class SubtitlesWriter(ResultWriter):
     always_include_hours: bool
     decimal_marker: str
 
-    def iterate_result(self, result: dict):
-        for segment in result["segments"]:
-            segment_start = self.format_timestamp(segment["start"])
-            segment_end = self.format_timestamp(segment["end"])
-            segment_text = segment["text"].strip().replace("-->", "->")
+    def iterate_result(self, result: dict, options: dict):
+        """
+        Yield `(start, end, text)` tuples, one per subtitle cue.
+
+        When the first segment carries word timings, the words are re-wrapped
+        into lines (`max_line_width`) and subtitles (`max_line_count`) as given
+        in `options`; with `highlight_words`, each word additionally gets its
+        own cue in which it is wrapped in `<u>` tags. Otherwise (including when
+        there are no segments) every segment is emitted as a single cue.
+        """
+        raw_max_line_width: Optional[int] = options["max_line_width"]
+        max_line_count: Optional[int] = options["max_line_count"]
+        highlight_words: bool = options["highlight_words"]
+        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+        preserve_segments = max_line_count is None or raw_max_line_width is None
 
-            if word_timings := segment.get("words", None):
-                all_words = [timing["word"] for timing in word_timings]
-                all_words[0] = all_words[0].strip()  # remove the leading space, if any
-                last = segment_start
-                for i, this_word in enumerate(word_timings):
-                    start = self.format_timestamp(this_word["start"])
-                    end = self.format_timestamp(this_word["end"])
-                    if last != start:
-                        yield last, start, segment_text
+        def iterate_subtitles():
+            line_len = 0
+            line_count = 1
+            # the next subtitle to yield (a list of word timings with whitespace)
+            subtitle: list[dict] = []
+            # start of the first timed word; 0.0 if no segment has any words
+            last = next(
+                (w["start"] for s in result["segments"] for w in s["words"]), 0.0
+            )
+            for segment in result["segments"]:
+                for i, original_timing in enumerate(segment["words"]):
+                    timing = original_timing.copy()
+                    long_pause = not preserve_segments and timing["start"] - last > 3.0
+                    has_room = line_len + len(timing["word"]) <= max_line_width
+                    seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+                    if line_len > 0 and has_room and not long_pause and not seg_break:
+                        # line continuation
+                        line_len += len(timing["word"])
+                    else:
+                        # new line
+                        timing["word"] = timing["word"].strip()
+                        if (
+                            len(subtitle) > 0
+                            and max_line_count is not None
+                            and (long_pause or line_count >= max_line_count)
+                            or seg_break
+                        ):
+                            # subtitle break
+                            yield subtitle
+                            subtitle = []
+                            line_count = 1
+                        elif line_len > 0:
+                            # line break
+                            line_count += 1
+                            timing["word"] = "\n" + timing["word"]
+                        line_len = len(timing["word"].strip())
+                    subtitle.append(timing)
+                    last = timing["start"]
+            if len(subtitle) > 0:
+                yield subtitle
 
-                    yield start, end, "".join(
-                        [
-                            f"<u>{word}</u>" if j == i else word
-                            for j, word in enumerate(all_words)
-                        ]
-                    )
-                    last = end
+        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
+            for subtitle in iterate_subtitles():
+                subtitle_start = self.format_timestamp(subtitle[0]["start"])
+                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+                subtitle_text = "".join([word["word"] for word in subtitle])
+                if highlight_words:
+                    last = subtitle_start
+                    all_words = [timing["word"] for timing in subtitle]
+                    for i, this_word in enumerate(subtitle):
+                        start = self.format_timestamp(this_word["start"])
+                        end = self.format_timestamp(this_word["end"])
+                        if last != start:
+                            yield last, start, subtitle_text
 
-                if last != segment_end:
-                    yield last, segment_end, segment_text
-            else:
+                        yield start, end, "".join(
+                            [
+                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                if j == i
+                                else word
+                                for j, word in enumerate(all_words)
+                            ]
+                        )
+                        last = end
+                else:
+                    yield subtitle_start, subtitle_end, subtitle_text
+        else:
+            for segment in result["segments"]:
+                segment_start = self.format_timestamp(segment["start"])
+                segment_end = self.format_timestamp(segment["end"])
+                segment_text = segment["text"].strip().replace("-->", "->")
                 yield segment_start, segment_end, segment_text
 
     def format_timestamp(self, seconds: float):
@@ -141,9 +202,9 @@ class WriteVTT(SubtitlesWriter):
     always_include_hours: bool = False
     decimal_marker: str = "."
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         print("WEBVTT\n", file=file)
-        for start, end, text in self.iterate_result(result):
+        for start, end, text in self.iterate_result(result, options):
             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
@@ -152,8 +213,10 @@ class WriteSRT(SubtitlesWriter):
     always_include_hours: bool = True
     decimal_marker: str = ","
 
-    def write_result(self, result: dict, file: TextIO):
-        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options), start=1
+        ):
             print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
@@ -169,7 +232,7 @@ class WriteTSV(ResultWriter):
 
     extension: str = "tsv"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
             print(round(1000 * segment["start"]), file=file, end="\t")
@@ -180,11 +243,13 @@ class WriteTSV(ResultWriter):
 class WriteJSON(ResultWriter):
     extension: str = "json"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         json.dump(result, file)
 
 
-def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
+def get_writer(
+    output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
     writers = {
         "txt": WriteTXT,
         "vtt": WriteVTT,
@@ -196,9 +261,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
     if output_format == "all":
         all_writers = [writer(output_dir) for writer in writers.values()]
 
-        def write_all(result: dict, file: TextIO):
+        def write_all(result: dict, file: TextIO, options: dict):
             for writer in all_writers:
-                writer(result, file)
+                writer(result, file, options)
 
         return write_all
 