From 0ce59f338ee29d5f2797b49e7e5412c5fcfbc4e5 Mon Sep 17 00:00:00 2001 From: devpramod-intel Date: Wed, 19 Jul 2023 18:14:08 -0400 Subject: [PATCH 1/2] added support for bfloat16 datatype --- whisper/decoding.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/whisper/decoding.py b/whisper/decoding.py index ecd98a4..18cae1f 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -110,7 +110,13 @@ class DecodingOptions: # implementation details fp16: bool = True # use fp16 for most of the calculation + bf16: bool = False # use bf16 for most of the calculation + def __post_init__(self): + if self.fp16 and self.bf16: + raise ValueError("Both fp16 and bf16 cannot be True at the same time") + if self.bf16: + object.__setattr__(self, "fp16", False) @dataclass(frozen=True) class DecodingResult: @@ -650,7 +656,9 @@ class DecodingTask: audio_features = self.model.encoder(mel) if audio_features.dtype != ( - torch.float16 if self.options.fp16 else torch.float32 + torch.float16 if self.options.fp16 else + torch.bfloat16 if self.options.bf16 else + torch.float32 ): return TypeError( f"audio_features has an incorrect dtype: {audio_features.dtype}" From a0936816d5624f975a039ffeb1ca25d473e26d5b Mon Sep 17 00:00:00 2001 From: Pramod Pai <108906450+devpramod-intel@users.noreply.github.com> Date: Wed, 1 Nov 2023 18:23:04 -0400 Subject: [PATCH 2/2] Update transcribe.py --- whisper/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 6e43a22..7df01a7 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -114,7 +114,7 @@ def transcribe( warnings.warn("FP16 is not supported on CPU; using FP32 instead") dtype = torch.float32 - if dtype == torch.float32: + if dtype == torch.float32 or dtype == torch.bfloat16: decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing