# Source code for tldap.database

# Copyright 2018 Brian May
#
# This file is part of python-tldap.
#
# python-tldap is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# python-tldap is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with python-tldap.  If not, see <http://www.gnu.org/licenses/>.

""" High level database interaction. """
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import ldap3.core
import ldap3.core.exceptions

import tldap.fields
import tldap.query
from tldap import Q
from tldap.backend.base import LdapBase
from tldap.dict import ImmutableDict
from tldap.dn import dn2str, str2dn
from tldap.exceptions import (
    MultipleObjectsReturned,
    ObjectAlreadyExists,
    ObjectDoesNotExist,
    ValidationError,
)


# Either a plain list of values or a NotLoadedList placeholder.  ``Union`` is
# required here: ``List[Any] or 'NotLoadedList'`` is an ordinary boolean
# expression that evaluates to ``List[Any]`` and drops the second alternative.
NotLoadedListType = Union[List[Any], 'NotLoadedList']


class SearchOptions:
    """
    Application specific search options.

    A plain value holder: the three constructor arguments are stored
    unchanged on the instance under the same names.
    """

    def __init__(self, base_dn: str, object_class: Set[str], pk_field: str) -> None:
        """
        :param base_dn: Stored as ``self.base_dn``.
        :param object_class: Stored as ``self.object_class``.
        :param pk_field: Stored as ``self.pk_field``.
        """
        self.pk_field = pk_field
        self.object_class = object_class
        self.base_dn = base_dn
class Database:
    """
    Pairs a backend connection with the settings dictionary used with it.
    """

    def __init__(self, connection: LdapBase, settings: Optional[dict] = None):
        """
        :param connection: Backend connection wrapped by this object.
        :param settings: Settings to expose; when ``None`` the connection's
            ``settings_dict`` is used instead.
        """
        self._connection = connection
        self._settings = connection.settings_dict if settings is None else settings

    @property
    def connection(self) -> LdapBase:
        """ The connection passed to the constructor. """
        return self._connection

    @property
    def settings(self) -> dict:
        """ The settings chosen in the constructor. """
        return self._settings
def get_default_database():
    """ Return a new Database wrapping the ``'default'`` entry of ``tldap.backend.connections``. """
    default_connection = tldap.backend.connections['default']
    return Database(default_connection)
def get_database(database: Optional[Database]) -> Database:
    """
    Resolve an optional database argument.

    :param database: An explicit Database, or ``None``.
    :return: ``database`` itself when given, otherwise the result of
        :func:`get_default_database`.
    """
    if database is not None:
        return database
    return get_default_database()
def _python_to_list(value: Any) -> NotLoadedListType:
    """
    Normalise a python value into its list representation.

    ``None`` becomes an empty list, a set becomes a new list, a list or
    NotLoadedList is returned as the very same object, and any other value
    is wrapped in a one element list.
    """
    if value is None:
        return []
    if isinstance(value, (NotLoadedList, list)):
        # Already in list form (or a placeholder for one); not copied.
        return value
    if isinstance(value, set):
        return list(value)
    return [value]


def _list_to_python(field: tldap.fields.Field, value: NotLoadedListType) -> Any:
    """
    Collapse the list representation of a non-list field to a single value.

    :param field: Field describing the value; asserted not to be a list field.
    :param value: A list or a NotLoadedList.
    :return: ``value`` itself for a NotLoadedList, ``None`` for an empty
        list, otherwise the only element.
    :raises RuntimeError: If the list holds more than one element.
    """
    assert not field.is_list
    assert isinstance(value, (list, NotLoadedList))

    # The list-field case is only reachable when assertions are disabled.
    if field.is_list or isinstance(value, NotLoadedList):
        return value

    if len(value) == 0:
        return None
    if len(value) == 1:
        return value[0]
    raise RuntimeError("Non-list value has more then 1 value.")


