Merge 5ed89d0ca24e99e93376da4626690354c1f1c6e7 into 25639fc17ddc013d56c594bfbf7644f2185fad84

Author: jax, committed by GitHub on 2024-10-01 14:18:29 +05:30
commit cb76906e06
2 changed files with 15 additions and 2 deletions

whisper/decoding.py

@@ -112,6 +112,7 @@ class DecodingOptions:
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
+    hotwords: Optional[str] = None


 @dataclass(frozen=True)
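
For illustration only, a minimal sketch of how the new field could be exercised through the low-level decoding API once this change is installed; the model size, audio path, and hotword string below are placeholders, not part of the commit:

import whisper

# Assumes a build that includes the hotwords field added above.
model = whisper.load_model("base")

# Standard preprocessing: load the audio, pad/trim to 30 seconds, compute the log-Mel spectrogram.
audio = whisper.pad_or_trim(whisper.load_audio("meeting.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

# Bias decoding toward domain-specific terms by passing them as hotwords.
options = whisper.DecodingOptions(hotwords="Kubernetes kubectl etcd")
result = whisper.decode(model, mel, options)
print(result.text)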
@@ -598,15 +599,20 @@ class DecodingTask:
                 prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens

-        if prompt := self.options.prompt:
+        if (prompt := self.options.prompt) or ((self.options.hotwords) and not self.options.prefix):
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip())
                 if isinstance(prompt, str)
                 else prompt
             )
+            if (hotwords := self.options.hotwords) and not self.options.prefix:
+                hotwords_tokens = self.tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.n_ctx // 2:
+                    hotwords_tokens = hotwords_tokens[: self.n_ctx // 2 - 1]
             tokens = (
                 [self.tokenizer.sot_prev]
-                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + (hotwords_tokens if self.options.hotwords is not None else [])
+                + (prompt_tokens[-(self.n_ctx // 2 - 1) :] if self.options.prompt is not None else [])
                 + tokens
             )
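
To make the prefix budget above concrete, here is an illustrative-only sketch of the truncation arithmetic, using made-up token lists and a placeholder id standing in for tokenizer.sot_prev; it mirrors the logic of the hunk rather than calling the real tokenizer:

# n_ctx is the decoder's text context length (448 for Whisper models).
n_ctx = 448

hotwords_tokens = list(range(300))       # pretend-encoded hotwords
prompt_tokens = list(range(1000, 1500))  # pretend-encoded rolling prompt

# Hotwords are capped just below half the context window, as in the change above.
if len(hotwords_tokens) >= n_ctx // 2:
    hotwords_tokens = hotwords_tokens[: n_ctx // 2 - 1]

# The prompt likewise keeps only its most recent n_ctx // 2 - 1 tokens.
prompt_tokens = prompt_tokens[-(n_ctx // 2 - 1):]

sot_prev = 50361  # placeholder id for the <|startofprev|> token
initial = [sot_prev] + hotwords_tokens + prompt_tokens
print(len(initial))  # 1 + 223 + 223 = 447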

whisper/transcribe.py

@@ -51,6 +51,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,!?::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    hotwords: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -113,6 +114,9 @@ def transcribe(
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected

+    hotwords: Optional[str]
+        Optional hotwords to provide as a prompt for all windows.
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -275,6 +279,7 @@ def transcribe(
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

+            decode_options["hotwords"] = hotwords
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
             result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)
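
For context, transcribe() rebuilds a DecodingOptions from this decode_options dict on each decoding attempt, which is how the hotwords key set above reaches the field added to DecodingOptions in the first file. A minimal sketch of that hand-off with placeholder values (the token ids are made up):

from whisper.decoding import DecodingOptions

decode_options = {
    "language": "en",
    "hotwords": "Grafana Prometheus Loki",  # stays the same for every window
    "prompt": [1770, 13, 2205, 4651],       # rolling prompt; changes per window
}

# Roughly what decode_with_fallback() does internally on each attempt.
options = DecodingOptions(**decode_options, temperature=0.0)
print(options.hotwords, options.prompt)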
@@ -546,6 +551,8 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--hotwords", type=str, default=None, help="optional hotwords to provide as a prompt for all windows")
     # fmt: on

     args = parser.parse_args().__dict__
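
With the new flag in place, a command-line invocation would look something like the following (the audio file and hotword string are placeholders):

whisper standup.m4a --model small --language en --hotwords "Grafana Prometheus Loki"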