mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Merge 5ed89d0ca24e99e93376da4626690354c1f1c6e7 into 25639fc17ddc013d56c594bfbf7644f2185fad84
This commit is contained in:
commit
cb76906e06
@ -112,6 +112,7 @@ class DecodingOptions:
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
hotwords: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -598,15 +599,20 @@ class DecodingTask:
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt := self.options.prompt:
|
||||
if (prompt := self.options.prompt) or ((self.options.hotwords) and not self.options.prefix):
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
if (hotwords := self.options.hotwords) and not self.options.prefix:
|
||||
hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip())
|
||||
if len(hotwords_tokens) >= self.n_ctx // 2:
|
||||
hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1]
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ (hotwords_tokens if self.options.hotwords is not None else [])
|
||||
+ (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else [])
|
||||
+ tokens
|
||||
)
|
||||
|
||||
|
||||
@ -51,6 +51,7 @@ def transcribe(
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
hotwords: Optional[str] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -113,6 +114,9 @@ def transcribe(
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
hotwords: Optional[str]
|
||||
optional hotwords to provide as a prompt for the all window.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
@ -275,6 +279,7 @@ def transcribe(
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
decode_options["hotwords"] = hotwords
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
@ -546,6 +551,8 @@ def cli():
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||
parser.add_argument("--hotwords", type=str, default=None, help="optional hotwords to provide as a prompt for the all window")
|
||||
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user