mirror of https://github.com/openai/whisper.git
synced 2025-11-25 23:15:57 +00:00

added cut region param for grid search

parent e59538ea21
commit e08bd26fce

util.py

@@ -3,8 +3,8 @@ import whisper
 from typing import Tuple
 import matplotlib.pyplot as plt
 
-def load_model(model_name: str = "tiny.en", ff: bool = False) -> whisper.Whisper:
-    return whisper.load_model(model_name, ext_feature_flag=ff)
+def load_model(model_name: str = "tiny.en", ff: bool = False, cut_region=None) -> whisper.Whisper:
+    return whisper.load_model(model_name, ext_feature_flag=ff, cut_region=cut_region)
 
 
 def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
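
Note: the commit message says the parameter exists "for grid search". A minimal driver over the util.py wrapper could look like the sketch below; the script name, candidate regions, and audio path are illustrative assumptions, not part of the commit, and the float returned by transcribe is simply logged next to the text.

# grid_search.py, an illustrative sketch only; names and values here are assumptions
from util import load_model, transcribe

candidates = [(700, 900), (750, 1000), (800, 1100)]  # hypothetical cut regions
for cut_region in candidates:
    # ff=True turns the pruning path on; the Whisper constructor (see the
    # last hunk below) requires cut_region whenever that flag is set
    model = load_model("tiny.en", ff=True, cut_region=cut_region)
    text, value = transcribe(model, "sample.wav")  # hypothetical audio file
    print(cut_region, value, text[:60])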

whisper/__init__.py

@@ -106,6 +106,7 @@ def load_model(
     download_root: str = None,
     in_memory: bool = False,
     ext_feature_flag: bool = False,
+    cut_region: Optional[tuple] = None,
 ) -> Whisper:
     """
     Load a Whisper ASR model

@@ -152,7 +153,7 @@ def load_model(
     del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims, ext_feat_flag=ext_feature_flag)
+    model = Whisper(dims, ext_feat_flag=ext_feature_flag, cut_region=cut_region)
     model.load_state_dict(checkpoint["model_state_dict"])
 
     if alignment_heads is not None:

whisper/model.py

@@ -180,20 +180,23 @@ class ResidualAttentionBlock(nn.Module):
         return x
 
 
 class AudioEncoderTokenPruner():
-    def __init__(self, n_extension: int):
+    def __init__(self, n_extension: int, cut_region: Tuple[int, int]):
         self.n_extension = n_extension
 
-    def prune(self, x: Tensor, positional_embedding: Tensor):
-        audio_length = int((x.shape[1] + 1) // 2)
+    def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int]=[750, 1000]):
+        # audio_length = int((x.shape[1] + 1) // 2)
+        # [0-950, -----, 1300-1500]
+
         pos_emb = torch.concat((
-            positional_embedding[:audio_length + self.n_extension, :],
-            positional_embedding[-self.n_extension:, :]), dim=0,
+            positional_embedding[:cut_region[0], :],
+            torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device),
+            positional_embedding[cut_region[1]:,:]), dim=0,
         )
 
-        # extend the x's first dimension by n_extension
         x = torch.concat((
-            x[:, :audio_length + self.n_extension, :],
-            x[:, -self.n_extension:, :]), dim=1,
+            x[:, :cut_region[0], :],
+            torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device),
+            x[:, cut_region[1]:,:]), dim=1,
         )
 
         x = (x + pos_emb).to(x.dtype)
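
Two things worth noting about this hunk. First, the behavior changes: the old prune kept the first audio_length + n_extension tokens plus the last n_extension tokens, shortening the sequence, while the new version zeroes out the tokens and positional embeddings in [cut_region[0], cut_region[1]) and leaves the sequence length unchanged. Second, __init__ now accepts cut_region but does not store it, so prune falls back to its default [750, 1000] unless the call site passes a region explicitly. A standalone sketch of the new logic with dummy tensors (the tiny-model-like shapes are assumptions, not values from this commit):

import torch

x = torch.randn(1, 1500, 384)                  # encoder tokens (batch, ctx, state)
positional_embedding = torch.randn(1500, 384)
cut_region = (750, 1000)

# zero the positional embeddings inside the cut region
pos_emb = torch.concat((
    positional_embedding[:cut_region[0], :],
    torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :]),
    positional_embedding[cut_region[1]:, :]), dim=0,
)
# zero the matching token span, then add the embeddings back in
x = torch.concat((
    x[:, :cut_region[0], :],
    torch.zeros_like(x[:, cut_region[0]:cut_region[1], :]),
    x[:, cut_region[1]:, :]), dim=1,
)
x = (x + pos_emb).to(x.dtype)
print(x.shape)  # torch.Size([1, 1500, 384]); sequence length is unchanged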

@@ -202,7 +205,7 @@ class AudioEncoderTokenPruner():
 
 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=[750, 1000]
     ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)

@@ -216,7 +219,7 @@ class AudioEncoder(nn.Module):
         self.ln_post = LayerNorm(n_state)
         self.ext_feat_flag = ext_feat_flag
         if ext_feat_flag:
-            self.token_pruner = AudioEncoderTokenPruner(n_extension=200)
+            self.token_pruner = AudioEncoderTokenPruner(n_extension=200, cut_region=cut_region)
 
     def forward(self, x: Tensor):
         """
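
For reference, a hypothetical direct construction of the encoder with the new parameter; the dimensions below roughly match the tiny models and are assumptions, not values from this commit:

enc = AudioEncoder(n_mels=80, n_ctx=1500, n_state=384, n_head=6, n_layer=4,
                   ext_feat_flag=True, cut_region=(750, 1000))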

@@ -287,7 +290,7 @@ class TextDecoder(nn.Module):
 
 
 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
+    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=None):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(

@@ -313,6 +316,10 @@ class Whisper(nn.Module):
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
         self.ext_feat_flag = ext_feat_flag
+        self.cut_region = cut_region
+
+        if self.ext_feat_flag and not self.cut_region:
+            raise ValueError("cut_region must be specified if ext_feat_flag is True")
 
     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
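
The new guard makes the coupling explicit: the encoder only builds its token pruner when ext_feat_flag is set, and the pruner now needs a region. A quick illustration (dims stands for any valid ModelDimensions):

model = Whisper(dims)                                               # OK, pruning disabled
model = Whisper(dims, ext_feat_flag=True, cut_region=(750, 1000))   # OK
model = Whisper(dims, ext_feat_flag=True)  # raises ValueError: cut_region must be specified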