# Type variable for methods returning the same LdapObject subclass as ``self``.
LdapObjectEntity = TypeVar('LdapObjectEntity', bound='LdapObject')
# The class object of LdapObject or of one of its subclasses.
LdapObjectClass = Type['LdapObject']
class LdapObject(ImmutableDict):
    """
    A high level python representation of a LDAP object.

    Every field reported by :meth:`get_fields` starts out as an empty list;
    values supplied to the constructor replace those defaults.
    """

    def __init__(self, d: Optional[dict] = None) -> None:
        """
        :param d: Optional initial values, applied on top of the empty-list
            defaults.
        """
        self._fields = self.get_fields()
        names = set(self._fields.keys())

        # A distinct empty list per field, so no two fields share one list.
        initial: Dict[str, NotLoadedListType] = {name: [] for name in names}
        if d is not None:
            initial.update(d)

        super().__init__(names, initial)

    @classmethod
    def get_fields(cls) -> Dict[str, tldap.fields.Field]:
        """ Return the fields of this object type, by name; subclasses must override. """
        raise NotImplementedError()

    @classmethod
    def get_search_options(cls, database: Database) -> SearchOptions:
        """ Return the search options for this object type; subclasses must override. """
        raise NotImplementedError()

    @classmethod
    def on_load(cls, python_data: 'LdapObject', database: Database) -> 'LdapObject':
        """ Hook applied to an object after it is loaded or saved; subclasses must override. """
        raise NotImplementedError()

    @classmethod
    def on_save(cls, changes: 'Changeset', database: Database) -> 'Changeset':
        """ Hook applied to a changeset before it is saved; subclasses must override. """
        raise NotImplementedError()

    def _set(self, key: str, value: NotLoadedListType) -> None:
        """ Store ``value`` under ``key`` after converting it to list form. """
        return super()._set(key, _python_to_list(value))

    def __copy__(self: LdapObjectEntity) -> LdapObjectEntity:
        """ Build a new instance of the same class from the current values. """
        same_class = self.__class__
        return same_class(self._dict)

    def get_as_single(self, key: str) -> Any:
        """
        Return the value of a non-list field as a single python value.

        The stored list is looked up with the key as given; the field is
        looked up with the key normalised by ``fix_key``.
        """
        stored = self._dict[key]
        field = self._fields[self.fix_key(key)]
        return _list_to_python(field, stored)

    def get_as_list(self, key: str) -> NotLoadedListType:
        """ Return the stored list (or NotLoadedList) for ``key`` without conversion. """
        return self._dict[key]
