From 9cf2f995bdcc265fb25a163c24004d7aa2107503 Mon Sep 17 00:00:00 2001 From: jax Date: Fri, 8 Mar 2024 11:30:43 +0800 Subject: [PATCH 1/2] add hotwords feature --- whisper/decoding.py | 10 +++++++++- whisper/transcribe.py | 7 +++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..994cc04 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -112,6 +112,7 @@ class DecodingOptions: # implementation details fp16: bool = True # use fp16 for most of the calculation + hotwords: Optional[str] = None @dataclass(frozen=True) @@ -598,15 +599,22 @@ class DecodingTask: prefix_tokens = prefix_tokens[-max_prefix_len:] tokens = tokens + prefix_tokens - if prompt := self.options.prompt: + if (prompt := self.options.prompt) or ((self.options.hotwords) and not self.options.prefix): prompt_tokens = ( self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt ) + if (hotwords := self.options.hotwords) and not self.options.prefix: + print(f"hotwords: {hotwords}") + hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip()) + if len(hotwords_tokens) >= self.n_ctx // 2: + hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1] tokens = ( [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + + (hotwords_tokens if self.options.hotwords is not None else []) + + (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else []) + tokens ) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..7238a16 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + hotwords: Optional[str] = None, **decode_options, ): """ @@ -113,6 +114,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) 
when a possible hallucination is detected + hotwords: Optional[str] + optional hotwords to provide as a prompt for every window. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -275,6 +279,7 @@ def transcribe( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) + decode_options["hotwords"] = hotwords decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -546,6 +551,8 @@ def cli(): parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") + parser.add_argument("--hotwords", type=str, default=None, help="optional hotwords to provide as a prompt for every window") + # fmt: on args = parser.parse_args().__dict__ From 5ed89d0ca24e99e93376da4626690354c1f1c6e7 Mon Sep 17 00:00:00 2001 From: jax Date: Fri, 8 Mar 2024 12:07:59 +0800 Subject: [PATCH 2/2] remove log --- whisper/decoding.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/whisper/decoding.py b/whisper/decoding.py index 994cc04..56bede8 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -606,13 +606,11 @@ class DecodingTask: else prompt ) if (hotwords := self.options.hotwords) and not self.options.prefix: - print(f"hotwords: {hotwords}") hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip()) 
if len(hotwords_tokens) >= self.n_ctx // 2: hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1] tokens = ( [self.tokenizer.sot_prev] - + prompt_tokens[-(self.n_ctx // 2 - 1) :] + (hotwords_tokens if self.options.hotwords is not None else []) + (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else []) + tokens