Merge a0936816d5624f975a039ffeb1ca25d473e26d5b into 517a43ecd132a2089d85f4ebc044728a71d49f6e

Pramod Pai 2025-01-10 17:57:59 +09:00 committed by GitHub
commit 52933bf957
2 changed files with 10 additions and 2 deletions

whisper/decoding.py

@@ -112,7 +112,13 @@ class DecodingOptions:
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
+    bf16: bool = False  # use bf16 for most of the calculation
+    def __post_init__(self):
+        if self.fp16 and self.bf16:
+            raise ValueError("Both fp16 and bf16 cannot be True at the same time")
+        if self.bf16:
+            object.__setattr__(self, "fp16", False)
 
 
 @dataclass(frozen=True)
 class DecodingResult:
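Because DecodingOptions is a frozen dataclass, the new __post_init__ has to use object.__setattr__ to clear fp16 when bf16 is requested, and it rejects an explicit fp16/bf16 conflict outright. A minimal sketch of that behavior, assuming a whisper build that includes this change; note that fp16 defaults to True, so bf16 has to be paired with fp16=False:

from whisper import DecodingOptions

# bf16 together with an explicit fp16=False passes validation;
# __post_init__ leaves fp16 off.
opts = DecodingOptions(fp16=False, bf16=True)
assert opts.bf16 and not opts.fp16

# bf16 alone conflicts with the fp16=True default and is rejected.
try:
    DecodingOptions(bf16=True)
except ValueError as err:
    print(err)  # Both fp16 and bf16 cannot be True at the same time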
@@ -655,7 +661,9 @@ class DecodingTask:
             audio_features = self.model.encoder(mel)
 
         if audio_features.dtype != (
-            torch.float16 if self.options.fp16 else torch.float32
+            torch.float16 if self.options.fp16 else
+            torch.bfloat16 if self.options.bf16 else
+            torch.float32
         ):
             return TypeError(
                 f"audio_features has an incorrect dtype: {audio_features.dtype}"

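The encoder dtype check now resolves the expected type with a chained conditional: float16 when fp16 is set, otherwise bfloat16 when bf16 is set, otherwise float32. A standalone sketch of that precedence; the helper name below is hypothetical and not part of the whisper API:

import torch

def expected_encoder_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    # Mirrors the chained conditional in the patched dtype check above.
    return (
        torch.float16 if fp16
        else torch.bfloat16 if bf16
        else torch.float32
    )

assert expected_encoder_dtype(fp16=True, bf16=False) == torch.float16
assert expected_encoder_dtype(fp16=False, bf16=True) == torch.bfloat16
assert expected_encoder_dtype(fp16=False, bf16=False) == torch.float32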
whisper/transcribe.py

@@ -132,7 +132,7 @@ def transcribe(
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
 
-    if dtype == torch.float32:
+    if dtype == torch.float32 or dtype == torch.bfloat16:
         decode_options["fp16"] = False
 
     # Pad 30-seconds of silence to the input audio, for slicing
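On the transcribe() side, the guard that clears the fp16 decode option now treats bfloat16 the same as float32, so the flag forwarded to DecodingOptions cannot conflict with a non-fp16 dtype. A minimal sketch of that guard, using a hypothetical helper name to keep it self-contained:

import torch

def resolve_fp16_option(dtype: torch.dtype, decode_options: dict) -> dict:
    # Mirrors the patched condition: both float32 and bfloat16 force fp16 off.
    if dtype == torch.float32 or dtype == torch.bfloat16:
        decode_options["fp16"] = False
    return decode_options

print(resolve_fp16_option(torch.bfloat16, {"fp16": True, "bf16": True}))
# {'fp16': False, 'bf16': True}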