clean up and make functional

Elijah Melton 2025-02-03 17:11:57 -08:00
parent 36e49d920f
commit 937313cee9
8 changed files with 64 additions and 57 deletions

main.py (deleted, 51 lines)

@@ -1,51 +0,0 @@
import whisper
import os

TEST_DATA_BASE = "test_data/"
FIVE_SEC_BASE = os.path.join(TEST_DATA_BASE, "5s/")
THIRTY_SEC_BASE = os.path.join(TEST_DATA_BASE, "30s/")
TRANSCRIPTS_BASE = "test_transcripts_before/"

model = whisper.load_model("tiny.en")

def transcribe_baseline(file_name):
    return model.transcribe(file_name).get('text', '')

def get_all_files(base_path, count=1000):
    return [os.path.join(base_path, f"out{i:03d}.wav") for i in range(count)]

def write_to_file(file_name, text):
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    with open(file_name, "w") as f:
        f.write(text)

def calculate_wer(hypothesis, actual):
    hyp_words = hypothesis.strip().lower().split()
    act_words = actual.strip().lower().split()
    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(act_words) + 1)]
    for i in range(len(act_words) + 1):
        dp[i][0] = i
    for j in range(len(hyp_words) + 1):
        dp[0][j] = j
    for i in range(1, len(act_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if act_words[i - 1] == hyp_words[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + 1)
    total_words = len(act_words)
    return dp[len(act_words)][len(hyp_words)] / total_words if total_words else float("inf") if hyp_words else 0.0

def process_files(files, output_base):
    for file_name in files:
        hypothesis = transcribe_baseline(file_name)
        sample_name = os.path.splitext(os.path.basename(file_name))[0]
        write_to_file(os.path.join(output_base, f"{sample_name}.txt"), hypothesis)

if __name__ == "__main__":
    # process_files(get_all_files(FIVE_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "5s"))
    process_files(get_all_files(THIRTY_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "30s"))
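
Note: with main.py removed, its batch loop can be reproduced through the new util.py API. A minimal sketch, assuming the test_data layout and out000.wav-style naming from the deleted script:

import os
from util import load_model, transcribe

model = load_model("tiny.en")
for i in range(1000):
    wav = os.path.join("test_data/30s", f"out{i:03d}.wav")
    text, _ = transcribe(model, wav)  # transcript text; elapsed seconds unused here
    out = os.path.join("test_transcripts_before/30s", f"out{i:03d}.txt")
    os.makedirs(os.path.dirname(out), exist_ok=True)
    with open(out, "w") as f:
        f.write(text)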

BIN  test_figs/enc_sim_1_5s.png (new file, 1010 KiB)
BIN  test_figs/enc_sim_2_5s.png (new file, 1.0 MiB)
BIN  test_figs/enc_sim_3_30s.png (new file, 969 KiB)
BIN  test_figs/enc_vis_4_30s.png (new file, 1.3 MiB)

util.py (new file, 55 lines)

@@ -0,0 +1,55 @@
import timeit
import whisper
from typing import Tuple

def load_model(model_name: str = "tiny.en") -> whisper.Whisper:
    return whisper.load_model(model_name, ext_feature_flag=False)

def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
    start_time = timeit.default_timer()
    transcription = model.transcribe(audio_path).get("text", "")
    elapsed_time = timeit.default_timer() - start_time
    return transcription, elapsed_time

def calculate_wer(hypothesis: str, reference: str) -> float:
    hyp_words = hypothesis.strip().lower().split()
    ref_words = reference.strip().lower().split()
    if not ref_words:
        return float("inf") if hyp_words else 0.0
    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
    for i in range(len(ref_words) + 1):
        dp[i][0] = i
    for j in range(len(hyp_words) + 1):
        dp[0][j] = j
    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            if ref_words[i - 1] == hyp_words[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(
                    dp[i - 1][j] + 1,  # deletion
                    dp[i][j - 1] + 1,  # insertion
                    dp[i - 1][j - 1] + 1,  # substitution
                )
    return dp[len(ref_words)][len(hyp_words)] / len(ref_words)

if __name__ == "__main__":
    model = load_model()
    audio_path = "test_data/30s/out000.wav"
    transcript_path = "test_transcripts_before/30s/out000.txt"
    hypothesis, elapsed_time = transcribe(model, audio_path)
    with open(transcript_path, "r") as f:
        reference = f.read()
    wer = calculate_wer(hypothesis, reference)
    print(f"WER: {wer:.4f}")

whisper/__init__.py

@@ -105,6 +105,7 @@ def load_model(
     device: Optional[Union[str, torch.device]] = None,
     download_root: str = None,
     in_memory: bool = False,
+    ext_feature_flag: bool = False,
 ) -> Whisper:
     """
     Load a Whisper ASR model
@@ -151,7 +152,7 @@ def load_model(
     del checkpoint_file
     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, ext_feat_flag=ext_feature_flag)
     model.load_state_dict(checkpoint["model_state_dict"])
     if alignment_heads is not None:
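
The new keyword threads straight through to the Whisper constructor, so callers opt in at load time. A minimal sketch using only what this diff adds:

import whisper

# ext_feature_flag is the parameter introduced in this commit
model = whisper.load_model("tiny.en", ext_feature_flag=True)
print(model.ext_feat_flag)  # True; stored on the module in the model.py diff below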

whisper/model.py

@@ -182,17 +182,19 @@ class ResidualAttentionBlock(nn.Module):
 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
     ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
         )
         self.ln_post = LayerNorm(n_state)
+        self.ext_feat_flag = ext_feat_flag

     def forward(self, x: Tensor):
         """
@@ -205,9 +207,7 @@ class AudioEncoder(nn.Module):
         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"

-        FEAT = False
-        if FEAT:
+        if self.ext_feat_flag:
             n_extension = 200
             audio_length = int((x.shape[2] + 1) // 2)
             pos_emb = torch.concat((
@@ -278,7 +278,7 @@ class TextDecoder(nn.Module):
 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
+    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
@@ -287,6 +287,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            ext_feat_flag=ext_feat_flag,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
@@ -302,6 +303,7 @@ class Whisper(nn.Module):
         )
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        self.ext_feat_flag = ext_feat_flag

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
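
The body of the ext_feat_flag branch in AudioEncoder.forward is cut off at the hunk boundary above. A hedged sketch of one plausible construction, based only on the visible lines (n_extension, audio_length, and a torch.concat over the positional embedding); the exact splicing is an assumption, not this commit's code:

import torch
from whisper.model import sinusoids

def extended_pos_emb(pos_emb: torch.Tensor, audio_length: int,
                     n_extension: int = 200) -> torch.Tensor:
    # Assumption: keep the positions covering the real audio frames,
    # then append n_extension fresh sinusoidal positions after them.
    extra = sinusoids(n_extension, pos_emb.shape[1]).to(pos_emb.dtype)
    return torch.concat((pos_emb[:audio_length], extra), dim=0)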