mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Merge 67f5b1d2317f1b7223b0d39c1b7fd2f3ed0c6272 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
e3a432ca64
378
whisper/utils.py
378
whisper/utils.py
@ -21,12 +21,51 @@ else:
|
||||
return string
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
def exact_div(x:int, y:int):
|
||||
"""
|
||||
Performs exact division of x by y.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : int
|
||||
The dividend.
|
||||
|
||||
y : int
|
||||
The divisor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
quotient : int
|
||||
The result of the exact division.
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If x is not exactly divisible by y.
|
||||
"""
|
||||
assert x % y == 0
|
||||
return x // y
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
def str2bool(string:str) -> bool:
|
||||
"""
|
||||
Converts a string representation of a boolean to its boolean equivalent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
string : str
|
||||
The string representation of the boolean.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
The boolean value represented by the input string.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input string does not represent a boolean value.
|
||||
"""
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
@ -34,15 +73,54 @@ def str2bool(string):
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
|
||||
def optional_int(string):
|
||||
def optional_int(string:str) -> int:
|
||||
"""
|
||||
Converts a string to an integer or returns None if the string is "None".
|
||||
|
||||
Parameters
|
||||
----------
|
||||
string : str
|
||||
The string to convert.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int or None
|
||||
The integer value of the string, or None if the string is "None".
|
||||
"""
|
||||
return None if string == "None" else int(string)
|
||||
|
||||
|
||||
def optional_float(string):
|
||||
def optional_float(string:str) -> float:
|
||||
"""
|
||||
Converts a string to a float or returns None if the string is "None".
|
||||
|
||||
Parameters
|
||||
----------
|
||||
string : str
|
||||
The string to convert.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float or None
|
||||
The float value of the string, or None if the string is "None".
|
||||
"""
|
||||
return None if string == "None" else float(string)
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
def compression_ratio(text:str) -> float:
|
||||
"""
|
||||
Calculates the compression ratio of a text using zlib compression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
The text to compress.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The compression ratio of the text.
|
||||
"""
|
||||
text_bytes = text.encode("utf-8")
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
@ -50,6 +128,25 @@ def compression_ratio(text) -> float:
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||
):
|
||||
"""
|
||||
Formats a timestamp in seconds into a human-readable string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seconds : float
|
||||
The timestamp in seconds.
|
||||
|
||||
always_include_hours : bool, optional
|
||||
Whether to always include hours in the formatted timestamp. Default is False.
|
||||
|
||||
decimal_marker : str, optional
|
||||
The decimal marker to use. Default is ".".
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The formatted timestamp string.
|
||||
"""
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
@ -69,6 +166,19 @@ def format_timestamp(
|
||||
|
||||
|
||||
def get_start(segments: List[dict]) -> Optional[float]:
|
||||
"""
|
||||
Get the start time from a list of segments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : List[dict]
|
||||
A list of segments, each containing a "start" field.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[float]
|
||||
The start time, or None if no segments are provided.
|
||||
"""
|
||||
return next(
|
||||
(w["start"] for s in segments for w in s["words"]),
|
||||
segments[0]["start"] if segments else None,
|
||||
@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]:
|
||||
|
||||
|
||||
def get_end(segments: List[dict]) -> Optional[float]:
|
||||
"""
|
||||
Get the end time from a list of segments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : List[dict]
|
||||
A list of segments, each containing a "end" field.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[float]
|
||||
The end time, or None if no segments are provided.
|
||||
"""
|
||||
return next(
|
||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||
segments[-1]["end"] if segments else None,
|
||||
@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]:
|
||||
|
||||
|
||||
class ResultWriter:
|
||||
"""
|
||||
Base class for result writers.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
extension : str
|
||||
The file extension associated with the writer.
|
||||
"""
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
@ -91,6 +222,24 @@ class ResultWriter:
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Writes the result to a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
audio_path : str
|
||||
The path to the audio file associated with the result.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
@ -103,20 +252,75 @@ class ResultWriter:
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Writes the result to a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteTXT(ResultWriter):
|
||||
"""
|
||||
Result writer for writing text results to a .txt file.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
extension : str
|
||||
The file extension associated with the writer.
|
||||
"""
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Writes the result to a .txt file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
"""
|
||||
Base class for subtitle writers.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
always_include_hours : bool
|
||||
Whether to always include hours in the formatted timestamps.
|
||||
|
||||
decimal_marker : str
|
||||
The decimal marker to use in formatted timestamps.
|
||||
"""
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter):
|
||||
highlight_words: bool = False,
|
||||
max_words_per_line: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Iterates over the result to generate subtitles.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to iterate over.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for iterating the result. Default is None.
|
||||
|
||||
max_line_width : int, optional
|
||||
The maximum width of each line. Default is None.
|
||||
|
||||
max_line_count : int, optional
|
||||
The maximum number of lines. Default is None.
|
||||
|
||||
highlight_words : bool, optional
|
||||
Whether to highlight individual words in the subtitles. Default is False.
|
||||
|
||||
max_words_per_line : int, optional
|
||||
The maximum number of words per line. Default is None.
|
||||
"""
|
||||
options = options or {}
|
||||
max_line_width = max_line_width or options.get("max_line_width")
|
||||
max_line_count = max_line_count or options.get("max_line_count")
|
||||
@ -228,6 +455,19 @@ class SubtitlesWriter(ResultWriter):
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
"""
|
||||
Formats a timestamp in seconds into a human-readable string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seconds : float
|
||||
The timestamp in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The formatted timestamp string.
|
||||
"""
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
@ -236,26 +476,90 @@ class SubtitlesWriter(ResultWriter):
|
||||
|
||||
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
"""
|
||||
Result writer for writing subtitles to a .vtt file.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
extension : str
|
||||
The file extension associated with the writer.
|
||||
|
||||
always_include_hours : bool
|
||||
Whether to always include hours in the formatted timestamps.
|
||||
|
||||
decimal_marker : str
|
||||
The decimal marker to use in formatted timestamps.
|
||||
"""
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Writes the result to a .vtt file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
"""
|
||||
Result writer for writing subtitles to a .srt file.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
extension : str
|
||||
The file extension associated with the writer.
|
||||
|
||||
always_include_hours : bool
|
||||
Whether to always include hours in the formatted timestamps.
|
||||
|
||||
decimal_marker : str
|
||||
The decimal marker to use in formatted timestamps.
|
||||
"""
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Writes the result to a .srt file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options, **kwargs), start=1
|
||||
):
|
||||
@ -276,7 +580,25 @@ class WriteTSV(ResultWriter):
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Writes the result to a .tsv file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
@ -285,17 +607,55 @@ class WriteTSV(ResultWriter):
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
"""
|
||||
Result writer for writing data to a .json file.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
extension : str
|
||||
The file extension associated with the writer.
|
||||
"""
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Writes the result to a .json file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
The result to write.
|
||||
|
||||
file : TextIO
|
||||
The file object to write to.
|
||||
|
||||
options : dict, optional
|
||||
Additional options for writing the result. Default is None.
|
||||
"""
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
"""
|
||||
Returns a result writer based on the specified output format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_format : str
|
||||
The desired output format for the writer.
|
||||
|
||||
output_dir : str
|
||||
The directory where the output files will be saved.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable[[dict, TextIO, dict], None]
|
||||
A function that can be used to write results to files.
|
||||
"""
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user