diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b8522d5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,53 @@ +import pytest +from typing import Optional + +from whisper.utils import optional_float, optional_int, str2bool + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("TRUE", True), + ("True", True), + ("true", True), + ("YES", True), + ("Yes", True), + ("yes", True), + ("Y", True), + ("y", True), + ("1", True), + + ("FALSE", False), + ("False", False), + ("false", False), + ("NO", False), + ("No", False), + ("no", False), + ("N", False), + ("n", False), + ("0", False), +]) +def test_str2bool(provided: str, expected: bool) -> None: + assert str2bool(provided) is expected + + +def test_str2bool_faulty_argument() -> None: + with pytest.raises(ValueError, match="Expected one of"): + str2bool("boom") + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("1", 1), + ("None", None), + ("none", None), +]) +def test_optional_int(provided: str, expected: Optional[int]) -> None: + assert optional_int(provided) == expected + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("1.23", 1.23), + ("1", 1), + ("None", None), + ("none", None), +]) +def test_optional_float(provided: str, expected: Optional[float]) -> None: + assert optional_float(provided) == expected diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b138..264f18b 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -26,20 +26,21 @@ def exact_div(x, y): return x // y -def str2bool(string): - str2val = {"True": True, "False": False} - if string in str2val: - return str2val[string] - else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") +def str2bool(string: str) -> bool: + if string.lower() in {'true', 'yes', 'y', '1'}: + return True + if string.lower() in {'false', 'no', 'n', '0'}: + return False + + raise ValueError(f"Expected one of true/yes/1 or false/no/0, but got {string}") def optional_int(string): - return None if string == "None" else int(string) + return None if string.lower() == "none" else int(string) def optional_float(string): - return None if string == "None" else float(string) + return None if string.lower() == "none" else float(string) def compression_ratio(text) -> float: