mirror of https://github.com/openai/whisper.git
Add HPU checks to decoding and transcribe handling
commit b0cf21b9b5
parent d36696f808
@@ -157,7 +157,7 @@ def load_model(
     del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, device=torch.device(device))
     model.load_state_dict(checkpoint["model_state_dict"])
 
     if alignment_heads is not None:
@@ -170,9 +170,8 @@ def load_model(
     if torch.hpu.is_available():
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 
-        model = model.eval().to(device)
-
         model = wrap_in_hpu_graph(model)
+        model = model.eval().to(torch.device(device))
 
         return model
     return model.to(device)
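For reference, a minimal invocation that would exercise the HPU branch above. This is a sketch, not part of the commit; it assumes a Gaudi host with the Habana PyTorch bridge (habana_frameworks) installed, which registers the "hpu" device type with PyTorch.

# Sketch only: assumes habana_frameworks is installed; importing the core
# module registers the "hpu" backend with PyTorch.
import habana_frameworks.torch.core  # noqa: F401
import whisper

# torch.hpu.is_available() is True here, so load_model wraps the model in an
# HPU graph before moving it to the device, as in the hunk above.
model = whisper.load_model("base", device="hpu")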
@@ -8,6 +8,7 @@ from torch import Tensor
 from torch.distributions import Categorical
 
 from .audio import CHUNK_LENGTH
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
 
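The whisper/hpu_utils.py module imported here (presumably the new file backing both this hunk and the transcribe one below) is not shown in this diff. A plausible minimal sketch of the is_hpu_device helper, assuming it only needs to inspect the device type:

# Hypothetical sketch of whisper/hpu_utils.py (the actual file is not shown
# in this commit); assumes the helper only inspects torch.device.type.
import torch


def is_hpu_device(device) -> bool:
    """Return True for Habana HPU devices such as "hpu" or "hpu:0"."""
    return torch.device(device).type == "hpu"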
@@ -456,7 +457,17 @@ class ApplyTimestampRules(LogitFilter):
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
+            if is_hpu_device(tokens.device):
+                """
+                If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.
+
+                On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors if not explicitly
+                evaluated. Cloning `sampled_tokens` ensures it is fully evaluated on the HPU, preventing potential
+                synchronization issues.
+                """
+                sampled_tokens = tokens[k, self.sample_begin :].clone()
+            else:
+                sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
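To illustrate the lazy-execution behavior the .clone() above guards against: in Habana lazy mode, operations are queued rather than run eagerly, and htcore.mark_step() is the usual explicit flush point. A standalone illustration under those assumptions (requires a Gaudi device):

# Illustration only: assumes a Gaudi device running in Habana lazy mode.
import torch
import habana_frameworks.torch.core as htcore

tokens = torch.randint(0, 100, (5, 32), device="hpu")
view = tokens[0, 4:]       # lazy mode: this op is queued, not yet executed
forced = view.clone()      # cloning forces evaluation of just this slice
htcore.mark_step()         # alternatively, flush every queued op at once
print(forced.tolist())     # read the evaluated values back on the host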
@@ -18,6 +18,7 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
+from .hpu_utils import is_hpu_device
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -126,6 +127,11 @@ def transcribe(
         warnings.warn("FP16 is not supported on CPU; using FP32 instead")
         dtype = torch.float32
 
+    if is_hpu_device(model.device):
+        if dtype == torch.float16:
+            warnings.warn("FP16 is not supported on HPU; using FP32 instead")
+            dtype = torch.float32
+
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
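End-to-end, the new check means a transcribe call on an HPU falls back to FP32 with a warning rather than failing on an unsupported dtype. A usage sketch (the audio file name is a placeholder):

# Usage sketch: fp16=True on an HPU device triggers the new warning and the
# run proceeds in FP32 ("audio.wav" is an illustrative file name).
import whisper

model = whisper.load_model("base", device="hpu")
result = model.transcribe("audio.wav", fp16=True)
print(result["text"])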