# Type variable for methods returning the same Changeset subclass as ``self``.
ChangesetEntity = TypeVar('ChangesetEntity', bound='Changeset')
# NOTE(review): ``or`` is evaluated when the module is imported, so this is
# just the first truthy operand (ldap3.MODIFY_ADD unless that constant is
# falsy), not a union of the three operations -- confirm the intended type.
Operation = ldap3.MODIFY_ADD or ldap3.MODIFY_REPLACE or ldap3.MODIFY_DELETE
class Changeset(ImmutableDict):
    """
    Represents a set of changes to an LdapObject.

    Besides the resulting values (held by the ImmutableDict base) the
    changeset records, per field, the list of LDAP modify operations that
    produced them, and any validation errors raised while applying them.
    """

    def __init__(self, fields: Dict[str, tldap.fields.Field], src: LdapObject,
                 d: Optional[dict] = None) -> None:
        """
        :param fields: Field definitions by name; their names are the
            allowed keys.
        :param src: The object the changes apply to.
        :param d: Optional initial changes, handed to the base class.
        """
        self._fields = fields
        self._src = src
        self._changes: Dict[str, List[Tuple[Operation, List[Any]]]] = {}
        self._errors: List[str] = []
        field_names = set(fields.keys())
        super().__init__(field_names, d)

    def __copy__(self: ChangesetEntity) -> ChangesetEntity:
        """ Build a new changeset with the same fields, source, values and recorded changes. """
        copy = self.__class__(self._fields, self._src, self._dict)
        # The recorded operations are carried over as-is; _add_mod never
        # mutates this dict in place, so sharing it is safe.
        copy._changes = self._changes
        return copy

    def get_value_as_single(self, key: str) -> Any:
        """
        Return the current value of a non-list field as a single value.

        The changed value is used when the field has one, otherwise the
        value of the source object.
        """
        key = self.fix_key(key)
        if key in self._dict:
            field = self._fields[key]
            return _list_to_python(field, self._dict[key])
        else:
            return self._src.get_as_single(key)

    def get_value_as_list(self, key: str) -> List[Any]:
        """
        Return the current value of a field in list form.

        The returned list is the stored object itself, not a copy; callers
        must not mutate it.
        """
        if key in self._dict:
            return self._dict[key]
        else:
            return self._src.get_as_list(key)

    @property
    def changes(self) -> Dict[str, List[Tuple[Operation, List[Any]]]]:
        """ The recorded ``(operation, values)`` pairs for each changed field. """
        return self._changes

    @staticmethod
    def _python_to_list(value: Any) -> List[Any]:
        """
        Convert a value to list form, rejecting unloaded placeholders.

        :raises RuntimeError: If the value, or any item in it, is a NotLoaded.
        """
        value_list = _python_to_list(value)
        if isinstance(value_list, NotLoaded):
            raise RuntimeError("Unexpected NotLoaded value in Changeset.")
        for item in value_list:
            if isinstance(item, NotLoaded):
                raise RuntimeError("Unexpected NotLoaded value in Changeset.")
        return value_list

    def _set(self, key: str, value: Any) -> None:
        """
        Set a field, recording a replace (or delete) only if the value differs.

        ``None`` and ``[]`` are recorded as a delete; anything else as a
        replace.  Earlier operations recorded for the field are discarded.
        """
        old_value = self.get_value_as_list(key)
        value_list = self._python_to_list(value)
        if value_list != old_value:
            operation: Operation = ldap3.MODIFY_REPLACE
            if value is None or value == []:
                operation = ldap3.MODIFY_DELETE
            self._add_mod(key, operation, value_list, overwrite=True)
            self._replay_mod(key, operation, value_list)

    def _force_mod(self, key: str, operation: Operation, value: Any) -> 'Changeset':
        """ Return a copy with ``operation`` appended for ``key`` and applied to its value. """
        value_list = self._python_to_list(value)
        clone = self.__copy__()
        clone._add_mod(key, operation, value_list)
        clone._replay_mod(key, operation, value_list)
        return clone

    def force_add(self, key: str, value: Any) -> 'Changeset':
        """ Return a copy that additionally adds ``value`` to field ``key``. """
        return self._force_mod(key, ldap3.MODIFY_ADD, value)

    def force_replace(self, key: str, value: Any) -> 'Changeset':
        """ Return a copy that additionally replaces field ``key`` with ``value``. """
        return self._force_mod(key, ldap3.MODIFY_REPLACE, value)

    def force_delete(self, key: str, value: Any) -> 'Changeset':
        """ Return a copy that additionally deletes ``value`` from field ``key``. """
        return self._force_mod(key, ldap3.MODIFY_DELETE, value)

    def _add_mod(self, key: str, operation: Operation, new_value_list: List[Any],
                 overwrite=False) -> None:
        """
        Record an operation for a field.

        :param overwrite: Discard operations already recorded for the field
            instead of appending to them.
        :raises RuntimeError: If ``new_value_list`` contains a list.
        """
        if any(isinstance(value, list) for value in new_value_list):
            raise RuntimeError("Got list inside a list.")

        key = self.fix_key(key)
        previous = [] if overwrite else self._changes.get(key, [])

        # Build a new list and a new dict so that copies sharing the old
        # ones are unaffected.
        self._changes = {
            **self._changes,
            key: previous + [(operation, new_value_list)],
        }

    def _replay_mod(self, key: str, operation: Operation, new_value_list: List[Any]):
        """
        Apply an operation to the current value of a field and validate it.

        Validation failures are collected in :attr:`errors`, not raised.

        :raises RuntimeError: If ``new_value_list`` contains a list, if the
            field is still empty after an add, or if the operation is unknown.
        :raises ValueError: If a value to delete is not present.
        """
        if any(isinstance(value, list) for value in new_value_list):
            raise RuntimeError("Got list inside a list.")

        key = self.fix_key(key)
        old_value_list = self.get_value_as_list(key)

        if operation == ldap3.MODIFY_ADD:
            assert isinstance(new_value_list, list)
            # Work on a copy: the list may belong to the source object or
            # to the changeset this one was copied from.
            old_value_list = list(old_value_list)
            for value in new_value_list:
                if value not in old_value_list:
                    old_value_list.append(value)
            if len(old_value_list) == 0:
                raise RuntimeError("Can't add 0 items.")
        elif operation == ldap3.MODIFY_REPLACE:
            old_value_list = new_value_list
        elif operation == ldap3.MODIFY_DELETE:
            if len(new_value_list) == 0:
                # No values given: the whole field is cleared.
                old_value_list = []
            else:
                # Copy for the same reason as in the add branch.
                old_value_list = list(old_value_list)
                for value in new_value_list:
                    old_value_list.remove(value)
        else:
            raise RuntimeError(f"Unknown LDAP operation {operation}.")

        self._dict[key] = old_value_list

        field = self._fields[key]
        try:
            field.validate(old_value_list)
        except tldap.exceptions.ValidationError as e:
            self._errors.append(f"{key}: {e}.")

    @property
    def is_valid(self) -> bool:
        """ ``True`` when no validation errors have been collected. """
        return len(self._errors) == 0

    @property
    def errors(self) -> List[str]:
        """ The collected validation error messages. """
        return self._errors

    @property
    def src(self) -> LdapObject:
        """ The object these changes apply to. """
        return self._src
