diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 02952df..c040441 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -260,7 +260,7 @@ def cli(): parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") diff --git a/whisper/utils.py b/whisper/utils.py index 8315e7f..8293641 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -102,6 +102,25 @@ class WriteSRT(ResultWriter): ) +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. + """ + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment['start']), file=file, end="\t") + print(round(1000 * segment['end']), file=file, end="\t") + print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + + class WriteJSON(ResultWriter): extension: str = "json" @@ -114,6 +133,7 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], "txt": WriteTXT, "vtt": WriteVTT, "srt": WriteSRT, + "tsv": WriteTSV, "json": WriteJSON, } @@ -127,3 +147,4 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], return write_all return writers[output_format](output_dir) +