diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..56bede8 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -112,6 +112,7 @@ class DecodingOptions: # implementation details fp16: bool = True # use fp16 for most of the calculation + hotwords: Optional[str] = None @dataclass(frozen=True) @@ -598,15 +599,20 @@ class DecodingTask: prefix_tokens = prefix_tokens[-max_prefix_len:] tokens = tokens + prefix_tokens - if prompt := self.options.prompt: + if (prompt := self.options.prompt) or ((self.options.hotwords) and not self.options.prefix): prompt_tokens = ( self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt ) + if (hotwords := self.options.hotwords) and not self.options.prefix: + hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip()) + if len(hotwords_tokens) >= self.n_ctx // 2: + hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1] tokens = ( [self.tokenizer.sot_prev] - + prompt_tokens[-(self.n_ctx // 2 - 1) :] + + (hotwords_tokens if self.options.hotwords is not None else []) + + (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else []) + tokens ) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8e1240b..87f3ccb 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + hotwords: Optional[str] = None, **decode_options, ): """ @@ -113,6 +114,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + hotwords: Optional[str] + optional hotwords to provide as a prompt for the all window. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -275,6 +279,7 @@ def transcribe( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) + decode_options["hotwords"] = hotwords decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -546,6 +551,8 @@ def cli(): parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") + parser.add_argument("--hotwords", type=str, default=None, help="optional hotwords to provide as a prompt for the all window") + # fmt: on args = parser.parse_args().__dict__