Update utils.py

Dear Developers,

I'm pleased to inform you that I have completed the documentation update the utils.py file.

The updated documentation provides clear explanations of function parameters, return types, and expected behavior. Additionally, it adheres to consistent formatting and organization, ensuring ease of understanding for both current and future developers.

Please review the updated documentation at your earliest convenience. If you have any feedback or suggestions for further improvements, please don't hesitate to let me know.

Thank you for your attention to this matter.

Best regards,
Louis Brulé Naudet
This commit is contained in:
Louis Brulé Naudet 2024-02-19 20:12:26 +01:00
parent ba3f3cd54b
commit 492c05c5f3

View File

@ -21,12 +21,51 @@ else:
return string return string
def exact_div(x, y): def exact_div(x:int, y:int):
"""
Performs exact division of x by y.
Parameters
----------
x : int
The dividend.
y : int
The divisor.
Returns
-------
quotient : int
The result of the exact division.
Raises
------
AssertionError
If x is not exactly divisible by y.
"""
assert x % y == 0 assert x % y == 0
return x // y return x // y
def str2bool(string): def str2bool(string:str) -> bool:
"""
Converts a string representation of a boolean to its boolean equivalent.
Parameters
----------
string : str
The string representation of the boolean.
Returns
-------
bool
The boolean value represented by the input string.
Raises
------
ValueError
If the input string does not represent a boolean value.
"""
str2val = {"True": True, "False": False} str2val = {"True": True, "False": False}
if string in str2val: if string in str2val:
return str2val[string] return str2val[string]
@ -34,15 +73,54 @@ def str2bool(string):
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string): def optional_int(string:str) -> int:
"""
Converts a string to an integer or returns None if the string is "None".
Parameters
----------
string : str
The string to convert.
Returns
-------
int or None
The integer value of the string, or None if the string is "None".
"""
return None if string == "None" else int(string) return None if string == "None" else int(string)
def optional_float(string): def optional_float(string:str) -> float:
"""
Converts a string to a float or returns None if the string is "None".
Parameters
----------
string : str
The string to convert.
Returns
-------
float or None
The float value of the string, or None if the string is "None".
"""
return None if string == "None" else float(string) return None if string == "None" else float(string)
def compression_ratio(text) -> float: def compression_ratio(text:str) -> float:
"""
Calculates the compression ratio of a text using zlib compression.
Parameters
----------
text : str
The text to compress.
Returns
-------
float
The compression ratio of the text.
"""
text_bytes = text.encode("utf-8") text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes)) return len(text_bytes) / len(zlib.compress(text_bytes))
@ -50,6 +128,25 @@ def compression_ratio(text) -> float:
def format_timestamp( def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "." seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
): ):
"""
Formats a timestamp in seconds into a human-readable string.
Parameters
----------
seconds : float
The timestamp in seconds.
always_include_hours : bool, optional
Whether to always include hours in the formatted timestamp. Default is False.
decimal_marker : str, optional
The decimal marker to use. Default is ".".
Returns
-------
str
The formatted timestamp string.
"""
assert seconds >= 0, "non-negative timestamp expected" assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0) milliseconds = round(seconds * 1000.0)
@ -69,6 +166,19 @@ def format_timestamp(
def get_start(segments: List[dict]) -> Optional[float]: def get_start(segments: List[dict]) -> Optional[float]:
"""
Get the start time from a list of segments.
Parameters
----------
segments : List[dict]
A list of segments, each containing a "start" field.
Returns
-------
Optional[float]
The start time, or None if no segments are provided.
"""
return next( return next(
(w["start"] for s in segments for w in s["words"]), (w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None, segments[0]["start"] if segments else None,
@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]:
def get_end(segments: List[dict]) -> Optional[float]: def get_end(segments: List[dict]) -> Optional[float]:
"""
Get the end time from a list of segments.
Parameters
----------
segments : List[dict]
A list of segments, each containing a "end" field.
Returns
-------
Optional[float]
The end time, or None if no segments are provided.
"""
return next( return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])), (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None, segments[-1]["end"] if segments else None,
@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]:
class ResultWriter: class ResultWriter:
"""
Base class for result writers.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str extension: str
def __init__(self, output_dir: str): def __init__(self, output_dir: str):
@ -91,6 +222,24 @@ class ResultWriter:
def __call__( def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
): ):
"""
Writes the result to a file.
Parameters
----------
result : dict
The result to write.
audio_path : str
The path to the audio file associated with the result.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0] audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join( output_path = os.path.join(
@ -103,20 +252,75 @@ class ResultWriter:
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ):
"""
Writes the result to a file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
raise NotImplementedError raise NotImplementedError
class WriteTXT(ResultWriter): class WriteTXT(ResultWriter):
"""
Result writer for writing text results to a .txt file.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str = "txt" extension: str = "txt"
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ):
"""
Writes the result to a .txt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
for segment in result["segments"]: for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True) print(segment["text"].strip(), file=file, flush=True)
class SubtitlesWriter(ResultWriter): class SubtitlesWriter(ResultWriter):
"""
Base class for subtitle writers.
Attributes
----------
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
always_include_hours: bool always_include_hours: bool
decimal_marker: str decimal_marker: str
@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter):
highlight_words: bool = False, highlight_words: bool = False,
max_words_per_line: Optional[int] = None, max_words_per_line: Optional[int] = None,
): ):
"""
Iterates over the result to generate subtitles.
Parameters
----------
result : dict
The result to iterate over.
options : dict, optional
Additional options for iterating the result. Default is None.
max_line_width : int, optional
The maximum width of each line. Default is None.
max_line_count : int, optional
The maximum number of lines. Default is None.
highlight_words : bool, optional
Whether to highlight individual words in the subtitles. Default is False.
max_words_per_line : int, optional
The maximum number of words per line. Default is None.
"""
options = options or {} options = options or {}
max_line_width = max_line_width or options.get("max_line_width") max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count") max_line_count = max_line_count or options.get("max_line_count")
@ -226,6 +453,19 @@ class SubtitlesWriter(ResultWriter):
yield segment_start, segment_end, segment_text yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float): def format_timestamp(self, seconds: float):
"""
Formats a timestamp in seconds into a human-readable string.
Parameters
----------
seconds : float
The timestamp in seconds.
Returns
-------
str
The formatted timestamp string.
"""
return format_timestamp( return format_timestamp(
seconds=seconds, seconds=seconds,
always_include_hours=self.always_include_hours, always_include_hours=self.always_include_hours,
@ -234,26 +474,90 @@ class SubtitlesWriter(ResultWriter):
class WriteVTT(SubtitlesWriter): class WriteVTT(SubtitlesWriter):
"""
Result writer for writing subtitles to a .vtt file.
Attributes
----------
extension : str
The file extension associated with the writer.
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
extension: str = "vtt" extension: str = "vtt"
always_include_hours: bool = False always_include_hours: bool = False
decimal_marker: str = "." decimal_marker: str = "."
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ) -> None:
"""
Writes the result to a .vtt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs): for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True) print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter): class WriteSRT(SubtitlesWriter):
"""
Result writer for writing subtitles to a .srt file.
Attributes
----------
extension : str
The file extension associated with the writer.
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
extension: str = "srt" extension: str = "srt"
always_include_hours: bool = True always_include_hours: bool = True
decimal_marker: str = "," decimal_marker: str = ","
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ) -> None:
"""
Writes the result to a .srt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
for i, (start, end, text) in enumerate( for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1 self.iterate_result(result, options, **kwargs), start=1
): ):
@ -274,7 +578,25 @@ class WriteTSV(ResultWriter):
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ) -> None:
"""
Writes the result to a .tsv file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
print("start", "end", "text", sep="\t", file=file) print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]: for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t") print(round(1000 * segment["start"]), file=file, end="\t")
@ -283,17 +605,55 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter): class WriteJSON(ResultWriter):
"""
Result writer for writing data to a .json file.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str = "json" extension: str = "json"
def write_result( def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
): ) -> None:
"""
Writes the result to a .json file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
"""
json.dump(result, file) json.dump(result, file)
def get_writer( def get_writer(
output_format: str, output_dir: str output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]: ) -> Callable[[dict, TextIO, dict], None]:
"""
Returns a result writer based on the specified output format.
Parameters
----------
output_format : str
The desired output format for the writer.
output_dir : str
The directory where the output files will be saved.
Returns
-------
Callable[[dict, TextIO, dict], None]
A function that can be used to write results to files.
"""
writers = { writers = {
"txt": WriteTXT, "txt": WriteTXT,
"vtt": WriteVTT, "vtt": WriteVTT,