diff --git a/bigchaindb/backend/connection.py b/bigchaindb/backend/connection.py index c1f0a629..b717703b 100644 --- a/bigchaindb/backend/connection.py +++ b/bigchaindb/backend/connection.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) def connect(backend=None, host=None, port=None, name=None, max_tries=None, - connection_timeout=None, replicaset=None): + connection_timeout=None, replicaset=None, ssl=None, login=None, password=None): """Create a new connection to the database backend. All arguments default to the current configuration's values if not @@ -50,6 +50,9 @@ def connect(backend=None, host=None, port=None, name=None, max_tries=None, # to handle these these additional args. In case of RethinkDBConnection # it just does not do anything with it. replicaset = replicaset or bigchaindb.config['database'].get('replicaset') + ssl = ssl if ssl is not None else bigchaindb.config['database'].get('ssl', False) + login = login or bigchaindb.config['database'].get('login') + password = password or bigchaindb.config['database'].get('password') try: module_name, _, class_name = BACKENDS[backend].rpartition('.') @@ -63,7 +66,7 @@ def connect(backend=None, host=None, port=None, name=None, max_tries=None, logger.debug('Connection: {}'.format(Class)) return Class(host=host, port=port, dbname=dbname, max_tries=max_tries, connection_timeout=connection_timeout, - replicaset=replicaset) + replicaset=replicaset, ssl=ssl, login=login, password=password) class Connection: diff --git a/bigchaindb/backend/mongodb/connection.py b/bigchaindb/backend/mongodb/connection.py index 8688e243..5c54470a 100644 --- a/bigchaindb/backend/mongodb/connection.py +++ b/bigchaindb/backend/mongodb/connection.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class MongoDBConnection(Connection): - def __init__(self, replicaset=None, **kwargs): + def __init__(self, replicaset=None, ssl=None, login=None, password=None, **kwargs): """Create a new Connection instance. Args: @@ -28,6 +28,9 @@ class MongoDBConnection(Connection): super().__init__(**kwargs) self.replicaset = replicaset or bigchaindb.config['database']['replicaset'] + self.ssl = ssl if ssl is not None else bigchaindb.config['database'].get('ssl', False) + self.login = login or bigchaindb.config['database'].get('login') + self.password = password or bigchaindb.config['database'].get('password') @property def db(self): @@ -71,14 +74,21 @@ class MongoDBConnection(Connection): # we should only return a connection if the replica set is # initialized. initialize_replica_set will check if the # replica set is initialized else it will initialize it. - initialize_replica_set(self.host, self.port, self.connection_timeout) + initialize_replica_set(self.host, self.port, self.connection_timeout, + self.dbname, self.ssl, self.login, self.password) # FYI: this might raise a `ServerSelectionTimeoutError`, # that is a subclass of `ConnectionFailure`. - return pymongo.MongoClient(self.host, - self.port, - replicaset=self.replicaset, - serverselectiontimeoutms=self.connection_timeout) + client = pymongo.MongoClient(self.host, + self.port, + replicaset=self.replicaset, + serverselectiontimeoutms=self.connection_timeout, + ssl=self.ssl) + + if self.login is not None and self.password is not None: + client[self.dbname].authenticate(self.login, self.password) + + return client # `initialize_replica_set` might raise `ConnectionFailure` or `OperationFailure`. except (pymongo.errors.ConnectionFailure, @@ -86,7 +96,7 @@ class MongoDBConnection(Connection): raise ConnectionError() from exc -def initialize_replica_set(host, port, connection_timeout): +def initialize_replica_set(host, port, connection_timeout, dbname, ssl, login, password): """Initialize a replica set. If already initialized skip.""" # Setup a MongoDB connection @@ -95,7 +105,12 @@ def initialize_replica_set(host, port, connection_timeout): # you try to connect to a replica set that is not yet initialized conn = pymongo.MongoClient(host=host, port=port, - serverselectiontimeoutms=connection_timeout) + serverselectiontimeoutms=connection_timeout, + ssl=ssl) + + if login is not None and password is not None: + conn[dbname].authenticate(login, password) + _check_replica_set(conn) host = '{}:{}'.format(bigchaindb.config['database']['host'], bigchaindb.config['database']['port']) diff --git a/tests/backend/mongodb/test_connection.py b/tests/backend/mongodb/test_connection.py index 6350a7c5..3edc31b1 100644 --- a/tests/backend/mongodb/test_connection.py +++ b/tests/backend/mongodb/test_connection.py @@ -99,6 +99,18 @@ def test_connection_run_errors(mock_client, mock_init_repl_set): assert query.run.call_count == 1 +@mock.patch('pymongo.database.Database.authenticate') +def test_connection_with_credentials(mock_authenticate): + import bigchaindb + from bigchaindb.backend.mongodb.connection import MongoDBConnection + conn = MongoDBConnection(host=bigchaindb.config['database']['host'], + port=bigchaindb.config['database']['port'], + login='theplague', + password='secret') + conn.connect() + assert mock_authenticate.call_count == 2 + + def test_check_replica_set_not_enabled(mongodb_connection): from bigchaindb.backend.mongodb.connection import _check_replica_set from bigchaindb.common.exceptions import ConfigurationError @@ -168,7 +180,7 @@ def test_initialize_replica_set(mock_cmd_line_opts): ] # check that it returns - assert initialize_replica_set('host', 1337, 1000) is None + assert initialize_replica_set('host', 1337, 1000, 'dbname', False, None, None) is None # test it raises OperationError if anything wrong with mock.patch.object(Database, 'command') as mock_command: @@ -178,4 +190,4 @@ def test_initialize_replica_set(mock_cmd_line_opts): ] with pytest.raises(pymongo.errors.OperationFailure): - initialize_replica_set('host', 1337, 1000) + initialize_replica_set('host', 1337, 1000, 'dbname', False, None, None)