diff --git a/whisper/utils.py b/whisper/utils.py
index 9b9b138..2ea4146 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -21,12 +21,51 @@ else:
         return string
 
 
-def exact_div(x, y):
+def exact_div(x: int, y: int) -> int:
+    """
+    Performs exact division of x by y.
+
+    Parameters
+    ----------
+    x : int
+        The dividend.
+
+    y : int
+        The divisor.
+
+    Returns
+    -------
+    quotient : int
+        The result of the exact division.
+
+    Raises
+    ------
+    AssertionError
+        If x is not exactly divisible by y.
+    """
     assert x % y == 0
     return x // y
 
 
-def str2bool(string):
+def str2bool(string: str) -> bool:
+    """
+    Converts a string representation of a boolean to its boolean equivalent.
+
+    Parameters
+    ----------
+    string : str
+        The string representation of the boolean.
+
+    Returns
+    -------
+    bool
+        The boolean value represented by the input string.
+
+    Raises
+    ------
+    ValueError
+        If the input string does not represent a boolean value.
+    """
     str2val = {"True": True, "False": False}
     if string in str2val:
         return str2val[string]
@@ -34,15 +73,54 @@ def str2bool(string):
         raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
 
 
-def optional_int(string):
+def optional_int(string: str) -> Optional[int]:
+    """
+    Converts a string to an integer or returns None if the string is "None".
+
+    Parameters
+    ----------
+    string : str
+        The string to convert.
+
+    Returns
+    -------
+    int or None
+        The integer value of the string, or None if the string is "None".
+    """
     return None if string == "None" else int(string)
 
 
-def optional_float(string):
+def optional_float(string: str) -> Optional[float]:
+    """
+    Converts a string to a float or returns None if the string is "None".
+
+    Parameters
+    ----------
+    string : str
+        The string to convert.
+
+    Returns
+    -------
+    float or None
+        The float value of the string, or None if the string is "None".
+    """
     return None if string == "None" else float(string)
 
 
-def compression_ratio(text) -> float:
+def compression_ratio(text: str) -> float:
+    """
+    Calculates the compression ratio of a text using zlib compression.
+
+    Parameters
+    ----------
+    text : str
+        The text to compress.
+
+    Returns
+    -------
+    float
+        The compression ratio of the text.
+    """
     text_bytes = text.encode("utf-8")
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
@@ -50,6 +128,25 @@ def compression_ratio(text) -> float:
 def format_timestamp(
     seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
 ):
+    """
+    Formats a timestamp in seconds into a human-readable string.
+
+    Parameters
+    ----------
+    seconds : float
+        The timestamp in seconds.
+
+    always_include_hours : bool, optional
+        Whether to always include hours in the formatted timestamp. Default is False.
+
+    decimal_marker : str, optional
+        The decimal marker to use. Default is ".".
+
+    Returns
+    -------
+    str
+        The formatted timestamp string.
+    """
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
 