class NotLoaded:
    """
    Base class to represent a related field that has not been loaded.

    Subclasses implement :meth:`__repr__` and :meth:`load`; the static
    helpers run the actual queries.
    """

    def __repr__(self):
        raise NotImplementedError()

    def load(self, database: Optional[Database] = None) -> Union[LdapObject, List[LdapObject]]:
        """
        Load the referenced object(s); subclasses must override.

        :param database: Database to query; ``None`` is passed through to
            the query functions.
        """
        # ``Union`` rather than ``A or B``: the latter is evaluated as a
        # boolean expression and would annotate only the first type.
        raise NotImplementedError()

    @staticmethod
    def _load_one(table: LdapObjectClass, key: str, value: str,
                  database: Optional[Database] = None) -> LdapObject:
        """ Return the result of ``get_one`` for ``table`` with the query ``key=value``. """
        q = Q(**{key: value})
        return get_one(table, q, database)

    @staticmethod
    def _load_list(table: LdapObjectClass, key: str, value: str,
                   database: Optional[Database] = None) -> List[LdapObject]:
        """ Return all objects of ``table`` found by ``search`` for ``key=value``, as a list. """
        q = Q(**{key: value})
        return list(search(table, q, database))
class NotLoadedObject(NotLoaded):
    """
    Represents a single object that needs to be loaded.

    Holds the table and the ``key=value`` lookup used to fetch it.
    """

    def __init__(self, *, table: LdapObjectClass, key: str, value: str):
        self._value = value
        self._key = key
        self._table = table

    def __repr__(self):
        return f"<NotLoadedObject {self._table} {self._key}={self._value}>"

    def load(self, database: Optional[Database] = None) -> LdapObject:
        """ Fetch the object through ``_load_one`` using the stored lookup. """
        table, key, value = self._table, self._key, self._value
        return self._load_one(table, key, value, database)
class NotLoadedList(NotLoaded):
    """
    Represents a list of objects that needs to be loaded via a single key.

    Holds the table and the ``key=value`` lookup used to fetch them.
    """

    def __init__(self, *, table: LdapObjectClass, key: str, value: str):
        self._value = value
        self._key = key
        self._table = table

    def __repr__(self):
        return f"<NotLoadedList {self._table} {self._key}={self._value}>"

    def load(self, database: Optional[Database] = None) -> List[LdapObject]:
        """ Fetch the objects through ``_load_list`` using the stored lookup. """
        table, key, value = self._table, self._key, self._value
        return self._load_list(table, key, value, database)
def changeset(python_data: LdapObject, d: dict) -> Changeset:
    """
    Generate changes object for ldap object.

    :param python_data: Object the changes apply to; its class supplies the
        field definitions.
    :param d: Initial changes handed to the Changeset.
    """
    field_definitions = type(python_data).get_fields()
    return Changeset(field_definitions, src=python_data, d=d)
def _db_to_python(db_data: dict, table: LdapObjectClass, dn: str) -> LdapObject:
    """
    Convert raw database data to a LdapObject.

    Only fields with a truthy ``db_field`` are read from ``db_data``; each is
    converted with the field's ``to_python``.  ``dn`` is merged in afterwards.
    """
    converted = {}
    for name, field in table.get_fields().items():
        if field.db_field:
            converted[name] = field.to_python(db_data[name])

    return table(converted).merge({
        'dn': dn,
    })


