whisper/nb.ipynb
2025-02-03 19:45:56 -08:00

195 lines
4.9 KiB
Plaintext
Generated

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from util import transcribe, calculate_wer, load_model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test: transcribe a single 5 s clip before running the full comparison.\n",
"# NOTE(review): `ff` and `cut_region` are interpreted by util.load_model; their\n",
"# meaning (and the units of cut_region) is not visible here -- confirm in util.py.\n",
"SMOKE_CUT_REGION = (750, 1000)\n",
"smoke_model = load_model(ff=True, cut_region=SMOKE_CUT_REGION)\n",
"\n",
"# transcribe returns (text, elapsed seconds).\n",
"smoke_text, smoke_seconds = transcribe(smoke_model, 'test_data/5s/out000.wav')\n",
"print(f\"Transcription: {smoke_text}\")\n",
"print(f\"Elapsed time: {smoke_seconds:.4f}s\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Evaluation inputs: clips out000 .. out{N_TESTS - 1}, each at 30 s and 5 s lengths.\n",
"# NOTE(review): references come from 'test_transcripts_before' -- presumably the\n",
"# baseline model's output rather than human ground truth; confirm before reading\n",
"# the WER numbers as absolute accuracy.\n",
"N_TESTS = 2\n",
"AUDIO_DIR = 'test_data'\n",
"TRANSCRIPT_DIR = 'test_transcripts_before'\n",
"\n",
"tests = [f'out{i:03d}' for i in range(N_TESTS)]\n",
"audio_paths_30 = [f\"{AUDIO_DIR}/30s/{t}.wav\" for t in tests]\n",
"transcript_paths_30 = [f\"{TRANSCRIPT_DIR}/30s/{t}.txt\" for t in tests]\n",
"audio_paths_5 = [f\"{AUDIO_DIR}/5s/{t}.wav\" for t in tests]\n",
"transcript_paths_5 = [f\"{TRANSCRIPT_DIR}/5s/{t}.txt\" for t in tests]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def eval(model, audio_paths, transcript_paths):\n",
"    \"\"\"Transcribe each clip with `model`, score it against its reference, and report.\n",
"\n",
"    NOTE: this name shadows the built-in `eval` for the rest of the notebook; it is\n",
"    kept because the evaluation loop calls it by this name.\n",
"\n",
"    Args:\n",
"        model: a model object accepted by `util.transcribe`.\n",
"        audio_paths: audio clip paths, in evaluation order.\n",
"        transcript_paths: reference transcript paths, parallel to `audio_paths`.\n",
"\n",
"    Returns:\n",
"        (WER, TIME): per-clip WER values from `calculate_wer` and per-clip elapsed\n",
"        seconds from `transcribe`, both in the order of `audio_paths`.\n",
"\n",
"    Raises:\n",
"        ValueError: if the two path sequences have different lengths (raised when\n",
"            the shorter one runs out).\n",
"    \"\"\"\n",
"    wers = []\n",
"    times = []\n",
"\n",
"    # strict=True: a missing or extra transcript is an error instead of a clip\n",
"    # that plain zip would silently drop from the averages.\n",
"    for audio_path, transcript_path in zip(audio_paths, transcript_paths, strict=True):\n",
"        hypothesis, elapsed_time = transcribe(model, audio_path)\n",
"        # Explicit encoding so the reference decodes the same on every platform.\n",
"        with open(transcript_path, \"r\", encoding=\"utf-8\") as f:\n",
"            reference = f.read()\n",
"\n",
"        times.append(elapsed_time)\n",
"\n",
"        wer = calculate_wer(hypothesis, reference)\n",
"        wers.append(wer)\n",
"        print(f\"Transcription: {hypothesis}\")\n",
"        print(f\"Reference: {reference}\")\n",
"        print(f\"Elapsed time: {elapsed_time:.4f}s\")\n",
"        print(f\"WER: {wer:.4f}\")\n",
"        print()\n",
"    return wers, times"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# One freshly loaded model per configuration, keyed \"<clip length>_<with|without>_heuristic\".\n",
"# The clip-length part of the key decides which test set the evaluation loop uses.\n",
"# Judging by the key names, ff=True turns the heuristic on.\n",
"# NOTE(review): identical configurations are loaded separately rather than shared;\n",
"# confirm util's models carry no per-run state before deduplicating them.\n",
"MODEL_NAME = \"tiny.en\"\n",
"\n",
"models = {\n",
"    \"30s_with_heuristic\": load_model(MODEL_NAME, ff=True),\n",
"    \"30s_without_heuristic\": load_model(MODEL_NAME, ff=False),\n",
"    \"5s_without_heuristic\": load_model(MODEL_NAME, ff=False),\n",
"    \"5s_with_heuristic\": load_model(MODEL_NAME, ff=True),\n",
"}\n",
"\n",
"# Filled by the evaluation loop: model name -> {\"WER\": [...], \"TIME\": [...]}.\n",
"metrics = {}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate every model on the test set that matches the clip length in its key\n",
"# (the text before the first '_', e.g. \"30s\" or \"5s\").\n",
"test_sets = {\n",
"    \"30s\": (audio_paths_30, transcript_paths_30),\n",
"    \"5s\": (audio_paths_5, transcript_paths_5),\n",
"}\n",
"\n",
"for model_name, model in models.items():\n",
"    clip_length = model_name.split(\"_\", 1)[0]\n",
"    # Fail loudly instead of silently evaluating an unrecognised key on the 5 s set.\n",
"    if clip_length not in test_sets:\n",
"        raise ValueError(f\"Model key {model_name!r} does not start with one of {sorted(test_sets)}\")\n",
"    print(f\"Model: {model_name}\")\n",
"    clip_audio, clip_transcripts = test_sets[clip_length]\n",
"    wers, times = eval(model, clip_audio, clip_transcripts)\n",
"    metrics[model_name] = {\n",
"        \"WER\": wers,\n",
"        \"TIME\": times,\n",
"    }\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bar charts of per-model averages: WER on top, elapsed time below.\n",
"def _mean(values):\n",
"    # NaN (an empty bar) rather than ZeroDivisionError when a model has no results.\n",
"    return sum(values) / len(values) if values else float(\"nan\")\n",
"\n",
"fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n",
"for model_name, metric in metrics.items():\n",
"    # One bar() call per model keeps a distinct colour per bar; the x tick label\n",
"    # already names each bar, so no legend is needed.\n",
"    ax[0].bar(model_name, _mean(metric[\"WER\"]))\n",
"    ax[1].bar(model_name, _mean(metric[\"TIME\"]))\n",
"\n",
"ax[0].set_title(\"Average WER\")\n",
"ax[0].set_ylabel(\"WER\")\n",
"ax[1].set_title(\"Average Time\")\n",
"ax[1].set_ylabel(\"Time (s)\")\n",
"for panel in ax:\n",
"    # Long model names overlap at this width; tilt them.\n",
"    panel.tick_params(axis=\"x\", labelrotation=20)\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}