Add HPU checking in decoding and transcribe handling

commit b0cf21b9b5
parent d36696f808
Author: PiotrBLL
Date: 2024-11-11 22:21:12 +01:00

3 changed files with 20 additions and 4 deletions

@@ -157,7 +157,7 @@ def load_model(
     del checkpoint_file
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, device=torch.device(device))
     model.load_state_dict(checkpoint["model_state_dict"])
     if alignment_heads is not None:
@@ -170,9 +170,8 @@ def load_model(
     if torch.hpu.is_available():
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-        model = model.eval().to(device)
         model = wrap_in_hpu_graph(model)
+        model = model.eval().to(torch.device(device))
         return model
     return model.to(device)
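
Taken together, the two hunks above change the HPU path in load_model: the model is now built with an explicit device, wrapped in an HPU graph first, and only then switched to eval mode and moved to the device. A minimal sketch of the resulting flow, not the fork's exact code; it assumes the upstream ModelDimensions and Whisper classes from whisper.model, plus the fork's device= constructor argument introduced in the first hunk:

    import torch
    from whisper.model import ModelDimensions, Whisper  # upstream class names

    def load_model_sketch(checkpoint: dict, device: str = "hpu") -> Whisper:
        dims = ModelDimensions(**checkpoint["dims"])
        # The commit passes the target device to the constructor (fork-specific argument)
        # instead of relying on a later .to() call alone.
        model = Whisper(dims, device=torch.device(device))
        model.load_state_dict(checkpoint["model_state_dict"])

        # torch.hpu only exists once habana_frameworks.torch has been imported.
        if hasattr(torch, "hpu") and torch.hpu.is_available():
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

            # New order: wrap in an HPU graph first, then eval() and move to the device.
            model = wrap_in_hpu_graph(model)
            return model.eval().to(torch.device(device))

        return model.to(device)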

@@ -8,6 +8,7 @@ from torch import Tensor
 from torch.distributions import Categorical
 from .audio import CHUNK_LENGTH
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
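
is_hpu_device comes from a new hpu_utils module that is not shown in this diff. A plausible one-line implementation, included only so the call sites below are self-explanatory; the fork's actual helper may differ:

    import torch

    def is_hpu_device(device) -> bool:
        # Assumed helper: accepts a torch.device or a string such as "hpu:0"
        # and reports whether it refers to a Habana HPU.
        return torch.device(device).type == "hpu"
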
@@ -456,6 +457,16 @@ class ApplyTimestampRules(LogitFilter):
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
+            if is_hpu_device(tokens.device):
+                """
+                If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.
+                On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors
+                if not explicitly evaluated. Cloning `sampled_tokens` ensures it is fully evaluated
+                on the HPU, preventing potential synchronization issues.
+                """
+                sampled_tokens = tokens[k, self.sample_begin :].clone()
+            else:
+                sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = (
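
The clone exists because Habana's PyTorch bridge defaults to lazy execution: operations are queued and only materialized when their values are needed on the host, as the added docstring explains. A sketch of the pattern this hunk relies on; the function name and structure are illustrative, not the fork's code:

    import torch

    def collect_sampled_tokens(tokens: torch.Tensor, sample_begin: int) -> list:
        # Mirrors the loop body above for a 2-D (batch, sequence) token tensor.
        seqs = []
        for k in range(tokens.shape[0]):
            if tokens.device.type == "hpu":
                # On HPU, cloning the slice forces the lazy graph to materialize it
                # before .tolist() pulls the values back to the host.
                sampled = tokens[k, sample_begin:].clone()
            else:
                sampled = tokens[k, sample_begin:]
            seqs.append(sampled.tolist())
        return seqs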

@@ -18,6 +18,7 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
+from .hpu_utils import is_hpu_device
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -126,6 +127,11 @@ def transcribe(
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
+    if is_hpu_device(model.device):
+        if dtype == torch.float16:
+            warnings.warn("FP16 is not supported on HPU; using FP32 instead")
+            dtype = torch.float32
     if dtype == torch.float32:
         decode_options["fp16"] = False
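
The new block mirrors the existing CPU fallback: on HPU, half precision is forced back to float32 and the fp16 decode option is cleared. A condensed restatement of the effective dtype selection in transcribe(), with is_hpu_device as assumed above and the function name chosen for illustration:

    import warnings
    import torch

    def resolve_dtype(device: torch.device, decode_options: dict) -> torch.dtype:
        # Default to FP16 unless the caller explicitly disabled it.
        dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
        # Neither CPU nor HPU supports FP16 inference here; fall back to FP32.
        if device.type in ("cpu", "hpu") and dtype == torch.float16:
            warnings.warn(f"FP16 is not supported on {device.type.upper()}; using FP32 instead")
            dtype = torch.float32
        if dtype == torch.float32:
            decode_options["fp16"] = False
        return dtype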