From ed4b0d14a2ba71b8dbb934d786c6afb1c94719c8 Mon Sep 17 00:00:00 2001 From: zzy981019 Date: Fri, 24 May 2024 17:32:32 +0800 Subject: [PATCH] add comments --- notebooks/LibriSpeech.ipynb | 367 ++------------------------------- whisper/__init__.py | 26 ++- whisper/audio.py | 29 ++- whisper/decoding.py | 5 + whisper/model.py | 86 ++++++-- whisper/normalizers/english.py | 9 + whisper/timing.py | 8 + whisper/tokenizer.py | 3 + whisper/transcribe.py | 17 +- whisper/triton_ops.py | 3 + whisper/utils.py | 7 + 11 files changed, 182 insertions(+), 378 deletions(-) diff --git a/notebooks/LibriSpeech.ipynb b/notebooks/LibriSpeech.ipynb index 3d90e65..95d020b 100644 --- a/notebooks/LibriSpeech.ipynb +++ b/notebooks/LibriSpeech.ipynb @@ -17,11 +17,11 @@ "metadata": { "id": "ZsJUxc0aRsAf" }, - "outputs": [], "source": [ "! pip install git+https://github.com/openai/whisper.git\n", "! pip install jiwer" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -40,7 +40,6 @@ "metadata": { "id": "3CqtR2Fi5-vP" }, - "outputs": [], "source": [ "import os\n", "import numpy as np\n", @@ -59,7 +58,8 @@ "\n", "\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -67,7 +67,6 @@ "metadata": { "id": "GuCCB2KYOJCE" }, - "outputs": [], "source": [ "class LibriSpeech(torch.utils.data.Dataset):\n", " \"\"\"\n", @@ -92,7 +91,8 @@ " mel = whisper.log_mel_spectrogram(audio)\n", " \n", " return (mel, text)" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -100,11 +100,11 @@ "metadata": { "id": "-YcRU5jqNqo2" }, - "outputs": [], "source": [ "dataset = LibriSpeech(\"test-clean\")\n", "loader = torch.utils.data.DataLoader(dataset, batch_size=16)" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -127,32 +127,24 @@ "id": "_PokfNJtOYNu", "outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model is English-only and has 71,825,408 
parameters.\n" - ] - } - ], "source": [ "model = whisper.load_model(\"base.en\")\n", "print(\n", " f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n", " f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n", ")" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], "source": [ "# predict without timestamps for short-form transcription\n", "options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -178,22 +170,6 @@ "id": "7OWTn_KvNk59", "outputId": "a813a792-3c91-4144-f11f-054fd6778023" }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9df048b46f764cf68cbe0045b8ff73a8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/164 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
hypothesisreference
0He hoped there would be stew for dinner, turni...HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...
1Stuffered into you, his belly counseled him.STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
2After early nightfall the yellow lamps would l...AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...
3Hello Bertie, any good in your mind?HELLO BERTIE ANY GOOD IN YOUR MIND
4Number 10. Fresh Nelly is waiting on you. Good...NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...
.........
2615Oh, to shoot my soul's full meaning into futur...OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...
2616Then I, long tried by natural ills, received t...THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...
2617I love thee freely as men strive for right. I ...I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...
2618I love thee with the passion put to use, in my...I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...
2619I love thee with the love I seemed to lose wit...I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...
\n", - "

2620 rows × 2 columns

\n", - "" - ], - "text/plain": [ - " hypothesis \\\n", - "0 He hoped there would be stew for dinner, turni... \n", - "1 Stuffered into you, his belly counseled him. \n", - "2 After early nightfall the yellow lamps would l... \n", - "3 Hello Bertie, any good in your mind? \n", - "4 Number 10. Fresh Nelly is waiting on you. Good... \n", - "... ... \n", - "2615 Oh, to shoot my soul's full meaning into futur... \n", - "2616 Then I, long tried by natural ills, received t... \n", - "2617 I love thee freely as men strive for right. I ... \n", - "2618 I love thee with the passion put to use, in my... \n", - "2619 I love thee with the love I seemed to lose wit... \n", - "\n", - " reference \n", - "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n", - "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n", - "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n", - "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n", - "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n", - "... ... \n", - "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n", - "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n", - "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n", - "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n", - "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n", - "\n", - "[2620 rows x 2 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n", "data" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -359,13 +215,13 @@ "metadata": { "id": "dl-KBDflMhrg" }, - "outputs": [], "source": [ "import jiwer\n", "from whisper.normalizers import EnglishTextNormalizer\n", "\n", "normalizer = EnglishTextNormalizer()" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -378,183 +234,12 @@ "id": "6-O048q4WI4o", "outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e" }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
hypothesisreferencehypothesis_cleanreference_clean
0He hoped there would be stew for dinner, turni...HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...he hoped there would be stew for dinner turnip...he hoped there would be stew for dinner turnip...
1Stuffered into you, his belly counseled him.STUFF IT INTO YOU HIS BELLY COUNSELLED HIMstuffered into you his belly counseled himstuff it into you his belly counseled him
2After early nightfall the yellow lamps would l...AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...after early nightfall the yellow lamps would l...after early nightfall the yellow lamps would l...
3Hello Bertie, any good in your mind?HELLO BERTIE ANY GOOD IN YOUR MINDhello bertie any good in your mindhello bertie any good in your mind
4Number 10. Fresh Nelly is waiting on you. Good...NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...number 10 fresh nelly is waiting on you good n...number 10 fresh nelly is waiting on you good n...
...............
2615Oh, to shoot my soul's full meaning into futur...OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...0 to shoot my soul is full meaning into future...0 to shoot my soul is full meaning into future...
2616Then I, long tried by natural ills, received t...THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...then i long tried by natural ills received the...then i long tried by natural ills received the...
2617I love thee freely as men strive for right. I ...I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...i love thee freely as men strive for right i l...i love thee freely as men strive for right i l...
2618I love thee with the passion put to use, in my...I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...i love thee with the passion put to use in my ...i love thee with the passion put to use in my ...
2619I love thee with the love I seemed to lose wit...I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...i love thee with the love i seemed to lose wit...i love thee with a love i seemed to lose with ...
\n", - "

2620 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " hypothesis \\\n", - "0 He hoped there would be stew for dinner, turni... \n", - "1 Stuffered into you, his belly counseled him. \n", - "2 After early nightfall the yellow lamps would l... \n", - "3 Hello Bertie, any good in your mind? \n", - "4 Number 10. Fresh Nelly is waiting on you. Good... \n", - "... ... \n", - "2615 Oh, to shoot my soul's full meaning into futur... \n", - "2616 Then I, long tried by natural ills, received t... \n", - "2617 I love thee freely as men strive for right. I ... \n", - "2618 I love thee with the passion put to use, in my... \n", - "2619 I love thee with the love I seemed to lose wit... \n", - "\n", - " reference \\\n", - "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n", - "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n", - "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n", - "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n", - "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n", - "... ... \n", - "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n", - "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n", - "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n", - "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n", - "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n", - "\n", - " hypothesis_clean \\\n", - "0 he hoped there would be stew for dinner turnip... \n", - "1 stuffered into you his belly counseled him \n", - "2 after early nightfall the yellow lamps would l... \n", - "3 hello bertie any good in your mind \n", - "4 number 10 fresh nelly is waiting on you good n... \n", - "... ... \n", - "2615 0 to shoot my soul is full meaning into future... \n", - "2616 then i long tried by natural ills received the... \n", - "2617 i love thee freely as men strive for right i l... \n", - "2618 i love thee with the passion put to use in my ... \n", - "2619 i love thee with the love i seemed to lose wit... 
\n", - "\n", - " reference_clean \n", - "0 he hoped there would be stew for dinner turnip... \n", - "1 stuff it into you his belly counseled him \n", - "2 after early nightfall the yellow lamps would l... \n", - "3 hello bertie any good in your mind \n", - "4 number 10 fresh nelly is waiting on you good n... \n", - "... ... \n", - "2615 0 to shoot my soul is full meaning into future... \n", - "2616 then i long tried by natural ills received the... \n", - "2617 i love thee freely as men strive for right i l... \n", - "2618 i love thee with the passion put to use in my ... \n", - "2619 i love thee with a love i seemed to lose with ... \n", - "\n", - "[2620 rows x 4 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n", "data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n", "data" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -566,20 +251,12 @@ "id": "EBGSITeBYPTT", "outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WER: 4.26 %\n" - ] - } - ], "source": [ "wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n", "\n", "print(f\"WER: {wer * 100:.2f} %\")" - ] + ], + "outputs": [] } ], "metadata": { diff --git a/whisper/__init__.py b/whisper/__init__.py index d7fbba3..a69805a 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -14,6 +14,7 @@ from .model import ModelDimensions, Whisper from .transcribe import transcribe from .version import __version__ +# what are these models? 
a: they are the pre-trained models that are available for use _MODELS = { "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", @@ -46,10 +47,12 @@ _ALIGNMENT_HEADS = { "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", } - +# q: download the model from the given url and save it to the given root directory def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: os.makedirs(root, exist_ok=True) + # what is sha256? + # a: it is a cryptographic hash function that produces a fixed-size hash value expected_sha256 = url.split("/")[-2] download_target = os.path.join(root, os.path.basename(url)) @@ -59,6 +62,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: if os.path.isfile(download_target): with open(download_target, "rb") as f: model_bytes = f.read() + # what is the purpose of this if statement? a: to check if the SHA256 checksum matches if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: return model_bytes if in_memory else download_target else: @@ -66,6 +70,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" ) + # is the following line re-downloading the model? a: yes + # so this function checks whether the model is already downloaded and if not, it downloads the model? + # a: yes with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm( total=int(source.info().get("Content-Length")), @@ -88,14 +95,17 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." ) + # is this the checkpoint file? 
a: yes return model_bytes if in_memory else download_target - +# q: what is the purpose of this function? a: to return the names of the available models def available_models() -> List[str]: """Returns the names of available models""" return list(_MODELS.keys()) - +# q: what is the purpose of this function? a: to load the model from the given name +# what does -> Whisper in Python mean? a: it means that the function returns an object of type Whisper +# load 一个模型,返回一个Whisper对象 def load_model( name: str, device: Optional[Union[str, torch.device]] = None, @@ -140,14 +150,24 @@ def load_model( f"Model {name} not found; available models = {available_models()}" ) + # what is "with" in Python? + # a: it is used to open a file and automatically close it after the block of code is executed with ( + # what if checkpoint_file is in memory? a: it uses io.BytesIO to read the file + # what if checkpoint_file is not in memory? a: it uses open to read the file io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: checkpoint = torch.load(fp, map_location=device) del checkpoint_file + # what is the **checkpoint["dims"]? a: it unpacks the dictionary into keyword arguments + # are arguments in ModelDimensions nullable? a: no + # so what if the checkpoint["dims"] is missing? a: it will raise an error + # how to confirm that checkpoint contains the "dims" key? a: by checking the keys of the dictionary dims = ModelDimensions(**checkpoint["dims"]) model = Whisper(dims) + + # what is load_state_dict? a: it loads the model weights model.load_state_dict(checkpoint["model_state_dict"]) if alignment_heads is not None: diff --git a/whisper/audio.py b/whisper/audio.py index cf6c66a..ebf5e1a 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -1,3 +1,7 @@ +""" +q: what is the usage of this file? 
a: this file contains the audio processing functions +""" + import os from functools import lru_cache from subprocess import CalledProcessError, run @@ -11,10 +15,13 @@ from .utils import exact_div # hard-coded audio hyperparameters SAMPLE_RATE = 16000 -N_FFT = 400 +N_FFT = 400 # 25ms window HOP_LENGTH = 160 CHUNK_LENGTH = 30 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + +# what is frame? a: a frame is a short segment of audio, usually 10ms +# what is frame used for? a: it is used to compute the spectrogram N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 @@ -135,23 +142,23 @@ def log_mel_spectrogram( torch.Tensor, shape = (80, n_frames) A Tensor that contains the Mel spectrogram """ - if not torch.is_tensor(audio): + if not torch.is_tensor(audio): # load audio if not already a tensor if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) + audio = load_audio(audio) # load audio from file + audio = torch.from_numpy(audio) # convert to tensor if device is not None: audio = audio.to(device) if padding > 0: - audio = F.pad(audio, (0, padding)) - window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 + audio = F.pad(audio, (0, padding)) # pad audio to the right + window = torch.hann_window(N_FFT).to(audio.device) # create a Hann window + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) # compute STFT + magnitudes = stft[..., :-1].abs() ** 2 # compute magnitudes - filters = mel_filters(audio.device, n_mels) - mel_spec = filters @ magnitudes + filters = mel_filters(audio.device, n_mels) # load mel filters + mel_spec = filters @ magnitudes # apply mel filters - log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.clamp(mel_spec, 
min=1e-10).log10() # compute log spectrogram log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..4e6598b 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -1,3 +1,7 @@ +""" +q: what is the usage of this file? a: this file contains the decoding logic (e.g. the PyTorchInference class and its key-value cached decoder forward pass) +""" + from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -152,6 +156,7 @@ class PyTorchInference(Inference): value_modules = [block.attn.value for block in self.model.decoder.blocks] self.kv_modules = key_modules + value_modules + # forward pass through the decoder, with key-value caching def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() diff --git a/whisper/model.py b/whisper/model.py index a678283..91b7cd9 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -8,11 +8,14 @@ import torch import torch.nn.functional as F from torch import Tensor, nn +# q: why does `.decoding` have a dot before it? a: it is a relative import +# q: what is a relative import? a: it is a way to import modules from the same package from .decoding import decode as decode_function from .decoding import detect_language as detect_language_function from .transcribe import transcribe as transcribe_function +# Q: what is ModelDimensions? a: it is a data class that stores the dimensions of the model @dataclass class ModelDimensions: n_mels: int @@ -27,13 +30,18 @@ class ModelDimensions: n_text_layer: int +# q: What is layer norm? a: https://arxiv.org/abs/1607.06450 +# q: explain it in short words? a: it normalizes the input tensor across the last dimension +# you are so cool! thanks! I know!
😎 class LayerNorm(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: return super().forward(x.float()).type(x.dtype) - +# q: what is the usage of this class? a: it is a linear layer that converts the input tensor to the output tensor class Linear(nn.Linear): def forward(self, x: Tensor) -> Tensor: + # q: what is F.linear? a: it is a function that applies a linear transformation to the input tensor + # q: what is F here? a: it is the torch.nn.functional module return F.linear( x, self.weight.to(x.dtype), @@ -41,15 +49,19 @@ class Linear(nn.Linear): ) +# q: what is the usage of this class? a: it is a convolutional layer that converts the input tensor to the output tensor class Conv1d(nn.Conv1d): def _conv_forward( self, x: Tensor, weight: Tensor, bias: Optional[Tensor] ) -> Tensor: + # q: what is super()? a: it is a reference to the parent class + #q: what is the parent class here? a: it is the nn.Conv1d class return super()._conv_forward( x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) ) +# q: what is the usage of this function? a: it returns sinusoids for positional embedding def sinusoids(length, channels, max_timescale=10000): """Returns sinusoids for positional embedding""" assert channels % 2 == 0 @@ -58,8 +70,9 @@ def sinusoids(length, channels, max_timescale=10000): scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) - +# q: what is the usage of this class? a: it is a multi-head attention layer class MultiHeadAttention(nn.Module): + # what is n_state? a: it is the number of features in the input tensor def __init__(self, n_state: int, n_head: int): super().__init__() self.n_head = n_head @@ -107,11 +120,15 @@ class MultiHeadAttention(nn.Module): w = F.softmax(qk, dim=-1).to(q.dtype) return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() - +# q: what is the usage of this class? 
a: it is a residual attention block class ResidualAttentionBlock(nn.Module): + # q: what is cross attention? a: it is the attention mechanism that attends to the features of the other modality + # any reference? a: https://arxiv.org/abs/1706.03762 + # why do we need cross attention? a: it helps to align the audio and text features def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): super().__init__() + # what is n_state? a: it is the number of features in the input tensor self.attn = MultiHeadAttention(n_state, n_head) self.attn_ln = LayerNorm(n_state) @@ -121,6 +138,8 @@ class ResidualAttentionBlock(nn.Module): self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None n_mlp = n_state * 4 + + # q: what is mlp? a: it is a multi-layer perceptron self.mlp = nn.Sequential( Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) ) @@ -139,7 +158,7 @@ class ResidualAttentionBlock(nn.Module): x = x + self.mlp(self.mlp_ln(x)) return x - +# q: what is the usage of this class? a: it is the encoder half of the model; it turns the mel spectrogram into audio features for the decoder class AudioEncoder(nn.Module): def __init__( self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int @@ -154,6 +173,10 @@ class AudioEncoder(nn.Module): ) self.ln_post = LayerNorm(n_state) + + # what is ctx? a: it is the context size + # what is context size? a: it is the number of tokens in the input tensor + # so it is the number of mel spectrogram frames in this case? a: yes def forward(self, x: Tensor): """ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) @@ -173,6 +196,7 @@ class AudioEncoder(nn.Module): return x +# q: what is the usage of this class? a: it is the decoder half of the model; it takes the text tokens plus the encoder output and returns logits class TextDecoder(nn.Module): def __init__( self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int @@ -217,33 +241,46 @@ class TextDecoder(nn.Module): return logits - +# so the whisper is made of an audio encoder and a text decoder? a: yes +# what is the usage of this class?
a: it is a model that transcribes the audio to text class Whisper(nn.Module): def __init__(self, dims: ModelDimensions): super().__init__() self.dims = dims self.encoder = AudioEncoder( - self.dims.n_mels, - self.dims.n_audio_ctx, - self.dims.n_audio_state, - self.dims.n_audio_head, - self.dims.n_audio_layer, + self.dims.n_mels, # the number of mel channels in the input spectrogram (not the number of frames) + self.dims.n_audio_ctx, # the context length (number of positions) of the audio encoder + self.dims.n_audio_state, # the number of features per position in the audio encoder + self.dims.n_audio_head, # the number of attention heads in the audio encoder + self.dims.n_audio_layer, # the number of layers in the audio encoder ) self.decoder = TextDecoder( - self.dims.n_vocab, - self.dims.n_text_ctx, - self.dims.n_text_state, - self.dims.n_text_head, - self.dims.n_text_layer, + self.dims.n_vocab, # the size of the token vocabulary + self.dims.n_text_ctx, # the context length (number of token positions) of the text decoder + self.dims.n_text_state, # the number of features per position in the text decoder + self.dims.n_text_head, # the number of attention heads in the text decoder + self.dims.n_text_layer, # the number of layers in the text decoder + # you are so clever! thanks! 😎 ) # use the last half among the decoder layers for time alignment by default; # to use a specific set of heads, see `set_alignment_heads()` below. + + # what is all_heads? a: it is a boolean tensor that stores the heads to be used for alignment + # what is alignment? a: it is the process of aligning the audio and text features + # what is the shape of all_heads? a: it is (n_text_layer, n_text_head) + # why is it of this shape? a: it has one flag per attention head of every text decoder layer all_heads = torch.zeros( self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool ) + # what does it mean? a: it means that the heads in the first half of the decoder layers are not used for alignment all_heads[self.dims.n_text_layer // 2 :] = True + # what is register_buffer? a: it is a method that registers a tensor as a buffer + # what is a buffer?
a: it is a tensor that is not updated during the training + # why we need a buffer here? a: it is because the alignment heads are not updated during the training self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + # what is the usage of this function? a: it sets the alignment heads + # what is alignment heads? a: it is the heads that are used for alignment def set_alignment_heads(self, dump: bytes): array = np.frombuffer( gzip.decompress(base64.b85decode(dump)), dtype=bool @@ -264,6 +301,7 @@ class Whisper(nn.Module): ) -> Dict[str, torch.Tensor]: return self.decoder(tokens, self.encoder(mel)) + # q: what is the usage of @property? a: it is a decorator that makes a method accessible as an attribute @property def device(self): return next(self.parameters()).device @@ -276,6 +314,7 @@ class Whisper(nn.Module): def num_languages(self): return self.dims.n_vocab - 51765 - int(self.is_multilingual) + # q: what is the usage of this function? a: it installs hooks to save the intermediate tensors def install_kv_cache_hooks(self, cache: Optional[dict] = None): """ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value @@ -293,16 +332,33 @@ class Whisper(nn.Module): cache = {**cache} if cache is not None else {} hooks = [] + # what does output.shape[1] > self.dims.n_text_ctx mean? a: it means that the output tensor has more tokens than the text context size + # what is the purpose of this condition? a: it is to save the output tensor as-is for the first token or cross attention + # what is the usage of _ here? a: it is a placeholder for the input tensor + # but _ is not used in the function? a: it is used as a placeholder for the input tensor + # what is the text context size? 
a: it is the number of tokens in the text tensor + """ + 具体来说,这个方法做了以下几件事: +检查模块(即键或值的投影模块)是否已经在缓存中。如果不在,或者输出张量的第二个维度(代表令牌的数量)大于文本上下文的大小,那么就将输出张量存储在缓存中。 +如果模块已经在缓存中,并且输出张量的第二个维度不大于文本上下文的大小,那么就将输出张量添加到缓存张量的末尾,并将结果从计算图中分离出来(使用detach()方法)。 +最后,这个方法返回更新后的缓存张量。 +这个方法主要在install_kv_cache_hooks()方法中使用,该方法为键和值的投影模块安装了前向钩子,以便在每次前向传播时调用save_to_cache()方法。 + """ def save_to_cache(module, _, output): if module not in cache or output.shape[1] > self.dims.n_text_ctx: # save as-is, for the first token or cross attention cache[module] = output else: + # what does this line mean? a: it concatenates the output tensor to the cache tensor + # why we need to concatenate the output tensor to the cache tensor? a: it is to save the intermediate tensors + # what does detach() mean? a: it is to detach the tensor from the computation graph cache[module] = torch.cat([cache[module], output], dim=1).detach() return cache[module] + def install_hooks(layer: nn.Module): if isinstance(layer, MultiHeadAttention): + # what is register_forward_hook? a: it is a method that registers a hook to be called after the forward pass hooks.append(layer.key.register_forward_hook(save_to_cache)) hooks.append(layer.value.register_forward_hook(save_to_cache)) diff --git a/whisper/normalizers/english.py b/whisper/normalizers/english.py index 4932042..969769b 100644 --- a/whisper/normalizers/english.py +++ b/whisper/normalizers/english.py @@ -1,3 +1,12 @@ +""" +q: what is the usage of this file? a: this file contains the English text normalization code (e.g. the EnglishSpellingNormalizer class) +q: do you think english.json is complicated? a: no, it's a simple mapping of british-american spellings +q: do you have a simpler way to realize this? a: yes, we can use a dictionary to map the words +q: so why doesn't english.json use a dictionary? a: it's easier to read and write the mappings in a json file +q: how to use the dictionary you mentioned, i mean what function to call?
+a: we can use the dictionary in the EnglishSpellingNormalizer class +""" + import json import os import re diff --git a/whisper/timing.py b/whisper/timing.py index b695ead..bc07517 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,3 +1,11 @@ +""" +q: what is the usage of this file? +a: This file contains the implementation of the `find_alignment` function, +which is used to align the text tokens with the audio frames. +The `add_word_timestamps` function is used to add timestamps to the words in the segments. + +""" + import itertools import subprocess import warnings diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 2af8375..5ceb14c 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -1,3 +1,6 @@ +""" +q: what is the usage of this file? a: this file is used to tokenize the text data +""" import base64 import os import string diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..2298a2c 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -1,3 +1,6 @@ +""" +q: what is the usage of this file? a: this file contains the `transcribe` function and the command-line interface (`cli`) +""" import argparse import os import traceback @@ -34,7 +37,8 @@ from .utils import ( if TYPE_CHECKING: from .model import Whisper - +# transcribe an audio input with the model; returns the text, the segments and the language +# def transcribe( model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], @@ -118,7 +122,9 @@ def transcribe( A dictionary containing the resulting text ("text") and segment-level details ("segments"), and the spoken language ("language"), which is detected when `decode_options["language"]` is None. """ - dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 + # what is fp16? a: half-precision floating-point format + # is fp16 better than fp32?
a: fp16 is faster but less accurate + dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 # type: ignore if model.device == torch.device("cpu"): if torch.cuda.is_available(): warnings.warn("Performing inference on CPU when CUDA is available") @@ -130,8 +136,10 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) - content_frames = mel.shape[-1] - N_FRAMES + # why? a: to make sure the audio is long enough to be processed + mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) # type: ignore + # why subtract N_FRAMES? a: to exclude the frames that come from the N_SAMPLES of padding added above, leaving only the content frames + content_frames = mel.shape[-1] - N_FRAMES # number of frames in the content content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) if decode_options.get("language", None) is None: @@ -498,6 +506,7 @@ def transcribe( ) +# what does cli stand for? command line interface def cli(): from . import available_models diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index edd4564..67cae48 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -1,3 +1,6 @@ +""" +q: what is the usage of this file? a: judging by its name, this file holds the Triton kernel implementations (TODO confirm); the audio processing functions live in audio.py +""" from functools import lru_cache import numpy as np diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b138..7a9a3a9 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -1,3 +1,10 @@ +""" +q: what is the usage of this file? a: this file contains the utility functions + +q: what is the usage of meanwhile.json? a: (unverified, TODO confirm) it is used to store the results of the meanwhile tests +q: what are the meanwhile tests? a: (unverified, TODO confirm) a test suite for the whisper project +""" + import json import os import re