From b1d213c0c784e04f7f413ae841c4bd352638491d Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Tue, 17 Jan 2023 13:43:36 -0800 Subject: [PATCH] allow test_transcribe to run on CPU when CUDA is not available --- .github/workflows/test.yml | 2 +- tests/test_transcribe.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b4f0828..7811428 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: pytorch-version: 1.10.2 steps: - uses: conda-incubator/setup-miniconda@v2 - - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch + - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch - uses: actions/checkout@v2 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install pytest diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 836cf40..f5d66c3 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -1,13 +1,15 @@ import os import pytest +import torch import whisper -@pytest.mark.parametrize('model_name', whisper.available_models()) +@pytest.mark.parametrize("model_name", whisper.available_models()) def test_transcribe(model_name: str): - model = whisper.load_model(model_name).cuda() + device = "cuda" if torch.cuda.is_available() else "cpu" + model = whisper.load_model(model_name).to(device) audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") language = "en" if model_name.endswith(".en") else None