Mirror of https://github.com/openai/whisper.git (synced 2025-11-28 08:11:11 +00:00)
clean up and make functional
This commit is contained in:
parent 36e49d920f
commit 937313cee9
main.py (deleted, 51 lines)
@@ -1,51 +0,0 @@
-import whisper
-import os
-
-TEST_DATA_BASE = "test_data/"
-FIVE_SEC_BASE = os.path.join(TEST_DATA_BASE, "5s/")
-THIRTY_SEC_BASE = os.path.join(TEST_DATA_BASE, "30s/")
-TRANSCRIPTS_BASE = "test_transcripts_before/"
-
-model = whisper.load_model("tiny.en")
-
-def transcribe_baseline(file_name):
-    return model.transcribe(file_name).get('text', '')
-
-def get_all_files(base_path, count=1000):
-    return [os.path.join(base_path, f"out{i:03d}.wav") for i in range(count)]
-
-def write_to_file(file_name, text):
-    os.makedirs(os.path.dirname(file_name), exist_ok=True)
-    with open(file_name, "w") as f:
-        f.write(text)
-
-def calculate_wer(hypothesis, actual):
-    hyp_words = hypothesis.strip().lower().split()
-    act_words = actual.strip().lower().split()
-
-    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(act_words) + 1)]
-
-    for i in range(len(act_words) + 1):
-        dp[i][0] = i
-    for j in range(len(hyp_words) + 1):
-        dp[0][j] = j
-
-    for i in range(1, len(act_words) + 1):
-        for j in range(1, len(hyp_words) + 1):
-            if act_words[i - 1] == hyp_words[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
-            else:
-                dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + 1)
-
-    total_words = len(act_words)
-    return dp[len(act_words)][len(hyp_words)] / total_words if total_words else float("inf") if hyp_words else 0.0
-
-def process_files(files, output_base):
-    for file_name in files:
-        hypothesis = transcribe_baseline(file_name)
-        sample_name = os.path.splitext(os.path.basename(file_name))[0]
-        write_to_file(os.path.join(output_base, f"{sample_name}.txt"), hypothesis)
-
-if __name__ == "__main__":
-    # process_files(get_all_files(FIVE_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "5s"))
-    process_files(get_all_files(THIRTY_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "30s"))
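Note: the deleted calculate_wer packed its empty-reference handling into one
chained conditional expression. Python's conditional expression groups to the
right, so the line behaved correctly but read poorly; util.py below trades it
for an explicit guard. The implied parentheses, spelled out:

    # Parenthesized reading of the deleted return statement (same behavior):
    def wer_tail(distance, total_words, hyp_words):
        return (distance / total_words) if total_words else (float("inf") if hyp_words else 0.0)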
BIN test_figs/enc_sim_1_5s.png (new file, 1010 KiB; binary not shown)
BIN test_figs/enc_sim_2_5s.png (new file, 1.0 MiB; binary not shown)
BIN test_figs/enc_sim_3_30s.png (new file, 969 KiB; binary not shown)
BIN test_figs/enc_vis_4_30s.png (new file, 1.3 MiB; binary not shown)
util.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import timeit
+import whisper
+from typing import Tuple
+
+
+def load_model(model_name: str = "tiny.en") -> whisper.Whisper:
+    return whisper.load_model(model_name, ext_feature_flag=False)
+
+
+def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
+    start_time = timeit.default_timer()
+    transcription = model.transcribe(audio_path).get("text", "")
+    elapsed_time = timeit.default_timer() - start_time
+    return transcription, elapsed_time
+
+
+def calculate_wer(hypothesis: str, reference: str) -> float:
+    hyp_words = hypothesis.strip().lower().split()
+    ref_words = reference.strip().lower().split()
+
+    if not ref_words:
+        return float("inf") if hyp_words else 0.0
+
+    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
+
+    for i in range(len(ref_words) + 1):
+        dp[i][0] = i
+    for j in range(len(hyp_words) + 1):
+        dp[0][j] = j
+
+    for i in range(1, len(ref_words) + 1):
+        for j in range(1, len(hyp_words) + 1):
+            if ref_words[i - 1] == hyp_words[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = min(
+                    dp[i - 1][j] + 1,  # deletion
+                    dp[i][j - 1] + 1,  # insertion
+                    dp[i - 1][j - 1] + 1,  # substitution
+                )
+
+    return dp[len(ref_words)][len(hyp_words)] / len(ref_words)
+
+
+if __name__ == "__main__":
+    model = load_model()
+    audio_path = "test_data/30s/out000.wav"
+    transcript_path = "test_transcripts_before/30s/out000.txt"
+
+    hypothesis, elapsed_time = transcribe(model, audio_path)
+    with open(transcript_path, "r") as f:
+        reference = f.read()
+
+    wer = calculate_wer(hypothesis, reference)
+    print(f"WER: {wer:.4f}")
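A quick sanity check of the new calculate_wer (a sketch, assuming this util.py
is importable from the working directory):

    from util import calculate_wer

    # One word dropped from a four-word reference: 1 edit / 4 words = 0.25.
    assert calculate_wer("the cat sat", "the cat sat down") == 0.25
    # Case and surrounding whitespace are normalized before comparison.
    assert calculate_wer("  The CAT sat down ", "the cat sat down") == 0.0
    # A non-empty hypothesis against an empty reference is defined as infinite WER.
    assert calculate_wer("anything", "") == float("inf")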
whisper/__init__.py
@@ -105,6 +105,7 @@ def load_model(
     device: Optional[Union[str, torch.device]] = None,
     download_root: str = None,
     in_memory: bool = False,
+    ext_feature_flag: bool = False,
 ) -> Whisper:
     """
     Load a Whisper ASR model
@@ -151,7 +152,7 @@ def load_model(
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, ext_feat_flag=ext_feature_flag)
     model.load_state_dict(checkpoint["model_state_dict"])

     if alignment_heads is not None:
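With the keyword in place, the extension is opt-in at load time (a usage
sketch against this patched checkout; upstream whisper has no such flag):

    import whisper

    # Default behavior is unchanged; the flag enables the encoder's
    # positional-embedding extension added by this commit.
    baseline = whisper.load_model("tiny.en")
    extended = whisper.load_model("tiny.en", ext_feature_flag=True)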
whisper/model.py
@@ -182,17 +182,19 @@ class ResidualAttentionBlock(nn.Module):

 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
     ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

+
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
         )
         self.ln_post = LayerNorm(n_state)
+        self.ext_feat_flag = ext_feat_flag

     def forward(self, x: Tensor):
         """
@@ -205,9 +207,7 @@ class AudioEncoder(nn.Module):

         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"

-        FEAT = False
-
-        if FEAT:
+        if self.ext_feat_flag:
             n_extension = 200
             audio_length = int((x.shape[2] + 1) // 2)
             pos_emb = torch.concat((
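The hunk above is truncated at the torch.concat call, so the exact splice the
commit performs is not visible here. Purely as a hypothetical sketch of the
technique it points at (reusing the names n_extension and audio_length from
above; the slicing is my assumption, not the commit's code):

    import torch

    def extend_pos_emb(pos: torch.Tensor, audio_length: int, n_extension: int) -> torch.Tensor:
        # pos is the (n_ctx, n_state) sinusoid buffer. Keep the rows that
        # cover the real audio, then splice in rows from the tail so frames
        # past the content still receive distinct positional codes.
        return torch.concat((
            pos[:audio_length],
            pos[-n_extension:],
        ))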
@@ -278,7 +278,7 @@ class TextDecoder(nn.Module):


 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
+    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
@@ -287,6 +287,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            ext_feat_flag=ext_feat_flag,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
@@ -302,6 +303,7 @@ class Whisper(nn.Module):
         )
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        self.ext_feat_flag = ext_feat_flag

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
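Since the flag is stored on both the Whisper model and its AudioEncoder, the
plumbing can be smoke-tested directly (assuming the patched checkout is the
whisper on the import path):

    import whisper

    model = whisper.load_model("tiny.en", ext_feature_flag=True)
    assert model.ext_feat_flag is True
    assert model.encoder.ext_feat_flag is True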