Merge 022b7aae8f0b91660ab5af4bd01ca39d12ccaaba into 173ff7dd1d9fb1c4fddea0d41d704cfefeb8908c

Ashish Patel 2024-11-15 09:43:40 +09:00 committed by GitHub
commit c7d891e9bd
2 changed files with 109 additions and 78 deletions

whisper/audio.py

@@ -24,24 +24,20 @@ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
-    Open an audio file and read as mono waveform, resampling as necessary
+    Open an audio file and read as mono waveform, resampling as necessary.

     Parameters
     ----------
     file: str
-        The audio file to open
+        The audio file to open.

     sr: int
-        The sample rate to resample the audio if necessary
+        The sample rate to resample the audio if necessary.

     Returns
     -------
     A NumPy array containing the audio waveform, in float32 dtype.
     """
-    # This launches a subprocess to decode audio while down-mixing
-    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
-    # fmt: off
     cmd = [
         "ffmpeg",
         "-nostdin",
@@ -53,7 +49,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "-ar", str(sr),
         "-"
     ]
-    # fmt: on
     try:
         out = run(cmd, capture_output=True, check=True).stdout
     except CalledProcessError as e:
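
Note: the comments deleted above documented that load_audio shells out to the ffmpeg CLI, which must be on PATH. A minimal usage sketch (the file path is hypothetical):

    from whisper.audio import SAMPLE_RATE, load_audio

    # Decodes, down-mixes to mono, and resamples via ffmpeg; returns float32.
    waveform = load_audio("example.wav", sr=SAMPLE_RATE)
    print(waveform.dtype, waveform.shape)  # float32, (n_samples,)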
@@ -65,6 +60,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+
+    Parameters
+    ----------
+    array: Union[np.ndarray, torch.Tensor]
+        The audio array to pad or trim.
+
+    length: int
+        The desired length of the audio array.
+
+    axis: int
+        The axis along which to pad or trim.
+
+    Returns
+    -------
+    A padded or trimmed array.
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
@@ -91,14 +101,20 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 @lru_cache(maxsize=None)
 def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
-    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-    Allows decoupling librosa dependency; saved using:
-
-        np.savez_compressed(
-            "mel_filters.npz",
-            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
-        )
+    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+
+    Parameters
+    ----------
+    device: torch.device
+        The device to load the filters on.
+
+    n_mels: int
+        The number of Mel-frequency filters.
+
+    Returns
+    -------
+    torch.Tensor
+        The Mel filterbank matrix.
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
@@ -114,44 +130,48 @@ def log_mel_spectrogram(
     device: Optional[Union[str, torch.device]] = None,
 ):
     """
-    Compute the log-Mel spectrogram of
+    Compute the log-Mel spectrogram of the audio.

     Parameters
     ----------
-    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
-        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz.

     n_mels: int
-        The number of Mel-frequency filters, only 80 is supported
+        The number of Mel-frequency filters.

     padding: int
-        Number of zero samples to pad to the right
+        Number of zero samples to pad to the right.

     device: Optional[Union[str, torch.device]]
-        If given, the audio tensor is moved to this device before STFT
+        If given, the audio tensor is moved to this device before STFT.

     Returns
     -------
-    torch.Tensor, shape = (80, n_frames)
-        A Tensor that contains the Mel spectrogram
+    torch.Tensor
+        A Tensor that contains the Mel spectrogram.
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
-
-    if device is not None:
-        audio = audio.to(device)
-    if padding > 0:
-        audio = F.pad(audio, (0, padding))
-    window = torch.hann_window(N_FFT).to(audio.device)
-    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
-    magnitudes = stft[..., :-1].abs() ** 2
-
-    filters = mel_filters(audio.device, n_mels)
-    mel_spec = filters @ magnitudes
-
-    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
-    log_spec = (log_spec + 4.0) / 4.0
-    return log_spec
+    try:
+        if not torch.is_tensor(audio):
+            if isinstance(audio, str):
+                audio = load_audio(audio)
+            audio = torch.from_numpy(audio)
+
+        if device is not None:
+            audio = audio.to(device)
+        if padding > 0:
+            audio = F.pad(audio, (0, padding))
+        window = torch.hann_window(N_FFT).to(audio.device)
+        stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+        magnitudes = stft[..., :-1].abs() ** 2
+
+        filters = mel_filters(audio.device, n_mels)
+        mel_spec = filters @ magnitudes
+
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+        return log_spec
+    except Exception as e:
+        print(f"Error computing log-mel spectrogram: {e}")
+        return None
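
Note: with the try/except added above, log_mel_spectrogram now returns None on failure instead of raising. Callers therefore need a None check; a minimal caller-side sketch (the file path is hypothetical, not part of this diff):

    from whisper.audio import log_mel_spectrogram

    mel = log_mel_spectrogram("example.wav", n_mels=80)
    if mel is None:
        # the function already printed the underlying error
        raise RuntimeError("failed to compute log-Mel spectrogram")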

whisper/decoding.py

@@ -125,6 +125,7 @@ class DecodingResult:
     no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
+    tokens_probs: list[float] = field(default_factory=list)


 class Inference:
@@ -218,8 +219,8 @@ class TokenDecoder:
         """Initialize any stateful variables for decoding a new sequence"""

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list[list[float]]
+    ) -> Tuple[Tensor, list, bool]:
         """Specify how to select the next token, based on the current trace and logits

         Parameters
@@ -275,8 +276,8 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
+    ) -> Tuple[Tensor, list, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
@@ -284,18 +285,25 @@ class GreedyDecoder(TokenDecoder):

         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
+        current_probs = torch.exp(current_logprobs)
         sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
+        tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)]

         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

         completed = (tokens[:, -1] == self.eot).all()
-        return tokens, completed
+        return tokens, tokens_probs, completed

-    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
-        # make sure each sequence has at least one EOT token at the end
+    def finalize(
+        self, tokens: Tensor, tokens_probs: list, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, list, list]:
         tokens = F.pad(tokens, (0, 1), value=self.eot)
-        return tokens, sum_logprobs.tolist()
+        tokens_probs = [[t + [1.0] for t in s] for s in tokens_probs]
+        return tokens, tokens_probs, sum_logprobs.tolist()
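
Note: GreedyDecoder.update now appends the probability of each chosen token to a per-sequence history, and finalize pads that history with 1.0 for the appended EOT token. A standalone sketch of the bookkeeping with toy logits (assumed shapes; not the PR's code):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])  # batch of 2
    tokens_probs = [[], []]  # one probability history per sequence
    next_tokens = logits.argmax(dim=-1)  # greedy pick
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
    current_probs = torch.exp(current_logprobs)  # probability of each pick
    tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)]
    print(tokens_probs)  # approximately [[0.75], [0.66]]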

 class BeamSearchDecoder(TokenDecoder):
@@ -321,37 +329,39 @@ class BeamSearchDecoder(TokenDecoder):
         self.finished_sequences = None

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
+    ) -> Tuple[Tensor, list, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

         n_audio = tokens.shape[0] // self.beam_size
-        if self.finished_sequences is None:  # for the first update
+        if self.finished_sequences is None:
             self.finished_sequences = [{} for _ in range(n_audio)]

         logprobs = F.log_softmax(logits.float(), dim=-1)
         next_tokens, source_indices, finished_sequences = [], [], []
         for i in range(n_audio):
-            scores, sources, finished = {}, {}, {}
+            scores, sources, finished, probs = {}, {}, {}, {}

-            # STEP 1: calculate the cumulative log probabilities for possible candidates
             for j in range(self.beam_size):
                 idx = i * self.beam_size + j
                 prefix = tokens[idx].tolist()
                 for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
+                    prob = torch.exp(logprob).item()
                     new_logprob = (sum_logprobs[idx] + logprob).item()
                     sequence = tuple(prefix + [token.item()])
                     scores[sequence] = new_logprob
                     sources[sequence] = idx
+                    probs[sequence] = tokens_probs[idx] + [prob]

-            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
             saved = 0
             for sequence in sorted(scores, key=scores.get, reverse=True):
                 if sequence[-1] == self.eot:
-                    finished[sequence] = scores[sequence]
+                    finished[sequence] = (scores[sequence], probs[sequence])
                 else:
                     sum_logprobs[len(next_tokens)] = scores[sequence]
+                    tokens_probs[len(next_tokens)] = probs[sequence]
                     next_tokens.append(sequence)
                     source_indices.append(sources[sequence])
@@ -364,44 +374,42 @@ class BeamSearchDecoder(TokenDecoder):
         tokens = torch.tensor(next_tokens, device=tokens.device)
         self.inference.rearrange_kv_cache(source_indices)

-        # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(
-            self.finished_sequences, finished_sequences
-        ):
+        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
-                    break  # the candidate list is full
+                    break
                 previously_finished[seq] = newly_finished[seq]

-        # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates
-            for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
         )
-        return tokens, completed

-    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
-        # collect all finished sequences, including patience, and add unfinished ones if not enough
+        return tokens, tokens_probs, completed
+
+    def finalize(
+        self, preceding_tokens: Tensor, preceding_tokens_prob: list, sum_logprobs: Tensor
+    ) -> Tuple[list, list, list]:
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if (
-                len(sequences) < self.beam_size
-            ):  # when not enough sequences are finished
+            if len(sequences) < self.beam_size:
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
-                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+                    sequences[tuple(sequence)] = (sum_logprobs[i][j].item(), preceding_tokens_prob[i][j] + [1.0])
                     if len(sequences) >= self.beam_size:
                         break

         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()]
-            for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
-            list(sequences.values()) for sequences in self.finished_sequences
+            [v[0] for v in sequences.values()] for sequences in self.finished_sequences
         ]
-        return tokens, sum_logprobs
+        tokens_probs: list[list[list[float]]] = [
+            [v[1] for v in sequences.values()] for sequences in self.finished_sequences
+        ]
+        return tokens, tokens_probs, sum_logprobs
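
Note: in BeamSearchDecoder, each finished sequence now maps to a (sum_logprob, per-token probs) tuple instead of a bare float. Sorting by dict .get still ranks candidates by cumulative log-probability, because Python compares tuples element-wise. A toy illustration (made-up sequences and scores):

    finished = {
        (1, 5, 2): (-1.2, [0.90, 0.80, 0.70]),
        (1, 5, 3): (-0.4, [0.95, 0.90, 0.85]),
    }
    ranked = sorted(finished, key=finished.get, reverse=True)
    print(ranked[0])  # (1, 5, 3): the highest cumulative log-probability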

 class LogitFilter:
@@ -700,7 +708,8 @@ class DecodingTask:
             logit_filter.apply(logits, tokens)

             # expand the tokens tensor with the selected next tokens
-            tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+            tokens, tokens_probs, completed = self.decoder.update(tokens, logits, sum_logprobs, tokens_probs)

             if completed or tokens.shape[-1] > self.n_ctx:
                 break
@@ -734,7 +743,7 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs, tokens_probs = self._main_loop(audio_features, tokens)

         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -747,8 +756,10 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
-        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens, tokens_probs, sum_logprobs = self.decoder.finalize(tokens, tokens_probs, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
-            for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
         ]
+        tokens_probs: list[list[list[float]]] = [
+            [probs[: tokens.shape[0]] for probs, tokens in zip(s, t)] for s, t in zip(tokens_probs, tokens)
+        ]

         # select the top-ranked sample in each group
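
Note: end to end, these changes thread per-token probabilities from the decoders through DecodingTask into the new DecodingResult.tokens_probs field. A usage sketch, assuming the rest of the PR populates that field when DecodingResult is constructed (not shown in these hunks; the audio path and model size are placeholders):

    import whisper

    model = whisper.load_model("base")
    audio = whisper.pad_or_trim(whisper.load_audio("example.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    result = whisper.decode(model, mel, whisper.DecodingOptions())
    for token, prob in zip(result.tokens, result.tokens_probs):
        print(token, f"{prob:.3f}")  # token id with its confidence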