Merge pull request #2 from BlueLabelLabs/feat/Add-HPU-support

Add hpu support
This commit is contained in:
PiotrBLL 2024-11-20 13:13:39 +01:00 committed by GitHub
commit 1a42019974
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 4626 additions and 4242 deletions

1
.dockerignore Normal file
View File

@ -0,0 +1 @@
.graph_dumps

34
Dockerfile.hpu Normal file
View File

@ -0,0 +1,34 @@
# Use the official Gaudi Docker image with PyTorch
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
# Set environment variables for Habana
ENV HABANA_VISIBLE_DEVICES=all
ENV OMPI_MCA_btl_vader_single_copy_mechanism=none
ENV PT_HPU_LAZY_ACC_PAR_MODE=0
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=1
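# The HABANA_VISIBLE_DEVICES and PT_HPU_* variables above are Habana runtime settings that expose the Gaudi devices and tune lazy-mode execution; see the Intel Gaudi documentation for details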
# Set timezone to UTC and install essential packages
ENV DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC
RUN apt-get update && apt-get install -y \
tzdata \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
# Download and install the static build of ffmpeg
RUN mkdir -p /usr/local/bin/ffmpeg && \
cd /usr/local/bin/ffmpeg && \
wget https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz && \
tar -xf ffmpeg-release-amd64-static.tar.xz && \
cp -a ffmpeg-*-static/ffmpeg /usr/bin/ffmpeg && \
cp -a ffmpeg-*-static/ffprobe /usr/bin/ffprobe && \
rm -rf /usr/local/bin/ffmpeg
COPY . /workspace/whisper
WORKDIR /workspace/whisper
# Copy HPU requirements
COPY requirements_hpu.txt /workspace/requirements_hpu.txt
# Install Python packages
RUN pip install --upgrade pip \
&& pip install -r requirements_hpu.txt

README.md
View File

@ -93,6 +93,10 @@ Adding `--task translate` will translate the speech into English:
whisper japanese.wav --language Japanese --task translate
The following command will transcribe speech in audio files, using the Intel® Gaudi® HPU (`--device hpu` option):
whisper audio.flac audio.mp3 audio.wav --model turbo --device hpu
Run the following to view all available options:
whisper --help
@ -140,6 +144,61 @@ result = whisper.decode(model, mel, options)
print(result.text)
```
## Intel® Gaudi® HPU usage
### Build the Docker Image
```bash
docker build -t whisper_hpu:latest -f Dockerfile.hpu .
```
`Dockerfile.hpu` uses the `vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest` base image; replace it with the version appropriate for your environment if needed.
See the [PyTorch Docker Images for the Intel® Gaudi® Accelerator](https://developer.habana.ai/catalog/pytorch-container/) for more information.
### Run the Container
```bash
docker run -it --runtime=habana whisper_hpu:latest
```
Mounting a volume (`-v`) is optional, but it lets you access your local Whisper repository from inside the container.
You can do this by adding `-v /path/to/your/whisper:/workspace/whisper` to the `docker run` command, as shown below.
If you use the mount, replace `/path/to/your/whisper` with the path to the Whisper repository on your local machine.
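For example, the run command with the optional mount looks like this (the host path is the same placeholder as above):
```bash
docker run -it --runtime=habana \
  -v /path/to/your/whisper:/workspace/whisper \
  whisper_hpu:latest
```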
### Command-line usage with Intel® Gaudi® HPU
To run the `transcribe` process with Intel® Gaudi® HPU, you can use the `--device hpu` option:
```bash
python3 -m whisper.transcribe audio_file.wav --model turbo --device hpu
```
* Note: Replace `audio_file.wav` with the path to the audio file you want to transcribe. (Example files: https://www.kaggle.com/datasets/pavanelisetty/sample-audio-files-for-speech-recognition?resource=download)
To run the `transcribe` tests with Intel® Gaudi® HPU, make sure to install the `pytest` package:
```bash
pip install pytest
```
and run the following command:
```bash
PYTHONPATH=. pytest -s tests/test_transcribe.py::test_transcribe_hpu
```
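The test is parametrized over `whisper.available_models()`, so a full run downloads and checks every model. To try a single model first, you can select one parametrized case (the `tiny` id below assumes the default model list):
```bash
PYTHONPATH=. pytest -s "tests/test_transcribe.py::test_transcribe_hpu[tiny]"
```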
### Python usage with Intel® Gaudi® HPU
To use an Intel® Gaudi® HPU within Python, specify the device when loading the model:
```python
import whisper
model = whisper.load_model("turbo", device="hpu")
result = model.transcribe("audio.mp3")
print(result["text"])
```
## More examples
Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc.

View File

@ -36,10 +36,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3CqtR2Fi5-vP"
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
@ -56,10 +53,35 @@
"import torchaudio\n", "import torchaudio\n",
"\n", "\n",
"from tqdm.notebook import tqdm\n", "from tqdm.notebook import tqdm\n",
"\n", "\n"
"\n", ],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Set the `DEVICE` based on available hardware\n",
"### If you're running on a machine with Intel Gaudi support, you can set the DEVICE to \"hpu\"\n",
"### This can be done by setting the device to \"hpu\" if Gaudi is available.\n",
"### Ensure you have the Habana PyTorch extension installed and properly configured."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# DEVICE = \"hpu\"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"" "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
] ],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",

File diff suppressed because one or more lines are too long

requirements.txt
View File

@ -5,3 +5,5 @@ tqdm
more-itertools
tiktoken
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
scipy
pytest

6
requirements_hpu.txt Normal file
View File

@ -0,0 +1,6 @@
optimum-habana==1.14.1
transformers==4.45.2
huggingface-hub==0.26.2
tiktoken==0.8.0
torch-geometric==2.6.1
numba==0.60.0

tests/conftest.py
View File

@ -6,6 +6,7 @@ import pytest
def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")
config.addinivalue_line("markers", "requires_hpu")
@pytest.fixture

tests/test_timing.py
View File

@ -3,7 +3,8 @@ import pytest
import scipy.ndimage
import torch
from whisper.hpu_utils import get_x_hpu
from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu
sizes = [
(10, 20),
@ -94,3 +95,15 @@ def test_median_filter_equivalence(shape):
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
assert np.allclose(filtered_cpu, filtered_gpu)
@pytest.mark.requires_hpu
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_hpu_equivalence(N: int, M: int):
x_numpy = np.random.randn(N, M).astype(np.float32)
x_hpu = get_x_hpu(x_numpy)
trace_cpu = dtw_cpu(x_numpy)
trace_hpu = dtw_hpu(x_hpu)
assert np.allclose(trace_cpu, trace_hpu)

tests/test_transcribe.py
View File

@ -10,7 +10,7 @@ from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name, device=device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
@ -40,3 +40,38 @@ def test_transcribe(model_name: str):
timing_checked = True
assert timing_checked
@pytest.mark.requires_hpu
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe_hpu(model_name: str):
device = "hpu"
model = whisper.load_model(model_name, device=device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower()
assert "my fellow americans" in transcription
assert "your country" in transcription
assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
timing_checked = True
assert timing_checked

whisper/__init__.py
View File

@ -147,14 +147,31 @@ def load_model(
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
if device == "hpu":
"""If the device is HPU,
the model should be loaded on CPU first
and then moved to HPU."""
checkpoint = torch.load(fp, map_location="cpu")
else:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, compute_device=torch.device(device))
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
if device == "hpu":
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
return model
return model.to(device)

whisper/decoding.py
View File

@ -8,6 +8,7 @@ from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .hpu_utils import is_hpu_device
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
@ -456,7 +457,17 @@ class ApplyTimestampRules(LogitFilter):
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
if is_hpu_device(tokens.device):
"""
If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.
On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors if not explicitly
evaluated. Cloning `sampled_tokens` ensures it is fully evaluated on the HPU, preventing potential
synchronization issues.
"""
sampled_tokens = tokens[k, self.sample_begin :].clone()
else:
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin

View File

@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
# Define layers
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(32 * 56 * 56, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
# Forward pass through the network
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 32 * 56 * 56) # Flatten the tensor
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# Example usage
if __name__ == "__main__":
# Load Habana module for HPU support
from habana_frameworks.torch.utils.library_loader import load_habana_module
import habana_frameworks.torch.hpu as hthpu
load_habana_module()
device = None
# Set device to HPU
if hthpu.is_available():
device = torch.device("hpu")
print("Using HPU")
if not device:
print("HPU is not available")
exit(1)
# Create model instance and move it to the HPU
model = SimpleCNN(num_classes=10).to(device)
# Create a dummy input tensor and move it to the HPU
input_tensor = torch.rand((64, 3, 224, 224), device=device) # Batch size of 64
# Forward pass through the model on HPU
output = model(input_tensor)
print("Output shape:", output.shape) # Should be [64, num_classes]

13
whisper/hpu_utils.py Normal file
View File

@ -0,0 +1,13 @@
import torch
def get_x_hpu(x_numpy):
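# NOTE: load_habana_module() loads the Habana PyTorch bridge so the "hpu" device is registered before the tensor is moved to it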
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
x_hpu = torch.from_numpy(x_numpy).to("hpu")
return x_hpu
def is_hpu_device(device: torch.device):
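# True only for the default HPU device ("hpu" or "hpu:0"); other explicit device indices are not matched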
return device in (torch.device("hpu:0"), torch.device("hpu"))

whisper/model.py
View File

@ -11,6 +11,7 @@ from torch import Tensor, nn
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .hpu_utils import is_hpu_device
from .transcribe import transcribe as transcribe_function
try:
@ -250,9 +251,12 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, compute_device: Optional[torch.device] = None):
super().__init__()
self.dims = dims
self.compute_device = compute_device or (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
@ -273,7 +277,11 @@ class Whisper(nn.Module):
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
if not is_hpu_device(self.compute_device):
# Convert to sparse format if device is not HPU
all_heads = all_heads.to_sparse()
self.register_buffer("alignment_heads", all_heads, persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
@ -282,7 +290,11 @@ class Whisper(nn.Module):
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
if not is_hpu_device(self.compute_device):
# Convert to sparse format if device is not HPU
mask = mask.to_sparse()
self.register_buffer("alignment_heads", mask, persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)

whisper/timing.py
View File

@ -10,6 +10,7 @@ import torch
import torch.nn.functional as F
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .hpu_utils import is_hpu_device
from .tokenizer import Tokenizer
if TYPE_CHECKING:
@ -138,7 +139,54 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
return backtrace(trace.cpu().numpy())
def dtw_hpu(x, BLOCK_SIZE=1024):
"""
DTW implementation for HPU.
"""
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
x_skew = x_skew.T.contiguous()
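# The padding and transpose above skew the matrix so each anti-diagonal of x becomes a row of x_skew, letting the loop below update one full diagonal of the cost matrix per iteration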
# Initialize cost and trace matrices with high values for comparison
cost = torch.ones(N + M + 2, M + 2, device="hpu") * np.inf
cost[0, 0] = 0 # Start point for DTW
trace = torch.zeros_like(cost, dtype=torch.int32, device="hpu")
for k in range(1, N + M + 1):
p0 = cost[k - 1, :M]
p1 = cost[k, :M]
p2 = cost[k, 1:M + 1]
c0 = p0.clone()
c1 = p1.clone()
c2 = p2.clone()
x_row = x_skew[k - 1, :M]
cost_row = x_row + torch.min(torch.min(c0, c1), c2)
cost[k + 1, 1:M + 1] = cost_row
# Track path by storing traces
trace[k + 1, 1:M + 1] = 2 * (c2 <= c0) * (c2 <= c1) + 1 * (c1 <= c0) * (c1 <= c2) + 0 * (c0 <= c1) * (c0 <= c2)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
try:
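# Prefer the HPU DTW path: if the Habana stack loads and an HPU device is visible, use dtw_hpu; otherwise fall through to the CUDA/CPU implementations below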
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
if torch.hpu.is_available():
return dtw_hpu(x)
except (ImportError, subprocess.CalledProcessError):
warnings.warn(
"Failed to import Habana modules, likely due to missing Habana libraries; "
)
if x.is_cuda:
try:
return dtw_cuda(x)
@ -204,7 +252,21 @@ def find_alignment(
hook.remove()
# heads * tokens * frames
# Adjust alignment head indices for HPU
weights = []
if is_hpu_device(model.device):
# Handle dense layout for HPU
alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
indices = alignment_heads_dense.nonzero(as_tuple=True)
for _l, _h in zip(*indices):
weights.append(QKs[_l][_h])
else:
# Default behavior for non-HPU devices
for _l, _h in model.alignment_heads.indices().T:
weights.append(QKs[_l][_h])
# Stack the weights
weights = torch.stack(weights)
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)

whisper/transcribe.py
View File

@ -18,6 +18,7 @@ from .audio import (
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .hpu_utils import is_hpu_device
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
@ -126,6 +127,11 @@ def transcribe(
warnings.warn("FP16 is not supported on CPU; using FP32 instead") warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32 dtype = torch.float32
if is_hpu_device(model.device):
if dtype == torch.float16:
warnings.warn("FP16 is not supported on HPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
@ -508,12 +514,26 @@ def cli():
f"model should be one of {available_models()} or path to a model checkpoint" f"model should be one of {available_models()} or path to a model checkpoint"
) )
def valid_device(device_name):
if device_name == "cuda" and not torch.cuda.is_available():
warnings.warn("CUDA is not available; using CPU instead")
device_name = "cpu"
if device_name == "hpu":
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
if not torch.hpu.is_available():
warnings.warn("HPU is not available; using CPU instead")
device_name = "hpu"
return device_name
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")