From 3e4bac324d751145ac6c9ec76964e8ef6517c320 Mon Sep 17 00:00:00 2001 From: Steven Van Ingelgem Date: Wed, 22 Nov 2023 06:42:40 +0100 Subject: [PATCH 1/2] (slightly) More robust conversion --- whisper/utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/whisper/utils.py b/whisper/utils.py index 7a172c4..41d1bda 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -26,20 +26,21 @@ def exact_div(x, y): return x // y -def str2bool(string): - str2val = {"True": True, "False": False} - if string in str2val: - return str2val[string] - else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") +def str2bool(string: str) -> bool: + if string.lower() in {'true', 'yes', 'y', '1'}: + return True + if string.lower() in {'false', 'no', 'n', '0'}: + return False + + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") def optional_int(string): - return None if string == "None" else int(string) + return None if string.lower() == "none" else int(string) def optional_float(string): - return None if string == "None" else float(string) + return None if string.lower() == "none" else float(string) def compression_ratio(text) -> float: From 8c27cca65bf11e814488d3ac98849947fd5fd034 Mon Sep 17 00:00:00 2001 From: Steven Van Ingelgem Date: Wed, 22 Nov 2023 06:51:35 +0100 Subject: [PATCH 2/2] Added tests for the changes. --- tests/test_utils.py | 53 +++++++++++++++++++++++++++++++++++++++++++++ whisper/utils.py | 2 +- 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b8522d5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,53 @@ +import pytest +from typing import Optional + +from whisper.utils import optional_float, optional_int, str2bool + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("TRUE", True), + ("True", True), + ("true", True), + ("YES", True), + ("Yes", True), + ("yes", True), + ("Y", True), + ("y", True), + ("1", True), + + ("FALSE", False), + ("False", False), + ("false", False), + ("NO", False), + ("No", False), + ("no", False), + ("N", False), + ("n", False), + ("0", False), +]) +def test_str2bool(provided: str, expected: bool) -> None: + assert str2bool(provided) is expected + + +def test_str2bool_faulty_argument() -> None: + with pytest.raises(ValueError, match="Expected one of"): + str2bool("boom") + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("1", 1), + ("None", None), + ("none", None), +]) +def test_optional_int(provided: str, expected: Optional[int]) -> None: + assert optional_int(provided) == expected + + +@pytest.mark.parametrize(("provided", "expected"), [ + ("1.23", 1.23), + ("1", 1), + ("None", None), + ("none", None), +]) +def test_optional_float(provided: str, expected: Optional[float]) -> None: + assert optional_float(provided) == expected diff --git a/whisper/utils.py b/whisper/utils.py index 41d1bda..b3f4783 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -32,7 +32,7 @@ def str2bool(string: str) -> bool: if string.lower() in {'false', 'no', 'n', '0'}: return False - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + raise ValueError(f"Expected one of true/yes/1 or false/no/0, but got {string}") def optional_int(string):