mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
fix(transcribe): fix censor
re has been added to imports and censor_path added to params. The goal is to allow users to create their own censor json file to use rather than have it supplied to them. A check is used to verify the file exists if the censor flag is set, and if it does not or it is not the proper file tye, the censor is disabled. Segments and full text are both censored. The returned dict was set to a variable called "data" to allow this to occur. To do so another way would be text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]) if not censor else censor_text(tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), forbidden_words).... which is much more difficult to read. BREAKING CHANGE: I have not confirmed issues yet, however it may be possible for the censor to bug if weird formats or improper design is put in place of the json file. Signed-off-by: matt@aero <motomatt5040@gmail.com>
This commit is contained in:
parent
517a43ecd1
commit
854784880b
@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -52,6 +54,8 @@ def transcribe(
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
censor: bool = False,
|
||||
censor_path: str = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -124,6 +128,8 @@ def transcribe(
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
|
||||
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
@ -165,6 +171,21 @@ def transcribe(
|
||||
task=task,
|
||||
)
|
||||
|
||||
forbidden_words = []
|
||||
if censor:
|
||||
if (
|
||||
censor_path is None
|
||||
or not os.path.exists(censor_path)
|
||||
or not censor_path.endswith(".json")
|
||||
):
|
||||
warnings.warn("Please provide a valid censor directory, censoring disabled.")
|
||||
censor = False
|
||||
else:
|
||||
with open(f'{censor_path}', 'r') as f:
|
||||
censor_data = json.load(f)
|
||||
|
||||
forbidden_words = censor_data.get(language, [])
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
@ -243,16 +264,32 @@ def transcribe(
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def censor_text(text, forbidden):
|
||||
|
||||
def censor_match(match):
|
||||
word = match.group(0)
|
||||
return '*' * len(word) if word.lower() in forbidden_words else word
|
||||
|
||||
censored_text = re.sub(r'\w+|[^\w\s]', censor_match, text)
|
||||
|
||||
return censored_text
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
|
||||
if censor:
|
||||
text = censor_text(tokenizer.decode(text_tokens), forbidden_words)
|
||||
else:
|
||||
text = tokenizer.decode(text_tokens)
|
||||
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"text": text,
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
@ -507,12 +544,19 @@ def transcribe(
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])
|
||||
|
||||
if censor:
|
||||
text = censor_text(text, forbidden_words)
|
||||
|
||||
data = dict(
|
||||
text=text,
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
@ -533,6 +577,8 @@ def cli():
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
parser.add_argument("--censor", type=str2bool, default=True, help="(requires --censor_path=\"<path>\") whether to censor out profanity or not")
|
||||
parser.add_argument("--censor_path", type=str2bool, default=True, help="censored words path. Use json format - {lang: [words]}")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
Loading…
x
Reference in New Issue
Block a user