added support for bfloat16 datatype

This commit is contained in:
devpramod-intel 2023-07-19 18:14:08 -04:00
parent b91c907694
commit 0ce59f338e

View File

@ -110,7 +110,13 @@ class DecodingOptions:
# implementation details
fp16: bool = True  # use fp16 for most of the calculation
bf16: bool = False # use bf16 for most of the calculation
def __post_init__(self):
    """Validate the mutually exclusive precision flags after dataclass init.

    Raises:
        ValueError: if both ``fp16`` and ``bf16`` are requested at once.

    NOTE(review): because ``fp16`` defaults to True, callers must pass
    ``fp16=False`` explicitly to enable ``bf16``. The original code also
    forced ``fp16`` off via ``object.__setattr__`` when ``bf16`` was set,
    but that assignment was a no-op: ``bf16=True`` together with
    ``fp16=True`` already raised above, so ``fp16`` was necessarily False
    by that point. The dead statement is removed.
    """
    if self.fp16 and self.bf16:
        raise ValueError("Both fp16 and bf16 cannot be True at the same time")
@dataclass(frozen=True)
class DecodingResult:
@ -650,7 +656,9 @@ class DecodingTask:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (
    torch.float16 if self.options.fp16 else
    torch.bfloat16 if self.options.bf16 else
    torch.float32
):
    return TypeError(
        f"audio_features has an incorrect dtype: {audio_features.dtype}"