In [1]:
import matplotlib.pyplot as plt
from util import transcribe, calculate_wer, load_model

In [2]:
# Smoke test: load a single model with ff enabled and a (750, 1000) cut
# region, then run it on one 5-second clip.
model = load_model(cut_region=(750, 1000), ff=True)
sample_clip = 'test_data/5s/out000.wav'

# transcribe() hands back a (text, elapsed_time) pair, so despite the plural
# name this prints one tuple rather than a list of transcriptions.
transcriptions = transcribe(model, sample_clip)
print(transcriptions)



(' The following content is provided under a Creative Commons license.', 0.5699582919478416)


In [3]:
# Identifiers of the evaluation clips: "out000" and "out001".
tests = [f'out{i:03d}' for i in range(0, 2)]

# 30-second clips: audio inputs paired with their reference transcripts.
audio_paths_30, transcript_paths_30 = (
    [f"test_data/30s/{t}.wav" for t in tests],
    [f"test_transcripts_before/30s/{t}.txt" for t in tests],
)

# 5-second clips: same layout under the 5s directories.
audio_paths_5, transcript_paths_5 = (
    [f"test_data/5s/{t}.wav" for t in tests],
    [f"test_transcripts_before/5s/{t}.txt" for t in tests],
)

In [4]:
def eval(model, audio_paths, transcript_paths):
    """Transcribe each audio file and score it against its reference transcript.

    NOTE(review): this name shadows the built-in ``eval`` for the rest of the
    notebook. It is kept because the evaluation loop calls it by this name.

    Args:
        model: Model object passed straight through to ``transcribe``.
        audio_paths: Audio file paths to transcribe.
        transcript_paths: Reference transcript paths, aligned index-for-index
            with ``audio_paths``.

    Returns:
        A ``(WER, TIME)`` tuple of two lists holding, per file and in input
        order, the value returned by ``calculate_wer`` and the elapsed time
        returned by ``transcribe``.

    Raises:
        ValueError: If ``audio_paths`` and ``transcript_paths`` differ in
            length.
    """
    audio_paths = list(audio_paths)
    transcript_paths = list(transcript_paths)

    # zip() would silently drop the unmatched tail, which would skew the
    # averages computed from the returned lists without any visible error.
    if len(audio_paths) != len(transcript_paths):
        raise ValueError(
            f"got {len(audio_paths)} audio paths but "
            f"{len(transcript_paths)} transcript paths"
        )

    wer_scores = []
    elapsed_times = []

    for audio_path, transcript_path in zip(audio_paths, transcript_paths):
        hypothesis, elapsed_time = transcribe(model, audio_path)

        # Explicit encoding so decoding does not depend on the platform
        # locale. Presumably the transcripts are UTF-8 -- TODO confirm.
        with open(transcript_path, "r", encoding="utf-8") as f:
            reference = f.read()

        wer = calculate_wer(hypothesis, reference)

        elapsed_times.append(elapsed_time)
        wer_scores.append(wer)

        print(f"Transcription: {hypothesis}")
        print(f"Reference: {reference}")
        print(f"Elapsed time: {elapsed_time:.4f}s")
        print(f"WER: {wer:.4f}")
        print()

    return wer_scores, elapsed_times

In [5]:

# Four separately loaded "tiny.en" models; within each clip length the only
# difference between the "with" and "without" variants is the ff flag.
# The key spelling ("hueristic") is kept as-is because the keys are printed
# and used as plot labels.
models = {
    label: load_model("tiny.en", ff=use_ff)
    for label, use_ff in (
        ("30s_with_hueristic", True),
        ("30s_without_hueristic", False),
        ("5s_without_hueristic", False),
        ("5s_with_hueristic", True),
    )
}

# Filled by the evaluation loop: model name -> {"WER": [...], "TIME": [...]}.
metrics = dict()

In [None]:
# Run every configured model over the test set matching its name: names
# containing "30s" use the 30-second clips, all others the 5-second clips.
for model_name, model in models.items():
    print(f"Model: {model_name}")

    if "30s" in model_name:
        audio_paths, transcript_paths = audio_paths_30, transcript_paths_30
    else:
        audio_paths, transcript_paths = audio_paths_5, transcript_paths_5

    WER, TIME = eval(model, audio_paths, transcript_paths)
    # Keep the per-file lists; averaging happens at plot time.
    metrics[model_name] = dict(WER=WER, TIME=TIME)


In [None]:

# Bar chart: one panel per metric, one bar per model showing the mean of
# that model's per-file values.
fig, ax = plt.subplots(2, 1, figsize=(10, 10))

# (axes, metrics key, panel title, y-axis label)
panels = (
    (ax[0], "WER", "Average WER", "WER"),
    (ax[1], "TIME", "Average Time", "Time (s)"),
)

for axis, key, title, ylabel in panels:
    for model_name, metric in metrics.items():
        values = metric[key]
        axis.bar(model_name, sum(values) / len(values), label=model_name)
    axis.set_title(title)
    axis.set_ylabel(ylabel)
    axis.legend()

plt.show()