add comments

zzy981019 2024-05-24 17:32:32 +08:00
parent ba3f3cd54b
commit ed4b0d14a2
11 changed files with 182 additions and 378 deletions

View File

@ -17,11 +17,11 @@
"metadata": { "metadata": {
"id": "ZsJUxc0aRsAf" "id": "ZsJUxc0aRsAf"
}, },
"outputs": [],
"source": [ "source": [
"! pip install git+https://github.com/openai/whisper.git\n", "! pip install git+https://github.com/openai/whisper.git\n",
"! pip install jiwer" "! pip install jiwer"
] ],
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -40,7 +40,6 @@
"metadata": { "metadata": {
"id": "3CqtR2Fi5-vP" "id": "3CqtR2Fi5-vP"
}, },
"outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"import numpy as np\n", "import numpy as np\n",
@ -59,7 +58,8 @@
"\n", "\n",
"\n", "\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"" "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -67,7 +67,6 @@
"metadata": { "metadata": {
"id": "GuCCB2KYOJCE" "id": "GuCCB2KYOJCE"
}, },
"outputs": [],
"source": [ "source": [
"class LibriSpeech(torch.utils.data.Dataset):\n", "class LibriSpeech(torch.utils.data.Dataset):\n",
" \"\"\"\n", " \"\"\"\n",
@ -92,7 +91,8 @@
" mel = whisper.log_mel_spectrogram(audio)\n", " mel = whisper.log_mel_spectrogram(audio)\n",
" \n", " \n",
" return (mel, text)" " return (mel, text)"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -100,11 +100,11 @@
"metadata": { "metadata": {
"id": "-YcRU5jqNqo2" "id": "-YcRU5jqNqo2"
}, },
"outputs": [],
"source": [ "source": [
"dataset = LibriSpeech(\"test-clean\")\n", "dataset = LibriSpeech(\"test-clean\")\n",
"loader = torch.utils.data.DataLoader(dataset, batch_size=16)" "loader = torch.utils.data.DataLoader(dataset, batch_size=16)"
] ],
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -127,32 +127,24 @@
"id": "_PokfNJtOYNu", "id": "_PokfNJtOYNu",
"outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e" "outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model is English-only and has 71,825,408 parameters.\n"
]
}
],
"source": [ "source": [
"model = whisper.load_model(\"base.en\")\n", "model = whisper.load_model(\"base.en\")\n",
"print(\n", "print(\n",
" f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n", " f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n",
" f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n", " f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n",
")" ")"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# predict without timestamps for short-form transcription\n", "# predict without timestamps for short-form transcription\n",
"options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)" "options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -178,22 +170,6 @@
"id": "7OWTn_KvNk59", "id": "7OWTn_KvNk59",
"outputId": "a813a792-3c91-4144-f11f-054fd6778023" "outputId": "a813a792-3c91-4144-f11f-054fd6778023"
}, },
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9df048b46f764cf68cbe0045b8ff73a8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/164 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"hypotheses = []\n", "hypotheses = []\n",
"references = []\n", "references = []\n",
@ -202,7 +178,8 @@
" results = model.decode(mels, options)\n", " results = model.decode(mels, options)\n",
" hypotheses.extend([result.text for result in results])\n", " hypotheses.extend([result.text for result in results])\n",
" references.extend(texts)" " references.extend(texts)"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -215,132 +192,11 @@
"id": "4nTyynELQ42j", "id": "4nTyynELQ42j",
"outputId": "1c72d25a-3e87-4c60-a8d1-1da9d2f73bd7" "outputId": "1c72d25a-3e87-4c60-a8d1-1da9d2f73bd7"
}, },
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hypothesis</th>\n",
" <th>reference</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>He hoped there would be stew for dinner, turni...</td>\n",
" <td>HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Stuffered into you, his belly counseled him.</td>\n",
" <td>STUFF IT INTO YOU HIS BELLY COUNSELLED HIM</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>After early nightfall the yellow lamps would l...</td>\n",
" <td>AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Hello Bertie, any good in your mind?</td>\n",
" <td>HELLO BERTIE ANY GOOD IN YOUR MIND</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Number 10. Fresh Nelly is waiting on you. Good...</td>\n",
" <td>NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2615</th>\n",
" <td>Oh, to shoot my soul's full meaning into futur...</td>\n",
" <td>OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2616</th>\n",
" <td>Then I, long tried by natural ills, received t...</td>\n",
" <td>THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2617</th>\n",
" <td>I love thee freely as men strive for right. I ...</td>\n",
" <td>I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2618</th>\n",
" <td>I love thee with the passion put to use, in my...</td>\n",
" <td>I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2619</th>\n",
" <td>I love thee with the love I seemed to lose wit...</td>\n",
" <td>I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2620 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" hypothesis \\\n",
"0 He hoped there would be stew for dinner, turni... \n",
"1 Stuffered into you, his belly counseled him. \n",
"2 After early nightfall the yellow lamps would l... \n",
"3 Hello Bertie, any good in your mind? \n",
"4 Number 10. Fresh Nelly is waiting on you. Good... \n",
"... ... \n",
"2615 Oh, to shoot my soul's full meaning into futur... \n",
"2616 Then I, long tried by natural ills, received t... \n",
"2617 I love thee freely as men strive for right. I ... \n",
"2618 I love thee with the passion put to use, in my... \n",
"2619 I love thee with the love I seemed to lose wit... \n",
"\n",
" reference \n",
"0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n",
"1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n",
"2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n",
"3 HELLO BERTIE ANY GOOD IN YOUR MIND \n",
"4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n",
"... ... \n",
"2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n",
"2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n",
"2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n",
"2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n",
"2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n",
"\n",
"[2620 rows x 2 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n", "data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n",
"data" "data"
] ],
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -359,13 +215,13 @@
"metadata": { "metadata": {
"id": "dl-KBDflMhrg" "id": "dl-KBDflMhrg"
}, },
"outputs": [],
"source": [ "source": [
"import jiwer\n", "import jiwer\n",
"from whisper.normalizers import EnglishTextNormalizer\n", "from whisper.normalizers import EnglishTextNormalizer\n",
"\n", "\n",
"normalizer = EnglishTextNormalizer()" "normalizer = EnglishTextNormalizer()"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -378,183 +234,12 @@
"id": "6-O048q4WI4o", "id": "6-O048q4WI4o",
"outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e" "outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e"
}, },
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hypothesis</th>\n",
" <th>reference</th>\n",
" <th>hypothesis_clean</th>\n",
" <th>reference_clean</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>He hoped there would be stew for dinner, turni...</td>\n",
" <td>HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...</td>\n",
" <td>he hoped there would be stew for dinner turnip...</td>\n",
" <td>he hoped there would be stew for dinner turnip...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Stuffered into you, his belly counseled him.</td>\n",
" <td>STUFF IT INTO YOU HIS BELLY COUNSELLED HIM</td>\n",
" <td>stuffered into you his belly counseled him</td>\n",
" <td>stuff it into you his belly counseled him</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>After early nightfall the yellow lamps would l...</td>\n",
" <td>AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...</td>\n",
" <td>after early nightfall the yellow lamps would l...</td>\n",
" <td>after early nightfall the yellow lamps would l...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Hello Bertie, any good in your mind?</td>\n",
" <td>HELLO BERTIE ANY GOOD IN YOUR MIND</td>\n",
" <td>hello bertie any good in your mind</td>\n",
" <td>hello bertie any good in your mind</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Number 10. Fresh Nelly is waiting on you. Good...</td>\n",
" <td>NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...</td>\n",
" <td>number 10 fresh nelly is waiting on you good n...</td>\n",
" <td>number 10 fresh nelly is waiting on you good n...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2615</th>\n",
" <td>Oh, to shoot my soul's full meaning into futur...</td>\n",
" <td>OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...</td>\n",
" <td>0 to shoot my soul is full meaning into future...</td>\n",
" <td>0 to shoot my soul is full meaning into future...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2616</th>\n",
" <td>Then I, long tried by natural ills, received t...</td>\n",
" <td>THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...</td>\n",
" <td>then i long tried by natural ills received the...</td>\n",
" <td>then i long tried by natural ills received the...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2617</th>\n",
" <td>I love thee freely as men strive for right. I ...</td>\n",
" <td>I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...</td>\n",
" <td>i love thee freely as men strive for right i l...</td>\n",
" <td>i love thee freely as men strive for right i l...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2618</th>\n",
" <td>I love thee with the passion put to use, in my...</td>\n",
" <td>I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...</td>\n",
" <td>i love thee with the passion put to use in my ...</td>\n",
" <td>i love thee with the passion put to use in my ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2619</th>\n",
" <td>I love thee with the love I seemed to lose wit...</td>\n",
" <td>I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...</td>\n",
" <td>i love thee with the love i seemed to lose wit...</td>\n",
" <td>i love thee with a love i seemed to lose with ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2620 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" hypothesis \\\n",
"0 He hoped there would be stew for dinner, turni... \n",
"1 Stuffered into you, his belly counseled him. \n",
"2 After early nightfall the yellow lamps would l... \n",
"3 Hello Bertie, any good in your mind? \n",
"4 Number 10. Fresh Nelly is waiting on you. Good... \n",
"... ... \n",
"2615 Oh, to shoot my soul's full meaning into futur... \n",
"2616 Then I, long tried by natural ills, received t... \n",
"2617 I love thee freely as men strive for right. I ... \n",
"2618 I love thee with the passion put to use, in my... \n",
"2619 I love thee with the love I seemed to lose wit... \n",
"\n",
" reference \\\n",
"0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n",
"1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n",
"2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n",
"3 HELLO BERTIE ANY GOOD IN YOUR MIND \n",
"4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n",
"... ... \n",
"2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n",
"2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n",
"2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n",
"2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n",
"2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n",
"\n",
" hypothesis_clean \\\n",
"0 he hoped there would be stew for dinner turnip... \n",
"1 stuffered into you his belly counseled him \n",
"2 after early nightfall the yellow lamps would l... \n",
"3 hello bertie any good in your mind \n",
"4 number 10 fresh nelly is waiting on you good n... \n",
"... ... \n",
"2615 0 to shoot my soul is full meaning into future... \n",
"2616 then i long tried by natural ills received the... \n",
"2617 i love thee freely as men strive for right i l... \n",
"2618 i love thee with the passion put to use in my ... \n",
"2619 i love thee with the love i seemed to lose wit... \n",
"\n",
" reference_clean \n",
"0 he hoped there would be stew for dinner turnip... \n",
"1 stuff it into you his belly counseled him \n",
"2 after early nightfall the yellow lamps would l... \n",
"3 hello bertie any good in your mind \n",
"4 number 10 fresh nelly is waiting on you good n... \n",
"... ... \n",
"2615 0 to shoot my soul is full meaning into future... \n",
"2616 then i long tried by natural ills received the... \n",
"2617 i love thee freely as men strive for right i l... \n",
"2618 i love thee with the passion put to use in my ... \n",
"2619 i love thee with a love i seemed to lose with ... \n",
"\n",
"[2620 rows x 4 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n", "data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n",
"data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n", "data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n",
"data" "data"
] ],
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@ -566,20 +251,12 @@
"id": "EBGSITeBYPTT", "id": "EBGSITeBYPTT",
"outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f" "outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f"
}, },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WER: 4.26 %\n"
]
}
],
"source": [ "source": [
"wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n", "wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n",
"\n", "\n",
"print(f\"WER: {wer * 100:.2f} %\")" "print(f\"WER: {wer * 100:.2f} %\")"
] ],
"outputs": []
} }
], ],
"metadata": { "metadata": {

View File

@ -14,6 +14,7 @@ from .model import ModelDimensions, Whisper
from .transcribe import transcribe from .transcribe import transcribe
from .version import __version__ from .version import __version__
# what are these models? a: they are the pre-trained models that are available for use
_MODELS = { _MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@ -46,10 +47,12 @@ _ALIGNMENT_HEADS = {
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
} }
# q: what does this function do? a: it downloads the model from the given url and saves it to the given root directory
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
# what is sha256?
# a: it is a cryptographic hash function that produces a fixed-size hash value
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url)) download_target = os.path.join(root, os.path.basename(url))
@ -59,6 +62,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.isfile(download_target): if os.path.isfile(download_target):
with open(download_target, "rb") as f: with open(download_target, "rb") as f:
model_bytes = f.read() model_bytes = f.read()
# what is the purpose of this if statement? a: to check if the SHA256 checksum matches
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target return model_bytes if in_memory else download_target
else: else:
@ -66,6 +70,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
) )
# is the following line re-downloading the model? a: yes
# so this function checks whether the model is already downloaded and if not, it downloads the model?
# a: yes
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm( with tqdm(
total=int(source.info().get("Content-Length")), total=int(source.info().get("Content-Length")),
@ -88,14 +95,17 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
) )
# is this the checkpoint file? a: yes
return model_bytes if in_memory else download_target return model_bytes if in_memory else download_target
# q: what is the purpose of this function? a: to return the names of the available models
def available_models() -> List[str]: def available_models() -> List[str]:
"""Returns the names of available models""" """Returns the names of available models"""
return list(_MODELS.keys()) return list(_MODELS.keys())
# q: what is the purpose of this function? a: to load the model from the given name
# what does -> Whisper in Python mean? a: it means that the function returns an object of type Whisper
# loading a model by name returns a Whisper object
def load_model( def load_model(
name: str, name: str,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
@ -140,14 +150,24 @@ def load_model(
f"Model {name} not found; available models = {available_models()}" f"Model {name} not found; available models = {available_models()}"
) )
# what is "with" in Python?
# a: it is used to open a file and automatically close it after the block of code is executed
with ( with (
# what if checkpoint_file is in memory? a: it uses io.BytesIO to read the file
# what if checkpoint_file is not in memory? a: it uses open to read the file
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp: ) as fp:
checkpoint = torch.load(fp, map_location=device) checkpoint = torch.load(fp, map_location=device)
del checkpoint_file del checkpoint_file
# what is the **checkpoint["dims"]? a: it unpacks the dictionary into keyword arguments
# are arguments in ModelDimensions nullable? a: no
# so what if the checkpoint["dims"] is missing? a: it will raise an error
# how to confirm that checkpoint contains the "dims" key? a: by checking the keys of the dictionary
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims)
# what is load_state_dict? a: it loads the model weights
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None: if alignment_heads is not None:
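The Q&A comments in _download above describe how a cached checkpoint is verified against the SHA-256 digest embedded in its download URL. A minimal sketch of that verification pattern, assuming a hypothetical URL and local filename (not the actual _MODELS entries):

import hashlib
import os

def matches_expected_sha256(path: str, expected_sha256: str) -> bool:
    # hash the whole file and compare against the expected hex digest
    if not os.path.isfile(path):
        return False
    with open(path, "rb") as f:
        file_bytes = f.read()
    return hashlib.sha256(file_bytes).hexdigest() == expected_sha256

# hypothetical example: the digest sits in the second-to-last URL segment,
# which is the convention the comments above point out for the model URLs
url = "https://example.com/models/0123abcd-placeholder-digest/tiny.en.pt"
expected = url.split("/")[-2]
if matches_expected_sha256("tiny.en.pt", expected):
    print("checksum matches, reuse the cached file")
else:
    print("missing file or checksum mismatch, re-download")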

View File

@ -1,3 +1,7 @@
"""
q: what is the usage of this file? a: this file contains the audio processing functions
"""
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run
@ -11,10 +15,13 @@ from .utils import exact_div
# hard-coded audio hyperparameters # hard-coded audio hyperparameters
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
N_FFT = 400 N_FFT = 400 # 25ms window
HOP_LENGTH = 160 HOP_LENGTH = 160
CHUNK_LENGTH = 30 CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
# what is frame? a: a frame is a short segment of audio, usually 10ms
# what is frame used for? a: it is used to compute the spectrogram
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
@ -135,23 +142,23 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames) torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram A Tensor that contains the Mel spectrogram
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio): # load audio if not already a tensor
if isinstance(audio, str): if isinstance(audio, str):
audio = load_audio(audio) audio = load_audio(audio) # load audio from file
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio) # convert to tensor
if device is not None: if device is not None:
audio = audio.to(device) audio = audio.to(device)
if padding > 0: if padding > 0:
audio = F.pad(audio, (0, padding)) audio = F.pad(audio, (0, padding)) # pad audio to the right
window = torch.hann_window(N_FFT).to(audio.device) window = torch.hann_window(N_FFT).to(audio.device) # create a Hann window
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) # compute STFT
magnitudes = stft[..., :-1].abs() ** 2 magnitudes = stft[..., :-1].abs() ** 2 # compute magnitudes
filters = mel_filters(audio.device, n_mels) filters = mel_filters(audio.device, n_mels) # load mel filters
mel_spec = filters @ magnitudes mel_spec = filters @ magnitudes # apply mel filters
log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.clamp(mel_spec, min=1e-10).log10() # compute log spectrogram
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0 log_spec = (log_spec + 4.0) / 4.0
return log_spec return log_spec
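As a quick check on the annotated constants above, a short sketch that works out the window, hop, and frame arithmetic (values restated from this file; the prints are only for illustration):

SAMPLE_RATE = 16000                       # samples per second
N_FFT = 400                               # 400 / 16000 s = 25 ms analysis window
HOP_LENGTH = 160                          # 160 / 16000 s = 10 ms between frames
CHUNK_LENGTH = 30                         # seconds per chunk

N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE    # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH        # 3000 spectrogram frames per chunk

print(N_FFT / SAMPLE_RATE * 1000, "ms window")    # 25.0 ms window
print(HOP_LENGTH / SAMPLE_RATE * 1000, "ms hop")  # 10.0 ms hop
print(N_SAMPLES, N_FRAMES)                        # 480000 3000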

View File

@ -1,3 +1,7 @@
"""
q: what is the usage of this file? a: this file contains the decoding logic, e.g. language detection and the token decoding loop
"""
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@ -152,6 +156,7 @@ class PyTorchInference(Inference):
value_modules = [block.attn.value for block in self.model.decoder.blocks] value_modules = [block.attn.value for block in self.model.decoder.blocks]
self.kv_modules = key_modules + value_modules self.kv_modules = key_modules + value_modules
# forward pass through the decoder, with key-value caching
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache: if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

View File

@ -8,11 +8,14 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, nn from torch import Tensor, nn
# q: why the decoding has a dot before it? a: it is a relative import
# q: what is relative import? a: it is a way to import modules from the same package
from .decoding import decode as decode_function from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function from .transcribe import transcribe as transcribe_function
# Q: what is ModelDimensions? a: it is a data class that stores the dimensions of the model
@dataclass @dataclass
class ModelDimensions: class ModelDimensions:
n_mels: int n_mels: int
@ -27,13 +30,18 @@ class ModelDimensions:
n_text_layer: int n_text_layer: int
# q: What is layer norm? a: https://arxiv.org/abs/1607.06450
# q: explain it in short words? a: it normalizes the input tensor across the last dimension
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
# q: what is the usage of this class? a: it is a linear layer that casts its weights and bias to the input dtype before applying the transformation
class Linear(nn.Linear): class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# q: what is F.linear? a: it is a function that applies a linear transformation to the input tensor
# q: what is F here? a: it is the torch.nn.functional module
return F.linear( return F.linear(
x, x,
self.weight.to(x.dtype), self.weight.to(x.dtype),
@ -41,15 +49,19 @@ class Linear(nn.Linear):
) )
# q: what is the usage of this class? a: it is a 1-D convolutional layer that casts its weights and bias to the input dtype
class Conv1d(nn.Conv1d): class Conv1d(nn.Conv1d):
def _conv_forward( def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor] self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor: ) -> Tensor:
# q: what is super()? a: it is a reference to the parent class
#q: what is the parent class here? a: it is the nn.Conv1d class
return super()._conv_forward( return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
) )
# q: what is the usage of this function? a: it returns sinusoids for positional embedding
def sinusoids(length, channels, max_timescale=10000): def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding""" """Returns sinusoids for positional embedding"""
assert channels % 2 == 0 assert channels % 2 == 0
@ -58,8 +70,9 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
# q: what is the usage of this class? a: it is a multi-head attention layer
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
# what is n_state? a: it is the number of features in the input tensor
def __init__(self, n_state: int, n_head: int): def __init__(self, n_state: int, n_head: int):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
@ -107,11 +120,15 @@ class MultiHeadAttention(nn.Module):
w = F.softmax(qk, dim=-1).to(q.dtype) w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
# q: what is the usage of this class? a: it is a residual attention block
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
# q: what is cross attention? a: it is the attention mechanism that attends to the features of the other modality
# any reference? a: https://arxiv.org/abs/1706.03762
# why we need cross attention? a: it helps to align the audio and text features
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__() super().__init__()
# what is n_state? a: it is the number of features in the input tensor
self.attn = MultiHeadAttention(n_state, n_head) self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state) self.attn_ln = LayerNorm(n_state)
@ -121,6 +138,8 @@ class ResidualAttentionBlock(nn.Module):
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4 n_mlp = n_state * 4
# q: what is mlp? a: it is a multi-layer perceptron
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
) )
@ -139,7 +158,7 @@ class ResidualAttentionBlock(nn.Module):
x = x + self.mlp(self.mlp_ln(x)) x = x + self.mlp(self.mlp_ln(x))
return x return x
# q: what is the usage of this class? a: it encodes the mel spectrogram into a sequence of audio features
class AudioEncoder(nn.Module): class AudioEncoder(nn.Module):
def __init__( def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
@ -154,6 +173,10 @@ class AudioEncoder(nn.Module):
) )
self.ln_post = LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
# what is ctx? a: it is the context size
# what is context size? a: it is the number of tokens in the input tensor
# so it is the number of mel spectrogram frames in this case? a: roughly; it is the number of frames after the stride-2 convolutions, i.e. half the mel frames
def forward(self, x: Tensor): def forward(self, x: Tensor):
""" """
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@ -173,6 +196,7 @@ class AudioEncoder(nn.Module):
return x return x
# q: what is the usage of this class? a: it decodes text tokens autoregressively, conditioned on the encoded audio features
class TextDecoder(nn.Module): class TextDecoder(nn.Module):
def __init__( def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
@ -217,33 +241,46 @@ class TextDecoder(nn.Module):
return logits return logits
# so the whisper is made of an audio encoder and a text decoder? a: yes
# what is the usage of this class? a: it is a model that transcribes the audio to text
class Whisper(nn.Module): class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions): def __init__(self, dims: ModelDimensions):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.encoder = AudioEncoder( self.encoder = AudioEncoder(
self.dims.n_mels, self.dims.n_mels, # the number of mel frequency bins
self.dims.n_audio_ctx, self.dims.n_audio_ctx, # the audio context length (number of encoder positions)
self.dims.n_audio_state, self.dims.n_audio_state, # the width (number of features) of the audio encoder
self.dims.n_audio_head, self.dims.n_audio_head, # the number of attention heads in the audio encoder
self.dims.n_audio_layer, self.dims.n_audio_layer, # the number of layers in the audio encoder
) )
self.decoder = TextDecoder( self.decoder = TextDecoder(
self.dims.n_vocab, self.dims.n_vocab, # the vocabulary size
self.dims.n_text_ctx, self.dims.n_text_ctx, # the text context length (maximum number of tokens)
self.dims.n_text_state, self.dims.n_text_state, # the width (number of features) of the text decoder
self.dims.n_text_head, self.dims.n_text_head, # the number of attention heads in the text decoder
self.dims.n_text_layer, self.dims.n_text_layer, # the number of layers in the text decoder
) )
# use the last half among the decoder layers for time alignment by default; # use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below. # to use a specific set of heads, see `set_alignment_heads()` below.
# what is all_heads? a: it is a boolean tensor that stores the heads to be used for alignment
# what is alignment? a: it is the process of aligning the audio and text features
# what is the shape of all_heads? a: it is (n_text_layer, n_text_head)
# why is it of this shape? a: because alignment uses the cross-attention heads of the text decoder, so there is one flag per (layer, head) pair
all_heads = torch.zeros( all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
) )
# what does it mean? a: it marks every head in the second half of the decoder layers for alignment; the first half of the layers is left unused
all_heads[self.dims.n_text_layer // 2 :] = True all_heads[self.dims.n_text_layer // 2 :] = True
# what is register_buffer? a: it is a method that registers a tensor as a buffer
# what is a buffer? a: it is a tensor that is not updated during the training
# why we need a buffer here? a: it is because the alignment heads are not updated during the training
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
# what is the usage of this function? a: it sets the alignment heads
# what is alignment heads? a: it is the heads that are used for alignment
def set_alignment_heads(self, dump: bytes): def set_alignment_heads(self, dump: bytes):
array = np.frombuffer( array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool gzip.decompress(base64.b85decode(dump)), dtype=bool
@ -264,6 +301,7 @@ class Whisper(nn.Module):
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel)) return self.decoder(tokens, self.encoder(mel))
# q: what is the usage of @property? a: it is a decorator that makes a method accessible as an attribute
@property @property
def device(self): def device(self):
return next(self.parameters()).device return next(self.parameters()).device
@ -276,6 +314,7 @@ class Whisper(nn.Module):
def num_languages(self): def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual) return self.dims.n_vocab - 51765 - int(self.is_multilingual)
# q: what is the usage of this function? a: it installs hooks to save the intermediate tensors
def install_kv_cache_hooks(self, cache: Optional[dict] = None): def install_kv_cache_hooks(self, cache: Optional[dict] = None):
""" """
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
@ -293,16 +332,33 @@ class Whisper(nn.Module):
cache = {**cache} if cache is not None else {} cache = {**cache} if cache is not None else {}
hooks = [] hooks = []
# what does output.shape[1] > self.dims.n_text_ctx mean? a: it means the output tensor has more tokens than the text context size
# what is the purpose of this condition? a: it saves the output tensor as-is for the first token or for cross attention
# what is the usage of _ here? a: it is a placeholder for the hook's input argument, which is not needed inside the function
# what is the text context size? a: it is the maximum number of text tokens the decoder can attend to (n_text_ctx)
"""
Specifically, this method does the following:
It checks whether the module (the key or value projection) is already in the cache. If it is not, or if the second dimension of the output tensor (the number of tokens) is larger than the text context size, the output tensor is stored in the cache as-is.
If the module is already in the cache and the second dimension of the output is not larger than the text context size, the output is concatenated to the end of the cached tensor and the result is detached from the computation graph with detach().
Finally, the method returns the updated cached tensor.
This method is used inside install_kv_cache_hooks(), which installs forward hooks on the key and value projection modules so that save_to_cache() is called on every forward pass.
"""
def save_to_cache(module, _, output): def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx: if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention # save as-is, for the first token or cross attention
cache[module] = output cache[module] = output
else: else:
# what does this line mean? a: it concatenates the output tensor to the cache tensor
# why do we need to concatenate the output tensor to the cache tensor? a: so the keys and values from earlier decoding steps are kept and do not need to be recomputed
# what does detach() mean? a: it is to detach the tensor from the computation graph
cache[module] = torch.cat([cache[module], output], dim=1).detach() cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module] return cache[module]
def install_hooks(layer: nn.Module): def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention): if isinstance(layer, MultiHeadAttention):
# what is register_forward_hook? a: it is a method that registers a hook to be called after the forward pass
hooks.append(layer.key.register_forward_hook(save_to_cache)) hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache)) hooks.append(layer.value.register_forward_hook(save_to_cache))
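The comments and docstring above walk through how install_kv_cache_hooks uses register_forward_hook to accumulate key/value projections across decoding steps. A minimal standalone sketch of that hook pattern on a toy module (the linear layer and sizes are made up for illustration, not the actual Whisper classes):

import torch
import torch.nn as nn

cache = {}

def save_to_cache(module, _inputs, output):
    # first call: store the output as-is; later calls: append along the token dimension
    if module not in cache:
        cache[module] = output
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]

key_proj = nn.Linear(8, 8)                           # stands in for a key/value projection
hook = key_proj.register_forward_hook(save_to_cache)

for step in range(3):
    token = torch.randn(1, 1, 8)                     # one new "token" per decoding step
    key_proj(token)

print(cache[key_proj].shape)                         # torch.Size([1, 3, 8]): keys from 3 steps
hook.remove()                                        # uninstall the hook when decoding is done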

View File

@ -1,3 +1,12 @@
"""
q: what is the usage of this file? a: this file contains the English text normalizers
q: do you think english.json is complicated? a: no, it's a simple mapping of british-american spellings
q: how is that mapping used? a: the json file is loaded into a Python dictionary and each word is looked up in it
q: so why keep the mappings in english.json instead of hard-coding a dict? a: it's easier to read and maintain the mappings in a json file
q: which class uses this dictionary?
a: the EnglishSpellingNormalizer class loads it and replaces each word with its mapped spelling
"""
import json import json
import os import os
import re import re
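As the Q&A above notes, english.json is just a British-to-American spelling map that gets loaded into a Python dictionary and applied word by word. A small sketch of that lookup pattern with a made-up two-entry mapping (the real file contains far more entries):

import json

# a tiny stand-in for english.json
mapping = json.loads('{"colour": "color", "normalise": "normalize"}')

def normalize_spelling(text: str) -> str:
    # replace each word with its American spelling if the mapping has one
    return " ".join(mapping.get(word, word) for word in text.split())

print(normalize_spelling("normalise the colour values"))  # normalize the color values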

View File

@ -1,3 +1,11 @@
"""
q: what is the usage of this file?
a: This file contains the implementation of the `find_alignment` function,
which is used to align the text tokens with the audio frames.
The `add_word_timestamps` function is used to add timestamps to the words in the segments.
"""
import itertools import itertools
import subprocess import subprocess
import warnings import warnings

View File

@ -1,3 +1,6 @@
"""
q: what is the usage of this file? a: this file is used to tokenize the text data
"""
import base64 import base64
import os import os
import string import string

View File

@ -1,3 +1,6 @@
"""
q: what is the usage of this file? a: this file contains the transcription pipeline: the transcribe() function and its command-line interface
"""
import argparse import argparse
import os import os
import traceback import traceback
@ -34,7 +37,8 @@ from .utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper
# transcribe() is the high-level entry point: it runs the model over successive 30-second windows of the audio and stitches the results together
def transcribe( def transcribe(
model: "Whisper", model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor], audio: Union[str, np.ndarray, torch.Tensor],
@ -118,7 +122,9 @@ def transcribe(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
""" """
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 # what is fp16? a: half-precision floating-point format
# is fp16 better than fp32? a: fp16 is faster but less accurate
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 # type: ignore
if model.device == torch.device("cpu"): if model.device == torch.device("cpu"):
if torch.cuda.is_available(): if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available") warnings.warn("Performing inference on CPU when CUDA is available")
@ -130,8 +136,10 @@ def transcribe(
decode_options["fp16"] = False decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) # why pad? a: so the final segment can still be sliced as a full 30-second window
content_frames = mel.shape[-1] - N_FRAMES mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) # type: ignore
# why subtract N_FRAMES? a: the padding added exactly N_FRAMES of silence, so subtracting it gives the number of frames of actual content
content_frames = mel.shape[-1] - N_FRAMES # number of frames in the content
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
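A quick numeric check of the padding comments above, a sketch using the constants from audio.py and an arbitrary 10-second clip (rounding at frame boundaries is ignored):

SAMPLE_RATE, HOP_LENGTH = 16000, 160
N_SAMPLES = 30 * SAMPLE_RATE                  # 480000: one 30-second chunk of padded silence
N_FRAMES = N_SAMPLES // HOP_LENGTH            # 3000 frames contributed by the padding

audio_seconds = 10                            # hypothetical clip length
audio_frames = audio_seconds * SAMPLE_RATE // HOP_LENGTH  # 1000 frames of real content

mel_frames = audio_frames + N_FRAMES          # roughly what mel.shape[-1] is after padding
content_frames = mel_frames - N_FRAMES        # subtracting the padding recovers the content
print(content_frames)                                  # 1000
print(content_frames * HOP_LENGTH / SAMPLE_RATE)        # 10.0 seconds = content_duration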
@ -498,6 +506,7 @@ def transcribe(
) )
# what does cli stand for? command line interface
def cli(): def cli():
from . import available_models from . import available_models

View File

@ -1,3 +1,6 @@
"""
q: what is the usage of this file? a: this file contains Triton GPU kernel implementations used by the word-level timing code
"""
from functools import lru_cache from functools import lru_cache
import numpy as np import numpy as np

View File

@ -1,3 +1,10 @@
"""
q: what is the usage of this file? a: this file contains the utility functions
q: what is the usage of meanwhile.json? a: it is used to store the results of the meanwhile tests
q: what are the meanwhile tests? a: they are part of the test suite for the whisper project
"""
import json import json
import os import os
import re import re