From 492c05c5f3b2666c8d8f6f8f06b54e1cdfe8761a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20Brul=C3=A9=20Naudet?= Date: Mon, 19 Feb 2024 20:12:26 +0100 Subject: [PATCH] Update utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dear Developers, I'm pleased to inform you that I have completed the documentation update the utils.py file. The updated documentation provides clear explanations of function parameters, return types, and expected behavior. Additionally, it adheres to consistent formatting and organization, ensuring ease of understanding for both current and future developers. Please review the updated documentation at your earliest convenience. If you have any feedback or suggestions for further improvements, please don't hesitate to let me know. Thank you for your attention to this matter. Best regards, Louis Brulé Naudet --- whisper/utils.py | 378 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 369 insertions(+), 9 deletions(-) diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b138..2ea4146 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -21,12 +21,51 @@ else: return string -def exact_div(x, y): +def exact_div(x:int, y:int): + """ + Performs exact division of x by y. + + Parameters + ---------- + x : int + The dividend. + + y : int + The divisor. + + Returns + ------- + quotient : int + The result of the exact division. + + Raises + ------ + AssertionError + If x is not exactly divisible by y. + """ assert x % y == 0 return x // y -def str2bool(string): +def str2bool(string:str) -> bool: + """ + Converts a string representation of a boolean to its boolean equivalent. + + Parameters + ---------- + string : str + The string representation of the boolean. + + Returns + ------- + bool + The boolean value represented by the input string. + + Raises + ------ + ValueError + If the input string does not represent a boolean value. + """ str2val = {"True": True, "False": False} if string in str2val: return str2val[string] @@ -34,15 +73,54 @@ def str2bool(string): raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") -def optional_int(string): +def optional_int(string:str) -> int: + """ + Converts a string to an integer or returns None if the string is "None". + + Parameters + ---------- + string : str + The string to convert. + + Returns + ------- + int or None + The integer value of the string, or None if the string is "None". + """ return None if string == "None" else int(string) -def optional_float(string): +def optional_float(string:str) -> float: + """ + Converts a string to a float or returns None if the string is "None". + + Parameters + ---------- + string : str + The string to convert. + + Returns + ------- + float or None + The float value of the string, or None if the string is "None". + """ return None if string == "None" else float(string) -def compression_ratio(text) -> float: +def compression_ratio(text:str) -> float: + """ + Calculates the compression ratio of a text using zlib compression. + + Parameters + ---------- + text : str + The text to compress. + + Returns + ------- + float + The compression ratio of the text. + """ text_bytes = text.encode("utf-8") return len(text_bytes) / len(zlib.compress(text_bytes)) @@ -50,6 +128,25 @@ def compression_ratio(text) -> float: def format_timestamp( seconds: float, always_include_hours: bool = False, decimal_marker: str = "." ): + """ + Formats a timestamp in seconds into a human-readable string. + + Parameters + ---------- + seconds : float + The timestamp in seconds. + + always_include_hours : bool, optional + Whether to always include hours in the formatted timestamp. Default is False. + + decimal_marker : str, optional + The decimal marker to use. Default is ".". + + Returns + ------- + str + The formatted timestamp string. + """ assert seconds >= 0, "non-negative timestamp expected" milliseconds = round(seconds * 1000.0) @@ -69,6 +166,19 @@ def format_timestamp( def get_start(segments: List[dict]) -> Optional[float]: + """ + Get the start time from a list of segments. + + Parameters + ---------- + segments : List[dict] + A list of segments, each containing a "start" field. + + Returns + ------- + Optional[float] + The start time, or None if no segments are provided. + """ return next( (w["start"] for s in segments for w in s["words"]), segments[0]["start"] if segments else None, @@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]: def get_end(segments: List[dict]) -> Optional[float]: + """ + Get the end time from a list of segments. + + Parameters + ---------- + segments : List[dict] + A list of segments, each containing a "end" field. + + Returns + ------- + Optional[float] + The end time, or None if no segments are provided. + """ return next( (w["end"] for s in reversed(segments) for w in reversed(s["words"])), segments[-1]["end"] if segments else None, @@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]: class ResultWriter: + """ + Base class for result writers. + + Attributes + ---------- + extension : str + The file extension associated with the writer. + """ extension: str def __init__(self, output_dir: str): @@ -91,6 +222,24 @@ class ResultWriter: def __call__( self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs ): + """ + Writes the result to a file. + + Parameters + ---------- + result : dict + The result to write. + + audio_path : str + The path to the audio file associated with the result. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ audio_basename = os.path.basename(audio_path) audio_basename = os.path.splitext(audio_basename)[0] output_path = os.path.join( @@ -103,20 +252,75 @@ class ResultWriter: def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs ): + """ + Writes the result to a file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ raise NotImplementedError class WriteTXT(ResultWriter): + """ + Result writer for writing text results to a .txt file. + + Attributes + ---------- + extension : str + The file extension associated with the writer. + """ extension: str = "txt" def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs ): + """ + Writes the result to a .txt file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ for segment in result["segments"]: print(segment["text"].strip(), file=file, flush=True) class SubtitlesWriter(ResultWriter): + """ + Base class for subtitle writers. + + Attributes + ---------- + always_include_hours : bool + Whether to always include hours in the formatted timestamps. + + decimal_marker : str + The decimal marker to use in formatted timestamps. + """ always_include_hours: bool decimal_marker: str @@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter): highlight_words: bool = False, max_words_per_line: Optional[int] = None, ): + """ + Iterates over the result to generate subtitles. + + Parameters + ---------- + result : dict + The result to iterate over. + + options : dict, optional + Additional options for iterating the result. Default is None. + + max_line_width : int, optional + The maximum width of each line. Default is None. + + max_line_count : int, optional + The maximum number of lines. Default is None. + + highlight_words : bool, optional + Whether to highlight individual words in the subtitles. Default is False. + + max_words_per_line : int, optional + The maximum number of words per line. Default is None. + """ options = options or {} max_line_width = max_line_width or options.get("max_line_width") max_line_count = max_line_count or options.get("max_line_count") @@ -226,6 +453,19 @@ class SubtitlesWriter(ResultWriter): yield segment_start, segment_end, segment_text def format_timestamp(self, seconds: float): + """ + Formats a timestamp in seconds into a human-readable string. + + Parameters + ---------- + seconds : float + The timestamp in seconds. + + Returns + ------- + str + The formatted timestamp string. + """ return format_timestamp( seconds=seconds, always_include_hours=self.always_include_hours, @@ -234,26 +474,90 @@ class SubtitlesWriter(ResultWriter): class WriteVTT(SubtitlesWriter): + """ + Result writer for writing subtitles to a .vtt file. + + Attributes + ---------- + extension : str + The file extension associated with the writer. + + always_include_hours : bool + Whether to always include hours in the formatted timestamps. + + decimal_marker : str + The decimal marker to use in formatted timestamps. + """ extension: str = "vtt" always_include_hours: bool = False decimal_marker: str = "." def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs - ): + ) -> None: + """ + Writes the result to a .vtt file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ print("WEBVTT\n", file=file) for start, end, text in self.iterate_result(result, options, **kwargs): print(f"{start} --> {end}\n{text}\n", file=file, flush=True) class WriteSRT(SubtitlesWriter): + """ + Result writer for writing subtitles to a .srt file. + + Attributes + ---------- + extension : str + The file extension associated with the writer. + + always_include_hours : bool + Whether to always include hours in the formatted timestamps. + + decimal_marker : str + The decimal marker to use in formatted timestamps. + """ extension: str = "srt" always_include_hours: bool = True decimal_marker: str = "," def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs - ): + ) -> None: + """ + Writes the result to a .srt file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ for i, (start, end, text) in enumerate( self.iterate_result(result, options, **kwargs), start=1 ): @@ -274,7 +578,25 @@ class WriteTSV(ResultWriter): def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs - ): + ) -> None: + """ + Writes the result to a .tsv file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + + Returns + ------- + None + """ print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: print(round(1000 * segment["start"]), file=file, end="\t") @@ -283,17 +605,55 @@ class WriteTSV(ResultWriter): class WriteJSON(ResultWriter): + """ + Result writer for writing data to a .json file. + + Attributes + ---------- + extension : str + The file extension associated with the writer. + """ extension: str = "json" def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs - ): + ) -> None: + """ + Writes the result to a .json file. + + Parameters + ---------- + result : dict + The result to write. + + file : TextIO + The file object to write to. + + options : dict, optional + Additional options for writing the result. Default is None. + """ json.dump(result, file) def get_writer( output_format: str, output_dir: str ) -> Callable[[dict, TextIO, dict], None]: + """ + Returns a result writer based on the specified output format. + + Parameters + ---------- + output_format : str + The desired output format for the writer. + + output_dir : str + The directory where the output files will be saved. + + Returns + ------- + Callable[[dict, TextIO, dict], None] + A function that can be used to write results to files. + """ writers = { "txt": WriteTXT, "vtt": WriteVTT,