added cut region param for grid search

Elijah Melton 2025-02-03 19:45:56 -08:00
parent e59538ea21
commit e08bd26fce
4 changed files with 88 additions and 427 deletions

nb.ipynb (generated, 479 lines changed): diff suppressed because one or more lines are too long


@@ -3,8 +3,8 @@ import whisper
 from typing import Tuple
 import matplotlib.pyplot as plt

-def load_model(model_name: str = "tiny.en", ff: bool = False) -> whisper.Whisper:
-    return whisper.load_model(model_name, ext_feature_flag=ff)
+def load_model(model_name: str = "tiny.en", ff: bool = False, cut_region=None) -> whisper.Whisper:
+    return whisper.load_model(model_name, ext_feature_flag=ff, cut_region=cut_region)

 def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
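The commit message states the motivation: with cut_region exposed through this notebook helper, a grid search over candidate regions becomes a plain loop. A minimal sketch, assuming the load_model and transcribe helpers above; the candidate boundaries and the audio path are illustrative placeholders, not values from the repository:

import itertools

starts = [500, 750, 1000]   # hypothetical region starts
ends = [1000, 1250, 1500]   # hypothetical region ends

results = {}
for start, end in itertools.product(starts, ends):
    if start >= end:
        continue  # skip empty or inverted regions
    model = load_model("tiny.en", ff=True, cut_region=(start, end))
    # transcribe returns Tuple[str, float] per its signature above
    text, metric = transcribe(model, "sample.wav")  # placeholder path
    results[(start, end)] = (text, metric)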


@@ -106,6 +106,7 @@ def load_model(
     download_root: str = None,
     in_memory: bool = False,
     ext_feature_flag: bool = False,
+    cut_region: Optional[tuple] = None,
 ) -> Whisper:
     """
     Load a Whisper ASR model
@@ -152,7 +153,7 @@ def load_model(
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims, ext_feat_flag=ext_feature_flag)
+    model = Whisper(dims, ext_feat_flag=ext_feature_flag, cut_region=cut_region)
     model.load_state_dict(checkpoint["model_state_dict"])

     if alignment_heads is not None:


@@ -180,20 +180,23 @@ class ResidualAttentionBlock(nn.Module):
         return x

 class AudioEncoderTokenPruner():
-    def __init__(self, n_extension: int):
+    def __init__(self, n_extension: int, cut_region: Tuple[int, int]):
         self.n_extension = n_extension

-    def prune(self, x: Tensor, positional_embedding: Tensor):
-        audio_length = int((x.shape[1] + 1) // 2)
+    def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int] = (750, 1000)):
+        # audio_length = int((x.shape[1] + 1) // 2)
+        # keep [0, start), zero out [start, end), keep the rest, e.g. [0-950, -----, 1300-1500]
         pos_emb = torch.concat((
-            positional_embedding[:audio_length + self.n_extension, :],
-            positional_embedding[-self.n_extension:, :]), dim=0,
+            positional_embedding[:cut_region[0], :],
+            torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device),
+            positional_embedding[cut_region[1]:, :]), dim=0,
         )
-        # extend the x's first dimension by n_extension
+        # zero out the same window of x instead of extending the sequence
         x = torch.concat((
-            x[:, :audio_length + self.n_extension, :],
-            x[:, -self.n_extension:, :]), dim=1,
+            x[:, :cut_region[0], :],
+            torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device),
+            x[:, cut_region[1]:, :]), dim=1,
         )

         x = (x + pos_emb).to(x.dtype)
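For intuition about the reworked prune: the sequence length stays fixed, and both the activations and the positional embeddings inside [cut_region[0], cut_region[1]) are replaced with zeros, where the old code extended the sequence by n_extension instead. A self-contained toy sketch of the same slicing pattern, with made-up shapes and names rather than the module itself:

import torch

def zero_region(x: torch.Tensor, pos: torch.Tensor, cut=(3, 6)) -> torch.Tensor:
    # x: (batch, seq, dim) activations; pos: (seq, dim) positional embedding
    pos = torch.cat(
        (pos[:cut[0]], torch.zeros_like(pos[cut[0]:cut[1]]), pos[cut[1]:]), dim=0
    )
    x = torch.cat(
        (x[:, :cut[0]], torch.zeros_like(x[:, cut[0]:cut[1]]), x[:, cut[1]:]), dim=1
    )
    return x + pos  # shape unchanged: (batch, seq, dim)

x, pos = torch.randn(2, 10, 4), torch.randn(10, 4)
out = zero_region(x, pos)
assert out.shape == x.shape
assert out[:, 3:6].abs().sum() == 0  # the cut window is exactly zero

One detail worth noting: the constructor now receives cut_region but does not store it, and prune takes its own cut_region argument with a (750, 1000) default, so a call site that omits the argument will silently ignore the value passed to __init__.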
@@ -202,7 +205,7 @@ class AudioEncoderTokenPruner():

 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int] = (750, 1000)
     ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
@@ -216,7 +219,7 @@ class AudioEncoder(nn.Module):
         self.ln_post = LayerNorm(n_state)

         self.ext_feat_flag = ext_feat_flag
         if ext_feat_flag:
-            self.token_pruner = AudioEncoderTokenPruner(n_extension=200)
+            self.token_pruner = AudioEncoderTokenPruner(n_extension=200, cut_region=cut_region)

     def forward(self, x: Tensor):
         """
@@ -287,7 +290,7 @@ class TextDecoder(nn.Module):

 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
+    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False, cut_region: Optional[Tuple[int, int]] = None):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
@@ -313,6 +316,10 @@ class Whisper(nn.Module):
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
         self.ext_feat_flag = ext_feat_flag
+        self.cut_region = cut_region
+        if self.ext_feat_flag and not self.cut_region:
+            raise ValueError("cut_region must be specified if ext_feat_flag is True")

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
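End to end, the parameter now flows from whisper.load_model through Whisper down to the encoder's pruner, and the new guard makes a missing region fail fast when the flag is set. A short usage sketch against this fork's API; the model name and region values are illustrative:

import whisper

# flag and region supplied together: loads normally
model = whisper.load_model("tiny.en", ext_feature_flag=True, cut_region=(750, 1000))

# flag without a region: Whisper.__init__ raises
# ValueError: cut_region must be specified if ext_feat_flag is True
whisper.load_model("tiny.en", ext_feature_flag=True)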