Mirror of https://github.com/openai/whisper.git, synced 2025-11-24 14:35:57 +00:00

Commit c7d891e9bd: Merge 022b7aae8f0b91660ab5af4bd01ca39d12ccaaba into 173ff7dd1d9fb1c4fddea0d41d704cfefeb8908c
whisper/audio.py

@@ -24,24 +24,20 @@ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
-    Open an audio file and read as mono waveform, resampling as necessary
+    Open an audio file and read as mono waveform, resampling as necessary.

     Parameters
     ----------
     file: str
-        The audio file to open
+        The audio file to open.

     sr: int
-        The sample rate to resample the audio if necessary
+        The sample rate to resample the audio if necessary.

     Returns
     -------
     A NumPy array containing the audio waveform, in float32 dtype.
     """
-
-    # This launches a subprocess to decode audio while down-mixing
-    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
-    # fmt: off
     cmd = [
         "ffmpeg",
         "-nostdin",
@@ -53,7 +49,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "-ar", str(sr),
         "-"
     ]
-    # fmt: on
     try:
         out = run(cmd, capture_output=True, check=True).stdout
     except CalledProcessError as e:
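For orientation (not part of the diff): load_audio shells out to the ffmpeg CLI shown in cmd above and reinterprets its stdout as a waveform. A minimal usage sketch, assuming the installed whisper package and a placeholder input file:

    import whisper

    # Decodes via the ffmpeg subprocess above; ffmpeg must be on PATH.
    # "speech.wav" is a hypothetical file name.
    audio = whisper.load_audio("speech.wav")
    print(audio.dtype, audio.shape)  # float32, (n_samples,) at 16 kHz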
@@ -65,6 +60,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+
+    Parameters
+    ----------
+    array: Union[np.ndarray, torch.Tensor]
+        The audio array to pad or trim.
+
+    length: int
+        The desired length of the audio array.
+
+    axis: int
+        The axis along which to pad or trim.
+
+    Returns
+    -------
+    A padded or trimmed array.
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
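The new docstring documents behavior that is easy to confirm; a quick sketch of pad_or_trim on a NumPy array (N_SAMPLES is 30 s at 16 kHz in whisper.audio):

    import numpy as np
    import whisper

    audio = np.zeros(10 * 16000, dtype=np.float32)  # 10 s of silence
    padded = whisper.pad_or_trim(audio)             # pads along the last axis
    print(padded.shape)                             # (480000,) == N_SAMPLES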
@@ -91,14 +101,20 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 @lru_cache(maxsize=None)
 def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
-    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-    Allows decoupling librosa dependency; saved using:
-
-        np.savez_compressed(
-            "mel_filters.npz",
-            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
-        )
+    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+
+    Parameters
+    ----------
+    device: torch.device
+        The device to load the filters on.
+
+    n_mels: int
+        The number of Mel-frequency filters.
+
+    Returns
+    -------
+    torch.Tensor
+        The Mel filterbank matrix.
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
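As the assert above suggests, only the 80- and 128-band filterbanks are bundled; a short sketch of calling mel_filters directly:

    import torch
    from whisper.audio import mel_filters

    # Projects a 201-bin STFT (n_fft=400) onto 80 Mel bands.
    filters = mel_filters(torch.device("cpu"), n_mels=80)
    print(filters.shape)  # torch.Size([80, 201])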
@@ -114,27 +130,28 @@ def log_mel_spectrogram(
     device: Optional[Union[str, torch.device]] = None,
 ):
     """
-    Compute the log-Mel spectrogram of
+    Compute the log-Mel spectrogram of the audio.

     Parameters
     ----------
-    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
-        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz.

     n_mels: int
-        The number of Mel-frequency filters, only 80 is supported
+        The number of Mel-frequency filters.

     padding: int
-        Number of zero samples to pad to the right
+        Number of zero samples to pad to the right.

     device: Optional[Union[str, torch.device]]
-        If given, the audio tensor is moved to this device before STFT
+        If given, the audio tensor is moved to this device before STFT.

     Returns
     -------
-    torch.Tensor, shape = (80, n_frames)
-        A Tensor that contains the Mel spectrogram
+    torch.Tensor
+        A Tensor that contains the Mel spectrogram.
     """
+    try:
         if not torch.is_tensor(audio):
             if isinstance(audio, str):
                 audio = load_audio(audio)

@@ -155,3 +172,6 @@ def log_mel_spectrogram(
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec
+    except Exception as e:
+        print(f"Error computing log-mel spectrogram: {e}")
+        return None
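With this change, log_mel_spectrogram returns None instead of raising when anything inside the wrapped body fails, so callers on this branch would need a guard. A minimal sketch, with a placeholder file name:

    import whisper

    mel = whisper.log_mel_spectrogram("speech.wav")  # hypothetical path
    if mel is None:
        raise RuntimeError("failed to compute log-Mel spectrogram")
    print(mel.shape)  # (n_mels, n_frames)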
whisper/decoding.py

@@ -125,6 +125,7 @@ class DecodingResult:
     no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
+    tokens_probs: list[float] = field(default_factory=list)


 class Inference:
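The new tokens_probs field surfaces one probability per sampled token on the result object. A sketch of how it would be read on this branch (model size and file name are placeholders):

    import whisper

    model = whisper.load_model("base")
    audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    result = whisper.decode(model, mel, whisper.DecodingOptions())
    # aligned with result.tokens, per the field added above
    for token, prob in zip(result.tokens, result.tokens_probs):
        print(token, round(prob, 3))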
@@ -218,8 +219,8 @@ class TokenDecoder:
         """Initialize any stateful variables for decoding a new sequence"""

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list[list[float]]
+    ) -> Tuple[Tensor, list, bool]:
         """Specify how to select the next token, based on the current trace and logits

         Parameters
@@ -275,8 +276,8 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
+    ) -> Tuple[Tensor, list, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
@@ -284,18 +285,25 @@ class GreedyDecoder(TokenDecoder):

         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
+        current_probs = torch.exp(current_logprobs)
         sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

+        tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)]
         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

         completed = (tokens[:, -1] == self.eot).all()
-        return tokens, completed
+        return tokens, tokens_probs, completed

-    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
-        # make sure each sequence has at least one EOT token at the end
+    def finalize(
+        self, tokens: Tensor, tokens_probs: list, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, list, list]:
         tokens = F.pad(tokens, (0, 1), value=self.eot)
-        return tokens, sum_logprobs.tolist()
+        tokens_probs = [[t + [1.0] for t in s] for s in tokens_probs]
+        return tokens, tokens_probs, sum_logprobs.tolist()


 class BeamSearchDecoder(TokenDecoder):
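The greedy path appends exp(logprob) of each chosen token to a per-sequence list. A self-contained sketch of that bookkeeping for one step, outside the class (batch and vocabulary sizes are illustrative):

    import torch
    import torch.nn.functional as F

    batch, vocab = 2, 51864
    logits = torch.randn(batch, vocab)         # one decoding step
    tokens_probs = [[] for _ in range(batch)]  # one growing list per sequence

    next_tokens = logits.argmax(dim=-1)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(batch), next_tokens]
    current_probs = torch.exp(current_logprobs)

    # same list update as GreedyDecoder.update above
    tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)]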
@@ -321,37 +329,39 @@ class BeamSearchDecoder(TokenDecoder):
         self.finished_sequences = None

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
+    ) -> Tuple[Tensor, list, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

         n_audio = tokens.shape[0] // self.beam_size
-        if self.finished_sequences is None:  # for the first update
+        if self.finished_sequences is None:
             self.finished_sequences = [{} for _ in range(n_audio)]

         logprobs = F.log_softmax(logits.float(), dim=-1)
         next_tokens, source_indices, finished_sequences = [], [], []
         for i in range(n_audio):
-            scores, sources, finished = {}, {}, {}
+            scores, sources, finished, probs = {}, {}, {}, {}

-            # STEP 1: calculate the cumulative log probabilities for possible candidates
             for j in range(self.beam_size):
                 idx = i * self.beam_size + j
                 prefix = tokens[idx].tolist()
                 for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
+                    prob = torch.exp(logprob).item()
                     new_logprob = (sum_logprobs[idx] + logprob).item()
                     sequence = tuple(prefix + [token.item()])
                     scores[sequence] = new_logprob
                     sources[sequence] = idx
+                    probs[sequence] = tokens_probs[idx] + [prob]

-            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
             saved = 0
             for sequence in sorted(scores, key=scores.get, reverse=True):
                 if sequence[-1] == self.eot:
-                    finished[sequence] = scores[sequence]
+                    finished[sequence] = (scores[sequence], probs[sequence])
                 else:
                     sum_logprobs[len(next_tokens)] = scores[sequence]
+                    tokens_probs[len(next_tokens)] = probs[sequence]
                     next_tokens.append(sequence)
                     source_indices.append(sources[sequence])

@@ -364,44 +374,42 @@ class BeamSearchDecoder(TokenDecoder):
         tokens = torch.tensor(next_tokens, device=tokens.device)
         self.inference.rearrange_kv_cache(source_indices)

-        # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(
-            self.finished_sequences, finished_sequences
-        ):
+        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
-                    break  # the candidate list is full
+                    break
                 previously_finished[seq] = newly_finished[seq]

-        # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates
-            for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
         )
-        return tokens, completed
+        return tokens, tokens_probs, completed

-    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
-        # collect all finished sequences, including patience, and add unfinished ones if not enough
+    def finalize(
+        self, preceding_tokens: Tensor, preceding_tokens_prob: list, sum_logprobs: Tensor
+    ) -> Tuple[list, list, list]:
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if (
-                len(sequences) < self.beam_size
-            ):  # when not enough sequences are finished
+            if len(sequences) < self.beam_size:
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
-                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+                    sequences[tuple(sequence)] = (sum_logprobs[i][j].item(), preceding_tokens_prob[i][j] + [1.0])
                     if len(sequences) >= self.beam_size:
                         break

         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()]
-            for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
-            list(sequences.values()) for sequences in self.finished_sequences
+            [v[0] for v in sequences.values()] for sequences in self.finished_sequences
         ]
-        return tokens, sum_logprobs
+        tokens_probs: list[list[list[float]]] = [
+            [v[1] for v in sequences.values()] for sequences in self.finished_sequences
+        ]
+
+        return tokens, tokens_probs, sum_logprobs


 class LogitFilter:
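Note the changed bookkeeping: finished_sequences now maps each sequence to a (sum_logprob, per-token probs) tuple rather than a bare score, which is why finalize unpacks v[0] and v[1]. A toy illustration of the new value layout (token ids and numbers are made up):

    # sequence -> (cumulative logprob, per-token probabilities)
    finished = {(50364, 1770, 13, 50257): (-4.2, [0.91, 0.88, 0.97, 1.0])}
    scores = {seq: v[0] for seq, v in finished.items()}  # ranking key
    probs = {seq: v[1] for seq, v in finished.items()}   # returned alongside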
@@ -700,7 +708,8 @@ class DecodingTask:
                 logit_filter.apply(logits, tokens)

                 # expand the tokens tensor with the selected next tokens
-                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+                tokens, tokens_probs, completed = self.decoder.update(tokens, logits, sum_logprobs, tokens_probs)
+

                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
@@ -734,7 +743,7 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs, tokens_probs = self._main_loop(audio_features, tokens)

         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -747,8 +756,10 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
-            for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+        ]
+        tokens_probs: list[list[list[float]]] = [
+            [probs[: tokens.shape[0]] for probs, tokens in zip(s, t)] for s, t in zip(tokens_probs, tokens)
         ]

         # select the top-ranked sample in each group
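The added comprehension truncates each probability list to the length of its sliced token sequence, dropping entries for tokens cut at EOT. A toy check of that invariant (values are made up):

    import torch

    t = [torch.tensor([1770, 13])]   # tokens sliced between sample_begin and EOT
    s = [[0.91, 0.88, 0.97, 1.0]]    # probs including trailing EOT entries
    aligned = [probs[: tokens.shape[0]] for probs, tokens in zip(s, t)]
    print(aligned)                   # [[0.91, 0.88]]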