diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3796a39..dffc17c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,8 +6,38 @@ on: pull_request: branches: - main + jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Fetch base branch + run: git fetch origin ${{ github.base_ref }} + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + architecture: x64 + - name: Get pip cache dir + id: pip-cache + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip/pre-commit cache + uses: actions/cache@v3 + with: + path: | + ${{ steps.pip-cache.outputs.dir }} + ~/.cache/pre-commit + key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }} + restore-keys: | + ${{ runner.os }}-pip-pre-commit + - name: pre-commit + run: | + pip install -U pre-commit + pre-commit install --install-hooks + pre-commit run --all-files whisper-test: + needs: pre-commit runs-on: ubuntu-latest strategy: matrix: @@ -23,7 +53,4 @@ jobs: - uses: actions/checkout@v3 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install .["dev"] - - run: black --check --diff -t py38 --include '(\.pyi?)$' . - - run: isort --check --diff . - - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3f5a74b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-json + - id: end-of-file-fixer + types: [file, python] + - id: trailing-whitespace + types: [file, python] + - id: mixed-line-ending + - id: check-added-large-files + args: [--maxkb=4096] + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"] + - repo: https://github.com/pycqa/flake8.git + rev: 6.0.0 + hooks: + - id: flake8 + types: [python] + args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"] diff --git a/CHANGELOG.md b/CHANGELOG.md index a77d966..50c0536 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # CHANGELOG +## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918) + +* Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528)) +* fix doc of TextDecoder ([#1526](https://github.com/openai/whisper/pull/1526)) +* Update model-card.md ([#1643](https://github.com/openai/whisper/pull/1643)) +* word timing tweaks ([#1559](https://github.com/openai/whisper/pull/1559)) +* Avoid rearranging all caches ([#1483](https://github.com/openai/whisper/pull/1483)) +* Improve timestamp heuristics. ([#1461](https://github.com/openai/whisper/pull/1461)) +* fix condition_on_previous_text ([#1224](https://github.com/openai/whisper/pull/1224)) +* Fix numba depreceation notice ([#1233](https://github.com/openai/whisper/pull/1233)) +* Updated README.md to provide more insight on BLEU and specific appendices ([#1236](https://github.com/openai/whisper/pull/1236)) +* Avoid computing higher temperatures on no_speech segments ([#1279](https://github.com/openai/whisper/pull/1279)) +* Dropped unused execute bit from mel_filters.npz. ([#1254](https://github.com/openai/whisper/pull/1254)) +* Drop ffmpeg-python dependency and call ffmpeg directly. ([#1242](https://github.com/openai/whisper/pull/1242)) +* Python 3.11 ([#1171](https://github.com/openai/whisper/pull/1171)) +* Update decoding.py ([#1219](https://github.com/openai/whisper/pull/1219)) +* Update decoding.py ([#1155](https://github.com/openai/whisper/pull/1155)) +* Update README.md to reference tiktoken ([#1105](https://github.com/openai/whisper/pull/1105)) +* Implement max line width and max line count, and make word highlighting optional ([#1184](https://github.com/openai/whisper/pull/1184)) +* Squash long words at window and sentence boundaries. ([#1114](https://github.com/openai/whisper/pull/1114)) +* python-publish.yml: bump actions version to fix node warning ([#1211](https://github.com/openai/whisper/pull/1211)) +* Update tokenizer.py ([#1163](https://github.com/openai/whisper/pull/1163)) + ## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314) * abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090)) diff --git a/model-card.md b/model-card.md index 2ed85cf..b5a571a 100644 --- a/model-card.md +++ b/model-card.md @@ -37,7 +37,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m ### Evaluated Use -The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research. +The primary intended users of these models are AI researchers studying the robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research. The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them. @@ -53,17 +53,17 @@ As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we s ## Performance and Limitations -Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. +Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, and technical language, as well as zero-shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself. -Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). +Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include a higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). -In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages. +In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages. ## Broader Implications We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box – their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications. -There are also potential dual use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. +There are also potential dual-use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. diff --git a/whisper/decoding.py b/whisper/decoding.py index 457ee7c..ecd98a4 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -146,6 +146,10 @@ class PyTorchInference(Inference): self.kv_cache = {} self.hooks = [] + key_modules = [block.attn.key for block in self.model.decoder.blocks] + value_modules = [block.attn.value for block in self.model.decoder.blocks] + self.kv_modules = key_modules + value_modules + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() @@ -164,9 +168,10 @@ class PyTorchInference(Inference): self.hooks = [] def rearrange_kv_cache(self, source_indices): - for module, tensor in self.kv_cache.items(): - # update the key/value cache to contain the selected sequences - self.kv_cache[module] = tensor[source_indices].detach() + if source_indices != list(range(len(source_indices))): + for module in self.kv_modules: + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = self.kv_cache[module][source_indices].detach() class SequenceRanker: @@ -668,7 +673,6 @@ class DecodingTask: return languages, lang_probs def _main_loop(self, audio_features: Tensor, tokens: Tensor): - assert audio_features.shape[0] == tokens.shape[0] n_batch = tokens.shape[0] sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) no_speech_probs = [np.nan] * n_batch @@ -721,8 +725,7 @@ class DecodingTask: ) ] - # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling - audio_features = audio_features.repeat_interleave(self.n_group, dim=0) + # repeat text tensors by the group size, for beam search or best-of-n sampling tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop diff --git a/whisper/model.py b/whisper/model.py index 3457fcf..6913002 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -197,7 +197,7 @@ class TextDecoder(nn.Module): """ x : torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens - xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) the encoded audio features to be attended on """ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 diff --git a/whisper/timing.py b/whisper/timing.py index 1a73eaa..befcf46 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -202,7 +202,7 @@ def find_alignment( hook.remove() # heads * tokens * frames - weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T]) + weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) weights = weights[:, :, : num_frames // 2] weights = (weights * qk_scale).softmax(dim=-1) std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) @@ -214,6 +214,13 @@ def find_alignment( text_indices, time_indices = dtw(-matrix) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) @@ -225,28 +232,6 @@ def find_alignment( for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) ] - # hack: truncate long words at the start of a window and the start of a sentence. - # a better segmentation algorithm based on VAD should be able to replace this. - word_durations = end_times - start_times - word_durations = word_durations[word_durations.nonzero()] - if len(word_durations) > 0: - median_duration = np.median(word_durations) - max_duration = median_duration * 2 - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries are not longer than twice the median word duration. - for i in range(1, len(start_times)): - if end_times[i] - start_times[i] > max_duration: - if words[i] in sentence_end_marks: - end_times[i] = start_times[i] + max_duration - elif words[i - 1] in sentence_end_marks: - start_times[i] = end_times[i] - max_duration - # ensure the first and second word is not longer than twice the median word duration. - if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration: - if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration: - boundary = max(end_times[1] / 2, end_times[1] - max_duration) - end_times[0] = start_times[1] = boundary - start_times[0] = max(0, end_times[0] - max_duration) - return [ WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip( @@ -298,6 +283,7 @@ def add_word_timestamps( num_frames: int, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + last_speech_timestamp: float, **kwargs, ): if len(segments) == 0: @@ -310,6 +296,23 @@ def add_word_timestamps( text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) + word_durations = np.array([t.end - t.start for t in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i].end - alignment[i].start > max_duration: + if alignment[i].word in sentence_end_marks: + alignment[i].end = alignment[i].start + max_duration + elif alignment[i - 1].word in sentence_end_marks: + alignment[i].start = alignment[i].end - max_duration + merge_punctuations(alignment, prepend_punctuations, append_punctuations) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE @@ -335,18 +338,48 @@ def add_word_timestamps( saved_tokens += len(timing.tokens) word_index += 1 + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. if len(words) > 0: - segment["start"] = words[0]["start"] - # hack: prefer the segment-level end timestamp if the last word is too long. - # a better segmentation algorithm based on VAD should be able to replace this. + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. if ( segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"] ): - # adjust the word-level timestamps based on the segment-level timestamps - words[-1]["end"] = segment["end"] + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) else: - # adjust the segment-level timestamps based on the word-level timestamps segment["end"] = words[-1]["end"] + last_speech_timestamp = segment["end"] + segment["words"] = words diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 4030e15..3b23991 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -226,7 +226,7 @@ class Tokenizer: @cached_property def all_language_codes(self) -> Tuple[str]: - return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) + return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) @cached_property def sot_sequence_including_notimestamps(self) -> Tuple[int]: diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 3a096ae..74e8c51 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -225,6 +225,7 @@ def transcribe( with tqdm.tqdm( total=content_frames, unit="frames", disable=verbose is not False ) as pbar: + last_speech_timestamp = 0.0 while seek < content_frames: time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) mel_segment = mel[:, seek : seek + N_FRAMES] @@ -324,10 +325,13 @@ def transcribe( num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, ) word_end_timestamps = [ w["end"] for s in current_segments for w in s["words"] ] + if len(word_end_timestamps) > 0: + last_speech_timestamp = word_end_timestamps[-1] if not single_timestamp_ending and len(word_end_timestamps) > 0: seek_shift = round( (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND diff --git a/whisper/version.py b/whisper/version.py index 572259a..c43bf6f 100644 --- a/whisper/version.py +++ b/whisper/version.py @@ -1 +1 @@ -__version__ = "20230314" +__version__ = "20230918"