mirror of https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00

add comments

This commit is contained in:
parent ba3f3cd54b
commit ed4b0d14a2

Changed files: notebooks/LibriSpeech.ipynb (generated), 367 lines
notebooks/LibriSpeech.ipynb

@@ -17,11 +17,11 @@
    "metadata": {
     "id": "ZsJUxc0aRsAf"
    },
-   "outputs": [],
    "source": [
     "! pip install git+https://github.com/openai/whisper.git\n",
     "! pip install jiwer"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -40,7 +40,6 @@
    "metadata": {
     "id": "3CqtR2Fi5-vP"
    },
-   "outputs": [],
    "source": [
     "import os\n",
     "import numpy as np\n",
@@ -59,7 +58,8 @@
     "\n",
     "\n",
     "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
@@ -67,7 +67,6 @@
    "metadata": {
     "id": "GuCCB2KYOJCE"
    },
-   "outputs": [],
    "source": [
     "class LibriSpeech(torch.utils.data.Dataset):\n",
     "    \"\"\"\n",
@@ -92,7 +91,8 @@
     "        mel = whisper.log_mel_spectrogram(audio)\n",
     "        \n",
     "        return (mel, text)"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
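The hunks above only show the edges of the notebook's LibriSpeech dataset class. For orientation, here is a minimal sketch of a dataset of that shape, assuming torchaudio's LIBRISPEECH loader and Whisper's pad_or_trim helper; the constructor details and cache path are assumptions, while the class name, the log_mel_spectrogram call, and the (mel, text) return value come from the diff itself.

import os

import torch
import torchaudio
import whisper

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class LibriSpeech(torch.utils.data.Dataset):
    """Wrap torchaudio's LibriSpeech split and pad/trim each utterance to 30 seconds."""

    def __init__(self, split: str = "test-clean", device: str = DEVICE):
        # downloads the split into ~/.cache on first use (the cache path is an assumption)
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"), url=split, download=True
        )
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000
        # pad or trim to exactly 30 s (480,000 samples), then compute the 80-bin log-mel spectrogram
        audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
        mel = whisper.log_mel_spectrogram(audio)
        return mel, text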
@@ -100,11 +100,11 @@
    "metadata": {
     "id": "-YcRU5jqNqo2"
    },
-   "outputs": [],
    "source": [
     "dataset = LibriSpeech(\"test-clean\")\n",
     "loader = torch.utils.data.DataLoader(dataset, batch_size=16)"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -127,32 +127,24 @@
     "id": "_PokfNJtOYNu",
     "outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e"
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Model is English-only and has 71,825,408 parameters.\n"
-     ]
-    }
-   ],
    "source": [
     "model = whisper.load_model(\"base.en\")\n",
     "print(\n",
     "    f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n",
     "    f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n",
     ")"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
-   "outputs": [],
    "source": [
     "# predict without timestamps for short-form transcription\n",
     "options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
@@ -178,22 +170,6 @@
     "id": "7OWTn_KvNk59",
     "outputId": "a813a792-3c91-4144-f11f-054fd6778023"
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9df048b46f764cf68cbe0045b8ff73a8",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       " 0%| | 0/164 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
    "source": [
     "hypotheses = []\n",
     "references = []\n",
@@ -202,7 +178,8 @@
     "    results = model.decode(mels, options)\n",
     "    hypotheses.extend([result.text for result in results])\n",
     "    references.extend(texts)"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
@@ -215,132 +192,11 @@
     "id": "4nTyynELQ42j",
     "outputId": "1c72d25a-3e87-4c60-a8d1-1da9d2f73bd7"
    },
-   "outputs": [
-    ... (deleted: the cell's stored output, an HTML table plus plain-text rendering of the 2620-row × 2-column hypothesis/reference DataFrame, recorded at "execution_count": 8) ...
-   ],
    "source": [
     "data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n",
     "data"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -359,13 +215,13 @@
    "metadata": {
     "id": "dl-KBDflMhrg"
    },
-   "outputs": [],
    "source": [
     "import jiwer\n",
     "from whisper.normalizers import EnglishTextNormalizer\n",
     "\n",
     "normalizer = EnglishTextNormalizer()"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
@@ -378,183 +234,12 @@
     "id": "6-O048q4WI4o",
     "outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e"
    },
-   "outputs": [
-    ... (deleted: the cell's stored output, an HTML table plus plain-text rendering of the 2620-row × 4-column DataFrame with hypothesis, reference, hypothesis_clean, and reference_clean columns, recorded at "execution_count": 10) ...
-   ],
    "source": [
     "data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n",
     "data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n",
     "data"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "code",
@@ -566,20 +251,12 @@
     "id": "EBGSITeBYPTT",
     "outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f"
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "WER: 4.26 %\n"
-     ]
-    }
-   ],
    "source": [
     "wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n",
     "\n",
     "print(f\"WER: {wer * 100:.2f} %\")"
-   ]
+   ],
+   "outputs": []
   }
  ],
  "metadata": {
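Read in order, the notebook cells touched by this diff amount to the following evaluation loop. This is a condensed, script-style sketch assembled from the source lines shown above; it assumes the LibriSpeech dataset class sketched earlier is already defined.

import jiwer
import pandas as pd
import torch
import whisper
from whisper.normalizers import EnglishTextNormalizer

model = whisper.load_model("base.en")
options = whisper.DecodingOptions(language="en", without_timestamps=True)
loader = torch.utils.data.DataLoader(LibriSpeech("test-clean"), batch_size=16)

hypotheses, references = [], []
for mels, texts in loader:
    results = model.decode(mels, options)  # batched decoding of 30-second mel chunks
    hypotheses.extend([result.text for result in results])
    references.extend(texts)

data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))

# normalize both sides before scoring, so casing and punctuation differences do not count as errors
normalizer = EnglishTextNormalizer()
data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]]
data["reference_clean"] = [normalizer(text) for text in data["reference"]]

wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"]))
print(f"WER: {wer * 100:.2f} %")  # the stored notebook run reported 4.26 %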
whisper/__init__.py

@@ -14,6 +14,7 @@ from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
+# what are these models? a: they are the pre-trained models that are available for use
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -46,10 +47,12 @@ _ALIGNMENT_HEADS = {
     "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
 }
 
 
+# q: what does this function do? a: it downloads the model from the given url and saves it to the given root directory
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
 
+    # what is sha256?
+    # a: it is a cryptographic hash function that produces a fixed-size hash value
     expected_sha256 = url.split("/")[-2]
     download_target = os.path.join(root, os.path.basename(url))
 
@@ -59,6 +62,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.isfile(download_target):
         with open(download_target, "rb") as f:
             model_bytes = f.read()
+        # what is the purpose of this if statement? a: to check whether the SHA256 checksum matches
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
             return model_bytes if in_memory else download_target
         else:
@@ -66,6 +70,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
             )
 
+    # is the following line re-downloading the model? a: yes
+    # so this function checks whether the model is already downloaded and, if not, downloads it?
+    # a: yes
     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
         with tqdm(
             total=int(source.info().get("Content-Length")),
@@ -88,14 +95,17 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )
 
+    # is this the checkpoint file? a: yes
     return model_bytes if in_memory else download_target
 
 
+# q: what is the purpose of this function? a: to return the names of the available models
 def available_models() -> List[str]:
     """Returns the names of available models"""
     return list(_MODELS.keys())
 
 
+# q: what is the purpose of this function? a: to load the model with the given name
+# what does -> Whisper in Python mean? a: it means that the function returns an object of type Whisper
+# load a model and return a Whisper object
 def load_model(
     name: str,
     device: Optional[Union[str, torch.device]] = None,
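The Q&A comments in _download boil down to one pattern: the expected SHA256 digest is embedded in the model URL, and a cached file is reused only when its digest matches. Below is a stripped-down sketch of that pattern using only the standard library; the helper name and signature are illustrative, not part of whisper.

import hashlib
import os
import urllib.request


def download_with_checksum(url: str, target: str, expected_sha256: str) -> str:
    """Download url to target, reusing the cached file only if its SHA256 digest matches."""
    if os.path.isfile(target):
        with open(target, "rb") as f:
            cached = f.read()
        if hashlib.sha256(cached).hexdigest() == expected_sha256:
            return target  # cache hit: checksum matches, no re-download needed
        # checksum mismatch: fall through and re-download

    with urllib.request.urlopen(url) as source, open(target, "wb") as output:
        output.write(source.read())

    with open(target, "rb") as f:
        if hashlib.sha256(f.read()).hexdigest() != expected_sha256:
            raise RuntimeError("downloaded file failed SHA256 verification; please retry")
    return target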
@@ -140,14 +150,24 @@ def load_model(
             f"Model {name} not found; available models = {available_models()}"
         )
 
+    # what is "with" in Python?
+    # a: it is used to open a file and automatically close it after the block of code is executed
     with (
+        # what if checkpoint_file is in memory? a: it uses io.BytesIO to read the file
+        # what if checkpoint_file is not in memory? a: it uses open to read the file
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
         checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file
 
+    # what is the **checkpoint["dims"]? a: it unpacks the dictionary into keyword arguments
+    # are arguments in ModelDimensions nullable? a: no
+    # so what if checkpoint["dims"] is missing? a: it will raise an error
+    # how to confirm that checkpoint contains the "dims" key? a: by checking the keys of the dictionary
     dims = ModelDimensions(**checkpoint["dims"])
     model = Whisper(dims)
 
+    # what is load_state_dict? a: it loads the model weights
     model.load_state_dict(checkpoint["model_state_dict"])
 
     if alignment_heads is not None:
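To make the **checkpoint["dims"] question concrete, here is a tiny self-contained sketch; the dataclass is a trimmed stand-in for ModelDimensions with illustrative field values, not the real class.

from dataclasses import dataclass


@dataclass
class ModelDimensions:
    # trimmed stand-in, just to show the unpacking
    n_mels: int
    n_vocab: int
    n_text_ctx: int


checkpoint = {
    "dims": {"n_mels": 80, "n_vocab": 51864, "n_text_ctx": 448},  # illustrative values
    "model_state_dict": {},  # in the real checkpoint this holds the model weights
}

# **checkpoint["dims"] unpacks the dict into keyword arguments of the dataclass;
# a missing or misspelled field raises a TypeError, and a missing "dims" key raises a KeyError
assert "dims" in checkpoint
dims = ModelDimensions(**checkpoint["dims"])
print(dims)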
whisper/audio.py

@@ -1,3 +1,7 @@
+"""
+q: what is the usage of this file? a: this file contains the audio processing functions
+"""
+
 import os
 from functools import lru_cache
 from subprocess import CalledProcessError, run
@@ -11,10 +15,13 @@ from .utils import exact_div
 
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
-N_FFT = 400
+N_FFT = 400  # 25ms window
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
 
+# what is a frame? a: a frame is a short segment of audio, usually 10ms
+# what is a frame used for? a: it is used to compute the spectrogram
 N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
@@ -135,23 +142,23 @@ def log_mel_spectrogram(
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
+    if not torch.is_tensor(audio):  # load audio if not already a tensor
         if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
+            audio = load_audio(audio)  # load audio from file
+        audio = torch.from_numpy(audio)  # convert to tensor
 
     if device is not None:
         audio = audio.to(device)
     if padding > 0:
-        audio = F.pad(audio, (0, padding))
-    window = torch.hann_window(N_FFT).to(audio.device)
-    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
-    magnitudes = stft[..., :-1].abs() ** 2
+        audio = F.pad(audio, (0, padding))  # pad audio to the right
+    window = torch.hann_window(N_FFT).to(audio.device)  # create a Hann window
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)  # compute STFT
+    magnitudes = stft[..., :-1].abs() ** 2  # compute magnitudes
 
-    filters = mel_filters(audio.device, n_mels)
-    mel_spec = filters @ magnitudes
+    filters = mel_filters(audio.device, n_mels)  # load mel filters
+    mel_spec = filters @ magnitudes  # apply mel filters
 
-    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()  # compute log spectrogram
     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
     log_spec = (log_spec + 4.0) / 4.0
     return log_spec
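The annotated lines in log_mel_spectrogram describe a standard pipeline: Hann window, STFT, power spectrum, mel filterbank, log compression, clipping, rescaling. Here is a self-contained sketch of the same steps on fake audio; torchaudio's melscale_fbanks stands in for whisper's bundled mel_filters, so the numbers will be close but not bit-identical.

import torch
import torchaudio

SAMPLE_RATE = 16000
N_FFT = 400        # 25 ms window at 16 kHz
HOP_LENGTH = 160   # 10 ms hop, i.e. 100 frames per second
N_MELS = 80

audio = torch.randn(SAMPLE_RATE * 3)  # 3 seconds of fake audio, stand-in for a loaded file

window = torch.hann_window(N_FFT)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2  # drop the last frame, keep the power spectrum

# stand-in mel filterbank of shape (n_mels, n_freqs)
filters = torchaudio.functional.melscale_fbanks(
    n_freqs=N_FFT // 2 + 1, f_min=0.0, f_max=SAMPLE_RATE / 2,
    n_mels=N_MELS, sample_rate=SAMPLE_RATE, norm="slaney", mel_scale="slaney",
).T
mel_spec = filters @ magnitudes  # (n_mels, n_frames)

log_spec = torch.clamp(mel_spec, min=1e-10).log10()       # log compression
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)  # clip to an 8-decade dynamic range
log_spec = (log_spec + 4.0) / 4.0                         # rescale roughly into [-1, 1]
print(log_spec.shape)                                     # torch.Size([80, 300])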
whisper/decoding.py

@@ -1,3 +1,7 @@
+"""
+q: what is the usage of this file? a: this file implements the decoding logic (language detection and token decoding)
+"""
+
 from dataclasses import dataclass, field, replace
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -152,6 +156,7 @@ class PyTorchInference(Inference):
         value_modules = [block.attn.value for block in self.model.decoder.blocks]
         self.kv_modules = key_modules + value_modules
 
+    # forward pass through the decoder, with key-value caching
     def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
         if not self.kv_cache:
             self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
whisper/model.py

@@ -8,11 +8,14 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+# q: why does `.decoding` have a dot before it? a: it is a relative import
+# q: what is a relative import? a: it is a way to import modules from the same package
 from .decoding import decode as decode_function
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
 
+# Q: what is ModelDimensions? a: it is a data class that stores the dimensions of the model
 @dataclass
 class ModelDimensions:
     n_mels: int
@@ -27,13 +30,18 @@ class ModelDimensions:
     n_text_layer: int
 
 
+# q: what is layer norm? a: https://arxiv.org/abs/1607.06450
+# q: explain it in short words? a: it normalizes the input tensor across the last dimension
+# you are so cool! thanks! I know! 😎
 class LayerNorm(nn.LayerNorm):
     def forward(self, x: Tensor) -> Tensor:
         return super().forward(x.float()).type(x.dtype)
 
 
+# q: what is the usage of this class? a: it is a linear layer that converts the input tensor to the output tensor
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
+        # q: what is F.linear? a: it is a function that applies a linear transformation to the input tensor
+        # q: what is F here? a: it is the torch.nn.functional module
         return F.linear(
             x,
             self.weight.to(x.dtype),
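A quick check to back up the layer-norm answer above: nn.LayerNorm normalizes over the last dimension, and whisper's subclass only adds an upcast to float32 (and a cast back) so it stays numerically stable in fp16. A minimal sketch:

import torch
from torch import nn

x = torch.randn(2, 5, 8)  # (batch, tokens, features)
ln = nn.LayerNorm(8)

# manual layer norm over the last dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
print(torch.allclose(ln(x), manual, atol=1e-5))  # True

# what whisper's LayerNorm.forward does: upcast, normalize, cast back to the input dtype
half = x.half()
print(ln(half.float()).type(half.dtype).dtype)   # torch.float16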
@@ -41,15 +49,19 @@ class Linear(nn.Linear):
         )
 
 
+# q: what is the usage of this class? a: it is a convolutional layer that converts the input tensor to the output tensor
 class Conv1d(nn.Conv1d):
     def _conv_forward(
         self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
     ) -> Tensor:
+        # q: what is super()? a: it is a reference to the parent class
+        # q: what is the parent class here? a: it is the nn.Conv1d class
         return super()._conv_forward(
             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
         )
 
 
+# q: what is the usage of this function? a: it returns sinusoids for positional embedding
 def sinusoids(length, channels, max_timescale=10000):
     """Returns sinusoids for positional embedding"""
     assert channels % 2 == 0
@@ -58,8 +70,9 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
 
+# q: what is the usage of this class? a: it is a multi-head attention layer
 class MultiHeadAttention(nn.Module):
+    # what is n_state? a: it is the number of features in the input tensor
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head
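The sinusoids function is split across the two hunks above; here it is restated as one runnable piece, with the two middle lines reconstructed from the surrounding definitions, plus a shape check for a 1500-position, 512-channel embedding:

import numpy as np
import torch


def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional embeddings, as in the annotated function above."""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


pe = sinusoids(length=1500, channels=512)
print(pe.shape)             # torch.Size([1500, 512]): one embedding per position
print(pe.min(), pe.max())   # sin/cos values stay within [-1, 1]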
@@ -107,11 +120,15 @@ class MultiHeadAttention(nn.Module):
         w = F.softmax(qk, dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
+# q: what is the usage of this class? a: it is a residual attention block
 class ResidualAttentionBlock(nn.Module):
+    # q: what is cross attention? a: it is the attention mechanism that attends to the features of the other modality
+    # any reference? a: https://arxiv.org/abs/1706.03762
+    # why do we need cross attention? a: it lets the text decoder attend to the audio encoder's features
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
         super().__init__()
 
+        # what is n_state? a: it is the number of features in the input tensor
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)
 
@@ -121,6 +138,8 @@ class ResidualAttentionBlock(nn.Module):
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
+
+        # q: what is mlp? a: it is a multi-layer perceptron
         self.mlp = nn.Sequential(
             Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
         )
@@ -139,7 +158,7 @@ class ResidualAttentionBlock(nn.Module):
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
+# q: what is the usage of this class? a: it encodes the log-mel spectrogram into audio features
 class AudioEncoder(nn.Module):
     def __init__(
         self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
@@ -154,6 +173,10 @@ class AudioEncoder(nn.Module):
         )
         self.ln_post = LayerNorm(n_state)
 
+    # what is ctx? a: it is the context size
+    # what is the context size? a: it is the number of positions the encoder attends over
+    # so it is derived from the number of mel spectrogram frames in this case? a: yes
     def forward(self, x: Tensor):
         """
         x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -173,6 +196,7 @@ class AudioEncoder(nn.Module):
         return x
 
 
+# q: what is the usage of this class? a: it decodes text tokens, conditioned on the encoded audio features
 class TextDecoder(nn.Module):
     def __init__(
         self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
@@ -217,33 +241,46 @@ class TextDecoder(nn.Module):
 
         return logits
 
+# so whisper is made of an audio encoder and a text decoder? a: yes
+# what is the usage of this class? a: it is the model that transcribes the audio to text
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
-            self.dims.n_mels,
-            self.dims.n_audio_ctx,
-            self.dims.n_audio_state,
-            self.dims.n_audio_head,
-            self.dims.n_audio_layer,
+            self.dims.n_mels,  # the number of mel frequency bins
+            self.dims.n_audio_ctx,  # the number of positions in the audio context
+            self.dims.n_audio_state,  # the number of features in the audio states
+            self.dims.n_audio_head,  # the number of attention heads in the encoder
+            self.dims.n_audio_layer,  # the number of layers in the encoder
         )
         self.decoder = TextDecoder(
-            self.dims.n_vocab,
-            self.dims.n_text_ctx,
-            self.dims.n_text_state,
-            self.dims.n_text_head,
-            self.dims.n_text_layer,
+            self.dims.n_vocab,  # the size of the vocabulary
+            self.dims.n_text_ctx,  # the maximum number of text tokens (context length)
+            self.dims.n_text_state,  # the number of features in the text states
+            self.dims.n_text_head,  # the number of attention heads in the decoder
+            self.dims.n_text_layer,  # the number of layers in the decoder
+            # you are so clever! thanks! 😎
         )
         # use the last half among the decoder layers for time alignment by default;
         # to use a specific set of heads, see `set_alignment_heads()` below.
+
+        # what is all_heads? a: it is a boolean tensor marking the heads to be used for alignment
+        # what is alignment? a: it is the process of aligning the text tokens with the audio frames (for word timestamps)
+        # what is the shape of all_heads? a: it is (n_text_layer, n_text_head)
+        # why is it of this shape? a: because there is one flag per (decoder layer, attention head) pair
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
+        # what does the next line mean? a: it marks every head in the second half of the decoder layers for alignment
        all_heads[self.dims.n_text_layer // 2 :] = True
+        # what is register_buffer? a: it is a method that registers a tensor as a buffer
+        # what is a buffer? a: it is part of the module's state, but it is not a trainable parameter
+        # why do we need a buffer here? a: because the alignment heads are not updated during training
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
 
+    # what is the usage of this function? a: it sets the alignment heads
+    # what are alignment heads? a: they are the attention heads used for alignment
     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
             gzip.decompress(base64.b85decode(dump)), dtype=bool
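The alignment-heads comments describe a (n_text_layer, n_text_head) boolean mask stored with register_buffer. A toy sketch of just that bookkeeping, with made-up layer and head counts:

import torch
from torch import nn


class TinyDecoderStub(nn.Module):
    """Toy module that only reproduces the alignment-heads bookkeeping described above."""

    def __init__(self, n_text_layer: int = 6, n_text_head: int = 8):
        super().__init__()
        # one boolean flag per (decoder layer, attention head) pair
        all_heads = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
        # by default, mark every head in the second half of the layers for alignment
        all_heads[n_text_layer // 2:] = True
        # a buffer travels with the module (device moves, state_dict when persistent)
        # but is not a trainable parameter and receives no gradients
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)


m = TinyDecoderStub()
print(m.alignment_heads.to_dense())
print(any(p is m.alignment_heads for p in m.parameters()))  # False: a buffer, not a parameter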
@@ -264,6 +301,7 @@ class Whisper(nn.Module):
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
+    # q: what is the usage of @property? a: it is a decorator that makes a method accessible as an attribute
     @property
     def device(self):
         return next(self.parameters()).device
@@ -276,6 +314,7 @@ class Whisper(nn.Module):
     def num_languages(self):
         return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
+    # q: what is the usage of this function? a: it installs hooks to save the intermediate tensors
     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """
         The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
@@ -293,16 +332,33 @@ class Whisper(nn.Module):
         cache = {**cache} if cache is not None else {}
         hooks = []
 
+        # what does output.shape[1] > self.dims.n_text_ctx mean? a: the output tensor has more tokens than the text context size
+        # what is the purpose of this condition? a: to save the output tensor as-is for the first token or for cross attention
+        # what is the usage of _ here? a: it is a placeholder for the hook's input argument
+        # but _ is not used in the function? a: right, the hook signature requires (module, input, output), but the input is not needed here
+        # what is the text context size? a: it is the maximum number of tokens in the text tensor
+        """
+        Specifically, this method does the following:
+        It checks whether the module (the key or value projection) is already in the cache. If it is not,
+        or if the second dimension of the output tensor (the number of tokens) is larger than the text
+        context size, it stores the output tensor in the cache as-is.
+        If the module is already in the cache and the second dimension of the output is not larger than
+        the text context size, it appends the output tensor to the end of the cached tensor and detaches
+        the result from the computation graph (using detach()).
+        Finally, the method returns the updated cached tensor.
+        This method is used by install_kv_cache_hooks(), which installs forward hooks on the key and value
+        projection modules so that save_to_cache() is called on every forward pass.
+        """
         def save_to_cache(module, _, output):
             if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                 # save as-is, for the first token or cross attention
                 cache[module] = output
             else:
+                # what does this line mean? a: it concatenates the output tensor to the cached tensor
+                # why do we need to concatenate? a: so that keys/values from earlier tokens are reused instead of recomputed
+                # what does detach() mean? a: it detaches the tensor from the computation graph
                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
             return cache[module]
 
 
         def install_hooks(layer: nn.Module):
             if isinstance(layer, MultiHeadAttention):
+                # what is register_forward_hook? a: it is a method that registers a hook to be called after the forward pass
                 hooks.append(layer.key.register_forward_hook(save_to_cache))
                 hooks.append(layer.value.register_forward_hook(save_to_cache))
 
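The save_to_cache discussion, including the translated note above, is easiest to see on a toy module: a forward hook stores the first output as-is, concatenates later outputs along the token dimension, and, by returning a value, replaces the module's output with the cached tensor. A minimal sketch (the module and sizes are made up):

import torch
from torch import nn

proj = nn.Linear(4, 4)  # toy stand-in for one key/value projection inside MultiHeadAttention
cache = {}
MAX_CTX = 8             # stands in for self.dims.n_text_ctx


def save_to_cache(module, _, output):
    # hook signature is (module, input, output); the input is unused, hence the underscore
    if module not in cache or output.shape[1] > MAX_CTX:
        cache[module] = output  # first token (or cross attention): store as-is
    else:
        # later tokens: append along the token dimension and detach from the graph
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    # returning a value from a forward hook replaces the module's output,
    # which is how the accumulated keys/values get fed back into attention
    return cache[module]


hook = proj.register_forward_hook(save_to_cache)
proj(torch.randn(1, 1, 4))  # decoding step 1
proj(torch.randn(1, 1, 4))  # decoding step 2
proj(torch.randn(1, 1, 4))  # decoding step 3
print(cache[proj].shape)    # torch.Size([1, 3, 4]): keys/values accumulate instead of being recomputed
hook.remove()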
whisper/normalizers/english.py

@@ -1,3 +1,12 @@
+"""
+q: what is the usage of this file? a: this file implements the English text normalizers
+q: do you think english.json is complicated? a: no, it's a simple mapping of british-american spellings
+q: do you have a simpler way to realize this? a: yes, we can use a dictionary to map the words
+q: so why doesn't english.json use a dictionary? a: it's easier to read and write the mappings in a json file
+q: how to use the dictionary you mentioned, i mean what function to call?
+a: we can use the dictionary in the EnglishSpellingNormalizer class
+"""
+
 import json
 import os
 import re
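The english.json Q&A above describes a plain word-for-word mapping. Here is a minimal sketch of a dictionary-based spelling normalizer in that spirit; the mapping entries are illustrative, and the real EnglishSpellingNormalizer loads the full mapping from english.json.

# a few illustrative entries in the spirit of english.json (british -> american)
SPELLING = {"colour": "color", "labelled": "labeled", "realise": "realize"}


class SpellingNormalizerSketch:
    """Minimal sketch of a dictionary-based spelling normalizer."""

    def __init__(self, mapping: dict):
        self.mapping = mapping

    def __call__(self, s: str) -> str:
        return " ".join(self.mapping.get(word, word) for word in s.split())


normalizer = SpellingNormalizerSketch(SPELLING)
print(normalizer("they labelled the colour chart"))  # they labeled the color chart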
whisper/timing.py

@@ -1,3 +1,11 @@
+"""
+q: what is the usage of this file?
+a: This file contains the implementation of the `find_alignment` function,
+which is used to align the text tokens with the audio frames.
+The `add_word_timestamps` function is used to add timestamps to the words in the segments.
+"""
+
 import itertools
 import subprocess
 import warnings

whisper/tokenizer.py

@@ -1,3 +1,6 @@
+"""
+q: what is the usage of this file? a: this file is used to tokenize the text data
+"""
 import base64
 import os
 import string
whisper/transcribe.py

@@ -1,3 +1,6 @@
+"""
+q: what is the usage of this file? a: this file implements the transcribe() pipeline and the command-line interface
+"""
 import argparse
 import os
 import traceback
@@ -34,7 +37,8 @@ from .utils import (
 if TYPE_CHECKING:
     from .model import Whisper
 
+# hard-coded audio hyperparameters
+#
 def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
@@ -118,7 +122,9 @@ def transcribe(
         A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
         the spoken language ("language"), which is detected when `decode_options["language"]` is None.
     """
-    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
+    # what is fp16? a: half-precision floating-point format
+    # is fp16 better than fp32? a: fp16 is faster but less accurate
+    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32  # type: ignore
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
@@ -130,8 +136,10 @@ def transcribe(
             decode_options["fp16"] = False
 
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
-    content_frames = mel.shape[-1] - N_FRAMES
+    # why? a: to make sure the audio is long enough to be processed
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)  # type: ignore
+    # why does it need to subtract N_FRAMES? a: to get the number of frames in the actual content
+    content_frames = mel.shape[-1] - N_FRAMES  # number of frames in the content
     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
 
     if decode_options.get("language", None) is None:
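A back-of-envelope restatement of why content_frames = mel.shape[-1] - N_FRAMES: the input is padded with 30 seconds of silence, so subtracting one chunk's worth of frames leaves the frames that belong to the real audio. Exact frame counts from log_mel_spectrogram can differ by one because of STFT framing.

# the constants from whisper/audio.py, restated to make the arithmetic explicit
SAMPLE_RATE = 16000
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480_000 samples per 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH      # 3_000 mel frames per 30-second chunk

# suppose the input audio is 45 seconds long
audio_samples = 45 * SAMPLE_RATE
padded_frames = (audio_samples + N_SAMPLES) // HOP_LENGTH  # mel frames after padding 30 s of silence
content_frames = padded_frames - N_FRAMES                  # frames that belong to the real audio
content_duration = content_frames * HOP_LENGTH / SAMPLE_RATE

print(padded_frames, content_frames, content_duration)     # 7500 4500 45.0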
@@ -498,6 +506,7 @@ def transcribe(
     )
 
 
+# what does cli stand for? command line interface
 def cli():
     from . import available_models
 
@@ -1,3 +1,6 @@
+"""
+q: what is the usage of this file? a: this file contains the audio processing functions
+"""
 from functools import lru_cache
 
 import numpy as np
@@ -1,3 +1,10 @@
+"""
+q: what is the usage of this file? a: this file contains the utility functions
+
+q: what is the usage of meanwhile.json? a: it is used to store the results of the meanwhile tests
+q: what is the meanwhile tests? a: it is a test suite for the whisper project
+"""
+
 import json
 import os
 import re