Add new option to generate subtitles by a specific number of words (#1729)

* ADD parser for new argument --max_words_count

* ADD max_words_count in words_options
ADD warning for max_line_width compatibility

* ADD logic for max_words_count

* rename to max_words_per_line

* make them kwargs

* allow specifying file path by --model

* black formatting

---------

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
This commit is contained in:
amosal 2023-11-06 10:49:33 +01:00 committed by GitHub
parent b38a1f20f4
commit 6ed314fe41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 47 deletions

View File

@ -378,10 +378,17 @@ def transcribe(
def cli(): def cli():
from . import available_models from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off # fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
@ -412,6 +419,7 @@ def cli():
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on # fmt: on
@ -444,17 +452,24 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"] word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]: if not args["word_timestamps"]:
for option in word_options: for option in word_options:
if args[option]: if args[option]:
parser.error(f"--{option} requires --word_timestamps True") parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]: if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width") warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options} writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, writer_args) writer(result, audio_path, **writer_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -74,7 +74,9 @@ class ResultWriter:
def __init__(self, output_dir: str): def __init__(self, output_dir: str):
self.output_dir = output_dir self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict): def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0] audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join( output_path = os.path.join(
@ -82,16 +84,20 @@ class ResultWriter:
) )
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options) self.write_result(result, file=f, options=options, **kwargs)
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError raise NotImplementedError
class WriteTXT(ResultWriter): class WriteTXT(ResultWriter):
extension: str = "txt" extension: str = "txt"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]: for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True) print(segment["text"].strip(), file=file, flush=True)
@ -100,12 +106,24 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool always_include_hours: bool
decimal_marker: str decimal_marker: str
def iterate_result(self, result: dict, options: dict): def iterate_result(
raw_max_line_width: Optional[int] = options["max_line_width"] self,
max_line_count: Optional[int] = options["max_line_count"] result: dict,
highlight_words: bool = options["highlight_words"] options: Optional[dict] = None,
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width *,
preserve_segments = max_line_count is None or raw_max_line_width is None max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
def iterate_subtitles(): def iterate_subtitles():
line_len = 0 line_len = 0
@ -114,34 +132,50 @@ class SubtitlesWriter(ResultWriter):
subtitle: list[dict] = [] subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"] last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]: for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]): chunk_index = 0
timing = original_timing.copy() words_count = max_words_per_line
long_pause = not preserve_segments and timing["start"] - last > 3.0 while chunk_index < len(segment["words"]):
has_room = line_len + len(timing["word"]) <= max_line_width remaining_words = len(segment["words"]) - chunk_index
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments if max_words_per_line > len(segment["words"]) - chunk_index:
if line_len > 0 and has_room and not long_pause and not seg_break: words_count = remaining_words
# line continuation for i, original_timing in enumerate(
line_len += len(timing["word"]) segment["words"][chunk_index : chunk_index + words_count]
else: ):
# new line timing = original_timing.copy()
timing["word"] = timing["word"].strip() long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if ( if (
len(subtitle) > 0 line_len > 0
and max_line_count is not None and has_room
and (long_pause or line_count >= max_line_count) and not long_pause
or seg_break and not seg_break
): ):
# subtitle break # line continuation
yield subtitle line_len += len(timing["word"])
subtitle = [] else:
line_count = 1 # new line
elif line_len > 0: timing["word"] = timing["word"].strip()
# line break if (
line_count += 1 len(subtitle) > 0
timing["word"] = "\n" + timing["word"] and max_line_count is not None
line_len = len(timing["word"].strip()) and (long_pause or line_count >= max_line_count)
subtitle.append(timing) or seg_break
last = timing["start"] ):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0: if len(subtitle) > 0:
yield subtitle yield subtitle
@ -190,9 +224,11 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False always_include_hours: bool = False
decimal_marker: str = "." decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options): for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@ -201,9 +237,11 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True always_include_hours: bool = True
decimal_marker: str = "," decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate( for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1 self.iterate_result(result, options, **kwargs), start=1
): ):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@ -220,7 +258,9 @@ class WriteTSV(ResultWriter):
extension: str = "tsv" extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file) print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]: for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t") print(round(1000 * segment["start"]), file=file, end="\t")
@ -231,7 +271,9 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter): class WriteJSON(ResultWriter):
extension: str = "json" extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict): def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file) json.dump(result, file)
@ -249,9 +291,11 @@ def get_writer(
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO, options: dict): def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers: for writer in all_writers:
writer(result, file, options) writer(result, file, options, **kwargs)
return write_all return write_all