{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transcription benchmark: WER and elapsed time\n",
    "\n",
    "Compares models loaded with `ff=True` against `ff=False` on the clips under `test_data/30s` and `test_data/5s`, reporting word error rate (WER) and per-clip transcription time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from util import transcribe, calculate_wer, load_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Smoke test on a single clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(' The following content is provided under a Creative Commons license.', 0.5699582919478416)\n"
     ]
    }
   ],
   "source": [
    "# transcribe() returns a (text, elapsed_seconds) pair, so the printed value\n",
    "# is a tuple rather than a bare string.\n",
    "model = load_model(ff=True, cut_region=(750, 1000))\n",
    "transcriptions = transcribe(model, 'test_data/5s/out000.wav')\n",
    "print(transcriptions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of clips evaluated per clip set; files are named out000, out001, ...\n",
    "N_CLIPS = 2\n",
    "\n",
    "tests = [f'out{i:03d}' for i in range(N_CLIPS)]\n",
    "audio_paths_30 = [f\"test_data/30s/{t}.wav\" for t in tests]\n",
    "transcript_paths_30 = [f\"test_transcripts_before/30s/{t}.txt\" for t in tests]\n",
    "audio_paths_5 = [f\"test_data/5s/{t}.wav\" for t in tests]\n",
    "transcript_paths_5 = [f\"test_transcripts_before/5s/{t}.txt\" for t in tests]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(model, audio_paths, transcript_paths):\n",
    "    \"\"\"Transcribe each audio file and score it against its reference transcript.\n",
    "\n",
    "    The name shadows the built-in ``eval`` for the rest of the notebook; it is\n",
    "    kept because the benchmark loop further down calls it by this name.\n",
    "\n",
    "    Args:\n",
    "        model: Model handed to ``transcribe`` (e.g. the result of ``load_model``).\n",
    "        audio_paths: Audio files to transcribe.\n",
    "        transcript_paths: Reference transcripts, in the same order as\n",
    "            ``audio_paths``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(WER, TIME)`` of two lists with one entry per file: the score\n",
    "        from ``calculate_wer`` and the elapsed time reported by ``transcribe``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the two path sequences differ in length.\n",
    "    \"\"\"\n",
    "    audio_paths = list(audio_paths)\n",
    "    transcript_paths = list(transcript_paths)\n",
    "    if len(audio_paths) != len(transcript_paths):\n",
    "        # zip() would silently drop the unmatched tail and skew the results.\n",
    "        raise ValueError(\n",
    "            f\"got {len(audio_paths)} audio files but {len(transcript_paths)} transcripts\"\n",
    "        )\n",
    "\n",
    "    WER = []\n",
    "    TIME = []\n",
    "\n",
    "    for audio_path, transcript_path in zip(audio_paths, transcript_paths):\n",
    "        hypothesis, elapsed_time = transcribe(model, audio_path)\n",
    "        # Explicit encoding so the reference text does not depend on the platform locale.\n",
    "        with open(transcript_path, \"r\", encoding=\"utf-8\") as f:\n",
    "            reference = f.read()\n",
    "\n",
    "        TIME.append(elapsed_time)\n",
    "\n",
    "        wer = calculate_wer(hypothesis, reference)\n",
    "        WER.append(wer)\n",
    "        print(f\"Transcription: {hypothesis}\")\n",
    "        print(f\"Reference: {reference}\")\n",
    "        print(f\"Elapsed time: {elapsed_time:.4f}s\")\n",
    "        print(f\"WER: {wer:.4f}\")\n",
    "        print()\n",
    "    return WER, TIME"
   ]
}, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "\n", "models = {\n", " \"30s_with_hueristic\": load_model(\"tiny.en\", ff=True),\n", " \"30s_without_hueristic\": load_model(\"tiny.en\", ff=False),\n", " \"5s_without_hueristic\": load_model(\"tiny.en\", ff=False),\n", " \"5s_with_hueristic\": load_model(\"tiny.en\", ff=True),\n", "}\n", "\n", "metrics = {}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for model_name, model in models.items():\n", " print(f\"Model: {model_name}\")\n", " if \"30s\" in model_name:\n", " WER, TIME = eval(model, audio_paths_30, transcript_paths_30)\n", " else:\n", " WER, TIME = eval(model, audio_paths_5, transcript_paths_5)\n", " metrics[model_name] = {\n", " \"WER\": WER,\n", " \"TIME\": TIME,\n", " }\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# bar chart\n", "fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n", "for model_name, metric in metrics.items():\n", " ax[0].bar(model_name, sum(metric[\"WER\"]) / len(metric[\"WER\"]), label=model_name)\n", " ax[1].bar(model_name, sum(metric[\"TIME\"]) / len(metric[\"TIME\"]), label=model_name)\n", "\n", "ax[0].set_title(\"Average WER\")\n", "ax[0].set_ylabel(\"WER\")\n", "ax[0].legend()\n", "ax[1].set_title(\"Average Time\")\n", "ax[1].set_ylabel(\"Time (s)\")\n", "ax[1].legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": 
"python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 2 }