Source code for tldap.backend.base

# Copyright 2012-2014 Brian May
#
# This file is part of python-tldap.
#
# python-tldap is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# python-tldap is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with python-tldap.  If not, see <http://www.gnu.org/licenses/>.

""" This module provides the LDAP base functions
with a subset of the functions from the real ldap module. """

import logging
import ssl
from typing import Callable, Generator, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import ldap3
import ldap3.core.exceptions as exceptions


logger = logging.getLogger(__name__)


def _debug(*argv):
    argv = [str(arg) for arg in argv]
    logger.debug(" ".join(argv))


# Type variable for the result of the callable handed to
# LdapBase._do_with_retry(), which returns that same result unchanged.
Entity = TypeVar('Entity')


class LdapBase(object):
    """ The base LDAP connection class.

    Holds the settings for one LDAP server and lazily opens a single shared
    bound connection (``self._obj``) used by :meth:`search`.  The
    transaction-management and write methods only raise
    ``NotImplementedError`` here, so they must be overridden to be usable.
    """

    def __init__(self, settings_dict: dict) -> None:
        """ Store the settings; no connection is opened yet.

        :param settings_dict: connection settings.  Keys read by this class:
            ``URI`` (an ``ldap://`` or ``ldaps://`` URL), ``USER`` and
            ``PASSWORD`` (credentials for the shared connection), and the
            optional ``START_TLS``, ``CIPHERS``, ``TLS_CA`` and
            ``REQUIRE_TLS``.
        """
        self.settings_dict = settings_dict
        # Shared connection, created on demand by _reconnect().
        self._obj = None
        self._connection_class = ldap3.Connection

    def close(self) -> None:
        """ Unbind and forget the shared connection, if one is open.

        The cached connection is cleared before unbinding, so even if
        ``unbind()`` raises, the next operation opens a fresh connection
        instead of reusing one in an unknown state.
        """
        obj, self._obj = self._obj, None
        if obj is not None:
            obj.unbind()

    #########################
    # Connection Management #
    #########################

    def set_connection_class(self, connection_class):
        """ Replace the class used to create connections.

        :param connection_class: called like ``ldap3.Connection`` (the
            default) with a server plus ``user``, ``password`` and
            ``authentication`` keyword arguments.
        """
        self._connection_class = connection_class

    def check_password(self, dn: str, password: str) -> bool:
        """ Check whether ``password`` is valid for ``dn``.

        A separate, short-lived connection is bound as ``dn`` and then
        unbound; the shared connection is not touched.

        :param dn: DN to bind as.
        :param password: password to test.
        :return: True if the bind succeeds; False if the server answers
            with invalidCredentials or unwillingToPerform.  Any other error
            propagates.
        """
        try:
            conn = self._connect(user=dn, password=password)
            conn.unbind()
            return True
        except exceptions.LDAPInvalidCredentialsResult:
            return False
        except exceptions.LDAPUnwillingToPerformResult:
            return False

    def _connect(self, user: str, password: str) -> ldap3.Connection:
        """ Open a new connection to the configured server and bind.

        :param user: DN to bind as (simple authentication).
        :param password: password for ``user``.
        :return: an open, bound connection with ``raise_exceptions`` set.
        :raises RuntimeError: if the URI scheme is neither ``ldap`` nor
            ``ldaps``.
        :raises ValueError: if the URI contains an invalid port.

        ldap3 errors raised while opening, starting TLS or binding
        propagate.  If STARTTLS or the bind fails, the connection is
        unbound first.
        """
        settings = self.settings_dict

        _debug("connecting")
        url = urlparse(settings['URI'])

        if url.scheme == "ldaps":
            use_ssl = True
        elif url.scheme == "ldap":
            use_ssl = False
        else:
            raise RuntimeError("Unknown scheme '%s'" % url.scheme)

        # urlparse's hostname/port (rather than splitting netloc on ":")
        # also handle IPv6 literals such as ldap://[::1]:389 (brackets are
        # stripped), "user@host" netlocs and a bare trailing ":".
        # Note: hostname is lower-cased.
        host = url.hostname or ""
        port = url.port
        if port is None:
            port = 636 if use_ssl else 389

        start_tls = bool(settings.get('START_TLS'))

        tls = None
        if use_ssl or start_tls:
            tls = ldap3.Tls()
            if 'CIPHERS' in settings:
                tls.ciphers = settings['CIPHERS']
            if settings.get('TLS_CA'):
                tls.ca_certs_file = settings['TLS_CA']
            # Certificate validation is only forced when REQUIRE_TLS is set;
            # otherwise ldap3.Tls's default validation mode applies
            # (NOTE(review): CERT_NONE per ldap3 docs -- confirm for the
            # installed version).
            if settings.get('REQUIRE_TLS'):
                tls.validate = ssl.CERT_REQUIRED

        server = ldap3.Server(host, port=port, use_ssl=use_ssl, tls=tls)
        conn = self._connection_class(
            server,
            user=user,
            password=password,
            authentication=ldap3.SIMPLE)
        # Don't sleep or retry should a restartable ldap3 strategy be in
        # use; reconnection is handled once by _do_with_retry() instead.
        conn.strategy.restartable_sleep_time = 0
        conn.strategy.restartable_tries = 1
        conn.raise_exceptions = True

        conn.open()
        try:
            if start_tls:
                conn.start_tls()
            conn.bind()
        except BaseException:
            # Release the already-open socket whether STARTTLS or the bind
            # failed, then let the original error propagate.
            conn.unbind()
            raise
        return conn

    def _reconnect(self) -> None:
        """ Replace the shared connection with one bound as ``USER``.

        On failure the shared connection is left as None and the error
        propagates.
        """
        settings = self.settings_dict
        try:
            self._obj = self._connect(
                user=settings['USER'], password=settings['PASSWORD'])
        except Exception:
            self._obj = None
            raise
        assert self._obj is not None

    def _do_with_retry(
            self, fn: Callable[[ldap3.Connection], Entity]) -> Entity:
        """ Call ``fn`` with the shared connection and return its result.

        Connects first if there is no shared connection.  If the server has
        terminated the session, reconnects once and calls ``fn`` again; a
        failure on that second attempt propagates.
        """
        if self._obj is None:
            self._reconnect()
        assert self._obj is not None

        try:
            return fn(self._obj)
        except exceptions.LDAPSessionTerminatedByServerError:
            # if it fails, reconnect then retry
            _debug("SERVER_DOWN, reconnecting")
            self._reconnect()
            return fn(self._obj)

    ###################
    # read only stuff #
    ###################

    def search(self, base, scope, filterstr='(objectClass=*)',
               attrlist=None, limit=None
               ) -> Generator[Tuple[str, dict], None, None]:
        """ Search for entries in LDAP database.

        This is a generator, so the search is only sent once iteration
        starts.

        :param base: DN to search under.
        :param scope: ldap3 search scope.
        :param filterstr: LDAP filter string.
        :param attrlist: attributes to fetch.  ``None`` means all attributes,
            and a ``set`` is converted to a list.
        :param limit: passed to ldap3 as ``paged_size``.  Only the first page
            is requested, which caps the number of results.
        :return: yields ``(dn, raw_attributes)`` tuples.  Entries that are
            not ``searchResEntry`` (e.g. referrals) are skipped.
        """

        _debug("search", base, scope, filterstr, attrlist, limit)

        # first results
        if attrlist is None:
            attrlist = ldap3.ALL_ATTRIBUTES
        elif isinstance(attrlist, set):
            attrlist = list(attrlist)

        def first_results(obj):
            _debug("---> searching ldap", limit)
            obj.search(
                base, filterstr, scope,
                attributes=attrlist, paged_size=limit)
            return obj.response

        # get the 1st result
        result_list = self._do_with_retry(first_results)

        # Loop over list of search results
        for result_item in result_list:
            # skip searchResRef for now
            if result_item['type'] != "searchResEntry":
                continue
            dn = result_item['dn']
            attributes = result_item['raw_attributes']
            _debug("---> got ldap result", dn)
            _debug("---> yielding", result_item)
            yield (dn, attributes)

        # we are finished - return results, eat cake
        _debug("---> done")
        return

    ####################
    # Cache Management #
    ####################

    def reset(self, force_flush_cache: bool = False) -> None:
        """
        Reset transaction back to original state, discarding all
        uncompleted transactions.  Does nothing in this base class.
        """
        pass

    ##########################
    # Transaction Management #
    ##########################

    # Fake it

    def is_dirty(self) -> bool:
        """ Are there uncommitted changes? """
        raise NotImplementedError()

    def is_managed(self) -> bool:
        """ Are we inside transaction management? """
        raise NotImplementedError()

    def enter_transaction_management(self) -> None:
        """ Start a transaction. """
        raise NotImplementedError()

    def leave_transaction_management(self) -> None:
        """
        End a transaction.  It must not be dirty when doing so, i.e.
        commit() or rollback() must be called if changes were made.  If
        dirty, changes will be discarded.
        """
        raise NotImplementedError()

    def commit(self) -> None:
        """
        Attempt to commit all changes to LDAP database, i.e. forget all
        rollbacks.  However, stay inside transaction management.
        """
        raise NotImplementedError()

    def rollback(self) -> None:
        """
        Roll back to previous database state.  However, stay inside
        transaction management.
        """
        raise NotImplementedError()

    ##################################
    # Functions needing Transactions #
    ##################################

    def add(self, dn: str, mod_list: dict) -> None:
        """
        Add a DN to the LDAP database; see the ldap module.  Doesn't return
        a result if transactions are enabled.
        """
        raise NotImplementedError()

    def modify(self, dn: str, mod_list: dict) -> None:
        """
        Modify a DN in the LDAP database; see the ldap module.  Doesn't
        return a result if transactions are enabled.
        """
        raise NotImplementedError()

    def modify_no_rollback(self, dn: str, mod_list: dict) -> None:
        """
        Modify a DN in the LDAP database; see the ldap module.  Doesn't
        return a result if transactions are enabled.
        """
        raise NotImplementedError()

    def delete(self, dn: str) -> None:
        """
        Delete a DN in the LDAP database; see the ldap module.  Doesn't
        return a result if transactions are enabled.
        """
        raise NotImplementedError()

    def rename(self, dn: str, new_rdn: str,
               new_base_dn: Optional[str] = None) -> None:
        """
        Rename a DN in the LDAP database; see the ldap module.  Doesn't
        return a result if transactions are enabled.
        """
        raise NotImplementedError()