diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index ed6d820..84feb12 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -401,6 +401,9 @@ def cli():
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+ parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+ parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
+ parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
@@ -433,9 +436,17 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
+ word_options = ["highlight_words", "max_line_count", "max_line_width"]
+ if not args["word_timestamps"]:
+ for option in word_options:
+ if args[option]:
+ parser.error(f"--{option} requires --word_timestamps True")
+ if args["max_line_count"] and not args["max_line_width"]:
+ warnings.warn("--max_line_count has no effect without --max_line_width")
+ writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
- writer(result, audio_path)
+ writer(result, audio_path, writer_args)
if __name__ == "__main__":
diff --git a/whisper/utils.py b/whisper/utils.py
index 490bdd1..ba5a10c 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -1,8 +1,9 @@
import json
import os
+import re
import sys
import zlib
-from typing import Callable, TextIO
+from typing import Callable, Optional, TextIO
system_encoding = sys.getdefaultencoding()
@@ -73,7 +74,7 @@ class ResultWriter:
def __init__(self, output_dir: str):
self.output_dir = output_dir
- def __call__(self, result: dict, audio_path: str):
+ def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
@@ -81,16 +82,16 @@ class ResultWriter:
)
with open(output_path, "w", encoding="utf-8") as f:
- self.write_result(result, file=f)
+ self.write_result(result, file=f, options=options)
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
@@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
- def iterate_result(self, result: dict):
- for segment in result["segments"]:
- segment_start = self.format_timestamp(segment["start"])
- segment_end = self.format_timestamp(segment["end"])
- segment_text = segment["text"].strip().replace("-->", "->")
+ def iterate_result(self, result: dict, options: dict):
+ raw_max_line_width: Optional[int] = options["max_line_width"]
+ max_line_count: Optional[int] = options["max_line_count"]
+ highlight_words: bool = options["highlight_words"]
+ max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+ preserve_segments = max_line_count is None or raw_max_line_width is None
- if word_timings := segment.get("words", None):
- all_words = [timing["word"] for timing in word_timings]
- all_words[0] = all_words[0].strip() # remove the leading space, if any
- last = segment_start
- for i, this_word in enumerate(word_timings):
- start = self.format_timestamp(this_word["start"])
- end = self.format_timestamp(this_word["end"])
- if last != start:
- yield last, start, segment_text
+ def iterate_subtitles():
+ line_len = 0
+ line_count = 1
+ # the next subtitle to yield (a list of word timings with whitespace)
+ subtitle: list[dict] = []
+ last = result["segments"][0]["words"][0]["start"]
+ for segment in result["segments"]:
+ for i, original_timing in enumerate(segment["words"]):
+ timing = original_timing.copy()
+ long_pause = not preserve_segments and timing["start"] - last > 3.0
+ has_room = line_len + len(timing["word"]) <= max_line_width
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+ if line_len > 0 and has_room and not long_pause and not seg_break:
+ # line continuation
+ line_len += len(timing["word"])
+ else:
+ # new line
+ timing["word"] = timing["word"].strip()
+ if (
+ len(subtitle) > 0
+ and max_line_count is not None
+ and (long_pause or line_count >= max_line_count)
+ or seg_break
+ ):
+ # subtitle break
+ yield subtitle
+ subtitle = []
+ line_count = 1
+ elif line_len > 0:
+ # line break
+ line_count += 1
+ timing["word"] = "\n" + timing["word"]
+ line_len = len(timing["word"].strip())
+ subtitle.append(timing)
+ last = timing["start"]
+ if len(subtitle) > 0:
+ yield subtitle
- yield start, end, "".join(
- [
- f"{word}" if j == i else word
- for j, word in enumerate(all_words)
- ]
- )
- last = end
+ if "words" in result["segments"][0]:
+ for subtitle in iterate_subtitles():
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+ subtitle_text = "".join([word["word"] for word in subtitle])
+ if highlight_words:
+ last = subtitle_start
+ all_words = [timing["word"] for timing in subtitle]
+ for i, this_word in enumerate(subtitle):
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
- if last != segment_end:
- yield last, segment_end, segment_text
- else:
+ yield start, end, "".join(
+ [
+ re.sub(r"^(\s*)(.*)$", r"\1\2", word)
+ if j == i
+ else word
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
+ else:
+ yield subtitle_start, subtitle_end, subtitle_text
+ else:
+ for segment in result["segments"]:
+ segment_start = self.format_timestamp(segment["start"])
+ segment_end = self.format_timestamp(segment["end"])
+ segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
@@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False
decimal_marker: str = "."
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
- for start, end, text in self.iterate_result(result):
+ for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True
decimal_marker: str = ","
- def write_result(self, result: dict, file: TextIO):
- for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options), start=1
+ ):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):
extension: str = "tsv"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
@@ -180,11 +231,13 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter):
extension: str = "json"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file)
-def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
+def get_writer(
+ output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
@@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
- def write_all(result: dict, file: TextIO):
+ def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
- writer(result, file)
+ writer(result, file, options)
return write_all