mirror of https://github.com/openai/whisper.git
synced 2025-11-28 16:14:00 +00:00

clean up and make functional

This commit is contained in:
parent 36e49d920f
commit 937313cee9
main.py  (51 lines removed)
@@ -1,51 +0,0 @@
-import whisper
-import os
-
-TEST_DATA_BASE = "test_data/"
-FIVE_SEC_BASE = os.path.join(TEST_DATA_BASE, "5s/")
-THIRTY_SEC_BASE = os.path.join(TEST_DATA_BASE, "30s/")
-TRANSCRIPTS_BASE = "test_transcripts_before/"
-
-model = whisper.load_model("tiny.en")
-
-def transcribe_baseline(file_name):
-    return model.transcribe(file_name).get('text', '')
-
-def get_all_files(base_path, count=1000):
-    return [os.path.join(base_path, f"out{i:03d}.wav") for i in range(count)]
-
-def write_to_file(file_name, text):
-    os.makedirs(os.path.dirname(file_name), exist_ok=True)
-    with open(file_name, "w") as f:
-        f.write(text)
-
-def calculate_wer(hypothesis, actual):
-    hyp_words = hypothesis.strip().lower().split()
-    act_words = actual.strip().lower().split()
-
-    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(act_words) + 1)]
-
-    for i in range(len(act_words) + 1):
-        dp[i][0] = i
-    for j in range(len(hyp_words) + 1):
-        dp[0][j] = j
-
-    for i in range(1, len(act_words) + 1):
-        for j in range(1, len(hyp_words) + 1):
-            if act_words[i - 1] == hyp_words[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
-            else:
-                dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + 1)
-
-    total_words = len(act_words)
-    return dp[len(act_words)][len(hyp_words)] / total_words if total_words else float("inf") if hyp_words else 0.0
-
-def process_files(files, output_base):
-    for file_name in files:
-        hypothesis = transcribe_baseline(file_name)
-        sample_name = os.path.splitext(os.path.basename(file_name))[0]
-        write_to_file(os.path.join(output_base, f"{sample_name}.txt"), hypothesis)
-
-if __name__ == "__main__":
-    # process_files(get_all_files(FIVE_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "5s"))
-    process_files(get_all_files(THIRTY_SEC_BASE), os.path.join(TRANSCRIPTS_BASE, "30s"))
BIN  test_figs/enc_sim_1_5s.png   Normal file  (Size: 1010 KiB)
BIN  test_figs/enc_sim_2_5s.png   Normal file  (Size: 1.0 MiB)
BIN  test_figs/enc_sim_3_30s.png  Normal file  (Size: 969 KiB)
BIN  test_figs/enc_vis_4_30s.png  Normal file  (Size: 1.3 MiB)
util.py  (55 lines added, new file)
@@ -0,0 +1,55 @@
+import timeit
+import whisper
+from typing import Tuple
+
+
+def load_model(model_name: str = "tiny.en") -> whisper.Whisper:
+    return whisper.load_model(model_name, ext_feature_flag=False)
+
+
+def transcribe(model: whisper.Whisper, audio_path: str) -> Tuple[str, float]:
+    start_time = timeit.default_timer()
+    transcription = model.transcribe(audio_path).get("text", "")
+    elapsed_time = timeit.default_timer() - start_time
+    return transcription, elapsed_time
+
+
+def calculate_wer(hypothesis: str, reference: str) -> float:
+    hyp_words = hypothesis.strip().lower().split()
+    ref_words = reference.strip().lower().split()
+
+    if not ref_words:
+        return float("inf") if hyp_words else 0.0
+
+    dp = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
+
+    for i in range(len(ref_words) + 1):
+        dp[i][0] = i
+    for j in range(len(hyp_words) + 1):
+        dp[0][j] = j
+
+    for i in range(1, len(ref_words) + 1):
+        for j in range(1, len(hyp_words) + 1):
+            if ref_words[i - 1] == hyp_words[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = min(
+                    dp[i - 1][j] + 1,  # deletion
+                    dp[i][j - 1] + 1,  # insertion
+                    dp[i - 1][j - 1] + 1,  # substitution
+                )
+
+    return dp[len(ref_words)][len(hyp_words)] / len(ref_words)
+
+
+if __name__ == "__main__":
+    model = load_model()
+    audio_path = "test_data/30s/out000.wav"
+    transcript_path = "test_transcripts_before/30s/out000.txt"
+
+    hypothesis, elapsed_time = transcribe(model, audio_path)
+    with open(transcript_path, "r") as f:
+        reference = f.read()
+
+    wer = calculate_wer(hypothesis, reference)
+    print(f"WER: {wer:.4f}")
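Note: `calculate_wer` computes a word-level Levenshtein distance and normalizes by the reference length, so values above 1.0 are possible when the hypothesis is much longer than the reference. A minimal sanity check with made-up strings (assuming `util.py` is on the import path):

    from util import calculate_wer

    # 3 of the 6 reference words are missing from the hypothesis:
    # edit distance 3 / reference length 6 -> WER 0.5
    assert calculate_wer("the cat sat", "the cat sat on the mat") == 0.5

    # empty reference: inf if anything was hypothesized, else 0.0
    assert calculate_wer("anything", "") == float("inf")
    assert calculate_wer("", "") == 0.0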
whisper/__init__.py
@@ -105,6 +105,7 @@ def load_model(
     device: Optional[Union[str, torch.device]] = None,
     download_root: str = None,
     in_memory: bool = False,
+    ext_feature_flag: bool = False,
 ) -> Whisper:
     """
     Load a Whisper ASR model
@@ -151,7 +152,7 @@ def load_model(
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
+    model = Whisper(dims, ext_feat_flag=ext_feature_flag)
     model.load_state_dict(checkpoint["model_state_dict"])

     if alignment_heads is not None:
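The new flag is threaded through the public loader, so callers opt in at load time. A usage sketch (mirroring the call in `util.py`, which passes `ext_feature_flag=False`):

    import whisper

    # opt in to the extended-feature encoder path added by this commit
    model = whisper.load_model("tiny.en", ext_feature_flag=True)
    print(model.transcribe("test_data/30s/out000.wav").get("text", ""))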
whisper/model.py
@@ -182,17 +182,19 @@ class ResidualAttentionBlock(nn.Module):

 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False
     ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
         )
         self.ln_post = LayerNorm(n_state)
+
+        self.ext_feat_flag = ext_feat_flag

     def forward(self, x: Tensor):
         """
@@ -205,9 +207,7 @@ class AudioEncoder(nn.Module):

         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"

-        FEAT = False
-
-        if FEAT:
+        if self.ext_feat_flag:
             n_extension = 200
             audio_length = int((x.shape[2] + 1) // 2)
             pos_emb = torch.concat((
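The hunk is cut off at `torch.concat((`, so the exact extension logic is not visible in this view. A plausible reading, stated as an assumption rather than the commit's actual code, is that the encoder concatenates `n_extension` extra sinusoidal positions onto the base table so the positional term covers the extended range. A self-contained sketch of that idea (hypothetical `n_ctx`/`n_state` values; `sinusoids` reimplemented here for completeness):

    import torch

    def sinusoids(length: int, channels: int, max_timescale: int = 10000) -> torch.Tensor:
        # Whisper-style fixed sinusoidal position embeddings, shape (length, channels)
        assert channels % 2 == 0
        log_timescale_increment = torch.log(torch.tensor(float(max_timescale))) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
        scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

    # hypothetical shapes: extend a (n_ctx, n_state) table by n_extension positions
    n_ctx, n_state, n_extension = 1500, 384, 200
    pos_emb = torch.concat(
        (sinusoids(n_ctx, n_state), sinusoids(n_extension, n_state)), dim=0
    )  # (n_ctx + n_extension, n_state)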
@@ -278,7 +278,7 @@ class TextDecoder(nn.Module):


 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
+    def __init__(self, dims: ModelDimensions, ext_feat_flag: bool = False):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
@@ -287,6 +287,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            ext_feat_flag=ext_feat_flag,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
@@ -302,6 +303,7 @@ class Whisper(nn.Module):
         )
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        self.ext_feat_flag = ext_feat_flag

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
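Taken together, the flag threads from `load_model` through `Whisper.__init__` into `AudioEncoder`. A quick smoke test, assuming the patched package is importable:

    from whisper import load_model

    model = load_model("tiny.en", ext_feature_flag=True)
    # both the model and its encoder record the flag set at load time
    assert model.ext_feat_flag is True
    assert model.encoder.ext_feat_flag is True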