mirror of
https://github.com/openai/whisper.git
synced 2025-11-25 15:06:10 +00:00
added cut region param for grid search
This commit is contained in:
parent
e59538ea21
commit
e08bd26fce
4
util.py
4
util.py
@ -3,8 +3,8 @@ import whisper
|
||||
from typing import Tuple
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def load_model(model_name: str = "tiny.en", ff: bool = False) -> whisper.Whisper:
|
||||
return whisper.load_model(model_name, ext_feature_flag=ff)
|
||||
def load_model(model_name: str = "tiny.en", ff: bool = False, cut_region=None) -> whisper.Whisper:
|
||||
return whisper.load_model(model_name, ext_feature_flag=ff, cut_region=cut_region)
|
||||
|
||||
|
||||
def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
|
||||
|
||||
@ -106,6 +106,7 @@ def load_model(
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
ext_feature_flag: bool = False,
|
||||
cut_region: Optional[tuple] = None,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
@ -152,7 +153,7 @@ def load_model(
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims, ext_feat_flag=ext_feature_flag)
|
||||
model = Whisper(dims, ext_feat_flag=ext_feature_flag, cut_region=cut_region)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
|
||||
@ -180,20 +180,23 @@ class ResidualAttentionBlock(nn.Module):
|
||||
return x
|
||||
|
||||
class AudioEncoderTokenPruner():
|
||||
def __init__(self, n_extension: int):
|
||||
def __init__(self, n_extension: int, cut_region: Tuple[int, int]):
|
||||
self.n_extension = n_extension
|
||||
|
||||
def prune(self, x: Tensor, positional_embedding: Tensor):
|
||||
audio_length = int((x.shape[1] + 1) // 2)
|
||||
def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int]=[750, 1000]):
|
||||
# audio_length = int((x.shape[1] + 1) // 2)
|
||||
# [0-950, -----, 1300-1500]
|
||||
|
||||
pos_emb = torch.concat((
|
||||
positional_embedding[:audio_length + self.n_extension, :],
|
||||
positional_embedding[-self.n_extension:, :]), dim=0,
|
||||
positional_embedding[:cut_region[0], :],
|
||||
torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device),
|
||||
positional_embedding[cut_region[1]:,:]), dim=0,
|
||||
)
|
||||
|
||||
# extend the x's first dimension by n_extension
|
||||
x = torch.concat((
|
||||
x[:, :audio_length + self.n_extension, :],
|
||||
x[:, -self.n_extension:, :]), dim=1,
|
||||
x[:, :cut_region[0], :],
|
||||
torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device),
|
||||
x[:, cut_region[1]:,:]), dim=1,
|
||||
)
|
||||
|
||||
x = (x + pos_emb).to(x.dtype)
|
||||
@ -202,7 +205,7 @@ class AudioEncoderTokenPruner():
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=[750, 1000]
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
@ -216,7 +219,7 @@ class AudioEncoder(nn.Module):
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
self.ext_feat_flag = ext_feat_flag
|
||||
if ext_feat_flag:
|
||||
self.token_pruner = AudioEncoderTokenPruner(n_extension=200)
|
||||
self.token_pruner = AudioEncoderTokenPruner(n_extension=200, cut_region=cut_region)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
@ -287,7 +290,7 @@ class TextDecoder(nn.Module):
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
|
||||
def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=None):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
@ -313,6 +316,10 @@ class Whisper(nn.Module):
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||
self.ext_feat_flag = ext_feat_flag
|
||||
self.cut_region = cut_region
|
||||
|
||||
if self.ext_feat_flag and not self.cut_region:
|
||||
raise ValueError("cut_region must be specified if ext_feat_flag is True")
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user