Mirror of https://github.com/openai/whisper.git, synced 2025-11-24 14:35:57 +00:00
Merge 5ed89d0ca24e99e93376da4626690354c1f1c6e7 into 25639fc17ddc013d56c594bfbf7644f2185fad84
This commit is contained in: commit cb76906e06
whisper/decoding.py
@@ -112,6 +112,7 @@ class DecodingOptions:
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
+    hotwords: Optional[str] = None


 @dataclass(frozen=True)
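In practice the new field would be set like any other decoding option. A minimal sketch, assuming this branch is installed (upstream whisper has no `hotwords` field yet; the audio file name is a placeholder):

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))  # placeholder file
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# Bias decoding toward domain terms without seeding the transcript text itself.
options = whisper.DecodingOptions(hotwords="Kubernetes, PyTorch, Grafana", fp16=False)
result = whisper.decode(model, mel, options)
print(result.text)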
@@ -598,15 +599,20 @@ class DecodingTask:
                 prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens

-        if prompt := self.options.prompt:
+        if (prompt := self.options.prompt) or (self.options.hotwords and not self.options.prefix):
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip())
                 if isinstance(prompt, str)
                 else prompt
             )
+            if (hotwords := self.options.hotwords) and not self.options.prefix:
+                hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.n_ctx // 2:
+                    hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1]
             tokens = (
                 [self.tokenizer.sot_prev]
-                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + (hotwords_tokens if self.options.hotwords is not None else [])
+                + (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else [])
                 + tokens
             )
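The net effect of the hunk above: when hotwords are given (and no prefix is set), they are encoded once, truncated to at most n_ctx // 2 - 1 tokens, and spliced between the sot_prev token and the tail of the rolling prompt. A toy illustration of that ordering with hypothetical token IDs (the real code in _get_initial_tokens goes through tokenizer.encode):

# Hypothetical token IDs standing in for tokenizer output.
n_ctx = 448
sot_prev = 50361                      # <|startofprev|>
sot_sequence = [50258, 50259, 50359]  # the usual SOT tokens
hotwords_tokens = [7, 8, 9]           # encode(" " + hotwords.strip())
prompt_tokens = list(range(500))      # rolling transcript, longer than fits

# Mirror of the truncation and splice performed in the hunk above:
if len(hotwords_tokens) >= n_ctx // 2:
    hotwords_tokens = hotwords_tokens[: n_ctx // 2 - 1]
tokens = (
    [sot_prev]
    + hotwords_tokens
    + prompt_tokens[-(n_ctx // 2 - 1):]
    + sot_sequence
)
print(len(tokens))  # 230 = 1 + 3 + 223 + 3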
whisper/transcribe.py
@@ -51,6 +51,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    hotwords: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -113,6 +114,9 @@ def transcribe(
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected

+    hotwords: Optional[str]
+        Optional hotwords to provide as a prompt for each window.
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
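With the keyword in place, hotwords pass straight through the high-level API. A sketch against this branch (file name and hotword list are placeholders):

import whisper

model = whisper.load_model("small")
result = model.transcribe(
    "meeting.wav",                  # placeholder audio file
    hotwords="Raft, etcd, quorum",  # biases every window
)
print(result["text"])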
@@ -275,6 +279,7 @@ def transcribe(
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

+            decode_options["hotwords"] = hotwords
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
             result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)
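Note the asymmetry this hunk introduces: prompt is rebuilt each window from the rolling transcript (all_tokens[prompt_reset_since:]), while hotwords is the same string every window, so the bias persists for the whole file. A runnable toy restatement of that loop shape (window count and token values are stand-ins):

hotwords = "Raft, etcd, quorum"
all_tokens: list = []
prompt_reset_since = 0

for window_index in range(3):  # stand-in for the 30-second windows
    decode_options = {
        "hotwords": hotwords,                       # identical every window
        "prompt": all_tokens[prompt_reset_since:],  # rolling transcript tail
    }
    all_tokens.extend([100 + window_index])         # stand-in for result.tokens
print(all_tokens)  # [100, 101, 102]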
@@ -546,6 +551,8 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--hotwords", type=str, default=None, help="optional hotwords to provide as a prompt for all windows")
     # fmt: on

     args = parser.parse_args().__dict__
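And the matching CLI invocation, assuming this branch (file name and hotwords are placeholders):

whisper meeting.wav --model small --hotwords "Raft, etcd, quorum"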