diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 0a4cc36..a115443 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -2,6 +2,8 @@ import argparse
 import os
 import traceback
 import warnings
+import json
+import re
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import numpy as np
@@ -52,6 +54,8 @@ def transcribe(
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    censor: bool = False,
+    censor_path: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -124,6 +128,8 @@ def transcribe(
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
     the spoken language ("language"), which is detected when `decode_options["language"]` is None.
     """
+    # `censor` masks forbidden words with asterisks in the returned text and segments;
+    # `censor_path` must point to a JSON file shaped like {language: [words]}.
     dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
@@ -165,6 +171,21 @@ def transcribe(
         task=task,
     )
 
+    forbidden_words = set()
+    if censor:
+        if (
+            censor_path is None
+            or not os.path.isfile(censor_path)
+            or not censor_path.endswith(".json")
+        ):
+            warnings.warn("Please provide a valid censor file (.json), censoring disabled.")
+            censor = False
+        else:
+            with open(censor_path, "r", encoding="utf-8") as f:
+                censor_data = json.load(f)
+            # lower-case once so the per-word membership test is case-insensitive
+            forbidden_words = {w.lower() for w in censor_data.get(language, [])}
+
     if isinstance(clip_timestamps, str):
         clip_timestamps = [
             float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
@@ -243,16 +264,32 @@ def transcribe(
     else:
         initial_prompt_tokens = []
 
+    def censor_text(text, forbidden):
+        # Replace every word of `text` whose lower-cased form is in `forbidden` with asterisks.
+        def censor_match(match):
+            word = match.group(0)
+            return "*" * len(word) if word.lower() in forbidden else word
+
+        censored_text = re.sub(r"\w+|[^\w\s]", censor_match, text)
+
+        return censored_text
+
     def new_segment(
         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
         tokens = tokens.tolist()
         text_tokens = [token for token in tokens if token < tokenizer.eot]
+
+        # decode once; mask forbidden words only while censoring is active
+        text = tokenizer.decode(text_tokens)
+        if censor:
+            text = censor_text(text, forbidden_words)
+
         return {
             "seek": seek,
             "start": start,
             "end": end,
-            "text": tokenizer.decode(text_tokens),
+            "text": text,
             "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
@@ -507,12 +544,19 @@ def transcribe(
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)
 
-    return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+    text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])
+
+    if censor:
+        text = censor_text(text, forbidden_words)
+
+    data = dict(
+        text=text,
         segments=all_segments,
         language=language,
     )
 
+    return data
+
 
 def cli():
     from . import available_models
@@ -533,6 +577,8 @@ def cli():
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+    parser.add_argument("--censor", type=str2bool, default=False, help="whether to censor out forbidden words (requires --censor_path)")
+    parser.add_argument("--censor_path", type=str, default=None, help="path to the censored words file. Use json format - {lang: [words]}")
     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")