def _python_to_mod_new(changes: Changeset) -> Dict[str, List[List[bytes]]]:
    """
    Convert a Changeset to a modlist for an add operation.

    Fields without ``db_field`` and fields whose database value is empty are
    left out.

    :raises ValidationError: Re-raised with the field name prepended.
    """
    fields = type(changes.src).get_fields()

    modlist: Dict[str, List[List[bytes]]] = {}
    for name, field in fields.items():
        if not field.db_field:
            continue
        try:
            db_value = field.to_db(changes.get_value_as_list(name))
            if len(db_value) > 0:
                modlist[name] = db_value
        except ValidationError as e:
            raise ValidationError(f"{name}: {e}.")
    return modlist


def _python_to_mod_modify(changes: Changeset) -> Dict[str, List[Tuple[Operation, List[bytes]]]]:
    """
    Convert a Changeset to a modlist for a modify operation.

    Every recorded ``(operation, values)`` pair of a database field has its
    values converted with the field's ``to_db``.

    :raises ValidationError: Re-raised with the field name prepended.
    """
    table: LdapObjectClass = type(changes.src)

    modlist: Dict[str, List[Tuple[Operation, List[bytes]]]] = {}
    for key, operations in changes.changes.items():
        field = _get_field_by_name(table, key)
        if not field.db_field:
            continue
        try:
            modlist[key] = [
                (operation, field.to_db(value))
                for operation, value in operations
            ]
        except ValidationError as e:
            raise ValidationError(f"{key}: {e}.")
    return modlist
def get_one(table: LdapObjectClass, query: Optional[Q] = None,
            database: Optional[Database] = None, base_dn: Optional[str] = None) -> LdapObject:
    """
    Get exactly one result from the database or fail.

    :raises ObjectDoesNotExist: If the search yields nothing.
    :raises MultipleObjectsReturned: If the search yields more than one result.
    """
    matches = search(table, query, database, base_dn)

    try:
        only_match = next(matches)
    except StopIteration:
        raise ObjectDoesNotExist(f"Cannot find result for {query}.")

    # A second result means the query was ambiguous.
    try:
        next(matches)
    except StopIteration:
        return only_match

    raise MultipleObjectsReturned(f"Found multiple results for {query}.")
def preload(python_data: LdapObject, database: Optional[Database] = None) -> LdapObject:
    """
    Preload all NotLoaded fields in LdapObject.

    A field holding a NotLoadedList is replaced by its loaded result; a list
    containing NotLoadedObject items has each NotLoaded item replaced by its
    loaded result.  Other fields are left alone.

    :return: ``python_data`` merged with the loaded values.
    :raises RuntimeError: If a NotLoadedObject is used as the whole value of
        a field, or a NotLoadedList appears inside a list.
    """
    loaded = {}

    for name in python_data.keys():
        current = python_data.get_as_list(name)

        if isinstance(current, NotLoadedObject):
            raise RuntimeError(f"{name}: Unexpected NotLoadedObject outside list.")

        if isinstance(current, NotLoadedList):
            replacement = current.load(database)
        elif any(isinstance(item, NotLoadedList) for item in current):
            raise RuntimeError(f"{name}: Unexpected NotLoadedList in list.")
        elif any(isinstance(item, NotLoadedObject) for item in current):
            replacement = [
                item.load(database) if isinstance(item, NotLoaded) else item
                for item in current
            ]
        else:
            # Nothing to load for this field.
            replacement = None

        if replacement is not None:
            loaded[name] = replacement

    return python_data.merge(loaded)
def insert(python_data: LdapObject, database: Optional[Database] = None) -> LdapObject:
    """
    Insert a new python_data object in the database.

    The object's values are expressed as changes against an empty instance
    of the same class and then passed to :func:`save`.
    """
    assert isinstance(python_data, LdapObject)

    blank = type(python_data)()
    pending = changeset(blank, python_data.to_dict())
    return save(pending, database)
