Use triton==2.0.0 (#1053)

This commit is contained in:
Jong Wook Kim 2023-03-07 19:56:31 -05:00 committed by GitHub
parent 924e1f8e06
commit 38e990d853
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,21 +13,7 @@ def read_version(fname="whisper/version.py"):
requirements = []
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
triton_requirement = "triton==2.0.0"
try:
import re
import subprocess
version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
triton_requirement = "triton==2.0.0.dev20221011"
except (IndexError, OSError, subprocess.SubprocessError):
pass
requirements.append(triton_requirement)
requirements.append("triton==2.0.0")
setup(
name="openai-whisper",