diff --git a/whisper/__init__.py b/whisper/__init__.py
index e8ed005..3c027d7 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -157,7 +157,7 @@ def load_model(
         del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, device=torch.device(device))
     model.load_state_dict(checkpoint["model_state_dict"])
 
     if alignment_heads is not None:
@@ -170,9 +170,8 @@ def load_model(
 
     if torch.hpu.is_available():
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 
-        model = model.eval().to(device)
-        model = wrap_in_hpu_graph(model)
+        model = model.eval().to(torch.device(device))
        return model
 
     return model.to(device)
diff --git a/whisper/decoding.py b/whisper/decoding.py
index 49485d0..07949aa 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -8,6 +8,7 @@ from torch import Tensor
 from torch.distributions import Categorical
 
 from .audio import CHUNK_LENGTH
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
 
@@ -456,7 +457,17 @@ class ApplyTimestampRules(LogitFilter):
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
+            if is_hpu_device(tokens.device):
+                """
+                If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.
+
+                On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors if not explicitly
+                evaluated. Cloning `sampled_tokens` ensures it is fully evaluated on the HPU, preventing potential
+                synchronization issues.
+                """
+                sampled_tokens = tokens[k, self.sample_begin :].clone()
+            else:
+                sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index ff321a9..7da349e 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -18,6 +18,7 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
+from .hpu_utils import is_hpu_device
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -126,6 +127,11 @@ def transcribe(
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
 
+    if is_hpu_device(model.device):
+        if dtype == torch.float16:
+            warnings.warn("FP16 is not supported on HPU; using FP32 instead")
+            dtype = torch.float32
+
     if dtype == torch.float32:
         decode_options["fp16"] = False
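Note: both decoding.py and transcribe.py import `is_hpu_device` from `whisper/hpu_utils.py`, but that file does not appear in the hunks above. A minimal sketch of what the helper presumably contains, assuming it only needs to inspect the torch device type (the module path and function name are taken from the imports above; the body is an assumption, not part of this diff):

diff --git a/whisper/hpu_utils.py b/whisper/hpu_utils.py
new file mode 100644
--- /dev/null
+++ b/whisper/hpu_utils.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def is_hpu_device(device) -> bool:
+    """Return True if `device` refers to a Habana HPU.
+
+    Accepts a torch.device or a device string; HPU devices report type "hpu".
+    """
+    return torch.device(device).type == "hpu"

This matches both call sites, which pass a torch.device (`tokens.device` and `model.device`) and rely only on the boolean result; if the real helper does more (e.g. guarding against a missing habana_frameworks install), the hunks above are unaffected.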