@@ -69,6 +166,19 @@ def format_timestamp(
 
 
 def get_start(segments: List[dict]) -> Optional[float]:
+    """
+    Get the start time from a list of segments.
+
+    Parameters
+    ----------
+    segments : List[dict]
+        A list of segments, each with "start" and "words" keys.
+
+    Returns
+    -------
+    Optional[float]
+        The start of the first word, else of the first segment; None if empty.
+    """
     return next(
         (w["start"] for s in segments for w in s["words"]),
         segments[0]["start"] if segments else None,
@@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]:
 
 
 def get_end(segments: List[dict]) -> Optional[float]:
+    """
+    Get the end time from a list of segments.
+
+    Parameters
+    ----------
+    segments : List[dict]
+        A list of segments, each with "end" and "words" keys.
+
+    Returns
+    -------
+    Optional[float]
+        The end of the last word, else of the last segment; None if empty.
+    """
     return next(
         (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
         segments[-1]["end"] if segments else None,
@@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]:
 
 
 class ResultWriter:
+    """
+    Base class for result writers.
+
+    Attributes
+    ----------
+    extension : str
+        The file extension associated with the writer.
+    """
     extension: str
 
     def __init__(self, output_dir: str):
@@ -91,6 +222,24 @@ class ResultWriter:
     def __call__(
         self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
     ):
+        """
+        Writes the result to a file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        audio_path : str
+            The path to the audio file associated with the result.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
         output_path = os.path.join(
@@ -103,20 +252,75 @@ class ResultWriter:
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
     ):
+        """
+        Writes the result to a file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         raise NotImplementedError
 
 
 class WriteTXT(ResultWriter):
+    """
+    Result writer for writing text results to a .txt file.
+
+    Attributes
+    ----------
+    extension : str
+        The file extension associated with the writer.
+    """
     extension: str = "txt"
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
     ):
+        """
+        Writes the result to a .txt file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         for segment in result["segments"]:
             print(segment["text"].strip(), file=file, flush=True)
 
 
 class SubtitlesWriter(ResultWriter):
+    """
+    Base class for subtitle writers.
+
+    Attributes
+    ----------
+    always_include_hours : bool
+        Whether to always include hours in the formatted timestamps.
+
+    decimal_marker : str
+        The decimal marker to use in formatted timestamps.
+    """
     always_include_hours: bool
     decimal_marker: str
 
@@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter):
         highlight_words: bool = False,
         max_words_per_line: Optional[int] = None,
     ):
+        """
+        Iterates over the result to generate subtitles.
+
+        Parameters
+        ----------
+        result : dict
+            The result to iterate over.
+
+        options : dict, optional
+            Additional options for iterating the result. Default is None.
+
+        max_line_width : int, optional
+            The maximum width of each line. Default is None.
+
+        max_line_count : int, optional
+            The maximum number of lines. Default is None.
+
+        highlight_words : bool, optional
+            Whether to highlight individual words in the subtitles. Default is False.
+
+        max_words_per_line : int, optional
+            The maximum number of words per line. Default is None.
+        """
         options = options or {}
         max_line_width = max_line_width or options.get("max_line_width")
         max_line_count = max_line_count or options.get("max_line_count")
@@ -226,6 +453,19 @@ class SubtitlesWriter(ResultWriter):
                 yield segment_start, segment_end, segment_text
 
     def format_timestamp(self, seconds: float):
+        """
+        Formats a timestamp in seconds into a human-readable string.
+
+        Parameters
+        ----------
+        seconds : float
+            The timestamp in seconds.
+
+        Returns
+        -------
+        str
+            The formatted timestamp string.
+        """
         return format_timestamp(
             seconds=seconds,
             always_include_hours=self.always_include_hours,
@@ -234,26 +474,90 @@ class SubtitlesWriter(ResultWriter):
 
 
 class WriteVTT(SubtitlesWriter):
+    """
+    Result writer for writing subtitles to a .vtt file.
+
+    Attributes
+    ----------
+    extension : str
+        The file extension associated with the writer.
+
+    always_include_hours : bool
+        Whether to always include hours in the formatted timestamps.
+
+    decimal_marker : str
+        The decimal marker to use in formatted timestamps.
+    """
     extension: str = "vtt"
     always_include_hours: bool = False
     decimal_marker: str = "."
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    ) -> None:
+        """
+        Writes the result to a .vtt file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         print("WEBVTT\n", file=file)
         for start, end, text in self.iterate_result(result, options, **kwargs):
             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
 class WriteSRT(SubtitlesWriter):
+    """
+    Result writer for writing subtitles to a .srt file.
+
+    Attributes
+    ----------
+    extension : str
+        The file extension associated with the writer.
+
+    always_include_hours : bool
+        Whether to always include hours in the formatted timestamps.
+
+    decimal_marker : str
+        The decimal marker to use in formatted timestamps.
+    """
     extension: str = "srt"
     always_include_hours: bool = True
     decimal_marker: str = ","
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    ) -> None:
+        """
+        Writes the result to a .srt file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         for i, (start, end, text) in enumerate(
             self.iterate_result(result, options, **kwargs), start=1
         ):
@@ -274,7 +578,25 @@ class WriteTSV(ResultWriter):
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    ) -> None:
+        """
+        Writes the result to a .tsv file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+
+        Returns
+        -------
+        None
+        """
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
             print(round(1000 * segment["start"]), file=file, end="\t")
@@ -283,17 +605,55 @@ class WriteTSV(ResultWriter):
 
 
 class WriteJSON(ResultWriter):
+    """
+    Result writer for writing data to a .json file.
+
+    Attributes
+    ----------
+    extension : str
+        The file extension associated with the writer.
+    """
     extension: str = "json"
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    ) -> None:
+        """
+        Writes the result to a .json file.
+
+        Parameters
+        ----------
+        result : dict
+            The result to write.
+
+        file : TextIO
+            The file object to write to.
+
+        options : dict, optional
+            Additional options for writing the result. Default is None.
+        """
         json.dump(result, file)
 
 
 def get_writer(
     output_format: str, output_dir: str
 ) -> Callable[[dict, TextIO, dict], None]:
+    """
+    Returns a result writer based on the specified output format.
+
+    Parameters
+    ----------
+    output_format : str
+        The desired output format for the writer.
+
+    output_dir : str
+        The directory where the output files will be saved.
+
+    Returns
+    -------
+    Callable[[dict, TextIO, dict], None]
+        A function that can be used to write results to files.
+    """
     writers = {
         "txt": WriteTXT,
         "vtt": WriteVTT,