kwargs in decode() for convenience (#1061)

* kwargs in decode() for convenience

* formatting fix
This commit is contained in:
Jong Wook Kim 2023-03-08 18:46:38 -05:00 committed by GitHub
parent 38f2f4d99d
commit c4b50c0824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -778,7 +778,10 @@ class DecodingTask:
 @torch.no_grad()
 def decode(
-    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+    model: "Whisper",
+    mel: Tensor,
+    options: DecodingOptions = DecodingOptions(),
+    **kwargs,
 ) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
 
@@ -802,6 +805,9 @@ def decode(
     if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
+    if kwargs:
+        options = replace(options, **kwargs)
+
     result = DecodingTask(model, options).run(mel)
 
     return result[0] if single else result