mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Merge 67f5b1d2317f1b7223b0d39c1b7fd2f3ed0c6272 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
e3a432ca64
378
whisper/utils.py
378
whisper/utils.py
@ -21,12 +21,51 @@ else:
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def exact_div(x, y):
|
def exact_div(x:int, y:int):
|
||||||
|
"""
|
||||||
|
Performs exact division of x by y.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : int
|
||||||
|
The dividend.
|
||||||
|
|
||||||
|
y : int
|
||||||
|
The divisor.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
quotient : int
|
||||||
|
The result of the exact division.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
AssertionError
|
||||||
|
If x is not exactly divisible by y.
|
||||||
|
"""
|
||||||
assert x % y == 0
|
assert x % y == 0
|
||||||
return x // y
|
return x // y
|
||||||
|
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string:str) -> bool:
|
||||||
|
"""
|
||||||
|
Converts a string representation of a boolean to its boolean equivalent.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
string : str
|
||||||
|
The string representation of the boolean.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
The boolean value represented by the input string.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input string does not represent a boolean value.
|
||||||
|
"""
|
||||||
str2val = {"True": True, "False": False}
|
str2val = {"True": True, "False": False}
|
||||||
if string in str2val:
|
if string in str2val:
|
||||||
return str2val[string]
|
return str2val[string]
|
||||||
@ -34,15 +73,54 @@ def str2bool(string):
|
|||||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||||
|
|
||||||
|
|
||||||
def optional_int(string):
|
def optional_int(string:str) -> int:
|
||||||
|
"""
|
||||||
|
Converts a string to an integer or returns None if the string is "None".
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
string : str
|
||||||
|
The string to convert.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int or None
|
||||||
|
The integer value of the string, or None if the string is "None".
|
||||||
|
"""
|
||||||
return None if string == "None" else int(string)
|
return None if string == "None" else int(string)
|
||||||
|
|
||||||
|
|
||||||
def optional_float(string):
|
def optional_float(string:str) -> float:
|
||||||
|
"""
|
||||||
|
Converts a string to a float or returns None if the string is "None".
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
string : str
|
||||||
|
The string to convert.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float or None
|
||||||
|
The float value of the string, or None if the string is "None".
|
||||||
|
"""
|
||||||
return None if string == "None" else float(string)
|
return None if string == "None" else float(string)
|
||||||
|
|
||||||
|
|
||||||
def compression_ratio(text) -> float:
|
def compression_ratio(text:str) -> float:
|
||||||
|
"""
|
||||||
|
Calculates the compression ratio of a text using zlib compression.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
text : str
|
||||||
|
The text to compress.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The compression ratio of the text.
|
||||||
|
"""
|
||||||
text_bytes = text.encode("utf-8")
|
text_bytes = text.encode("utf-8")
|
||||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||||
|
|
||||||
@ -50,6 +128,25 @@ def compression_ratio(text) -> float:
|
|||||||
def format_timestamp(
|
def format_timestamp(
|
||||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Formats a timestamp in seconds into a human-readable string.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
seconds : float
|
||||||
|
The timestamp in seconds.
|
||||||
|
|
||||||
|
always_include_hours : bool, optional
|
||||||
|
Whether to always include hours in the formatted timestamp. Default is False.
|
||||||
|
|
||||||
|
decimal_marker : str, optional
|
||||||
|
The decimal marker to use. Default is ".".
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The formatted timestamp string.
|
||||||
|
"""
|
||||||
assert seconds >= 0, "non-negative timestamp expected"
|
assert seconds >= 0, "non-negative timestamp expected"
|
||||||
milliseconds = round(seconds * 1000.0)
|
milliseconds = round(seconds * 1000.0)
|
||||||
|
|
||||||
@ -69,6 +166,19 @@ def format_timestamp(
|
|||||||
|
|
||||||
|
|
||||||
def get_start(segments: List[dict]) -> Optional[float]:
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get the start time from a list of segments.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
segments : List[dict]
|
||||||
|
A list of segments, each containing a "start" field and a "words" list whose entries also have a "start" field.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[float]
|
||||||
|
The start of the first word found across the segments; if no segment has any words, the first segment's "start"; None if no segments are provided.
|
||||||
|
"""
|
||||||
return next(
|
return next(
|
||||||
(w["start"] for s in segments for w in s["words"]),
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
segments[0]["start"] if segments else None,
|
segments[0]["start"] if segments else None,
|
||||||
@ -76,6 +186,19 @@ def get_start(segments: List[dict]) -> Optional[float]:
|
|||||||
|
|
||||||
|
|
||||||
def get_end(segments: List[dict]) -> Optional[float]:
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get the end time from a list of segments.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
segments : List[dict]
|
||||||
|
A list of segments, each containing an "end" field and a "words" list whose entries also have an "end" field.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[float]
|
||||||
|
The end of the last word found across the segments; if no segment has any words, the last segment's "end"; None if no segments are provided.
|
||||||
|
"""
|
||||||
return next(
|
return next(
|
||||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
segments[-1]["end"] if segments else None,
|
segments[-1]["end"] if segments else None,
|
||||||
@ -83,6 +206,14 @@ def get_end(segments: List[dict]) -> Optional[float]:
|
|||||||
|
|
||||||
|
|
||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
|
"""
|
||||||
|
Base class for result writers.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
extension : str
|
||||||
|
The file extension associated with the writer.
|
||||||
|
"""
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
def __init__(self, output_dir: str):
|
||||||
@ -91,6 +222,24 @@ class ResultWriter:
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Writes the result to a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
audio_path : str
|
||||||
|
The path to the audio file associated with the result.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
audio_basename = os.path.basename(audio_path)
|
audio_basename = os.path.basename(audio_path)
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
audio_basename = os.path.splitext(audio_basename)[0]
|
||||||
output_path = os.path.join(
|
output_path = os.path.join(
|
||||||
@ -103,20 +252,75 @@ class ResultWriter:
|
|||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Writes the result to a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class WriteTXT(ResultWriter):
|
class WriteTXT(ResultWriter):
|
||||||
|
"""
|
||||||
|
Result writer for writing text results to a .txt file.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
extension : str
|
||||||
|
The file extension associated with the writer.
|
||||||
|
"""
|
||||||
extension: str = "txt"
|
extension: str = "txt"
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Writes the result to a .txt file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(segment["text"].strip(), file=file, flush=True)
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
class SubtitlesWriter(ResultWriter):
|
class SubtitlesWriter(ResultWriter):
|
||||||
|
"""
|
||||||
|
Base class for subtitle writers.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
always_include_hours : bool
|
||||||
|
Whether to always include hours in the formatted timestamps.
|
||||||
|
|
||||||
|
decimal_marker : str
|
||||||
|
The decimal marker to use in formatted timestamps.
|
||||||
|
"""
|
||||||
always_include_hours: bool
|
always_include_hours: bool
|
||||||
decimal_marker: str
|
decimal_marker: str
|
||||||
|
|
||||||
@ -130,6 +334,29 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
highlight_words: bool = False,
|
highlight_words: bool = False,
|
||||||
max_words_per_line: Optional[int] = None,
|
max_words_per_line: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Iterates over the result to generate subtitles.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to iterate over.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for iterating the result. Default is None.
|
||||||
|
|
||||||
|
max_line_width : int, optional
|
||||||
|
The maximum width of each line. Default is None.
|
||||||
|
|
||||||
|
max_line_count : int, optional
|
||||||
|
The maximum number of lines. Default is None.
|
||||||
|
|
||||||
|
highlight_words : bool, optional
|
||||||
|
Whether to highlight individual words in the subtitles. Default is False.
|
||||||
|
|
||||||
|
max_words_per_line : int, optional
|
||||||
|
The maximum number of words per line. Default is None.
|
||||||
|
"""
|
||||||
options = options or {}
|
options = options or {}
|
||||||
max_line_width = max_line_width or options.get("max_line_width")
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
max_line_count = max_line_count or options.get("max_line_count")
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
@ -228,6 +455,19 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
yield segment_start, segment_end, segment_text
|
yield segment_start, segment_end, segment_text
|
||||||
|
|
||||||
def format_timestamp(self, seconds: float):
|
def format_timestamp(self, seconds: float):
|
||||||
|
"""
|
||||||
|
Formats a timestamp in seconds into a human-readable string.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
seconds : float
|
||||||
|
The timestamp in seconds.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The formatted timestamp string.
|
||||||
|
"""
|
||||||
return format_timestamp(
|
return format_timestamp(
|
||||||
seconds=seconds,
|
seconds=seconds,
|
||||||
always_include_hours=self.always_include_hours,
|
always_include_hours=self.always_include_hours,
|
||||||
@ -236,26 +476,90 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
|
|
||||||
|
|
||||||
class WriteVTT(SubtitlesWriter):
|
class WriteVTT(SubtitlesWriter):
|
||||||
|
"""
|
||||||
|
Result writer for writing subtitles to a .vtt file.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
extension : str
|
||||||
|
The file extension associated with the writer.
|
||||||
|
|
||||||
|
always_include_hours : bool
|
||||||
|
Whether to always include hours in the formatted timestamps.
|
||||||
|
|
||||||
|
decimal_marker : str
|
||||||
|
The decimal marker to use in formatted timestamps.
|
||||||
|
"""
|
||||||
extension: str = "vtt"
|
extension: str = "vtt"
|
||||||
always_include_hours: bool = False
|
always_include_hours: bool = False
|
||||||
decimal_marker: str = "."
|
decimal_marker: str = "."
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
) -> None:
|
||||||
|
"""
|
||||||
|
Writes the result to a .vtt file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
print("WEBVTT\n", file=file)
|
print("WEBVTT\n", file=file)
|
||||||
for start, end, text in self.iterate_result(result, options, **kwargs):
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
class WriteSRT(SubtitlesWriter):
|
class WriteSRT(SubtitlesWriter):
|
||||||
|
"""
|
||||||
|
Result writer for writing subtitles to a .srt file.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
extension : str
|
||||||
|
The file extension associated with the writer.
|
||||||
|
|
||||||
|
always_include_hours : bool
|
||||||
|
Whether to always include hours in the formatted timestamps.
|
||||||
|
|
||||||
|
decimal_marker : str
|
||||||
|
The decimal marker to use in formatted timestamps.
|
||||||
|
"""
|
||||||
extension: str = "srt"
|
extension: str = "srt"
|
||||||
always_include_hours: bool = True
|
always_include_hours: bool = True
|
||||||
decimal_marker: str = ","
|
decimal_marker: str = ","
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
) -> None:
|
||||||
|
"""
|
||||||
|
Writes the result to a .srt file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
for i, (start, end, text) in enumerate(
|
for i, (start, end, text) in enumerate(
|
||||||
self.iterate_result(result, options, **kwargs), start=1
|
self.iterate_result(result, options, **kwargs), start=1
|
||||||
):
|
):
|
||||||
@ -276,7 +580,25 @@ class WriteTSV(ResultWriter):
|
|||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
) -> None:
|
||||||
|
"""
|
||||||
|
Writes the result to a .tsv file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
print("start", "end", "text", sep="\t", file=file)
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||||
@ -285,17 +607,55 @@ class WriteTSV(ResultWriter):
|
|||||||
|
|
||||||
|
|
||||||
class WriteJSON(ResultWriter):
|
class WriteJSON(ResultWriter):
|
||||||
|
"""
|
||||||
|
Result writer for writing data to a .json file.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
extension : str
|
||||||
|
The file extension associated with the writer.
|
||||||
|
"""
|
||||||
extension: str = "json"
|
extension: str = "json"
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
):
|
) -> None:
|
||||||
|
"""
|
||||||
|
Writes the result to a .json file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result : dict
|
||||||
|
The result to write.
|
||||||
|
|
||||||
|
file : TextIO
|
||||||
|
The file object to write to.
|
||||||
|
|
||||||
|
options : dict, optional
|
||||||
|
Additional options for writing the result. Default is None.
|
||||||
|
"""
|
||||||
json.dump(result, file)
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
def get_writer(
|
def get_writer(
|
||||||
output_format: str, output_dir: str
|
output_format: str, output_dir: str
|
||||||
) -> Callable[[dict, TextIO, dict], None]:
|
) -> Callable[[dict, TextIO, dict], None]:
|
||||||
|
"""
|
||||||
|
Returns a result writer based on the specified output format.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output_format : str
|
||||||
|
The desired output format for the writer.
|
||||||
|
|
||||||
|
output_dir : str
|
||||||
|
The directory where the output files will be saved.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Callable[[dict, TextIO, dict], None]
|
||||||
|
A function that can be used to write results to files.
|
||||||
|
"""
|
||||||
writers = {
|
writers = {
|
||||||
"txt": WriteTXT,
|
"txt": WriteTXT,
|
||||||
"vtt": WriteVTT,
|
"vtt": WriteVTT,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user