Merge 44cc156a7cd777f5d177e34ff8243e80466b336c into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
MotoMatt5040 2025-06-26 10:57:08 -06:00 committed by GitHub
commit abf7955309
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,8 @@ import argparse
import os import os
import traceback import traceback
import warnings import warnings
import json
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np import numpy as np
@ -52,6 +54,8 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0", clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None, hallucination_silence_threshold: Optional[float] = None,
censor: bool = False,
censor_path: str = None,
**decode_options, **decode_options,
): ):
""" """
@ -124,6 +128,8 @@ def transcribe(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
""" """
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"): if model.device == torch.device("cpu"):
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -165,6 +171,21 @@ def transcribe(
task=task, task=task,
) )
# Optional profanity censoring: load the per-language word list once, up front.
# `censor_path` must point to an existing JSON file of the form
# {language_code: [words, ...]} — presumably lowercase entries, since
# censor_text() compares with word.lower(); TODO confirm against callers.
forbidden_words = []
if censor:
    if (
        censor_path is None
        or not os.path.exists(censor_path)
        or not censor_path.endswith(".json")
    ):
        # Fixed message: the option is a JSON file path, not a directory.
        warnings.warn("Please provide a valid censor file path, censoring disabled.")
        censor = False
    else:
        # Plain path (no pointless f-string); explicit encoding for JSON text.
        with open(censor_path, "r", encoding="utf-8") as f:
            censor_data = json.load(f)
        # `language` is resolved earlier in transcribe(); missing languages
        # yield an empty list, i.e. nothing gets censored.
        forbidden_words = censor_data.get(language, [])
if isinstance(clip_timestamps, str): if isinstance(clip_timestamps, str):
clip_timestamps = [ clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
@ -243,16 +264,32 @@ def transcribe(
else: else:
initial_prompt_tokens = [] initial_prompt_tokens = []
def censor_text(text, forbidden):
    """Return *text* with every forbidden word replaced by asterisks.

    Each maximal run of word characters (``\\w+``) is lowercased and
    checked against *forbidden*; matches are replaced by a same-length
    run of ``*``. Punctuation and whitespace pass through unchanged.
    Entries in *forbidden* are expected to be lowercase — TODO confirm.

    Parameters
    ----------
    text : str
        The text to censor.
    forbidden : container of str
        Words to mask (membership is tested with ``in``).
    """

    def censor_match(match):
        word = match.group(0)
        # Bug fix: the original tested against the outer `forbidden_words`
        # closure, silently ignoring the `forbidden` parameter.
        return '*' * len(word) if word.lower() in forbidden else word

    # \w+ grabs whole words; [^\w\s] grabs single punctuation characters
    # so they are passed through censor_match (and kept) individually.
    return re.sub(r'\w+|[^\w\s]', censor_match, text)
def new_segment( def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
): ):
tokens = tokens.tolist() tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot] text_tokens = [token for token in tokens if token < tokenizer.eot]
if censor:
text = censor_text(tokenizer.decode(text_tokens), forbidden_words)
else:
text = tokenizer.decode(text_tokens)
return { return {
"seek": seek, "seek": seek,
"start": start, "start": start,
"end": end, "end": end,
"text": tokenizer.decode(text_tokens), "text": text,
"tokens": tokens, "tokens": tokens,
"temperature": result.temperature, "temperature": result.temperature,
"avg_logprob": result.avg_logprob, "avg_logprob": result.avg_logprob,
@ -507,8 +544,13 @@ def transcribe(
# update progress bar # update progress bar
pbar.update(min(content_frames, seek) - previous_seek) pbar.update(min(content_frames, seek) - previous_seek)
text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])
if censor:
text = censor_text(text, forbidden_words)
return dict( return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), text=text,
segments=all_segments, segments=all_segments,
language=language, language=language,
) )
@ -533,6 +575,8 @@ def cli():
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
# CLI flags for the optional profanity-censoring feature.
# Bug fixes: --censor_path previously used type=str2bool with default=True,
# so the path argument was coerced to a bool and unusable; --censor previously
# defaulted to True, contradicting transcribe()'s `censor: bool = False`
# default and the help text saying it requires --censor_path.
parser.add_argument("--censor", type=str2bool, default=False, help="(requires --censor_path=\"<path>\") whether to censor out profanity or not")
parser.add_argument("--censor_path", type=str, default=None, help="censored words path. Use json format - {lang: [words]}")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")