mirror of
https://github.com/openai/whisper.git
synced 2025-11-29 08:28:53 +00:00
added support for bfloat16 datatype
This commit is contained in:
parent
b91c907694
commit
0ce59f338e
@ -110,7 +110,13 @@ class DecodingOptions:
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
bf16: bool = False # use bf16 for most of the calculation
|
||||
|
||||
def __post_init__(self):
|
||||
if self.fp16 and self.bf16:
|
||||
raise ValueError("Both fp16 and bf16 cannot be True at the same time")
|
||||
if self.bf16:
|
||||
object.__setattr__(self, "fp16", False)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
@ -650,7 +656,9 @@ class DecodingTask:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
torch.float16 if self.options.fp16 else
|
||||
torch.bfloat16 if self.options.bf16 else
|
||||
torch.float32
|
||||
):
|
||||
return TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user