Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 22:45:52 +00:00)
Add HPU checking in decoding and transcribe handling
commit b0cf21b9b5
parent d36696f808
whisper/__init__.py

@@ -157,7 +157,7 @@ def load_model(
     del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, device=torch.device(device))
     model.load_state_dict(checkpoint["model_state_dict"])
 
     if alignment_heads is not None:
@@ -170,9 +170,8 @@ def load_model(
     if torch.hpu.is_available():
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 
-        model = model.eval().to(device)
-
         model = wrap_in_hpu_graph(model)
+        model = model.eval().to(torch.device(device))
 
         return model
     return model.to(device)
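For context, a minimal usage sketch of the updated load_model on Gaudi hardware. The "hpu" device string and an HPU-enabled PyTorch build are assumptions not confirmed by this diff; on other machines the sketch falls back to CPU.

```python
import torch
import whisper

# Minimal sketch, assuming torch.hpu exists and reports availability when the
# Habana plugin is installed; neither is shown in this commit.
device = "hpu" if hasattr(torch, "hpu") and torch.hpu.is_available() else "cpu"

# With this commit, the Whisper module is constructed directly on the target device
# and, on HPU, wrapped in an HPU graph before being returned in eval mode.
model = whisper.load_model("base", device=device)
print(model.device)
```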
whisper/decoding.py

@@ -8,6 +8,7 @@ from torch import Tensor
 from torch.distributions import Categorical
 
 from .audio import CHUNK_LENGTH
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
 
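The new import pulls is_hpu_device from whisper/hpu_utils.py, which is not part of this diff. A plausible sketch of such a helper, offered only as an assumption about what that module provides:

```python
import torch

def is_hpu_device(device: torch.device) -> bool:
    """Hypothetical helper: True when a tensor's or model's device is a Gaudi HPU.

    The real whisper/hpu_utils.py is not shown in this commit; this sketch only
    assumes the check keys off the device type string "hpu".
    """
    return torch.device(device).type == "hpu"
```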
@@ -456,7 +457,17 @@ class ApplyTimestampRules(LogitFilter):
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
+            if is_hpu_device(tokens.device):
+                """
+                If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.
+
+                On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors if not explicitly
+                evaluated. Cloning `sampled_tokens` ensures it is fully evaluated on the HPU, preventing potential
+                synchronization issues.
+                """
+                sampled_tokens = tokens[k, self.sample_begin :].clone()
+            else:
+                sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
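The docstring above motivates the .clone() call. A device-agnostic illustration of the same pattern, using a hypothetical helper name that is not part of whisper:

```python
import torch

def materialized_suffix(tokens: torch.Tensor, sample_begin: int) -> list:
    # Illustrative only: on lazily executed backends such as Gaudi HPUs, slicing
    # yields a view that may not be evaluated yet; .clone() forces materialization
    # before .tolist() moves the values to the host.
    suffix = tokens[0, sample_begin:]
    if suffix.device.type == "hpu":  # assumption: HPU devices report type "hpu"
        suffix = suffix.clone()
    return suffix.tolist()

# Usage: behaves identically on CPU, where the clone is simply skipped.
print(materialized_suffix(torch.arange(10).unsqueeze(0), 4))  # [4, 5, 6, 7, 8, 9]
```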
whisper/transcribe.py

@@ -18,6 +18,7 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
+from .hpu_utils import is_hpu_device
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -126,6 +127,11 @@ def transcribe(
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
 
+    if is_hpu_device(model.device):
+        if dtype == torch.float16:
+            warnings.warn("FP16 is not supported on HPU; using FP32 instead")
+            dtype = torch.float32
+
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
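A hedged end-to-end sketch of the behaviour this hunk adds: requesting FP16 on an HPU now warns and falls back to FP32 instead of failing. The "hpu" device string and the audio path are placeholders for illustration:

```python
import torch
import whisper

# Sketch only: assumes an HPU-enabled PyTorch build; on machines without one,
# fall back to CPU so the example still runs.
device = "hpu" if hasattr(torch, "hpu") and torch.hpu.is_available() else "cpu"
model = whisper.load_model("base", device=device)

# With this commit, fp16=True on HPU (as on CPU) triggers the FP32 fallback warning
# and sets decode_options["fp16"] = False internally before decoding.
result = model.transcribe("audio.wav", fp16=True)  # "audio.wav" is a placeholder path
print(result["text"])
```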