fix(transcribe): fix censor

re has been added to the imports and censor_path has been added to the parameters. The goal is to let users create their own censor json file rather than have one supplied to them. When the censor flag is set, a check verifies that the file exists; if it does not, or it is not the proper file type, censoring is disabled. Both the segments and the full text are censored. To make this possible, the returned dict is assigned to a variable called "data". The alternative would be text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]) if not censor else censor_text(tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), forbidden_words)..., which is much more difficult to read.
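
For reference, a censor file in the {lang: [words]} layout the loader expects could be produced like this (a minimal sketch; the file name and word lists are placeholder assumptions, and entries should be lowercase since matched words are lowercased before the lookup):

    import json

    # placeholder word lists; keys are language codes, values are lowercase words to mask
    forbidden = {
        "en": ["badword", "heck"],
        "es": ["palabrota"],
    }
    with open("censor.json", "w") as f:
        json.dump(forbidden, f, ensure_ascii=False)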

BREAKING CHANGE: No issues have been confirmed yet, but the censor may misbehave if the json file uses an unexpected format or improper structure.
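
Example usage (a sketch assuming the standard whisper Python entry points; the model name, audio path, and censor path are placeholders):

    import whisper

    model = whisper.load_model("base")
    result = whisper.transcribe(model, "audio.mp3", censor=True, censor_path="censor.json")
    print(result["text"])  # forbidden words come back masked with asterisks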

Signed-off-by: matt@aero <motomatt5040@gmail.com>
matt@aero 2025-02-04 17:37:09 -06:00
parent 517a43ecd1
commit 854784880b

@@ -2,6 +2,8 @@ import argparse
import os
import traceback
import warnings
import json
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
@@ -52,6 +54,8 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
censor: bool = False,
censor_path: Optional[str] = None,
**decode_options,
):
"""
@@ -124,6 +128,8 @@ def transcribe(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
@@ -165,6 +171,21 @@ def transcribe(
task=task,
)
forbidden_words = []
if censor:
if (
censor_path is None
or not os.path.exists(censor_path)
or not censor_path.endswith(".json")
):
warnings.warn("Please provide a valid path to a censor json file; censoring disabled.")
censor = False
else:
with open(censor_path, "r") as f:
censor_data = json.load(f)
forbidden_words = censor_data.get(language, [])
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
@@ -243,16 +264,32 @@ def transcribe(
else:
initial_prompt_tokens = []
def censor_text(text, forbidden):
def censor_match(match):
word = match.group(0)
return '*' * len(word) if word.lower() in forbidden else word
censored_text = re.sub(r'\w+|[^\w\s]', censor_match, text)
return censored_text
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
if censor:
text = censor_text(tokenizer.decode(text_tokens), forbidden_words)
else:
text = tokenizer.decode(text_tokens)
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"text": text,
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
@@ -507,12 +544,19 @@ def transcribe(
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])
if censor:
text = censor_text(text, forbidden_words)
data = dict(
text=text,
segments=all_segments,
language=language,
)
return data
def cli():
from . import available_models
@@ -533,6 +577,8 @@ def cli():
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--censor", type=str2bool, default=True, help="(requires --censor_path=\"<path>\") whether to censor out profanity or not")
parser.add_argument("--censor_path", type=str2bool, default=True, help="censored words path. Use json format - {lang: [words]}")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")