def save(changes: Changeset, database: Optional[Database] = None) -> LdapObject:
    """
    Save all changes in a Changeset.

    The entry is created when only the changes carry a DN and modified when
    only the source object carries one.

    :return: The source object merged with the changes, after ``on_load``.
    :raises RuntimeError: If the changeset has errors, if no DN is available,
        or if both the source object and the changes carry a DN.
    :raises ObjectAlreadyExists: If the entry to create already exists.
    :raises ObjectDoesNotExist: If the entry to modify does not exist.
    """
    assert isinstance(changes, Changeset)

    if not changes.is_valid:
        raise RuntimeError(f"Changeset has errors {changes.errors}.")

    database = get_database(database)
    connection = database.connection
    table = type(changes.src)

    # Give the object class a chance to adjust the changes first.
    changes = table.on_save(changes, database)

    # A DN must come from exactly one side: the source object (modify) or
    # the changes (create).
    src_dn = changes.src.get_as_single('dn')
    has_src_dn = src_dn is not None
    has_new_dn = 'dn' in changes

    if not has_src_dn and not has_new_dn:
        raise RuntimeError("No DN was given")
    if has_src_dn and has_new_dn:
        raise RuntimeError("Changes to DN are not supported.")

    create = has_new_dn
    dn = changes.get_value_as_single('dn') if create else src_dn
    assert dn is not None

    if create:
        add_list = _python_to_mod_new(changes)
        try:
            connection.add(dn, add_list)
        except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult:
            raise ObjectAlreadyExists(
                "Object with dn %r already exists doing add" % dn)
    else:
        modify_list = _python_to_mod_modify(changes)
        # Nothing is sent when there is nothing to modify.
        if len(modify_list) > 0:
            try:
                connection.modify(dn, modify_list)
            except ldap3.core.exceptions.LDAPNoSuchObjectResult:
                raise ObjectDoesNotExist(
                    "Object with dn %r doesn't already exist doing modify" % dn)

    # Rebuild the object from the source values plus the changes.
    saved = table(changes.src.to_dict()).merge(changes.to_dict())
    return saved.on_load(saved, database)
def delete(python_data: LdapObject, database: Optional[Database] = None) -> None:
    """
    Delete a LdapObject from the database.

    The object's ``dn`` value is passed to the connection's ``delete``.
    """
    dn = python_data.get_as_single('dn')
    assert dn is not None

    get_database(database).connection.delete(dn)
def _get_field_by_name(table: LdapObjectClass, name: str) -> tldap.fields.Field:
    """
    Lookup a field by its name.

    :raises KeyError: If ``table.get_fields()`` has no entry for ``name``
        (assuming it returns a plain dict).
    """
    return table.get_fields()[name]
def rename(python_data: LdapObject, new_base_dn: Optional[str] = None,
           database: Optional[Database] = None, **kwargs) -> LdapObject:
    """
    Move/rename a LdapObject in the database.

    :param python_data: Object to rename; must have a ``dn`` value.
    :param new_base_dn: New parent DN, or ``None`` to keep the current one.
    :param database: Database to use, or ``None`` for the default.
    :param kwargs: At most one ``field=value`` pair giving the new RDN; with
        none, the current RDN is kept.
    :return: ``python_data`` merged with the new field value (if any) and
        the new ``dn``.
    :raises AssertionError: If more than one keyword argument is given.
    """
    table = type(python_data)

    dn = python_data.get_as_single('dn')
    assert dn is not None

    database = get_database(database)
    connection = database.connection

    # Work out the new RDN of the object.
    if len(kwargs) == 1:
        (name, value), = kwargs.items()
        split_new_rdn = [[(name, value, 1)]]

        field = _get_field_by_name(table, name)
        assert field.db_field

        python_data = python_data.merge({
            name: value,
        })
    elif len(kwargs) == 0:
        split_new_rdn = [str2dn(dn)[0]]
    else:
        # Raised explicitly: a bare ``assert False`` is stripped under -O and
        # would leave split_new_rdn unbound below.
        raise AssertionError(
            "rename() accepts at most one keyword argument for the new RDN.")

    new_rdn = dn2str(split_new_rdn)

    connection.rename(
        dn,
        new_rdn,
        new_base_dn,
    )

    # The new DN is the new RDN followed by the new, or the unchanged, parent.
    if new_base_dn is not None:
        split_base_dn = str2dn(new_base_dn)
    else:
        split_base_dn = str2dn(dn)[1:]

    new_dn = dn2str([split_new_rdn[0], *split_base_dn])

    python_data = python_data.merge({
        'dn': new_dn,
    })

    return python_data