From dcfe23f2927c1c102b5f70f78ffa5ee9feb4d7c4 Mon Sep 17 00:00:00 2001 From: Troy McConaghy Date: Sun, 25 Nov 2018 20:24:03 +0100 Subject: [PATCH] Account for values that are arrays/lists (#2607) when checking if keys are valid. --- bigchaindb/backend/schema.py | 6 ++++-- bigchaindb/common/utils.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/bigchaindb/backend/schema.py b/bigchaindb/backend/schema.py index 08fdbc0e..defd38ac 100644 --- a/bigchaindb/backend/schema.py +++ b/bigchaindb/backend/schema.py @@ -10,7 +10,7 @@ import logging import bigchaindb from bigchaindb.backend.connection import connect from bigchaindb.common.exceptions import ValidationError -from bigchaindb.common.utils import validate_all_values_for_key +from bigchaindb.common.utils import validate_all_values_for_key_in_obj, validate_all_values_for_key_in_list logger = logging.getLogger(__name__) @@ -101,7 +101,9 @@ def validate_language_key(obj, key): if backend == 'localmongodb': data = obj.get(key, {}) if isinstance(data, dict): - validate_all_values_for_key(data, 'language', validate_language) + validate_all_values_for_key_in_obj(data, 'language', validate_language) + elif isinstance(data, list): + validate_all_values_for_key_in_list(data, 'language', validate_language) def validate_language(value): diff --git a/bigchaindb/common/utils.py b/bigchaindb/common/utils.py index 4e1f8ca2..62b5e5c9 100644 --- a/bigchaindb/common/utils.py +++ b/bigchaindb/common/utils.py @@ -76,10 +76,20 @@ def validate_txn_obj(obj_name, obj, key, validation_fun): if backend == 'localmongodb': data = obj.get(key, {}) if isinstance(data, dict): - validate_all_keys(obj_name, data, validation_fun) + validate_all_keys_in_obj(obj_name, data, validation_fun) + elif isinstance(data, list): + validate_all_items_in_list(obj_name, data, validation_fun) -def validate_all_keys(obj_name, obj, validation_fun): +def validate_all_items_in_list(obj_name, data, validation_fun): + for item in data: + if isinstance(item, dict): + validate_all_keys_in_obj(obj_name, item, validation_fun) + elif isinstance(item, list): + validate_all_items_in_list(obj_name, item, validation_fun) + + +def validate_all_keys_in_obj(obj_name, obj, validation_fun): """Validate all (nested) keys in `obj` by using `validation_fun`. Args: @@ -97,10 +107,12 @@ def validate_all_keys(obj_name, obj, validation_fun): for key, value in obj.items(): validation_fun(obj_name, key) if isinstance(value, dict): - validate_all_keys(obj_name, value, validation_fun) + validate_all_keys_in_obj(obj_name, value, validation_fun) + elif isinstance(value, list): + validate_all_items_in_list(obj_name, value, validation_fun) -def validate_all_values_for_key(obj, key, validation_fun): +def validate_all_values_for_key_in_obj(obj, key, validation_fun): """Validate value for all (nested) occurrence of `key` in `obj` using `validation_fun`. @@ -117,7 +129,17 @@ def validate_all_values_for_key(obj, key, validation_fun): if vkey == key: validation_fun(value) elif isinstance(value, dict): - validate_all_values_for_key(value, key, validation_fun) + validate_all_values_for_key_in_obj(value, key, validation_fun) + elif isinstance(value, list): + validate_all_values_for_key_in_list(value, key, validation_fun) + + +def validate_all_values_for_key_in_list(input_list, key, validation_fun): + for item in input_list: + if isinstance(item, dict): + validate_all_values_for_key_in_obj(item, key, validation_fun) + elif isinstance(item, list): + validate_all_values_for_key_in_list(item, key, validation_fun) def validate_key(obj_name, key):