mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 08:11:11 +00:00
Merge 44cc156a7cd777f5d177e34ff8243e80466b336c into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
abf7955309
@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -52,6 +54,8 @@ def transcribe(
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
censor: bool = False,
|
||||
censor_path: str = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -124,6 +128,8 @@ def transcribe(
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
|
||||
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
@ -165,6 +171,21 @@ def transcribe(
|
||||
task=task,
|
||||
)
|
||||
|
||||
# Load the list of words to mask for the selected language.
# The censor file is JSON shaped like {language: [word, ...]}.
forbidden_words = []
if censor:
    if (
        censor_path is None
        # isfile (not exists): a directory named "*.json" must not reach open()
        or not os.path.isfile(censor_path)
        or not censor_path.endswith(".json")
    ):
        warnings.warn("Please provide a valid censor JSON file, censoring disabled.")
        censor = False
    else:
        # explicit encoding so non-ASCII word lists load the same on every platform
        with open(censor_path, "r", encoding="utf-8") as f:
            censor_data = json.load(f)

        # censor_text() lowercases each transcript word before the lookup,
        # so the stored entries are lowercased here to make them matchable.
        forbidden_words = [
            str(word).lower() for word in censor_data.get(language, [])
        ]
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
@ -243,16 +264,32 @@ def transcribe(
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def censor_text(text, forbidden):
    """
    Mask every forbidden word found in `text` with asterisks.

    Parameters
    ----------
    text: str
        The text to censor.

    forbidden: Iterable[str]
        Words to mask. Matching is case-insensitive and applies to whole
        tokens only (a run of word characters, or a single punctuation mark).

    Returns
    -------
    The text with each matching token replaced by `*` repeated to the token's
    length; all other characters, including whitespace, are left untouched.
    """
    # Normalize once: the lookup below compares lowercased tokens, and a set
    # keeps the per-token membership test constant-time.
    banned = {str(word).lower() for word in forbidden}
    if not banned:
        return text

    def censor_match(match):
        word = match.group(0)
        return "*" * len(word) if word.lower() in banned else word

    return re.sub(r"\w+|[^\w\s]", censor_match, text)
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
|
||||
if censor:
|
||||
text = censor_text(tokenizer.decode(text_tokens), forbidden_words)
|
||||
else:
|
||||
text = tokenizer.decode(text_tokens)
|
||||
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"text": text,
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
@ -507,8 +544,13 @@ def transcribe(
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :])
|
||||
|
||||
if censor:
|
||||
text = censor_text(text, forbidden_words)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
text=text,
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
@ -533,6 +575,8 @@ def cli():
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
# --censor is opt-in (matching transcribe()'s censor=False default) and
# --censor_path is a plain string path to the JSON word list, not a boolean.
parser.add_argument("--censor", type=str2bool, default=False, help="(requires --censor_path=\"<path>\") whether to censor out profanity or not")
parser.add_argument("--censor_path", type=str, default=None, help="censored words path. Use json format - {lang: [words]}")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user