Merge 67f5b1d2317f1b7223b0d39c1b7fd2f3ed0c6272 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
Louis Brulé Naudet 2025-06-27 02:27:48 +00:00 committed by GitHub
commit e3a432ca64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,12 +21,51 @@ else:
return string
def exact_div(x, y):
def exact_div(x:int, y:int):
"""
Performs exact division of x by y.
Parameters
----------
x : int
The dividend.
y : int
The divisor.
Returns
-------
quotient : int
The result of the exact division.
Raises
------
AssertionError
If x is not exactly divisible by y.
"""
assert x % y == 0
return x // y
def str2bool(string):
def str2bool(string:str) -> bool:
"""
Converts a string representation of a boolean to its boolean equivalent.
Parameters
----------
string : str
The string representation of the boolean.
Returns
-------
bool
The boolean value represented by the input string.
Raises
------
ValueError
If the input string does not represent a boolean value.
"""
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
@ -34,15 +73,54 @@ def str2bool(string):
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
def optional_int(string:str) -> int:
"""
Converts a string to an integer or returns None if the string is "None".
Parameters
----------
string : str
The string to convert.
Returns
-------
int or None
The integer value of the string, or None if the string is "None".
"""
return None if string == "None" else int(string)
def optional_float(string):
def optional_float(string:str) -> float:
"""
Converts a string to a float or returns None if the string is "None".
Parameters
----------
string : str
The string to convert.
Returns
-------
float or None
The float value of the string, or None if the string is "None".
"""
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
def compression_ratio(text:str) -> float:
"""
Calculates the compression ratio of a text using zlib compression.
Parameters
----------
text : str
The text to compress.
Returns
-------
float
The compression ratio of the text.
"""
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
@ -50,6 +128,25 @@ def compression_ratio(text) -> float:
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
"""
Formats a timestamp in seconds into a human-readable string.
Parameters
----------
seconds : float
The timestamp in seconds.
always_include_hours : bool, optional
Whether to always include hours in the formatted timestamp. Default is False.
decimal_marker : str, optional
The decimal marker to use. Default is ".".
Returns
-------
str
The formatted timestamp string.
"""
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@ -69,6 +166,19 @@ def format_timestamp(
def get_start(segments: List[dict]) -> Optional[float]:
"""
Get the start time from a list of segments.
Parameters
----------
segments : List[dict]
A list of segments, each containing a "start" field.
Returns
-------
Optional[float]
The start time, or None if no segments are provided.
"""
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]:
def get_end(segments: List[dict]) -> Optional[float]:
"""
Get the end time from a list of segments.
Parameters
----------
segments : List[dict]
A list of segments, each containing a "end" field.
Returns
-------
Optional[float]
The end time, or None if no segments are provided.
"""
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]:
class ResultWriter:
"""
Base class for result writers.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str
def __init__(self, output_dir: str):
@ -91,6 +222,24 @@ class ResultWriter:
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
"""
Writes the result to a file.
Parameters
----------
result : dict
The result to write.
audio_path : str
The path to the audio file associated with the result.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
@ -103,20 +252,75 @@ class ResultWriter:
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
"""
Writes the result to a file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
raise NotImplementedError
class WriteTXT(ResultWriter):
"""
Result writer for writing text results to a .txt file.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str = "txt"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
"""
Writes the result to a .txt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
class SubtitlesWriter(ResultWriter):
"""
Base class for subtitle writers.
Attributes
----------
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
always_include_hours: bool
decimal_marker: str
@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter):
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
"""
Iterates over the result to generate subtitles.
Parameters
----------
result : dict
The result to iterate over.
options : dict, optional
Additional options for iterating the result. Default is None.
max_line_width : int, optional
The maximum width of each line. Default is None.
max_line_count : int, optional
The maximum number of lines. Default is None.
highlight_words : bool, optional
Whether to highlight individual words in the subtitles. Default is False.
max_words_per_line : int, optional
The maximum number of words per line. Default is None.
"""
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
@ -228,6 +455,19 @@ class SubtitlesWriter(ResultWriter):
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
"""
Formats a timestamp in seconds into a human-readable string.
Parameters
----------
seconds : float
The timestamp in seconds.
Returns
-------
str
The formatted timestamp string.
"""
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
@ -236,26 +476,90 @@ class SubtitlesWriter(ResultWriter):
class WriteVTT(SubtitlesWriter):
"""
Result writer for writing subtitles to a .vtt file.
Attributes
----------
extension : str
The file extension associated with the writer.
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
) -> None:
"""
Writes the result to a .vtt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
"""
Result writer for writing subtitles to a .srt file.
Attributes
----------
extension : str
The file extension associated with the writer.
always_include_hours : bool
Whether to always include hours in the formatted timestamps.
decimal_marker : str
The decimal marker to use in formatted timestamps.
"""
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
) -> None:
"""
Writes the result to a .srt file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
@ -276,7 +580,25 @@ class WriteTSV(ResultWriter):
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
) -> None:
"""
Writes the result to a .tsv file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
Returns
-------
None
"""
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
@ -285,17 +607,55 @@ class WriteTSV(ResultWriter):
class WriteJSON(ResultWriter):
"""
Result writer for writing data to a .json file.
Attributes
----------
extension : str
The file extension associated with the writer.
"""
extension: str = "json"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
) -> None:
"""
Writes the result to a .json file.
Parameters
----------
result : dict
The result to write.
file : TextIO
The file object to write to.
options : dict, optional
Additional options for writing the result. Default is None.
"""
json.dump(result, file)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
"""
Returns a result writer based on the specified output format.
Parameters
----------
output_format : str
The desired output format for the writer.
output_dir : str
The directory where the output files will be saved.
Returns
-------
Callable[[dict, TextIO, dict], None]
A function that can be used to write results to files.
"""
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,