2025-12-25 upload

This commit is contained in:
“shengyudong”
2025-12-25 11:16:59 +08:00
commit 322ac74336
2241 changed files with 639966 additions and 0 deletions

View File

@@ -0,0 +1,119 @@
"""
"""
# Created on 2016.07.10
#
# Author: Giovanni Cannata
#
# Copyright 2016 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
try:
from queue import Queue
except ImportError: # Python 2
# noinspection PyUnresolvedReferences
from Queue import Queue
from io import StringIO
from os import linesep
from ..protocol.rfc2849 import decode_persistent_search_control
from ..strategy.asynchronous import AsyncStrategy
from ..core.exceptions import LDAPLDIFError
from ..utils.conv import prepare_for_stream
from ..protocol.rfc2849 import persistent_search_response_to_ldif, add_ldif_header
# noinspection PyProtectedMember
class AsyncStreamStrategy(AsyncStrategy):
"""
This strategy is asynchronous. It streams responses in a generator as they appear in the self._responses container
"""
def __init__(self, ldap_connection):
AsyncStrategy.__init__(self, ldap_connection)
self.can_stream = True
self.line_separator = linesep
self.all_base64 = False
self.stream = None
self.order = dict()
self._header_added = False
self.persistent_search_message_id = None
self.streaming = False
self.callback = None
if ldap_connection.pool_size:
self.events = Queue(ldap_connection.pool_size)
else:
self.events = Queue()
del self._requests # remove _requests dict from Async Strategy
def _start_listen(self):
AsyncStrategy._start_listen(self)
if self.streaming:
if not self.stream or (isinstance(self.stream, StringIO) and self.stream.closed):
self.set_stream(StringIO())
def _stop_listen(self):
AsyncStrategy._stop_listen(self)
if self.streaming:
self.stream.close()
def accumulate_stream(self, message_id, change):
if message_id == self.persistent_search_message_id:
with self.async_lock:
self._responses[message_id] = []
if self.streaming:
if not self._header_added and self.stream.tell() == 0:
header = add_ldif_header(['-'])[0]
self.stream.write(prepare_for_stream(header + self.line_separator + self.line_separator))
ldif_lines = persistent_search_response_to_ldif(change)
if self.stream and ldif_lines and not self.connection.closed:
fragment = self.line_separator.join(ldif_lines)
if not self._header_added and self.stream.tell() == 0:
self._header_added = True
header = add_ldif_header(['-'])[0]
self.stream.write(prepare_for_stream(header + self.line_separator + self.line_separator))
self.stream.write(prepare_for_stream(fragment + self.line_separator + self.line_separator))
else: # strategy is not streaming, events are added to a queue
notification = decode_persistent_search_control(change)
if notification:
change.update(notification)
del change['controls']['2.16.840.1.113730.3.4.7']
if not self.callback:
self.events.put(change)
else:
self.callback(change)
def get_stream(self):
if self.streaming:
return self.stream
return None
def set_stream(self, value):
error = False
try:
if not value.writable():
error = True
except (ValueError, AttributeError):
error = True
if error:
raise LDAPLDIFError('stream must be writable')
self.stream = value
self.streaming = True

View File

@@ -0,0 +1,292 @@
"""
"""
# Created on 2013.07.15
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from threading import Thread, Lock, Event
import socket
from .. import get_config_parameter, DIGEST_MD5
from ..core.exceptions import LDAPSSLConfigurationError, LDAPStartTLSError, LDAPOperationResult, LDAPSignatureVerificationFailedError
from ..strategy.base import BaseStrategy, RESPONSE_COMPLETE
from ..protocol.rfc4511 import LDAPMessage
from ..utils.log import log, log_enabled, format_ldap_message, ERROR, NETWORK, EXTENDED
from ..utils.asn1 import decoder, decode_message_fast
from ..protocol.sasl.digestMd5 import md5_hmac
# noinspection PyProtectedMember
class AsyncStrategy(BaseStrategy):
"""
This strategy is asynchronous. You send the request and get the messageId of the request sent
Receiving data from socket is managed in a separated thread in a blocking mode
Requests return an int value to indicate the messageId of the requested Operation
You get the response with get_response, it has a timeout to wait for response to appear
Connection.response will contain the whole LDAP response for the messageId requested in a dict form
Connection.request will contain the result LDAP message in a dict form
Response appear in strategy._responses dictionary
"""
# noinspection PyProtectedMember
class ReceiverSocketThread(Thread):
"""
The thread that actually manage the receiver socket
"""
def __init__(self, ldap_connection):
Thread.__init__(self)
self.connection = ldap_connection
self.socket_size = get_config_parameter('SOCKET_SIZE')
def run(self):
"""
Waits for data on socket, computes the length of the message and waits for enough bytes to decode the message
Message are appended to strategy._responses
"""
unprocessed = b''
get_more_data = True
listen = True
data = b''
sasl_total_bytes_recieved = 0
sasl_received_data = b'' # used to verify the signature, typo GC
sasl_next_packet = b''
sasl_buffer_length = -1
# sasl_signature = b'' # not needed here GC
# sasl_sec_num = b'' # used to verify the signature, not needed here GC
while listen:
if get_more_data:
try:
data = self.connection.socket.recv(self.socket_size)
except (OSError, socket.error, AttributeError):
if self.connection.receive_timeout: # a receive timeout has been detected - keep kistening on the socket
continue
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', str(e), self.connection)
raise # unexpected exception - re-raise
if len(data) > 0:
# If we are using DIGEST-MD5 and LDAP signing is set : verify & remove the signature from the message
if self.connection.sasl_mechanism == DIGEST_MD5 and self.connection._digest_md5_kis and not self.connection.sasl_in_progress:
data = sasl_next_packet + data
if sasl_received_data == b'' or sasl_next_packet:
# Remove the sizeOf(encoded_message + signature + 0x0001 + secNum) from data.
sasl_buffer_length = int.from_bytes(data[0:4], "big")
data = data[4:]
sasl_next_packet = b''
sasl_total_bytes_recieved += len(data)
sasl_received_data += data
if sasl_total_bytes_recieved >= sasl_buffer_length:
# When the LDAP response is splitted accross multiple TCP packets, the SASL buffer length is equal to the MTU of each packet..Which is usually not equal to self.socket_size
# This means that the end of one SASL packet/beginning of one other....could be located in the middle of data
# We are using "sasl_received_data" instead of "data" & "unprocessed" for this reason
# structure of messages when LDAP signing is enabled : sizeOf(encoded_message + signature + 0x0001 + secNum) + encoded_message + signature + 0x0001 + secNum
sasl_signature = sasl_received_data[sasl_buffer_length - 16:sasl_buffer_length - 6]
sasl_sec_num = sasl_received_data[sasl_buffer_length - 4:sasl_buffer_length]
sasl_next_packet = sasl_received_data[sasl_buffer_length:] # the last "data" variable may contain another sasl packet. We'll process it at the next iteration.
sasl_received_data = sasl_received_data[:sasl_buffer_length - 16] # remove signature + 0x0001 + secNum + the next packet if any, from sasl_received_data
kis = self.connection._digest_md5_kis # renamed to lowercase GC
calculated_signature = bytes.fromhex(md5_hmac(kis, sasl_sec_num + sasl_received_data)[0:20])
if sasl_signature != calculated_signature:
raise LDAPSignatureVerificationFailedError("Signature verification failed for the recieved LDAP message number " + str(int.from_bytes(sasl_sec_num, 'big')) + ". Expected signature " + calculated_signature.hex() + " but got " + sasl_signature.hex() + ".")
sasl_total_bytes_recieved = 0
unprocessed += sasl_received_data
sasl_received_data = b''
else:
unprocessed += data
data = b''
else:
listen = False
length = BaseStrategy.compute_ldap_message_size(unprocessed)
if length == -1 or len(unprocessed) < length:
get_more_data = True
elif len(unprocessed) >= length: # add message to message list
if self.connection.usage:
self.connection._usage.update_received_message(length)
if log_enabled(NETWORK):
log(NETWORK, 'received %d bytes via <%s>', length, self.connection)
if self.connection.fast_decoder:
ldap_resp = decode_message_fast(unprocessed[:length])
dict_response = self.connection.strategy.decode_response_fast(ldap_resp)
else:
ldap_resp = decoder.decode(unprocessed[:length], asn1Spec=LDAPMessage())[0]
dict_response = self.connection.strategy.decode_response(ldap_resp)
message_id = int(ldap_resp['messageID'])
if log_enabled(NETWORK):
log(NETWORK, 'received 1 ldap message via <%s>', self.connection)
if log_enabled(EXTENDED):
log(EXTENDED, 'ldap message received via <%s>:%s', self.connection, format_ldap_message(ldap_resp, '<<'))
if dict_response['type'] == 'extendedResp' and (dict_response['responseName'] == '1.3.6.1.4.1.1466.20037' or hasattr(self.connection, '_awaiting_for_async_start_tls')):
if dict_response['result'] == 0: # StartTls in progress
if self.connection.server.tls:
self.connection.server.tls._start_tls(self.connection)
else:
self.connection.last_error = 'no Tls object defined in Server'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSSLConfigurationError(self.connection.last_error)
else:
self.connection.last_error = 'asynchronous StartTls failed'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPStartTLSError(self.connection.last_error)
del self.connection._awaiting_for_async_start_tls
if message_id != 0: # 0 is reserved for 'Unsolicited Notification' from server as per RFC4511 (paragraph 4.4)
with self.connection.strategy.async_lock:
if message_id in self.connection.strategy._responses:
self.connection.strategy._responses[message_id].append(dict_response)
else:
self.connection.strategy._responses[message_id] = [dict_response]
if dict_response['type'] not in ['searchResEntry', 'searchResRef', 'intermediateResponse']:
self.connection.strategy._responses[message_id].append(RESPONSE_COMPLETE)
self.connection.strategy.set_event_for_message(message_id)
if self.connection.strategy.can_stream: # for AsyncStreamStrategy, used for PersistentSearch
self.connection.strategy.accumulate_stream(message_id, dict_response)
unprocessed = unprocessed[length:]
get_more_data = False if unprocessed else True
listen = True if self.connection.listening or unprocessed else False
else: # Unsolicited Notification
if dict_response['responseName'] == '1.3.6.1.4.1.1466.20036': # Notice of Disconnection as per RFC4511 (paragraph 4.4.1)
listen = False
else:
self.connection.last_error = 'unknown unsolicited notification from server'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPStartTLSError(self.connection.last_error)
self.connection.strategy.close()
def __init__(self, ldap_connection):
BaseStrategy.__init__(self, ldap_connection)
self.sync = False
self.no_real_dsa = False
self.pooled = False
self._responses = None
self._requests = None
self.can_stream = False
self.receiver = None
self.async_lock = Lock()
self.event_lock = Lock()
self._events = {}
def open(self, reset_usage=True, read_server_info=True):
"""
Open connection and start listen on the socket in a different thread
"""
with self.connection.connection_lock:
self._responses = dict()
self._requests = dict()
BaseStrategy.open(self, reset_usage, read_server_info)
if read_server_info:
try:
self.connection.refresh_server_info()
except LDAPOperationResult: # catch errors from server if raise_exception = True
self.connection.server._dsa_info = None
self.connection.server._schema_info = None
def close(self):
"""
Close connection and stop socket thread
"""
with self.connection.connection_lock:
BaseStrategy.close(self)
def _add_event_for_message(self, message_id):
with self.event_lock:
# Should have the check here because the receiver thread may has created it
if message_id not in self._events:
self._events[message_id] = Event()
def set_event_for_message(self, message_id):
with self.event_lock:
# The receiver thread may receive the response before the sender set the event for the message_id,
# so we have to check if the event exists
if message_id not in self._events:
self._events[message_id] = Event()
self._events[message_id].set()
def _get_event_for_message(self, message_id):
with self.event_lock:
if message_id not in self._events:
raise RuntimeError('Event for message[{}] should have been created before accessing'.format(message_id))
return self._events[message_id]
def post_send_search(self, message_id):
"""
Clears connection.response and returns messageId
"""
self.connection.response = None
self.connection.request = None
self.connection.result = None
self._add_event_for_message(message_id)
return message_id
def post_send_single_response(self, message_id):
"""
Clears connection.response and returns messageId.
"""
self.connection.response = None
self.connection.request = None
self.connection.result = None
self._add_event_for_message(message_id)
return message_id
def _start_listen(self):
"""
Start thread in daemon mode
"""
if not self.connection.listening:
self.receiver = AsyncStrategy.ReceiverSocketThread(self.connection)
self.connection.listening = True
self.receiver.daemon = True
self.receiver.start()
def _get_response(self, message_id, timeout):
"""
Performs the capture of LDAP response for this strategy
The response is only complete after the event been set
"""
event = self._get_event_for_message(message_id)
flag = event.wait(timeout)
if not flag:
# timeout
return None
# In this stage we could ensure the response is already there
self._events.pop(message_id)
with self.async_lock:
return self._responses.pop(message_id)
def receiving(self):
raise NotImplementedError
def get_stream(self):
raise NotImplementedError
def set_stream(self, value):
raise NotImplementedError

View File

@@ -0,0 +1,925 @@
"""
"""
# Created on 2013.07.15
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more dectails.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
import socket
try: # try to discover if unix sockets are available for LDAP over IPC (ldapi:// scheme)
# noinspection PyUnresolvedReferences
from socket import AF_UNIX
unix_socket_available = True
except ImportError:
unix_socket_available = False
from struct import pack
from platform import system
from random import choice
from .. import SYNC, ANONYMOUS, get_config_parameter, BASE, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES, DIGEST_MD5
from ..core.results import DO_NOT_RAISE_EXCEPTIONS, RESULT_REFERRAL
from ..core.exceptions import LDAPOperationResult, LDAPSASLBindInProgressError, LDAPSocketOpenError, LDAPSessionTerminatedByServerError,\
LDAPUnknownResponseError, LDAPUnknownRequestError, LDAPReferralError, communication_exception_factory, LDAPStartTLSError, \
LDAPSocketSendError, LDAPExceptionError, LDAPControlError, LDAPResponseTimeoutError, LDAPTransactionError
from ..utils.uri import parse_uri
from ..protocol.rfc4511 import LDAPMessage, ProtocolOp, MessageID, SearchResultEntry
from ..operation.add import add_response_to_dict, add_request_to_dict
from ..operation.modify import modify_request_to_dict, modify_response_to_dict
from ..operation.search import search_result_reference_response_to_dict, search_result_done_response_to_dict,\
search_result_entry_response_to_dict, search_request_to_dict, search_result_entry_response_to_dict_fast,\
search_result_reference_response_to_dict_fast, attributes_to_dict, attributes_to_dict_fast
from ..operation.bind import bind_response_to_dict, bind_request_to_dict, sicily_bind_response_to_dict, bind_response_to_dict_fast, \
sicily_bind_response_to_dict_fast
from ..operation.compare import compare_response_to_dict, compare_request_to_dict
from ..operation.extended import extended_request_to_dict, extended_response_to_dict, intermediate_response_to_dict, extended_response_to_dict_fast, intermediate_response_to_dict_fast
from ..core.server import Server
from ..operation.modifyDn import modify_dn_request_to_dict, modify_dn_response_to_dict
from ..operation.delete import delete_response_to_dict, delete_request_to_dict
from ..protocol.convert import prepare_changes_for_request, build_controls_list
from ..operation.abandon import abandon_request_to_dict
from ..core.tls import Tls
from ..protocol.oid import Oids
from ..protocol.rfc2696 import RealSearchControlValue
from ..protocol.microsoft import DirSyncControlResponseValue
from ..utils.log import log, log_enabled, ERROR, BASIC, PROTOCOL, NETWORK, EXTENDED, format_ldap_message
from ..utils.asn1 import encode, decoder, ldap_result_to_dict_fast, decode_sequence
from ..utils.conv import to_unicode
from ..protocol.sasl.digestMd5 import md5_h, md5_hmac
SESSION_TERMINATED_BY_SERVER = 'TERMINATED_BY_SERVER'
TRANSACTION_ERROR = 'TRANSACTION_ERROR'
RESPONSE_COMPLETE = 'RESPONSE_FROM_SERVER_COMPLETE'
# noinspection PyProtectedMember
class BaseStrategy(object):
"""
Base class for connection strategy
"""
def __init__(self, ldap_connection):
self.connection = ldap_connection
self._outstanding = None
self._referrals = []
self.sync = None # indicates a synchronous connection
self.no_real_dsa = None # indicates a connection to a fake LDAP server
self.pooled = None # Indicates a connection with a connection pool
self.can_stream = None # indicates if a strategy keeps a stream of responses (i.e. LdifProducer can accumulate responses with a single header). Stream must be initialized and closed in _start_listen() and _stop_listen()
self.referral_cache = {}
self.thread_safe = False # Indicates that connection can be used in a multithread application
if log_enabled(BASIC):
log(BASIC, 'instantiated <%s>: <%s>', self.__class__.__name__, self)
def __str__(self):
s = [
str(self.connection) if self.connection else 'None',
'sync' if self.sync else 'async',
'no real DSA' if self.no_real_dsa else 'real DSA',
'pooled' if self.pooled else 'not pooled',
'can stream output' if self.can_stream else 'cannot stream output',
]
return ' - '.join(s)
def open(self, reset_usage=True, read_server_info=True):
"""
Open a socket to a server. Choose a server from the server pool if available
"""
if log_enabled(NETWORK):
log(NETWORK, 'opening connection for <%s>', self.connection)
if self.connection.lazy and not self.connection._executing_deferred:
self.connection._deferred_open = True
self.connection.closed = False
if log_enabled(NETWORK):
log(NETWORK, 'deferring open connection for <%s>', self.connection)
else:
if not self.connection.closed and not self.connection._executing_deferred: # try to close connection if still open
self.close()
self._outstanding = dict()
if self.connection.usage:
if reset_usage or not self.connection._usage.initial_connection_start_time:
self.connection._usage.start()
if self.connection.server_pool:
new_server = self.connection.server_pool.get_server(self.connection) # get a server from the server_pool if available
if self.connection.server != new_server:
self.connection.server = new_server
if self.connection.usage:
self.connection._usage.servers_from_pool += 1
exception_history = []
if not self.no_real_dsa: # tries to connect to a real server
for candidate_address in self.connection.server.candidate_addresses():
try:
if log_enabled(BASIC):
log(BASIC, 'try to open candidate address %s', candidate_address[:-2])
self._open_socket(candidate_address, self.connection.server.ssl, unix_socket=self.connection.server.ipc)
self.connection.server.current_address = candidate_address
self.connection.server.update_availability(candidate_address, True)
break
except Exception as e:
self.connection.server.update_availability(candidate_address, False)
# exception_history.append((datetime.now(), exc_type, exc_value, candidate_address[4]))
exception_history.append((type(e)(str(e)), candidate_address[4]))
if not self.connection.server.current_address and exception_history:
if len(exception_history) == 1: # only one exception, reraise
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', str(exception_history[0][0]) + ' ' + str((exception_history[0][1])), self.connection)
raise exception_history[0][0]
else:
if log_enabled(ERROR):
log(ERROR, 'unable to open socket for <%s>', self.connection)
raise LDAPSocketOpenError('unable to open socket', exception_history)
elif not self.connection.server.current_address:
if log_enabled(ERROR):
log(ERROR, 'invalid server address for <%s>', self.connection)
raise LDAPSocketOpenError('invalid server address')
self.connection._deferred_open = False
self._start_listen()
if log_enabled(NETWORK):
log(NETWORK, 'connection open for <%s>', self.connection)
def close(self):
"""
Close connection
"""
if log_enabled(NETWORK):
log(NETWORK, 'closing connection for <%s>', self.connection)
if self.connection.lazy and not self.connection._executing_deferred and (self.connection._deferred_bind or self.connection._deferred_open):
self.connection.listening = False
self.connection.closed = True
if log_enabled(NETWORK):
log(NETWORK, 'deferred connection closed for <%s>', self.connection)
else:
if not self.connection.closed:
self._stop_listen()
if not self. no_real_dsa:
self._close_socket()
if log_enabled(NETWORK):
log(NETWORK, 'connection closed for <%s>', self.connection)
self.connection.bound = False
self.connection.request = None
self.connection.response = None
self.connection.tls_started = False
self._outstanding = None
self._referrals = []
if not self.connection.strategy.no_real_dsa:
self.connection.server.current_address = None
if self.connection.usage:
self.connection._usage.stop()
def _open_socket(self, address, use_ssl=False, unix_socket=False):
"""
Tries to open and connect a socket to a Server
raise LDAPExceptionError if unable to open or connect socket
"""
try:
self.connection.socket = socket.socket(*address[:3])
except Exception as e:
self.connection.last_error = 'socket creation error: ' + str(e)
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise communication_exception_factory(LDAPSocketOpenError, exc)(self.connection.last_error)
raise communication_exception_factory(LDAPSocketOpenError, type(e)(str(e)))(self.connection.last_error)
# Try to bind the socket locally before connecting to the remote address
# We go through our connection's source ports and try to bind our socket to our connection's source address
# with them.
# If no source address or ports were specified, this will have the same success/fail result as if we
# tried to connect to the remote server without binding locally first.
# This is actually a little bit better, as it lets us distinguish the case of "issue binding the socket
# locally" from "remote server is unavailable" with more clarity, though this will only really be an
# issue when no source address/port is specified if the system checking server availability is running
# as a very unprivileged user.
last_bind_exc = None
if unix_socket_available and self.connection.socket.family != socket.AF_UNIX:
socket_bind_succeeded = False
for source_port in self.connection.source_port_list:
try:
self.connection.socket.bind((self.connection.source_address, source_port))
socket_bind_succeeded = True
break
except Exception as bind_ex:
last_bind_exc = bind_ex
# we'll always end up logging at error level if we cannot bind any ports to the address locally.
# but if some work and some don't you probably don't want the ones that don't at ERROR level
if log_enabled(NETWORK):
log(NETWORK, 'Unable to bind to local address <%s> with source port <%s> due to <%s>',
self.connection.source_address, source_port, bind_ex)
if not socket_bind_succeeded:
self.connection.last_error = 'socket connection error while locally binding: ' + str(last_bind_exc)
if log_enabled(ERROR):
log(ERROR, 'Unable to locally bind to local address <%s> with any of the source ports <%s> for connection <%s due to <%s>',
self.connection.source_address, self.connection.source_port_list, self.connection, last_bind_exc)
raise communication_exception_factory(LDAPSocketOpenError, type(last_bind_exc)(str(last_bind_exc)))(last_bind_exc)
try: # set socket timeout for opening connection
if self.connection.server.connect_timeout:
self.connection.socket.settimeout(self.connection.server.connect_timeout)
self.connection.socket.connect(address[4])
except socket.error as e:
self.connection.last_error = 'socket connection error while opening: ' + str(e)
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise communication_exception_factory(LDAPSocketOpenError, exc)(self.connection.last_error)
raise communication_exception_factory(LDAPSocketOpenError, type(e)(str(e)))(self.connection.last_error)
# Set connection recv timeout (must be set after connect,
# because socket.settimeout() affects both, connect() as
# well as recv(). Set it before tls.wrap_socket() because
# the recv timeout should take effect during the TLS
# handshake.
if self.connection.receive_timeout is not None:
try: # set receive timeout for the connection socket
self.connection.socket.settimeout(self.connection.receive_timeout)
if system().lower() == 'windows':
self.connection.socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVTIMEO, int(1000 * self.connection.receive_timeout))
else:
self.connection.socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVTIMEO, pack('LL', self.connection.receive_timeout, 0))
except socket.error as e:
self.connection.last_error = 'unable to set receive timeout for socket connection: ' + str(e)
# if exc:
# if log_enabled(ERROR):
# log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise communication_exception_factory(LDAPSocketOpenError, exc)(self.connection.last_error)
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise communication_exception_factory(LDAPSocketOpenError, type(e)(str(e)))(self.connection.last_error)
if use_ssl:
try:
self.connection.server.tls.wrap_socket(self.connection, do_handshake=True)
if self.connection.usage:
self.connection._usage.wrapped_sockets += 1
except Exception as e:
self.connection.last_error = 'socket ssl wrapping error: ' + str(e)
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise communication_exception_factory(LDAPSocketOpenError, type(e)(str(e)))(self.connection.last_error)
if self.connection.usage:
self.connection._usage.open_sockets += 1
self.connection.closed = False
def _close_socket(self):
"""
Try to close a socket
don't raise exception if unable to close socket, assume socket is already closed
"""
try:
self.connection.socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
try:
self.connection.socket.close()
except Exception:
pass
self.connection.socket = None
self.connection.closed = True
if self.connection.usage:
self.connection._usage.closed_sockets += 1
def _stop_listen(self):
self.connection.listening = False
def send(self, message_type, request, controls=None):
"""
Send an LDAP message
Returns the message_id
"""
self.connection.request = None
if self.connection.listening:
if self.connection.sasl_in_progress and message_type not in ['bindRequest']: # as per RFC4511 (4.2.1)
self.connection.last_error = 'cannot send operation requests while SASL bind is in progress'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSASLBindInProgressError(self.connection.last_error)
message_id = self.connection.server.next_message_id()
ldap_message = LDAPMessage()
ldap_message['messageID'] = MessageID(message_id)
ldap_message['protocolOp'] = ProtocolOp().setComponentByName(message_type, request)
message_controls = build_controls_list(controls)
if message_controls is not None:
ldap_message['controls'] = message_controls
self.connection.request = BaseStrategy.decode_request(message_type, request, controls)
self._outstanding[message_id] = self.connection.request
self.sending(ldap_message)
else:
self.connection.last_error = 'unable to send message, socket is not open'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketOpenError(self.connection.last_error)
return message_id
def get_response(self, message_id, timeout=None, get_request=False):
"""
Get response LDAP messages
Responses are returned by the underlying connection strategy
Check if message_id LDAP message is still outstanding and wait for timeout to see if it appears in _get_response
Result is stored in connection.result
Responses without result is stored in connection.response
A tuple (responses, result) is returned
"""
if timeout is None:
timeout = get_config_parameter('RESPONSE_WAITING_TIMEOUT')
response = None
result = None
# request = None
if self._outstanding and message_id in self._outstanding:
responses = self._get_response(message_id, timeout)
if not responses:
if log_enabled(ERROR):
log(ERROR, 'socket timeout, no response from server for <%s>', self.connection)
raise LDAPResponseTimeoutError('no response from server')
if responses == SESSION_TERMINATED_BY_SERVER:
try: # try to close the session but don't raise any error if server has already closed the session
self.close()
except (socket.error, LDAPExceptionError):
pass
self.connection.last_error = 'session terminated by server'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSessionTerminatedByServerError(self.connection.last_error)
elif responses == TRANSACTION_ERROR: # Novell LDAP Transaction unsolicited notification
self.connection.last_error = 'transaction error'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPTransactionError(self.connection.last_error)
# if referral in response opens a new connection to resolve referrals if requested
if responses[-2]['result'] == RESULT_REFERRAL:
if self.connection.usage:
self.connection._usage.referrals_received += 1
if self.connection.auto_referrals:
ref_response, ref_result = self.do_operation_on_referral(self._outstanding[message_id], responses[-2]['referrals'])
if ref_response is not None:
responses = ref_response + [ref_result]
responses.append(RESPONSE_COMPLETE)
elif ref_result is not None:
responses = [ref_result, RESPONSE_COMPLETE]
self._referrals = []
if responses:
result = responses[-2]
response = responses[:-2]
self.connection.result = None
self.connection.response = None
if self.connection.raise_exceptions and result and result['result'] not in DO_NOT_RAISE_EXCEPTIONS:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'operation result <%s> for <%s>', result, self.connection)
self._outstanding.pop(message_id)
self.connection.result = result.copy()
raise LDAPOperationResult(result=result['result'], description=result['description'], dn=result['dn'], message=result['message'], response_type=result['type'])
# checks if any response has a range tag
# self._auto_range_searching is set as a flag to avoid recursive searches
if self.connection.auto_range and not hasattr(self, '_auto_range_searching') and any((True for resp in response if 'raw_attributes' in resp for name in resp['raw_attributes'] if ';range=' in name)):
self._auto_range_searching = result.copy()
temp_response = response[:] # copy
if self.do_search_on_auto_range(self._outstanding[message_id], response):
for resp in temp_response:
if resp['type'] == 'searchResEntry':
keys = [key for key in resp['raw_attributes'] if ';range=' in key]
for key in keys:
del resp['raw_attributes'][key]
del resp['attributes'][key]
response = temp_response
result = self._auto_range_searching
del self._auto_range_searching
if self.connection.empty_attributes:
for entry in response:
if entry['type'] == 'searchResEntry':
for attribute_type in self._outstanding[message_id]['attributes']:
if attribute_type not in entry['raw_attributes'] and attribute_type not in (ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES):
entry['raw_attributes'][attribute_type] = list()
entry['attributes'][attribute_type] = list()
if log_enabled(PROTOCOL):
log(PROTOCOL, 'attribute set to empty list for missing attribute <%s> in <%s>', attribute_type, self)
if not self.connection.auto_range:
attrs_to_remove = []
# removes original empty attribute in case a range tag is returned
for attribute_type in entry['attributes']:
if ';range' in attribute_type.lower():
orig_attr, _, _ = attribute_type.partition(';')
attrs_to_remove.append(orig_attr)
for attribute_type in attrs_to_remove:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'attribute type <%s> removed in response because of same attribute returned as range by the server in <%s>', attribute_type, self)
del entry['raw_attributes'][attribute_type]
del entry['attributes'][attribute_type]
request = self._outstanding.pop(message_id)
else:
if log_enabled(ERROR):
log(ERROR, 'message id not in outstanding queue for <%s>', self.connection)
raise(LDAPResponseTimeoutError('message id not in outstanding queue'))
if get_request:
return response, result, request
else:
return response, result
@staticmethod
def compute_ldap_message_size(data):
"""
Compute LDAP Message size according to BER definite length rules
Returns -1 if too few data to compute message length
"""
if isinstance(data, str): # fix for Python 2, data is string not bytes
data = bytearray(data) # Python 2 bytearray is equivalent to Python 3 bytes
ret_value = -1
if len(data) > 2:
if data[1] <= 127: # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long
ret_value = data[1] + 2
else: # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length
bytes_length = data[1] - 128
if len(data) >= bytes_length + 2:
value_length = 0
cont = bytes_length
for byte in data[2:2 + bytes_length]:
cont -= 1
value_length += byte * (256 ** cont)
ret_value = value_length + 2 + bytes_length
return ret_value
def decode_response(self, ldap_message):
"""
Convert received LDAPMessage to a dict
"""
message_type = ldap_message.getComponentByName('protocolOp').getName()
component = ldap_message['protocolOp'].getComponent()
controls = ldap_message['controls'] if ldap_message['controls'].hasValue() else None
if message_type == 'bindResponse':
if not bytes(component['matchedDN']).startswith(b'NTLM'): # patch for microsoft ntlm authentication
result = bind_response_to_dict(component)
else:
result = sicily_bind_response_to_dict(component)
elif message_type == 'searchResEntry':
result = search_result_entry_response_to_dict(component, self.connection.server.schema, self.connection.server.custom_formatter, self.connection.check_names)
elif message_type == 'searchResDone':
result = search_result_done_response_to_dict(component)
elif message_type == 'searchResRef':
result = search_result_reference_response_to_dict(component)
elif message_type == 'modifyResponse':
result = modify_response_to_dict(component)
elif message_type == 'addResponse':
result = add_response_to_dict(component)
elif message_type == 'delResponse':
result = delete_response_to_dict(component)
elif message_type == 'modDNResponse':
result = modify_dn_response_to_dict(component)
elif message_type == 'compareResponse':
result = compare_response_to_dict(component)
elif message_type == 'extendedResp':
result = extended_response_to_dict(component)
elif message_type == 'intermediateResponse':
result = intermediate_response_to_dict(component)
else:
if log_enabled(ERROR):
log(ERROR, 'unknown response <%s> for <%s>', message_type, self.connection)
raise LDAPUnknownResponseError('unknown response')
result['type'] = message_type
if controls:
result['controls'] = dict()
for control in controls:
decoded_control = self.decode_control(control)
result['controls'][decoded_control[0]] = decoded_control[1]
return result
def decode_response_fast(self, ldap_message):
"""
Convert received LDAPMessage from fast ber decoder to a dict
"""
if ldap_message['protocolOp'] == 1: # bindResponse
if not ldap_message['payload'][1][3].startswith(b'NTLM'): # patch for microsoft ntlm authentication
result = bind_response_to_dict_fast(ldap_message['payload'])
else:
result = sicily_bind_response_to_dict_fast(ldap_message['payload'])
result['type'] = 'bindResponse'
elif ldap_message['protocolOp'] == 4: # searchResEntry'
result = search_result_entry_response_to_dict_fast(ldap_message['payload'], self.connection.server.schema, self.connection.server.custom_formatter, self.connection.check_names)
result['type'] = 'searchResEntry'
elif ldap_message['protocolOp'] == 5: # searchResDone
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'searchResDone'
elif ldap_message['protocolOp'] == 19: # searchResRef
result = search_result_reference_response_to_dict_fast(ldap_message['payload'])
result['type'] = 'searchResRef'
elif ldap_message['protocolOp'] == 7: # modifyResponse
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'modifyResponse'
elif ldap_message['protocolOp'] == 9: # addResponse
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'addResponse'
elif ldap_message['protocolOp'] == 11: # delResponse
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'delResponse'
elif ldap_message['protocolOp'] == 13: # modDNResponse
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'modDNResponse'
elif ldap_message['protocolOp'] == 15: # compareResponse
result = ldap_result_to_dict_fast(ldap_message['payload'])
result['type'] = 'compareResponse'
elif ldap_message['protocolOp'] == 24: # extendedResp
result = extended_response_to_dict_fast(ldap_message['payload'])
result['type'] = 'extendedResp'
elif ldap_message['protocolOp'] == 25: # intermediateResponse
result = intermediate_response_to_dict_fast(ldap_message['payload'])
result['type'] = 'intermediateResponse'
else:
if log_enabled(ERROR):
log(ERROR, 'unknown response <%s> for <%s>', ldap_message['protocolOp'], self.connection)
raise LDAPUnknownResponseError('unknown response')
if ldap_message['controls']:
result['controls'] = dict()
for control in ldap_message['controls']:
decoded_control = self.decode_control_fast(control[3])
result['controls'][decoded_control[0]] = decoded_control[1]
return result
@staticmethod
def decode_control(control):
"""
decode control, return a 2-element tuple where the first element is the control oid
and the second element is a dictionary with description (from Oids), criticality and decoded control value
"""
control_type = str(control['controlType'])
criticality = bool(control['criticality'])
control_value = bytes(control['controlValue'])
unprocessed = None
if control_type == '1.2.840.113556.1.4.319': # simple paged search as per RFC2696
control_resp, unprocessed = decoder.decode(control_value, asn1Spec=RealSearchControlValue())
control_value = dict()
control_value['size'] = int(control_resp['size'])
control_value['cookie'] = bytes(control_resp['cookie'])
elif control_type == '1.2.840.113556.1.4.841': # DirSync AD
control_resp, unprocessed = decoder.decode(control_value, asn1Spec=DirSyncControlResponseValue())
control_value = dict()
control_value['more_results'] = bool(control_resp['MoreResults']) # more_result if nonzero
control_value['cookie'] = bytes(control_resp['CookieServer'])
elif control_type == '1.3.6.1.1.13.1' or control_type == '1.3.6.1.1.13.2': # Pre-Read control, Post-Read Control as per RFC 4527
control_resp, unprocessed = decoder.decode(control_value, asn1Spec=SearchResultEntry())
control_value = dict()
control_value['result'] = attributes_to_dict(control_resp['attributes'])
if unprocessed:
if log_enabled(ERROR):
log(ERROR, 'unprocessed control response in substrate')
raise LDAPControlError('unprocessed control response in substrate')
return control_type, {'description': Oids.get(control_type, ''), 'criticality': criticality, 'value': control_value}
@staticmethod
def decode_control_fast(control, from_server=True):
"""
decode control, return a 2-element tuple where the first element is the control oid
and the second element is a dictionary with description (from Oids), criticality and decoded control value
"""
control_type = str(to_unicode(control[0][3], from_server=from_server))
criticality = False
control_value = None
for r in control[1:]:
if r[2] == 4: # controlValue
control_value = r[3]
else:
criticality = False if r[3] == 0 else True # criticality (booleand default to False)
if control_type == '1.2.840.113556.1.4.319': # simple paged search as per RFC2696
control_resp = decode_sequence(control_value, 0, len(control_value))
control_value = dict()
control_value['size'] = int(control_resp[0][3][0][3])
control_value['cookie'] = bytes(control_resp[0][3][1][3])
elif control_type == '1.2.840.113556.1.4.841': # DirSync AD
control_resp = decode_sequence(control_value, 0, len(control_value))
control_value = dict()
control_value['more_results'] = True if control_resp[0][3][0][3] else False # more_result if nonzero
control_value['cookie'] = control_resp[0][3][2][3]
elif control_type == '1.3.6.1.1.13.1' or control_type == '1.3.6.1.1.13.2': # Pre-Read control, Post-Read Control as per RFC 4527
control_resp = decode_sequence(control_value, 0, len(control_value))
control_value = dict()
control_value['result'] = attributes_to_dict_fast(control_resp[0][3][1][3])
return control_type, {'description': Oids.get(control_type, ''), 'criticality': criticality, 'value': control_value}
@staticmethod
def decode_request(message_type, component, controls=None):
# message_type = ldap_message.getComponentByName('protocolOp').getName()
# component = ldap_message['protocolOp'].getComponent()
if message_type == 'bindRequest':
result = bind_request_to_dict(component)
elif message_type == 'unbindRequest':
result = dict()
elif message_type == 'addRequest':
result = add_request_to_dict(component)
elif message_type == 'compareRequest':
result = compare_request_to_dict(component)
elif message_type == 'delRequest':
result = delete_request_to_dict(component)
elif message_type == 'extendedReq':
result = extended_request_to_dict(component)
elif message_type == 'modifyRequest':
result = modify_request_to_dict(component)
elif message_type == 'modDNRequest':
result = modify_dn_request_to_dict(component)
elif message_type == 'searchRequest':
result = search_request_to_dict(component)
elif message_type == 'abandonRequest':
result = abandon_request_to_dict(component)
else:
if log_enabled(ERROR):
log(ERROR, 'unknown request <%s>', message_type)
raise LDAPUnknownRequestError('unknown request')
result['type'] = message_type
result['controls'] = controls
return result
def valid_referral_list(self, referrals):
referral_list = []
for referral in referrals:
candidate_referral = parse_uri(referral)
if candidate_referral:
for ref_host in self.connection.server.allowed_referral_hosts:
if ref_host[0] == candidate_referral['host'] or ref_host[0] == '*':
if candidate_referral['host'] not in self._referrals:
candidate_referral['anonymousBindOnly'] = not ref_host[1]
referral_list.append(candidate_referral)
break
return referral_list
def do_next_range_search(self, request, response, attr_name):
done = False
current_response = response
while not done:
attr_type, _, returned_range = attr_name.partition(';range=')
_, _, high_range = returned_range.partition('-')
response['raw_attributes'][attr_type] += current_response['raw_attributes'][attr_name]
response['attributes'][attr_type] += current_response['attributes'][attr_name]
if high_range != '*':
if log_enabled(PROTOCOL):
log(PROTOCOL, 'performing next search on auto-range <%s> via <%s>', str(int(high_range) + 1), self.connection)
requested_range = attr_type + ';range=' + str(int(high_range) + 1) + '-*'
result = self.connection.search(search_base=response['dn'],
search_filter='(objectclass=*)',
search_scope=BASE,
dereference_aliases=request['dereferenceAlias'],
attributes=[attr_type + ';range=' + str(int(high_range) + 1) + '-*'])
if self.connection.strategy.thread_safe:
status, result, _response, _ = result
else:
status = result
result = self.connection.result
_response = self.connection.response
if self.connection.strategy.sync:
if status:
current_response = _response[0]
else:
done = True
else:
current_response, _ = self.get_response(status)
current_response = current_response[0]
if not done:
if requested_range in current_response['raw_attributes'] and len(current_response['raw_attributes'][requested_range]) == 0:
del current_response['raw_attributes'][requested_range]
del current_response['attributes'][requested_range]
attr_name = list(filter(lambda a: ';range=' in a, current_response['raw_attributes'].keys()))[0]
continue
done = True
def do_search_on_auto_range(self, request, response):
for resp in [r for r in response if r['type'] == 'searchResEntry']:
for attr_name in list(resp['raw_attributes'].keys()): # generate list to avoid changing of dict size error
if ';range=' in attr_name:
attr_type, _, range_values = attr_name.partition(';range=')
if range_values in ('1-1', '0-0'): # DirSync returns these values for adding and removing members
return False
if attr_type not in resp['raw_attributes'] or resp['raw_attributes'][attr_type] is None:
resp['raw_attributes'][attr_type] = list()
if attr_type not in resp['attributes'] or resp['attributes'][attr_type] is None:
resp['attributes'][attr_type] = list()
self.do_next_range_search(request, resp, attr_name)
return True
def create_referral_connection(self, referrals):
referral_connection = None
selected_referral = None
cachekey = None
valid_referral_list = self.valid_referral_list(referrals)
if valid_referral_list:
preferred_referral_list = [referral for referral in valid_referral_list if
referral['ssl'] == self.connection.server.ssl]
selected_referral = choice(preferred_referral_list) if preferred_referral_list else choice(
valid_referral_list)
cachekey = (selected_referral['host'], selected_referral['port'] or self.connection.server.port, selected_referral['ssl'])
if self.connection.use_referral_cache and cachekey in self.referral_cache:
referral_connection = self.referral_cache[cachekey]
else:
referral_server = Server(host=selected_referral['host'],
port=selected_referral['port'] or self.connection.server.port,
use_ssl=selected_referral['ssl'],
get_info=self.connection.server.get_info,
formatter=self.connection.server.custom_formatter,
connect_timeout=self.connection.server.connect_timeout,
mode=self.connection.server.mode,
allowed_referral_hosts=self.connection.server.allowed_referral_hosts,
tls=Tls(local_private_key_file=self.connection.server.tls.private_key_file,
local_certificate_file=self.connection.server.tls.certificate_file,
validate=self.connection.server.tls.validate,
version=self.connection.server.tls.version,
ca_certs_file=self.connection.server.tls.ca_certs_file) if
selected_referral['ssl'] else None)
from ..core.connection import Connection
referral_connection = Connection(server=referral_server,
user=self.connection.user if not selected_referral['anonymousBindOnly'] else None,
password=self.connection.password if not selected_referral['anonymousBindOnly'] else None,
version=self.connection.version,
authentication=self.connection.authentication if not selected_referral['anonymousBindOnly'] else ANONYMOUS,
client_strategy=SYNC,
auto_referrals=True,
read_only=self.connection.read_only,
check_names=self.connection.check_names,
raise_exceptions=self.connection.raise_exceptions,
fast_decoder=self.connection.fast_decoder,
receive_timeout=self.connection.receive_timeout,
sasl_mechanism=self.connection.sasl_mechanism,
sasl_credentials=self.connection.sasl_credentials)
if self.connection.usage:
self.connection._usage.referrals_connections += 1
referral_connection.open()
referral_connection.strategy._referrals = self._referrals
if self.connection.tls_started and not referral_server.ssl: # if the original server was in start_tls mode and the referral server is not in ssl then start_tls on the referral connection
if not referral_connection.start_tls():
error = 'start_tls in referral not successful' + (' - ' + referral_connection.last_error if referral_connection.last_error else '')
if log_enabled(ERROR):
log(ERROR, '%s for <%s>', error, self)
self.unbind()
raise LDAPStartTLSError(error)
if self.connection.bound:
referral_connection.bind()
if self.connection.usage:
self.connection._usage.referrals_followed += 1
return selected_referral, referral_connection, cachekey
def do_operation_on_referral(self, request, referrals):
if log_enabled(PROTOCOL):
log(PROTOCOL, 'following referral for <%s>', self.connection)
selected_referral, referral_connection, cachekey = self.create_referral_connection(referrals)
if selected_referral:
if request['type'] == 'searchRequest':
referral_connection.search(selected_referral['base'] or request['base'],
selected_referral['filter'] or request['filter'],
selected_referral['scope'] or request['scope'],
request['dereferenceAlias'],
selected_referral['attributes'] or request['attributes'],
request['sizeLimit'],
request['timeLimit'],
request['typesOnly'],
controls=request['controls'])
elif request['type'] == 'addRequest':
referral_connection.add(selected_referral['base'] or request['entry'],
None,
request['attributes'],
controls=request['controls'])
elif request['type'] == 'compareRequest':
referral_connection.compare(selected_referral['base'] or request['entry'],
request['attribute'],
request['value'],
controls=request['controls'])
elif request['type'] == 'delRequest':
referral_connection.delete(selected_referral['base'] or request['entry'],
controls=request['controls'])
elif request['type'] == 'extendedReq':
referral_connection.extended(request['name'],
request['value'],
controls=request['controls'],
no_encode=True
)
elif request['type'] == 'modifyRequest':
referral_connection.modify(selected_referral['base'] or request['entry'],
prepare_changes_for_request(request['changes']),
controls=request['controls'])
elif request['type'] == 'modDNRequest':
referral_connection.modify_dn(selected_referral['base'] or request['entry'],
request['newRdn'],
request['deleteOldRdn'],
request['newSuperior'],
controls=request['controls'])
else:
self.connection.last_error = 'referral operation not permitted'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPReferralError(self.connection.last_error)
response = referral_connection.response
result = referral_connection.result
if self.connection.use_referral_cache:
self.referral_cache[cachekey] = referral_connection
else:
referral_connection.unbind()
else:
response = None
result = None
return response, result
def sending(self, ldap_message):
if log_enabled(NETWORK):
log(NETWORK, 'sending 1 ldap message for <%s>', self.connection)
try:
encoded_message = encode(ldap_message)
if self.connection.sasl_mechanism == DIGEST_MD5 and self.connection._digest_md5_kic and not self.connection.sasl_in_progress:
# If we are using DIGEST-MD5 and LDAP signing is enabled: add a signature to the message
sec_num = self.connection._digest_md5_sec_num # added underscore GC
kic = self.connection._digest_md5_kic # lowercase GC
# RFC 2831 : encoded_message = sizeOf(encored_message + signature + 0x0001 + secNum) + encoded_message + signature + 0x0001 + secNum
signature = bytes.fromhex(md5_hmac(kic, int(sec_num).to_bytes(4, 'big') + encoded_message)[0:20])
encoded_message = int(len(encoded_message) + 4 + 2 + 10).to_bytes(4, 'big') + encoded_message + signature + int(1).to_bytes(2, 'big') + int(sec_num).to_bytes(4, 'big')
self.connection._digest_md5_sec_num += 1
self.connection.socket.sendall(encoded_message)
if log_enabled(EXTENDED):
log(EXTENDED, 'ldap message sent via <%s>:%s', self.connection, format_ldap_message(ldap_message, '>>'))
if log_enabled(NETWORK):
log(NETWORK, 'sent %d bytes via <%s>', len(encoded_message), self.connection)
except socket.error as e:
self.connection.last_error = 'socket sending error' + str(e)
encoded_message = None
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise communication_exception_factory(LDAPSocketSendError, exc)(self.connection.last_error)
raise communication_exception_factory(LDAPSocketSendError, type(e)(str(e)))(self.connection.last_error)
if self.connection.usage:
self.connection._usage.update_transmitted_message(self.connection.request, len(encoded_message))
def _start_listen(self):
# overridden on strategy class
raise NotImplementedError
def _get_response(self, message_id, timeout):
# overridden in strategy class
raise NotImplementedError
def receiving(self):
# overridden in strategy class
raise NotImplementedError
def post_send_single_response(self, message_id):
# overridden in strategy class
raise NotImplementedError
def post_send_search(self, message_id):
# overridden in strategy class
raise NotImplementedError
def get_stream(self):
raise NotImplementedError
def set_stream(self, value):
raise NotImplementedError
def unbind_referral_cache(self):
while len(self.referral_cache) > 0:
cachekey, referral_connection = self.referral_cache.popitem()
referral_connection.unbind()

View File

@@ -0,0 +1,152 @@
"""
"""
# Created on 2013.07.15
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from io import StringIO
from os import linesep
import random
from ..core.exceptions import LDAPLDIFError
from ..utils.conv import prepare_for_stream
from ..protocol.rfc4511 import LDAPMessage, MessageID, ProtocolOp, LDAP_MAX_INT
from ..protocol.rfc2849 import operation_to_ldif, add_ldif_header
from ..protocol.convert import build_controls_list
from .base import BaseStrategy
class LdifProducerStrategy(BaseStrategy):
"""
This strategy is used to create the LDIF stream for the Add, Delete, Modify, ModifyDn operations.
You send the request and get the request in the ldif-change representation of the operation.
NO OPERATION IS SENT TO THE LDAP SERVER!
Connection.request will contain the result LDAP message in a dict form
Connection.response will contain the ldif-change format of the requested operation if available
You don't need a real server to connect to for this strategy
"""
def __init__(self, ldap_connection):
BaseStrategy.__init__(self, ldap_connection)
self.sync = True
self.no_real_dsa = True
self.pooled = False
self.can_stream = True
self.line_separator = linesep
self.all_base64 = False
self.stream = None
self.order = dict()
self._header_added = False
random.seed()
def _open_socket(self, address, use_ssl=False, unix_socket=False): # fake open socket
self.connection.socket = NotImplemented # placeholder for a dummy socket
if self.connection.usage:
self.connection._usage.open_sockets += 1
self.connection.closed = False
def _close_socket(self):
if self.connection.usage:
self.connection._usage.closed_sockets += 1
self.connection.socket = None
self.connection.closed = True
def _start_listen(self):
self.connection.listening = True
self.connection.closed = False
self.connection.bound = True
self._header_added = False
if not self.stream or (isinstance(self.stream, StringIO) and self.stream.closed):
self.set_stream(StringIO())
def _stop_listen(self):
self.stream.close()
self.connection.listening = False
self.connection.bound = False
self.connection.closed = True
def receiving(self):
return None
def send(self, message_type, request, controls=None):
"""
Build the LDAPMessage without sending to server
"""
message_id = random.randint(0, LDAP_MAX_INT)
ldap_message = LDAPMessage()
ldap_message['messageID'] = MessageID(message_id)
ldap_message['protocolOp'] = ProtocolOp().setComponentByName(message_type, request)
message_controls = build_controls_list(controls)
if message_controls is not None:
ldap_message['controls'] = message_controls
self.connection.request = BaseStrategy.decode_request(message_type, request, controls)
self.connection.request['controls'] = controls
if self._outstanding is None:
self._outstanding = dict()
self._outstanding[message_id] = self.connection.request
return message_id
def post_send_single_response(self, message_id):
self.connection.response = None
self.connection.result = None
if self._outstanding and message_id in self._outstanding:
request = self._outstanding.pop(message_id)
ldif_lines = operation_to_ldif(self.connection.request['type'], request, self.all_base64, self.order.get(self.connection.request['type']))
if self.stream and ldif_lines and not self.connection.closed:
self.accumulate_stream(self.line_separator.join(ldif_lines))
ldif_lines = add_ldif_header(ldif_lines)
self.connection.response = self.line_separator.join(ldif_lines)
return self.connection.response
return None
def post_send_search(self, message_id):
raise LDAPLDIFError('LDIF-CONTENT cannot be produced for Search operations')
def _get_response(self, message_id, timeout):
pass
def accumulate_stream(self, fragment):
if not self._header_added and self.stream.tell() == 0:
self._header_added = True
header = add_ldif_header(['-'])[0]
self.stream.write(prepare_for_stream(header + self.line_separator + self.line_separator))
self.stream.write(prepare_for_stream(fragment + self.line_separator + self.line_separator))
def get_stream(self):
return self.stream
def set_stream(self, value):
error = False
try:
if not value.writable():
error = True
except (ValueError, AttributeError):
error = True
if error:
raise LDAPLDIFError('stream must be writable')
self.stream = value

View File

@@ -0,0 +1,200 @@
"""
"""
# Created on 2016.04.30
#
# Author: Giovanni Cannata
#
# Copyright 2016 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from .. import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES
from .mockBase import MockBaseStrategy
from .asynchronous import AsyncStrategy
from ..operation.search import search_result_done_response_to_dict, search_result_entry_response_to_dict
from ..core.results import DO_NOT_RAISE_EXCEPTIONS
from ..utils.log import log, log_enabled, ERROR, PROTOCOL
from ..core.exceptions import LDAPResponseTimeoutError, LDAPOperationResult
from ..operation.bind import bind_response_to_dict
from ..operation.delete import delete_response_to_dict
from ..operation.add import add_response_to_dict
from ..operation.compare import compare_response_to_dict
from ..operation.modifyDn import modify_dn_response_to_dict
from ..operation.modify import modify_response_to_dict
from ..operation.search import search_result_done_response_to_dict, search_result_entry_response_to_dict
from ..operation.extended import extended_response_to_dict
# LDAPResult ::= SEQUENCE {
# resultCode ENUMERATED {
# success (0),
# operationsError (1),
# protocolError (2),
# timeLimitExceeded (3),
# sizeLimitExceeded (4),
# compareFalse (5),
# compareTrue (6),
# authMethodNotSupported (7),
# strongerAuthRequired (8),
# -- 9 reserved --
# referral (10),
# adminLimitExceeded (11),
# unavailableCriticalExtension (12),
# confidentialityRequired (13),
# saslBindInProgress (14),
# noSuchAttribute (16),
# undefinedAttributeType (17),
# inappropriateMatching (18),
# constraintViolation (19),
# attributeOrValueExists (20),
# invalidAttributeSyntax (21),
# -- 22-31 unused --
# noSuchObject (32),
# aliasProblem (33),
# invalidDNSyntax (34),
# -- 35 reserved for undefined isLeaf --
# aliasDereferencingProblem (36),
# -- 37-47 unused --
# inappropriateAuthentication (48),
# invalidCredentials (49),
# insufficientAccessRights (50),
# busy (51),
# unavailable (52),
# unwillingToPerform (53),
# loopDetect (54),
# -- 55-63 unused --
# namingViolation (64),
# objectClassViolation (65),
# notAllowedOnNonLeaf (66),
# notAllowedOnRDN (67),
# entryAlreadyExists (68),
# objectClassModsProhibited (69),
# -- 70 reserved for CLDAP --
# affectsMultipleDSAs (71),
# -- 72-79 unused --
# other (80),
# ... },
# matchedDN LDAPDN,
# diagnosticMessage LDAPString,
# referral [3] Referral OPTIONAL }
class MockAsyncStrategy(MockBaseStrategy, AsyncStrategy): # class inheritance sequence is important, MockBaseStrategy must be the first one
"""
This strategy create a mock LDAP server, with asynchronous access
It can be useful to test LDAP without accessing a real Server
"""
def __init__(self, ldap_connection):
AsyncStrategy.__init__(self, ldap_connection)
MockBaseStrategy.__init__(self)
#outstanding = dict() # a dictionary with the message id as key and a tuple (result, response) as value
def post_send_search(self, payload):
message_id, message_type, request, controls = payload
async_response = []
async_result = dict()
if message_type == 'searchRequest':
responses, result = self.mock_search(request, controls)
result['type'] = 'searchResDone'
for entry in responses:
response = search_result_entry_response_to_dict(entry, self.connection.server.schema, self.connection.server.custom_formatter, self.connection.check_names)
response['type'] = 'searchResEntry'
if self.connection.empty_attributes:
for attribute_type in request['attributes']:
attribute_name = str(attribute_type)
if attribute_name not in response['raw_attributes'] and attribute_name not in (ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES):
response['raw_attributes'][attribute_name] = list()
response['attributes'][attribute_name] = list()
if log_enabled(PROTOCOL):
log(PROTOCOL, 'attribute set to empty list for missing attribute <%s> in <%s>',
attribute_type, self)
if not self.connection.auto_range:
attrs_to_remove = []
# removes original empty attribute in case a range tag is returned
for attribute_type in response['attributes']:
attribute_name = str(attribute_type)
if ';range' in attribute_name.lower():
orig_attr, _, _ = attribute_name.partition(';')
attrs_to_remove.append(orig_attr)
for attribute_type in attrs_to_remove:
if log_enabled(PROTOCOL):
log(PROTOCOL,
'attribute type <%s> removed in response because of same attribute returned as range by the server in <%s>',
attribute_type, self)
del response['raw_attributes'][attribute_type]
del response['attributes'][attribute_type]
async_response.append(response)
async_result = search_result_done_response_to_dict(result)
async_result['type'] = 'searchResDone'
self._responses[message_id] = (request, async_result, async_response)
return message_id
def post_send_single_response(self, payload): # payload is a tuple sent by self.send() made of message_type, request, controls
message_id, message_type, request, controls = payload
responses = []
result = None
if message_type == 'bindRequest':
result = bind_response_to_dict(self.mock_bind(request, controls))
result['type'] = 'bindResponse'
elif message_type == 'unbindRequest':
self.bound = None
elif message_type == 'abandonRequest':
pass
elif message_type == 'delRequest':
result = delete_response_to_dict(self.mock_delete(request, controls))
result['type'] = 'delResponse'
elif message_type == 'addRequest':
result = add_response_to_dict(self.mock_add(request, controls))
result['type'] = 'addResponse'
elif message_type == 'compareRequest':
result = compare_response_to_dict(self.mock_compare(request, controls))
result['type'] = 'compareResponse'
elif message_type == 'modDNRequest':
result = modify_dn_response_to_dict(self.mock_modify_dn(request, controls))
result['type'] = 'modDNResponse'
elif message_type == 'modifyRequest':
result = modify_response_to_dict(self.mock_modify(request, controls))
result['type'] = 'modifyResponse'
elif message_type == 'extendedReq':
result = extended_response_to_dict(self.mock_extended(request, controls))
result['type'] = 'extendedResp'
responses.append(result)
if self.connection.raise_exceptions and result and result['result'] not in DO_NOT_RAISE_EXCEPTIONS:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'operation result <%s> for <%s>', result, self.connection)
raise LDAPOperationResult(result=result['result'], description=result['description'], dn=result['dn'], message=result['message'], response_type=result['type'])
self._responses[message_id] = (request, result, responses)
return message_id
def get_response(self, message_id, timeout=None, get_request=False):
if message_id in self._responses:
request, result, response = self._responses.pop(message_id)
else:
raise(LDAPResponseTimeoutError('message id not in outstanding queue'))
if self.connection.raise_exceptions and result and result['result'] not in DO_NOT_RAISE_EXCEPTIONS:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'operation result <%s> for <%s>', result, self.connection)
raise LDAPOperationResult(result=result['result'], description=result['description'], dn=result['dn'], message=result['message'], response_type=result['type'])
if get_request:
return response, result, request
else:
return response, result

View File

@@ -0,0 +1,921 @@
"""
"""
# Created on 2016.04.30
#
# Author: Giovanni Cannata
#
# Copyright 2016 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
import json
import re
from random import SystemRandom
from pyasn1.type.univ import OctetString
from .. import SEQUENCE_TYPES, ALL_ATTRIBUTES
from ..operation.bind import bind_request_to_dict
from ..operation.delete import delete_request_to_dict
from ..operation.add import add_request_to_dict
from ..operation.compare import compare_request_to_dict
from ..operation.modifyDn import modify_dn_request_to_dict
from ..operation.modify import modify_request_to_dict
from ..operation.extended import extended_request_to_dict
from ..operation.search import search_request_to_dict, parse_filter, ROOT, AND, OR, NOT, MATCH_APPROX, \
MATCH_GREATER_OR_EQUAL, MATCH_LESS_OR_EQUAL, MATCH_EXTENSIBLE, MATCH_PRESENT,\
MATCH_SUBSTRING, MATCH_EQUAL
from ..utils.conv import json_hook, to_unicode, to_raw
from ..core.exceptions import LDAPDefinitionError, LDAPPasswordIsMandatoryError, LDAPInvalidValueError, LDAPSocketOpenError
from ..core.results import RESULT_SUCCESS, RESULT_OPERATIONS_ERROR, RESULT_UNAVAILABLE_CRITICAL_EXTENSION, \
RESULT_INVALID_CREDENTIALS, RESULT_NO_SUCH_OBJECT, RESULT_ENTRY_ALREADY_EXISTS, RESULT_COMPARE_TRUE, \
RESULT_COMPARE_FALSE, RESULT_NO_SUCH_ATTRIBUTE, RESULT_UNWILLING_TO_PERFORM, RESULT_PROTOCOL_ERROR, RESULT_CONSTRAINT_VIOLATION, RESULT_NOT_ALLOWED_ON_RDN
from ..utils.ciDict import CaseInsensitiveDict
from ..utils.dn import to_dn, safe_dn, safe_rdn
from ..protocol.sasl.sasl import validate_simple_password
from ..protocol.formatters.standard import find_attribute_validator, format_attribute_values
from ..protocol.rfc2696 import paged_search_control
from ..utils.log import log, log_enabled, ERROR, BASIC
from ..utils.asn1 import encode
from ..utils.conv import ldap_escape_to_bytes
from ..strategy.base import BaseStrategy # needed for decode_control() method
from ..protocol.rfc4511 import LDAPMessage, ProtocolOp, MessageID
from ..protocol.convert import build_controls_list
# LDAPResult ::= SEQUENCE {
# resultCode ENUMERATED {
# success (0),
# operationsError (1),
# protocolError (2),
# timeLimitExceeded (3),
# sizeLimitExceeded (4),
# compareFalse (5),
# compareTrue (6),
# authMethodNotSupported (7),
# strongerAuthRequired (8),
# -- 9 reserved --
# referral (10),
# adminLimitExceeded (11),
# unavailableCriticalExtension (12),
# confidentialityRequired (13),
# saslBindInProgress (14),
# noSuchAttribute (16),
# undefinedAttributeType (17),
# inappropriateMatching (18),
# constraintViolation (19),
# attributeOrValueExists (20),
# invalidAttributeSyntax (21),
# -- 22-31 unused --
# noSuchObject (32),
# aliasProblem (33),
# invalidDNSyntax (34),
# -- 35 reserved for undefined isLeaf --
# aliasDereferencingProblem (36),
# -- 37-47 unused --
# inappropriateAuthentication (48),
# invalidCredentials (49),
# insufficientAccessRights (50),
# busy (51),
# unavailable (52),
# unwillingToPerform (53),
# loopDetect (54),
# -- 55-63 unused --
# namingViolation (64),
# objectClassViolation (65),
# notAllowedOnNonLeaf (66),
# notAllowedOnRDN (67),
# entryAlreadyExists (68),
# objectClassModsProhibited (69),
# -- 70 reserved for CLDAP --
# affectsMultipleDSAs (71),
# -- 72-79 unused --
# other (80),
# ... },
# matchedDN LDAPDN,
# diagnosticMessage LDAPString,
# referral [3] Referral OPTIONAL }
# noinspection PyProtectedMember,PyUnresolvedReferences
SEARCH_CONTROLS = ['1.2.840.113556.1.4.319' # simple paged search [RFC 2696]
]
SERVER_ENCODING = 'utf-8'
def random_cookie():
return to_raw(SystemRandom().random())[-6:]
class PagedSearchSet(object):
def __init__(self, response, size, criticality):
self.size = size
self.response = response
self.cookie = None
self.sent = 0
self.done = False
def next(self, size=None):
if size:
self.size=size
message = ''
response = self.response[self.sent: self.sent + self.size]
self.sent += self.size
if self.sent > len(self.response):
self.done = True
self.cookie = ''
else:
self.cookie = random_cookie()
response_control = paged_search_control(False, len(self.response), self.cookie)
result = {'resultCode': RESULT_SUCCESS,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None,
'controls': [BaseStrategy.decode_control(response_control)]
}
return response, result
class MockBaseStrategy(object):
"""
Base class for connection strategy
"""
def __init__(self):
if not hasattr(self.connection.server, 'dit'): # create entries dict if not already present
self.connection.server.dit = CaseInsensitiveDict()
self.entries = self.connection.server.dit # for simpler reference
self.no_real_dsa = True
self.bound = None
self.custom_validators = None
self.operational_attributes = ['entryDN']
self.add_entry('cn=schema', [], validate=False) # add default entry for schema
self._paged_sets = [] # list of paged search in progress
if log_enabled(BASIC):
log(BASIC, 'instantiated <%s>: <%s>', self.__class__.__name__, self)
def _start_listen(self):
self.connection.listening = True
self.connection.closed = False
if self.connection.usage:
self.connection._usage.open_sockets += 1
def _stop_listen(self):
self.connection.listening = False
self.connection.closed = True
if self.connection.usage:
self.connection._usage.closed_sockets += 1
def _prepare_value(self, attribute_type, value, validate=True):
"""
Prepare a value for being stored in the mock DIT
:param value: object to store
:return: raw value to store in the DIT
"""
if validate: # if loading from json dump do not validate values:
validator = find_attribute_validator(self.connection.server.schema, attribute_type, self.custom_validators)
validated = validator(value)
if validated is False:
raise LDAPInvalidValueError('value non valid for attribute \'%s\'' % attribute_type)
elif validated is not True: # a valid LDAP value equivalent to the actual value
value = validated
raw_value = to_raw(value)
if not isinstance(raw_value, bytes):
raise LDAPInvalidValueError('The value "%s" of type %s for "%s" must be bytes or an offline schema needs to be provided when Mock strategy is used.' % (
value,
type(value),
attribute_type,
))
return raw_value
def _update_attribute(self, dn, attribute_type, value):
pass
def add_entry(self, dn, attributes, validate=True):
with self.connection.server.dit_lock:
escaped_dn = safe_dn(dn)
if escaped_dn not in self.connection.server.dit:
new_entry = CaseInsensitiveDict()
for attribute in attributes:
if attribute in self.operational_attributes: # no restore of operational attributes, should be computed at runtime
continue
if not isinstance(attributes[attribute], SEQUENCE_TYPES): # entry attributes are always lists of bytes values
attributes[attribute] = [attributes[attribute]]
if self.connection.server.schema and self.connection.server.schema.attribute_types[attribute].single_value and len(attributes[attribute]) > 1: # multiple values in single-valued attribute
return False
if attribute.lower() == 'objectclass' and self.connection.server.schema: # builds the objectClass hierarchy only if schema is present
class_set = set()
for object_class in attributes[attribute]:
if self.connection.server.schema.object_classes and object_class not in self.connection.server.schema.object_classes:
return False
# walkups the class hierarchy and buils a set of all classes in it
class_set.add(object_class)
class_set_size = 0
while class_set_size != len(class_set):
new_classes = set()
class_set_size = len(class_set)
for class_name in class_set:
if self.connection.server.schema.object_classes[class_name].superior:
new_classes.update(self.connection.server.schema.object_classes[class_name].superior)
class_set.update(new_classes)
new_entry['objectClass'] = [to_raw(value) for value in class_set]
else:
new_entry[attribute] = [self._prepare_value(attribute, value, validate) for value in attributes[attribute]]
for rdn in safe_rdn(escaped_dn, decompose=True): # adds rdns to entry attributes
if rdn[0] not in new_entry: # if rdn attribute is missing adds attribute and its value
new_entry[rdn[0]] = [to_raw(rdn[1])]
else:
raw_rdn = to_raw(rdn[1])
if raw_rdn not in new_entry[rdn[0]]: # add rdn value if rdn attribute is present but value is missing
new_entry[rdn[0]].append(raw_rdn)
new_entry['entryDN'] = [to_raw(escaped_dn)]
self.connection.server.dit[escaped_dn] = new_entry
return True
return False
def remove_entry(self, dn):
with self.connection.server.dit_lock:
escaped_dn = safe_dn(dn)
if escaped_dn in self.connection.server.dit:
del self.connection.server.dit[escaped_dn]
return True
return False
def entries_from_json(self, json_entry_file):
target = open(json_entry_file, 'r')
definition = json.load(target, object_hook=json_hook)
if 'entries' not in definition:
self.connection.last_error = 'invalid JSON definition, missing "entries" section'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPDefinitionError(self.connection.last_error)
if not self.connection.server.dit:
self.connection.server.dit = CaseInsensitiveDict()
for entry in definition['entries']:
if 'raw' not in entry:
self.connection.last_error = 'invalid JSON definition, missing "raw" section'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPDefinitionError(self.connection.last_error)
if 'dn' not in entry:
self.connection.last_error = 'invalid JSON definition, missing "dn" section'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPDefinitionError(self.connection.last_error)
self.add_entry(entry['dn'], entry['raw'], validate=False)
target.close()
def mock_bind(self, request_message, controls):
# BindRequest ::= [APPLICATION 0] SEQUENCE {
# version INTEGER (1 .. 127),
# name LDAPDN,
# authentication AuthenticationChoice }
#
# BindResponse ::= [APPLICATION 1] SEQUENCE {
# COMPONENTS OF LDAPResult,
# serverSaslCreds [7] OCTET STRING OPTIONAL }
#
# request: version, name, authentication
# response: LDAPResult + serverSaslCreds
request = bind_request_to_dict(request_message)
identity = request['name']
if 'simple' in request['authentication']:
try:
password = validate_simple_password(request['authentication']['simple'])
except LDAPPasswordIsMandatoryError:
password = ''
identity = '<anonymous>'
else:
self.connection.last_error = 'only Simple Bind allowed in Mock strategy'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPDefinitionError(self.connection.last_error)
# checks userPassword for password. userPassword must be a text string or a list of text strings
if identity in self.connection.server.dit:
if 'userPassword' in self.connection.server.dit[identity]:
# if self.connection.server.dit[identity]['userPassword'] == password or password in self.connection.server.dit[identity]['userPassword']:
if self.equal(identity, 'userPassword', password):
result_code = RESULT_SUCCESS
message = ''
self.bound = identity
else:
result_code = RESULT_INVALID_CREDENTIALS
message = 'invalid credentials'
else: # no user found, returns invalidCredentials
result_code = RESULT_INVALID_CREDENTIALS
message = 'missing userPassword attribute'
elif identity == '<anonymous>':
result_code = RESULT_SUCCESS
message = ''
self.bound = identity
else:
result_code = RESULT_INVALID_CREDENTIALS
message = 'missing object'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None,
'serverSaslCreds': None
}
def mock_delete(self, request_message, controls):
# DelRequest ::= [APPLICATION 10] LDAPDN
#
# DelResponse ::= [APPLICATION 11] LDAPResult
#
# request: entry
# response: LDAPResult
request = delete_request_to_dict(request_message)
dn = safe_dn(request['entry'])
if dn in self.connection.server.dit:
del self.connection.server.dit[dn]
result_code = RESULT_SUCCESS
message = ''
else:
result_code = RESULT_NO_SUCH_OBJECT
message = 'object not found'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
def mock_add(self, request_message, controls):
# AddRequest ::= [APPLICATION 8] SEQUENCE {
# entry LDAPDN,
# attributes AttributeList }
#
# AddResponse ::= [APPLICATION 9] LDAPResult
#
# request: entry, attributes
# response: LDAPResult
request = add_request_to_dict(request_message)
dn = safe_dn(request['entry'])
attributes = request['attributes']
# converts attributes values to bytes
if dn not in self.connection.server.dit:
if self.add_entry(dn, attributes):
result_code = RESULT_SUCCESS
message = ''
else:
result_code = RESULT_OPERATIONS_ERROR
message = 'error adding entry'
else:
result_code = RESULT_ENTRY_ALREADY_EXISTS
message = 'entry already exist'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
def mock_compare(self, request_message, controls):
# CompareRequest ::= [APPLICATION 14] SEQUENCE {
# entry LDAPDN,
# ava AttributeValueAssertion }
#
# CompareResponse ::= [APPLICATION 15] LDAPResult
#
# request: entry, attribute, value
# response: LDAPResult
request = compare_request_to_dict(request_message)
dn = safe_dn(request['entry'])
attribute = request['attribute']
value = to_raw(request['value'])
if dn in self.connection.server.dit:
if attribute in self.connection.server.dit[dn]:
if self.equal(dn, attribute, value):
result_code = RESULT_COMPARE_TRUE
message = ''
else:
result_code = RESULT_COMPARE_FALSE
message = ''
else:
result_code = RESULT_NO_SUCH_ATTRIBUTE
message = 'attribute not found'
else:
result_code = RESULT_NO_SUCH_OBJECT
message = 'object not found'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
def mock_modify_dn(self, request_message, controls):
# ModifyDNRequest ::= [APPLICATION 12] SEQUENCE {
# entry LDAPDN,
# newrdn RelativeLDAPDN,
# deleteoldrdn BOOLEAN,
# newSuperior [0] LDAPDN OPTIONAL }
#
# ModifyDNResponse ::= [APPLICATION 13] LDAPResult
#
# request: entry, newRdn, deleteOldRdn, newSuperior
# response: LDAPResult
request = modify_dn_request_to_dict(request_message)
dn = safe_dn(request['entry'])
new_rdn = request['newRdn']
delete_old_rdn = request['deleteOldRdn']
new_superior = safe_dn(request['newSuperior']) if request['newSuperior'] else ''
dn_components = to_dn(dn)
if dn in self.connection.server.dit:
if new_superior and new_rdn: # performs move in the DIT
new_dn = safe_dn(dn_components[0] + ',' + new_superior)
self.connection.server.dit[new_dn] = self.connection.server.dit[dn].copy()
moved_entry = self.connection.server.dit[new_dn]
if delete_old_rdn:
del self.connection.server.dit[dn]
result_code = RESULT_SUCCESS
message = 'entry moved'
moved_entry['entryDN'] = [to_raw(new_dn)]
elif new_rdn and not new_superior: # performs rename
new_dn = safe_dn(new_rdn + ',' + safe_dn(dn_components[1:]))
self.connection.server.dit[new_dn] = self.connection.server.dit[dn].copy()
renamed_entry = self.connection.server.dit[new_dn]
del self.connection.server.dit[dn]
renamed_entry['entryDN'] = [to_raw(new_dn)]
for rdn in safe_rdn(new_dn, decompose=True): # adds rdns to entry attributes
renamed_entry[rdn[0]] = [to_raw(rdn[1])]
result_code = RESULT_SUCCESS
message = 'entry rdn renamed'
else:
result_code = RESULT_UNWILLING_TO_PERFORM
message = 'newRdn or newSuperior missing'
else:
result_code = RESULT_NO_SUCH_OBJECT
message = 'object not found'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
def mock_modify(self, request_message, controls):
# ModifyRequest ::= [APPLICATION 6] SEQUENCE {
# object LDAPDN,
# changes SEQUENCE OF change SEQUENCE {
# operation ENUMERATED {
# add (0),
# delete (1),
# replace (2),
# ... },
# modification PartialAttribute } }
#
# ModifyResponse ::= [APPLICATION 7] LDAPResult
#
# request: entry, changes
# response: LDAPResult
#
# changes is a dictionary in the form {'attribute': [(operation, [val1, ...]), ...], ...}
# operation is 0 (add), 1 (delete), 2 (replace), 3 (increment)
request = modify_request_to_dict(request_message)
dn = safe_dn(request['entry'])
changes = request['changes']
result_code = 0
message = ''
rdns = [rdn[0] for rdn in safe_rdn(dn, decompose=True)]
if dn in self.connection.server.dit:
entry = self.connection.server.dit[dn]
original_entry = entry.copy() # to preserve atomicity of operation
for modification in changes:
operation = modification['operation']
attribute = modification['attribute']['type']
elements = modification['attribute']['value']
if operation == 0: # add
if attribute not in entry and elements: # attribute not present, creates the new attribute and add elements
if self.connection.server.schema and self.connection.server.schema.attribute_types and self.connection.server.schema.attribute_types[attribute].single_value and len(elements) > 1: # multiple values in single-valued attribute
result_code = RESULT_CONSTRAINT_VIOLATION
message = 'attribute is single-valued'
else:
entry[attribute] = [to_raw(element) for element in elements]
else: # attribute present, adds elements to current values
if self.connection.server.schema and self.connection.server.schema.attribute_types and self.connection.server.schema.attribute_types[attribute].single_value: # multiple values in single-valued attribute
result_code = RESULT_CONSTRAINT_VIOLATION
message = 'attribute is single-valued'
else:
entry[attribute].extend([to_raw(element) for element in elements])
elif operation == 1: # delete
if attribute not in entry: # attribute must exist
result_code = RESULT_NO_SUCH_ATTRIBUTE
message = 'attribute must exists for deleting its values'
elif attribute in rdns: # attribute can't be used in dn
result_code = RESULT_NOT_ALLOWED_ON_RDN
message = 'cannot delete an rdn'
else:
if not elements: # deletes whole attribute if element list is empty
del entry[attribute]
else:
for element in elements:
raw_element = to_raw(element)
if self.equal(dn, attribute, raw_element): # removes single element
entry[attribute].remove(raw_element)
else:
result_code = 1
message = 'value to delete not found'
if not entry[attribute]: # removes the whole attribute if no elements remained
del entry[attribute]
elif operation == 2: # replace
if attribute not in entry and elements: # attribute not present, creates the new attribute and add elements
if self.connection.server.schema and self.connection.server.schema.attribute_types and self.connection.server.schema.attribute_types[attribute].single_value and len(elements) > 1: # multiple values in single-valued attribute
result_code = RESULT_CONSTRAINT_VIOLATION
message = 'attribute is single-valued'
else:
entry[attribute] = [to_raw(element) for element in elements]
elif not elements and attribute in rdns: # attribute can't be used in dn
result_code = RESULT_NOT_ALLOWED_ON_RDN
message = 'cannot replace an rdn'
elif not elements: # deletes whole attribute if element list is empty
if attribute in entry:
del entry[attribute]
else: # substitutes elements
entry[attribute] = [to_raw(element) for element in elements]
elif operation == 3: # increment
if attribute not in entry: # attribute must exist
result_code = RESULT_NO_SUCH_ATTRIBUTE
message = 'attribute must exists for incrementing its values'
else:
if len(elements) != 1:
result_code = RESULT_PROTOCOL_ERROR
message = 'only one increment value is allowed'
else:
try:
entry[attribute] = [bytes(str(int(value) + int(elements[0])), encoding='utf-8') for value in entry[attribute]]
except:
result_code = RESULT_UNWILLING_TO_PERFORM
message = 'unable to increment value'
if result_code: # an error has happened, restores the original dn
self.connection.server.dit[dn] = original_entry
else:
result_code = RESULT_NO_SUCH_OBJECT
message = 'object not found'
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
def mock_search(self, request_message, controls):
# SearchRequest ::= [APPLICATION 3] SEQUENCE {
# baseObject LDAPDN,
# scope ENUMERATED {
# baseObject (0),
# singleLevel (1),
# wholeSubtree (2),
# ... },
# derefAliases ENUMERATED {
# neverDerefAliases (0),
# derefInSearching (1),
# derefFindingBaseObj (2),
# derefAlways (3) },
# sizeLimit INTEGER (0 .. maxInt),
# timeLimit INTEGER (0 .. maxInt),
# typesOnly BOOLEAN,
# filter Filter,
# attributes AttributeSelection }
#
# SearchResultEntry ::= [APPLICATION 4] SEQUENCE {
# objectName LDAPDN,
# attributes PartialAttributeList }
#
#
# SearchResultReference ::= [APPLICATION 19] SEQUENCE
# SIZE (1..MAX) OF uri URI
#
# SearchResultDone ::= [APPLICATION 5] LDAPResult
#
# request: base, scope, dereferenceAlias, sizeLimit, timeLimit, typesOnly, filter, attributes
# response_entry: object, attributes
# response_done: LDAPResult
request = search_request_to_dict(request_message)
if controls:
decoded_controls = [self.decode_control(control) for control in controls if control]
for decoded_control in decoded_controls:
if decoded_control[1]['criticality'] and decoded_control[0] not in SEARCH_CONTROLS:
message = 'Critical requested control ' + str(decoded_control[0]) + ' not available'
result = {'resultCode': RESULT_UNAVAILABLE_CRITICAL_EXTENSION,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
return [], result
elif decoded_control[0] == '1.2.840.113556.1.4.319': # Simple paged search
if not decoded_control[1]['value']['cookie']: # new paged search
response, result = self._execute_search(request)
if result['resultCode'] == RESULT_SUCCESS: # success
paged_set = PagedSearchSet(response, int(decoded_control[1]['value']['size']), decoded_control[1]['criticality'])
response, result = paged_set.next()
if paged_set.done: # paged search already completed, no need to store the set
del paged_set
else:
self._paged_sets.append(paged_set)
return response, result
else:
return [], result
else:
for paged_set in self._paged_sets:
if paged_set.cookie == decoded_control[1]['value']['cookie']: # existing paged set
response, result = paged_set.next() # returns next bunch of entries as per paged set specifications
if paged_set.done:
self._paged_sets.remove(paged_set)
return response, result
# paged set not found
message = 'Invalid cookie in simple paged search'
result = {'resultCode': RESULT_OPERATIONS_ERROR,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
return [], result
else:
return self._execute_search(request)
def _execute_search(self, request):
responses = []
base = safe_dn(request['base'])
scope = request['scope']
attributes = request['attributes']
if '+' in attributes: # operational attributes requested
attributes.extend(self.operational_attributes)
attributes.remove('+')
attributes = [attr.lower() for attr in request['attributes']]
filter_root = parse_filter(request['filter'], self.connection.server.schema, auto_escape=True, auto_encode=False, validator=self.connection.server.custom_validator, check_names=self.connection.check_names)
candidates = []
if scope == 0: # base object
if base in self.connection.server.dit or base.lower() == 'cn=schema':
candidates.append(base)
elif scope == 1: # single level
for entry in self.connection.server.dit:
if entry.lower().endswith(base.lower()) and ',' not in entry[:-len(base) - 1]: # only leafs without commas in the remaining dn
candidates.append(entry)
elif scope == 2: # whole subtree
for entry in self.connection.server.dit:
if entry.lower().endswith(base.lower()):
candidates.append(entry)
if not candidates: # incorrect base
result_code = RESULT_NO_SUCH_OBJECT
message = 'incorrect base object'
else:
matched = self.evaluate_filter_node(filter_root, candidates)
if self.connection.raise_exceptions and 0 < request['sizeLimit'] < len(matched):
result_code = 4
message = 'size limit exceeded'
else:
for match in matched:
responses.append({
'object': match,
'attributes': [{'type': attribute,
'vals': [] if request['typesOnly'] else self.connection.server.dit[match][attribute]}
for attribute in self.connection.server.dit[match]
if attribute.lower() in attributes or ALL_ATTRIBUTES in attributes]
})
if '+' not in attributes: # remove operational attributes
for op_attr in self.operational_attributes:
if op_attr.lower() in attributes:
# if the op_attr was explicitly requested, then keep it
continue
for i, attr in enumerate(responses[len(responses)-1]['attributes']):
if attr['type'] == op_attr:
del responses[len(responses)-1]['attributes'][i]
result_code = 0
message = ''
result = {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None
}
return responses[:request['sizeLimit']] if request['sizeLimit'] > 0 else responses, result
def mock_extended(self, request_message, controls):
# ExtendedRequest ::= [APPLICATION 23] SEQUENCE {
# requestName [0] LDAPOID,
# requestValue [1] OCTET STRING OPTIONAL }
#
# ExtendedResponse ::= [APPLICATION 24] SEQUENCE {
# COMPONENTS OF LDAPResult,
# responseName [10] LDAPOID OPTIONAL,
# responseValue [11] OCTET STRING OPTIONAL }
#
# IntermediateResponse ::= [APPLICATION 25] SEQUENCE {
# responseName [0] LDAPOID OPTIONAL,
# responseValue [1] OCTET STRING OPTIONAL }
request = extended_request_to_dict(request_message)
result_code = RESULT_UNWILLING_TO_PERFORM
message = 'not implemented'
response_name = None
response_value = None
if self.connection.server.info:
for extension in self.connection.server.info.supported_extensions:
if request['name'] == extension[0]: # server can answer the extended request
if extension[0] == '2.16.840.1.113719.1.27.100.31': # getBindDNRequest [NOVELL]
result_code = 0
message = ''
response_name = OctetString('2.16.840.1.113719.1.27.100.32') # getBindDNResponse [NOVELL]
response_value = OctetString(self.bound)
elif extension[0] == '1.3.6.1.4.1.4203.1.11.3': # WhoAmI [RFC4532]
result_code = 0
message = ''
response_name = OctetString('1.3.6.1.4.1.4203.1.11.3') # WhoAmI [RFC4532]
response_value = OctetString(self.bound)
break
return {'resultCode': result_code,
'matchedDN': '',
'diagnosticMessage': to_unicode(message, SERVER_ENCODING),
'referral': None,
'responseName': response_name,
'responseValue': response_value
}
def evaluate_filter_node(self, node, candidates):
"""After evaluation each 2 sets are added to each MATCH node, one for the matched object and one for unmatched object.
The unmatched object set is needed if a superior node is a NOT that reverts the evaluation. The BOOLEAN nodes mix the sets
returned by the MATCH nodes"""
node.matched = set()
node.unmatched = set()
if node.elements:
for element in node.elements:
self.evaluate_filter_node(element, candidates)
if node.tag == ROOT:
return node.elements[0].matched
elif node.tag == AND:
first_element = node.elements[0]
node.matched.update(first_element.matched)
node.unmatched.update(first_element.unmatched)
for element in node.elements[1:]:
node.matched.intersection_update(element.matched)
node.unmatched.intersection_update(element.unmatched)
elif node.tag == OR:
for element in node.elements:
node.matched.update(element.matched)
node.unmatched.update(element.unmatched)
elif node.tag == NOT:
node.matched = node.elements[0].unmatched
node.unmatched = node.elements[0].matched
elif node.tag == MATCH_GREATER_OR_EQUAL:
attr_name = node.assertion['attr']
attr_value = node.assertion['value']
for candidate in candidates:
if attr_name in self.connection.server.dit[candidate]:
for value in self.connection.server.dit[candidate][attr_name]:
if value.isdigit() and attr_value.isdigit(): # int comparison
if int(value) >= int(attr_value):
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
else:
if to_unicode(value, SERVER_ENCODING).lower() >= to_unicode(attr_value, SERVER_ENCODING).lower(): # case insensitive string comparison
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
elif node.tag == MATCH_LESS_OR_EQUAL:
attr_name = node.assertion['attr']
attr_value = node.assertion['value']
for candidate in candidates:
if attr_name in self.connection.server.dit[candidate]:
for value in self.connection.server.dit[candidate][attr_name]:
if value.isdigit() and attr_value.isdigit(): # int comparison
if int(value) <= int(attr_value):
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
else:
if to_unicode(value, SERVER_ENCODING).lower() <= to_unicode(attr_value, SERVER_ENCODING).lower(): # case insentive string comparison
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
elif node.tag == MATCH_EXTENSIBLE:
self.connection.last_error = 'Extensible match not allowed in Mock strategy'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPDefinitionError(self.connection.last_error)
elif node.tag == MATCH_PRESENT:
attr_name = node.assertion['attr']
for candidate in candidates:
if attr_name in self.connection.server.dit[candidate]:
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
elif node.tag == MATCH_SUBSTRING:
attr_name = node.assertion['attr']
# rebuild the original substring filter
if 'initial' in node.assertion and node.assertion['initial'] is not None:
substring_filter = re.escape(to_unicode(node.assertion['initial'], SERVER_ENCODING))
else:
substring_filter = ''
if 'any' in node.assertion and node.assertion['any'] is not None:
for middle in node.assertion['any']:
substring_filter += '.*' + re.escape(to_unicode(middle, SERVER_ENCODING))
if 'final' in node.assertion and node.assertion['final'] is not None:
substring_filter += '.*' + re.escape(to_unicode(node.assertion['final'], SERVER_ENCODING))
if substring_filter and not node.assertion.get('any', None) and not node.assertion.get('final', None): # only initial, adds .*
substring_filter += '.*'
regex_filter = re.compile(substring_filter, flags=re.UNICODE | re.IGNORECASE) # unicode AND ignorecase
for candidate in candidates:
if attr_name in self.connection.server.dit[candidate]:
for value in self.connection.server.dit[candidate][attr_name]:
if regex_filter.match(to_unicode(value, SERVER_ENCODING)):
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
else:
node.unmatched.add(candidate)
elif node.tag == MATCH_EQUAL or node.tag == MATCH_APPROX:
attr_name = node.assertion['attr']
attr_value = node.assertion['value']
for candidate in candidates:
if attr_name in self.connection.server.dit[candidate] and self.equal(candidate, attr_name, attr_value):
node.matched.add(candidate)
else:
node.unmatched.add(candidate)
def equal(self, dn, attribute_type, value_to_check):
# value is the value to match
attribute_values = self.connection.server.dit[dn][attribute_type]
if not isinstance(attribute_values, SEQUENCE_TYPES):
attribute_values = [attribute_values]
escaped_value_to_check = ldap_escape_to_bytes(value_to_check)
for attribute_value in attribute_values:
if self._check_equality(escaped_value_to_check, attribute_value):
return True
if self._check_equality(self._prepare_value(attribute_type, value_to_check), attribute_value):
return True
return False
@staticmethod
def _check_equality(value1, value2):
if value1 == value2: # exact matching
return True
if str(value1).isdigit() and str(value2).isdigit():
if int(value1) == int(value2): # int comparison
return True
try:
if to_unicode(value1, SERVER_ENCODING).lower() == to_unicode(value2, SERVER_ENCODING).lower(): # case insensitive comparison
return True
except UnicodeError:
pass
return False
def send(self, message_type, request, controls=None):
self.connection.request = self.decode_request(message_type, request, controls)
if self.connection.listening:
message_id = self.connection.server.next_message_id()
if self.connection.usage: # ldap message is built for updating metrics only
ldap_message = LDAPMessage()
ldap_message['messageID'] = MessageID(message_id)
ldap_message['protocolOp'] = ProtocolOp().setComponentByName(message_type, request)
message_controls = build_controls_list(controls)
if message_controls is not None:
ldap_message['controls'] = message_controls
asn1_request = BaseStrategy.decode_request(message_type, request, controls)
self.connection._usage.update_transmitted_message(asn1_request, len(encode(ldap_message)))
return message_id, message_type, request, controls
else:
self.connection.last_error = 'unable to send message, connection is not open'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketOpenError(self.connection.last_error)

View File

@@ -0,0 +1,133 @@
"""
"""
# Created on 2014.11.17
#
# Author: Giovanni Cannata
#
# Copyright 2014 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from ..core.results import DO_NOT_RAISE_EXCEPTIONS
from .mockBase import MockBaseStrategy
from .. import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES
from .sync import SyncStrategy
from ..operation.bind import bind_response_to_dict
from ..operation.delete import delete_response_to_dict
from ..operation.add import add_response_to_dict
from ..operation.compare import compare_response_to_dict
from ..operation.modifyDn import modify_dn_response_to_dict
from ..operation.modify import modify_response_to_dict
from ..operation.search import search_result_done_response_to_dict, search_result_entry_response_to_dict
from ..operation.extended import extended_response_to_dict
from ..core.exceptions import LDAPSocketOpenError, LDAPOperationResult
from ..utils.log import log, log_enabled, ERROR, PROTOCOL
class MockSyncStrategy(MockBaseStrategy, SyncStrategy): # class inheritance sequence is important, MockBaseStrategy must be the first one
"""
This strategy create a mock LDAP server, with synchronous access
It can be useful to test LDAP without accessing a real Server
"""
def __init__(self, ldap_connection):
SyncStrategy.__init__(self, ldap_connection)
MockBaseStrategy.__init__(self)
def post_send_search(self, payload):
message_id, message_type, request, controls = payload
self.connection.response = []
self.connection.result = dict()
if message_type == 'searchRequest':
responses, result = self.mock_search(request, controls)
for entry in responses:
response = search_result_entry_response_to_dict(entry, self.connection.server.schema, self.connection.server.custom_formatter, self.connection.check_names)
response['type'] = 'searchResEntry'
###
if self.connection.empty_attributes:
for attribute_type in request['attributes']:
attribute_name = str(attribute_type)
if attribute_name not in response['raw_attributes'] and attribute_name not in (ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, NO_ATTRIBUTES):
response['raw_attributes'][attribute_name] = list()
response['attributes'][attribute_name] = list()
if log_enabled(PROTOCOL):
log(PROTOCOL, 'attribute set to empty list for missing attribute <%s> in <%s>',
attribute_type, self)
if not self.connection.auto_range:
attrs_to_remove = []
# removes original empty attribute in case a range tag is returned
for attribute_type in response['attributes']:
attribute_name = str(attribute_type)
if ';range' in attribute_name.lower():
orig_attr, _, _ = attribute_name.partition(';')
attrs_to_remove.append(orig_attr)
for attribute_type in attrs_to_remove:
if log_enabled(PROTOCOL):
log(PROTOCOL,
'attribute type <%s> removed in response because of same attribute returned as range by the server in <%s>',
attribute_type, self)
del response['raw_attributes'][attribute_type]
del response['attributes'][attribute_type]
###
self.connection.response.append(response)
result = search_result_done_response_to_dict(result)
result['type'] = 'searchResDone'
self.connection.result = result
if self.connection.raise_exceptions and result and result['result'] not in DO_NOT_RAISE_EXCEPTIONS:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'operation result <%s> for <%s>', result, self.connection)
raise LDAPOperationResult(result=result['result'], description=result['description'], dn=result['dn'], message=result['message'], response_type=result['type'])
return self.connection.response
def post_send_single_response(self, payload): # payload is a tuple sent by self.send() made of message_type, request, controls
message_id, message_type, request, controls = payload
responses = []
result = None
if message_type == 'bindRequest':
result = bind_response_to_dict(self.mock_bind(request, controls))
result['type'] = 'bindResponse'
elif message_type == 'unbindRequest':
self.bound = None
elif message_type == 'abandonRequest':
pass
elif message_type == 'delRequest':
result = delete_response_to_dict(self.mock_delete(request, controls))
result['type'] = 'delResponse'
elif message_type == 'addRequest':
result = add_response_to_dict(self.mock_add(request, controls))
result['type'] = 'addResponse'
elif message_type == 'compareRequest':
result = compare_response_to_dict(self.mock_compare(request, controls))
result['type'] = 'compareResponse'
elif message_type == 'modDNRequest':
result = modify_dn_response_to_dict(self.mock_modify_dn(request, controls))
result['type'] = 'modDNResponse'
elif message_type == 'modifyRequest':
result = modify_response_to_dict(self.mock_modify(request, controls))
result['type'] = 'modifyResponse'
elif message_type == 'extendedReq':
result = extended_response_to_dict(self.mock_extended(request, controls))
result['type'] = 'extendedResp'
self.connection.result = result
responses.append(result)
if self.connection.raise_exceptions and result and result['result'] not in DO_NOT_RAISE_EXCEPTIONS:
if log_enabled(PROTOCOL):
log(PROTOCOL, 'operation result <%s> for <%s>', result, self.connection)
raise LDAPOperationResult(result=result['result'], description=result['description'], dn=result['dn'], message=result['message'], response_type=result['type'])
return responses

View File

@@ -0,0 +1,260 @@
"""
"""
# Created on 2014.03.04
#
# Author: Giovanni Cannata
#
# Copyright 2014 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from time import sleep
import socket
from .. import get_config_parameter
from .sync import SyncStrategy
from ..core.exceptions import LDAPSocketOpenError, LDAPOperationResult, LDAPMaximumRetriesError, LDAPStartTLSError
from ..utils.log import log, log_enabled, ERROR, BASIC
# noinspection PyBroadException,PyProtectedMember
class RestartableStrategy(SyncStrategy):
def __init__(self, ldap_connection):
SyncStrategy.__init__(self, ldap_connection)
self.sync = True
self.no_real_dsa = False
self.pooled = False
self.can_stream = False
self.restartable_sleep_time = get_config_parameter('RESTARTABLE_SLEEPTIME')
self.restartable_tries = get_config_parameter('RESTARTABLE_TRIES')
self._restarting = False
self._last_bind_controls = None
self._current_message_type = None
self._current_request = None
self._current_controls = None
self._restart_tls = None
self.exception_history = []
def open(self, reset_usage=False, read_server_info=True):
SyncStrategy.open(self, reset_usage, read_server_info)
def _open_socket(self, address, use_ssl=False, unix_socket=False):
"""
Try to open and connect a socket to a Server
raise LDAPExceptionError if unable to open or connect socket
if connection is restartable tries for the number of restarting requested or forever
"""
try:
SyncStrategy._open_socket(self, address, use_ssl, unix_socket) # try to open socket using SyncWait
self._reset_exception_history()
return
except Exception as e: # machinery for restartable connection
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
if not self._restarting: # if not already performing a restart
self._restarting = True
counter = self.restartable_tries
while counter > 0: # includes restartable_tries == True
if log_enabled(BASIC):
log(BASIC, 'try #%d to open Restartable connection <%s>', self.restartable_tries - counter, self.connection)
sleep(self.restartable_sleep_time)
if not self.connection.closed:
try: # resetting connection
self.connection.unbind()
except (socket.error, LDAPSocketOpenError): # don't trace catch socket errors because socket could already be closed
pass
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
try: # reissuing same operation
if self.connection.server_pool:
new_server = self.connection.server_pool.get_server(self.connection) # get a server from the server_pool if available
if self.connection.server != new_server:
self.connection.server = new_server
if self.connection.usage:
self.connection._usage.servers_from_pool += 1
SyncStrategy._open_socket(self, address, use_ssl, unix_socket) # calls super (not restartable) _open_socket()
if self.connection.usage:
self.connection._usage.restartable_successes += 1
self.connection.closed = False
self._restarting = False
self._reset_exception_history()
return
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
if self.connection.usage:
self.connection._usage.restartable_failures += 1
if not isinstance(self.restartable_tries, bool):
counter -= 1
self._restarting = False
self.connection.last_error = 'restartable connection strategy failed while opening socket'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPMaximumRetriesError(self.connection.last_error, self.exception_history, self.restartable_tries)
def send(self, message_type, request, controls=None):
self._current_message_type = message_type
self._current_request = request
self._current_controls = controls
if not self._restart_tls: # RFCs doesn't define how to stop tls once started
self._restart_tls = self.connection.tls_started
if message_type == 'bindRequest': # stores controls used in bind operation to be used again when restarting the connection
self._last_bind_controls = controls
try:
message_id = SyncStrategy.send(self, message_type, request, controls) # tries to send using SyncWait
self._reset_exception_history()
return message_id
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
if not self._restarting: # machinery for restartable connection
self._restarting = True
counter = self.restartable_tries
while counter > 0:
if log_enabled(BASIC):
log(BASIC, 'try #%d to send in Restartable connection <%s>', self.restartable_tries - counter, self.connection)
sleep(self.restartable_sleep_time)
if not self.connection.closed:
try: # resetting connection
self.connection.unbind()
except (socket.error, LDAPSocketOpenError): # don't trace socket errors because socket could already be closed
pass
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
failure = False
try: # reopening connection
self.connection.open(reset_usage=False, read_server_info=False)
if self._restart_tls: # restart tls if start_tls was previously used
if not self.connection.start_tls(read_server_info=False):
error = 'restart tls in restartable not successful' + (' - ' + self.connection.last_error if self.connection.last_error else '')
if log_enabled(ERROR):
log(ERROR, '%s for <%s>', error, self)
self.connection.unbind()
raise LDAPStartTLSError(error)
if message_type != 'bindRequest':
self.connection.bind(read_server_info=False, controls=self._last_bind_controls) # binds with previously used controls unless the request is already a bindRequest
if not self.connection.server.schema and not self.connection.server.info:
self.connection.refresh_server_info()
else:
self.connection._fire_deferred(read_info=False) # in case of lazy connection, not open by the refresh_server_info
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
failure = True
if not failure:
try: # reissuing same operation
ret_value = self.connection.send(message_type, request, controls)
if self.connection.usage:
self.connection._usage.restartable_successes += 1
self._restarting = False
self._reset_exception_history()
return ret_value # successful send
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
failure = True
if failure and self.connection.usage:
self.connection._usage.restartable_failures += 1
if not isinstance(self.restartable_tries, bool):
counter -= 1
self._restarting = False
self.connection.last_error = 'restartable connection failed to send'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPMaximumRetriesError(self.connection.last_error, self.exception_history, self.restartable_tries)
def post_send_single_response(self, message_id):
try:
ret_value = SyncStrategy.post_send_single_response(self, message_id)
self._reset_exception_history()
return ret_value
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
# if an LDAPExceptionError is raised then resend the request
try:
ret_value = SyncStrategy.post_send_single_response(self, self.send(self._current_message_type, self._current_request, self._current_controls))
self._reset_exception_history()
return ret_value
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
if not isinstance(e, LDAPOperationResult):
self.connection.last_error = 'restartable connection strategy failed in post_send_single_response'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise
def post_send_search(self, message_id):
try:
ret_value = SyncStrategy.post_send_search(self, message_id)
self._reset_exception_history()
return ret_value
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
# if an LDAPExceptionError is raised then resend the request
try:
ret_value = SyncStrategy.post_send_search(self, self.connection.send(self._current_message_type, self._current_request, self._current_controls))
self._reset_exception_history()
return ret_value
except Exception as e:
if log_enabled(ERROR):
log(ERROR, '<%s> while restarting <%s>', e, self.connection)
self._add_exception_to_history(type(e)(str(e)))
if not isinstance(e, LDAPOperationResult):
self.connection.last_error = e.args
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise e
def _add_exception_to_history(self, exc):
if not isinstance(self.restartable_tries, bool): # doesn't accumulate when restarting forever
if not isinstance(exc, LDAPMaximumRetriesError): # doesn't add the LDAPMaximumRetriesError exception
self.exception_history.append(exc)
def _reset_exception_history(self):
if self.exception_history:
self.exception_history = []
def get_stream(self):
raise NotImplementedError
def set_stream(self, value):
raise NotImplementedError

View File

@@ -0,0 +1,495 @@
"""
"""
# Created on 2014.03.23
#
# Author: Giovanni Cannata
#
# Copyright 2014 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime
from os import linesep
from threading import Thread, Lock
from time import sleep
from .. import RESTARTABLE, get_config_parameter, AUTO_BIND_DEFAULT, AUTO_BIND_NONE, AUTO_BIND_NO_TLS, AUTO_BIND_TLS_AFTER_BIND, AUTO_BIND_TLS_BEFORE_BIND
from .base import BaseStrategy
from ..core.usage import ConnectionUsage
from ..core.exceptions import LDAPConnectionPoolNameIsMandatoryError, LDAPConnectionPoolNotStartedError, LDAPOperationResult, LDAPExceptionError, LDAPResponseTimeoutError
from ..utils.log import log, log_enabled, ERROR, BASIC
from ..protocol.rfc4511 import LDAP_MAX_INT
TERMINATE_REUSABLE = 'TERMINATE_REUSABLE_CONNECTION'
BOGUS_BIND = -1
BOGUS_UNBIND = -2
BOGUS_EXTENDED = -3
BOGUS_ABANDON = -4
try:
from queue import Queue, Empty
except ImportError: # Python 2
# noinspection PyUnresolvedReferences
from Queue import Queue, Empty
# noinspection PyProtectedMember
class ReusableStrategy(BaseStrategy):
"""
A pool of reusable SyncWaitRestartable connections with lazy behaviour and limited lifetime.
The connection using this strategy presents itself as a normal connection, but internally the strategy has a pool of
connections that can be used as needed. Each connection lives in its own thread and has a busy/available status.
The strategy performs the requested operation on the first available connection.
The pool of connections is instantiated at strategy initialization.
Strategy has two customizable properties, the total number of connections in the pool and the lifetime of each connection.
When lifetime is expired the connection is closed and will be open again when needed.
"""
pools = dict()
def receiving(self):
raise NotImplementedError
def _start_listen(self):
raise NotImplementedError
def _get_response(self, message_id, timeout):
raise NotImplementedError
def get_stream(self):
raise NotImplementedError
def set_stream(self, value):
raise NotImplementedError
# noinspection PyProtectedMember
class ConnectionPool(object):
"""
Container for the Connection Threads
"""
def __new__(cls, connection):
if connection.pool_name in ReusableStrategy.pools: # returns existing connection pool
pool = ReusableStrategy.pools[connection.pool_name]
if not pool.started: # if pool is not started remove it from the pools singleton and create a new onw
del ReusableStrategy.pools[connection.pool_name]
return object.__new__(cls)
if connection.pool_keepalive and pool.keepalive != connection.pool_keepalive: # change lifetime
pool.keepalive = connection.pool_keepalive
if connection.pool_lifetime and pool.lifetime != connection.pool_lifetime: # change keepalive
pool.lifetime = connection.pool_lifetime
if connection.pool_size and pool.pool_size != connection.pool_size: # if pool size has changed terminate and recreate the connections
pool.terminate_pool()
pool.pool_size = connection.pool_size
return pool
else:
return object.__new__(cls)
def __init__(self, connection):
if not hasattr(self, 'workers'):
self.name = connection.pool_name
self.master_connection = connection
self.workers = []
self.pool_size = connection.pool_size or get_config_parameter('REUSABLE_THREADED_POOL_SIZE')
self.lifetime = connection.pool_lifetime or get_config_parameter('REUSABLE_THREADED_LIFETIME')
self.keepalive = connection.pool_keepalive
self.request_queue = Queue()
self.open_pool = False
self.bind_pool = False
self.tls_pool = False
self._incoming = dict()
self.counter = 0
self.terminated_usage = ConnectionUsage() if connection._usage else None
self.terminated = False
self.pool_lock = Lock()
ReusableStrategy.pools[self.name] = self
self.started = False
if log_enabled(BASIC):
log(BASIC, 'instantiated ConnectionPool: <%r>', self)
def __str__(self):
s = 'POOL: ' + str(self.name) + ' - status: ' + ('started' if self.started else 'terminated')
s += ' - responses in queue: ' + str(len(self._incoming))
s += ' - pool size: ' + str(self.pool_size)
s += ' - lifetime: ' + str(self.lifetime)
s += ' - keepalive: ' + str(self.keepalive)
s += ' - open: ' + str(self.open_pool)
s += ' - bind: ' + str(self.bind_pool)
s += ' - tls: ' + str(self.tls_pool) + linesep
s += 'MASTER CONN: ' + str(self.master_connection) + linesep
s += 'WORKERS:'
if self.workers:
for i, worker in enumerate(self.workers):
s += linesep + str(i).rjust(5) + ': ' + str(worker)
else:
s += linesep + ' no active workers in pool'
return s
def __repr__(self):
return self.__str__()
def get_info_from_server(self):
for worker in self.workers:
with worker.worker_lock:
if not worker.connection.server.schema or not worker.connection.server.info:
worker.get_info_from_server = True
else:
worker.get_info_from_server = False
def rebind_pool(self):
for worker in self.workers:
with worker.worker_lock:
worker.connection.rebind(self.master_connection.user,
self.master_connection.password,
self.master_connection.authentication,
self.master_connection.sasl_mechanism,
self.master_connection.sasl_credentials)
def start_pool(self):
if not self.started:
self.create_pool()
for worker in self.workers:
with worker.worker_lock:
worker.thread.start()
self.started = True
self.terminated = False
if log_enabled(BASIC):
log(BASIC, 'worker started for pool <%s>', self)
return True
return False
def create_pool(self):
if log_enabled(BASIC):
log(BASIC, 'created pool <%s>', self)
self.workers = [ReusableStrategy.PooledConnectionWorker(self.master_connection, self.request_queue) for _ in range(self.pool_size)]
def terminate_pool(self):
if not self.terminated:
if log_enabled(BASIC):
log(BASIC, 'terminating pool <%s>', self)
self.started = False
self.request_queue.join() # waits for all queue pending operations
for _ in range(len([worker for worker in self.workers if worker.thread.is_alive()])): # put a TERMINATE signal on the queue for each active thread
self.request_queue.put((TERMINATE_REUSABLE, None, None, None))
self.request_queue.join() # waits for all queue terminate operations
self.terminated = True
if log_enabled(BASIC):
log(BASIC, 'pool terminated for <%s>', self)
class PooledConnectionThread(Thread):
"""
The thread that holds the Reusable connection and receive operation request via the queue
Result are sent back in the pool._incoming list when ready
"""
def __init__(self, worker, master_connection):
Thread.__init__(self)
self.daemon = True
self.worker = worker
self.master_connection = master_connection
if log_enabled(BASIC):
log(BASIC, 'instantiated PooledConnectionThread: <%r>', self)
# noinspection PyProtectedMember
def run(self):
self.worker.running = True
terminate = False
pool = self.master_connection.strategy.pool
while not terminate:
try:
counter, message_type, request, controls = pool.request_queue.get(block=True, timeout=self.master_connection.strategy.pool.keepalive)
except Empty: # issue an Abandon(0) operation to keep the connection live - Abandon(0) is a harmless operation
if not self.worker.connection.closed:
self.worker.connection.abandon(0)
continue
with self.worker.worker_lock:
self.worker.busy = True
if counter == TERMINATE_REUSABLE:
terminate = True
if self.worker.connection.bound:
try:
self.worker.connection.unbind()
if log_enabled(BASIC):
log(BASIC, 'thread terminated')
except LDAPExceptionError:
pass
else:
if (datetime.now() - self.worker.creation_time).seconds >= self.master_connection.strategy.pool.lifetime: # destroy and create a new connection
try:
self.worker.connection.unbind()
except LDAPExceptionError:
pass
self.worker.new_connection()
if log_enabled(BASIC):
log(BASIC, 'thread respawn')
if message_type not in ['bindRequest', 'unbindRequest']:
try:
if pool.open_pool and self.worker.connection.closed:
self.worker.connection.open(read_server_info=False)
if pool.tls_pool and not self.worker.connection.tls_started:
self.worker.connection.start_tls(read_server_info=False)
if pool.bind_pool and not self.worker.connection.bound:
self.worker.connection.bind(read_server_info=False)
elif pool.open_pool and not self.worker.connection.closed: # connection already open, issues a start_tls
if pool.tls_pool and not self.worker.connection.tls_started:
self.worker.connection.start_tls(read_server_info=False)
if self.worker.get_info_from_server and counter:
self.worker.connection.refresh_server_info()
self.worker.get_info_from_server = False
response = None
result = None
if message_type == 'searchRequest':
response = self.worker.connection.post_send_search(self.worker.connection.send(message_type, request, controls))
else:
response = self.worker.connection.post_send_single_response(self.worker.connection.send(message_type, request, controls))
result = self.worker.connection.result
with pool.pool_lock:
pool._incoming[counter] = (response, result, BaseStrategy.decode_request(message_type, request, controls))
except LDAPOperationResult as e: # raise_exceptions has raised an exception. It must be redirected to the original connection thread
with pool.pool_lock:
pool._incoming[counter] = (e, None, None)
# pool._incoming[counter] = (type(e)(str(e)), None, None)
# except LDAPOperationResult as e: # raise_exceptions has raised an exception. It must be redirected to the original connection thread
# exc = e
# with pool.pool_lock:
# if exc:
# pool._incoming[counter] = (exc, None, None)
# else:
# pool._incoming[counter] = (response, result, BaseStrategy.decode_request(message_type, request, controls))
self.worker.busy = False
pool.request_queue.task_done()
self.worker.task_counter += 1
if log_enabled(BASIC):
log(BASIC, 'thread terminated')
if self.master_connection.usage:
pool.terminated_usage += self.worker.connection.usage
self.worker.running = False
class PooledConnectionWorker(object):
"""
Container for the restartable connection. it includes a thread and a lock to execute the connection in the pool
"""
def __init__(self, connection, request_queue):
self.master_connection = connection
self.request_queue = request_queue
self.running = False
self.busy = False
self.get_info_from_server = False
self.connection = None
self.creation_time = None
self.task_counter = 0
self.new_connection()
self.thread = ReusableStrategy.PooledConnectionThread(self, self.master_connection)
self.worker_lock = Lock()
if log_enabled(BASIC):
log(BASIC, 'instantiated PooledConnectionWorker: <%s>', self)
def __str__(self):
s = 'CONN: ' + str(self.connection) + linesep + ' THREAD: '
s += 'running' if self.running else 'halted'
s += ' - ' + ('busy' if self.busy else 'available')
s += ' - ' + ('created at: ' + self.creation_time.isoformat())
s += ' - time to live: ' + str(self.master_connection.strategy.pool.lifetime - (datetime.now() - self.creation_time).seconds)
s += ' - requests served: ' + str(self.task_counter)
return s
def new_connection(self):
from ..core.connection import Connection
# noinspection PyProtectedMember
self.creation_time = datetime.now()
self.connection = Connection(server=self.master_connection.server_pool if self.master_connection.server_pool else self.master_connection.server,
user=self.master_connection.user,
password=self.master_connection.password,
auto_bind=AUTO_BIND_NONE, # do not perform auto_bind because it reads again the schema
version=self.master_connection.version,
authentication=self.master_connection.authentication,
client_strategy=RESTARTABLE,
auto_referrals=self.master_connection.auto_referrals,
auto_range=self.master_connection.auto_range,
sasl_mechanism=self.master_connection.sasl_mechanism,
sasl_credentials=self.master_connection.sasl_credentials,
check_names=self.master_connection.check_names,
collect_usage=self.master_connection._usage,
read_only=self.master_connection.read_only,
raise_exceptions=self.master_connection.raise_exceptions,
lazy=False,
fast_decoder=self.master_connection.fast_decoder,
receive_timeout=self.master_connection.receive_timeout,
return_empty_attributes=self.master_connection.empty_attributes)
# simulates auto_bind, always with read_server_info=False
if self.master_connection.auto_bind and self.master_connection.auto_bind not in [AUTO_BIND_NONE, AUTO_BIND_DEFAULT]:
if log_enabled(BASIC):
log(BASIC, 'performing automatic bind for <%s>', self.connection)
self.connection.open(read_server_info=False)
if self.master_connection.auto_bind == AUTO_BIND_NO_TLS:
self.connection.bind(read_server_info=False)
elif self.master_connection.auto_bind == AUTO_BIND_TLS_BEFORE_BIND:
self.connection.start_tls(read_server_info=False)
self.connection.bind(read_server_info=False)
elif self.master_connection.auto_bind == AUTO_BIND_TLS_AFTER_BIND:
self.connection.bind(read_server_info=False)
self.connection.start_tls(read_server_info=False)
if self.master_connection.server_pool:
self.connection.server_pool = self.master_connection.server_pool
self.connection.server_pool.initialize(self.connection)
# ReusableStrategy methods
def __init__(self, ldap_connection):
BaseStrategy.__init__(self, ldap_connection)
self.sync = False
self.no_real_dsa = False
self.pooled = True
self.can_stream = False
if hasattr(ldap_connection, 'pool_name') and ldap_connection.pool_name:
self.pool = ReusableStrategy.ConnectionPool(ldap_connection)
else:
if log_enabled(ERROR):
log(ERROR, 'reusable connection must have a pool_name')
raise LDAPConnectionPoolNameIsMandatoryError('reusable connection must have a pool_name')
def open(self, reset_usage=True, read_server_info=True):
# read_server_info not used
self.pool.open_pool = True
self.pool.start_pool()
self.connection.closed = False
if self.connection.usage:
if reset_usage or not self.connection._usage.initial_connection_start_time:
self.connection._usage.start()
def terminate(self):
self.pool.terminate_pool()
self.pool.open_pool = False
self.connection.bound = False
self.connection.closed = True
self.pool.bind_pool = False
self.pool.tls_pool = False
def _close_socket(self):
"""
Doesn't really close the socket
"""
self.connection.closed = True
if self.connection.usage:
self.connection._usage.closed_sockets += 1
def send(self, message_type, request, controls=None):
if self.pool.started:
if message_type == 'bindRequest':
self.pool.bind_pool = True
counter = BOGUS_BIND
elif message_type == 'unbindRequest':
self.pool.bind_pool = False
counter = BOGUS_UNBIND
elif message_type == 'abandonRequest':
counter = BOGUS_ABANDON
elif message_type == 'extendedReq' and self.connection.starting_tls:
self.pool.tls_pool = True
counter = BOGUS_EXTENDED
else:
with self.pool.pool_lock:
self.pool.counter += 1
if self.pool.counter > LDAP_MAX_INT:
self.pool.counter = 1
counter = self.pool.counter
self.pool.request_queue.put((counter, message_type, request, controls))
return counter
if log_enabled(ERROR):
log(ERROR, 'reusable connection pool not started')
raise LDAPConnectionPoolNotStartedError('reusable connection pool not started')
def validate_bind(self, controls):
# in case of a new connection or different credentials
if (self.connection.user != self.pool.master_connection.user or
self.connection.password != self.pool.master_connection.password or
self.connection.authentication != self.pool.master_connection.authentication or
self.connection.sasl_mechanism != self.pool.master_connection.sasl_mechanism or
self.connection.sasl_credentials != self.pool.master_connection.sasl_credentials):
self.pool.master_connection.user = self.connection.user
self.pool.master_connection.password = self.connection.password
self.pool.master_connection.authentication = self.connection.authentication
self.pool.master_connection.sasl_mechanism = self.connection.sasl_mechanism
self.pool.master_connection.sasl_credentials = self.connection.sasl_credentials
self.pool.rebind_pool()
temp_connection = self.pool.workers[0].connection
old_lazy = temp_connection.lazy
temp_connection.lazy = False
if not self.connection.server.schema or not self.connection.server.info:
result = self.pool.workers[0].connection.bind(controls=controls)
else:
result = self.pool.workers[0].connection.bind(controls=controls, read_server_info=False)
temp_connection.unbind()
temp_connection.lazy = old_lazy
if result:
self.pool.bind_pool = True # bind pool if bind is validated
return result
def get_response(self, counter, timeout=None, get_request=False):
sleeptime = get_config_parameter('RESPONSE_SLEEPTIME')
request=None
if timeout is None:
timeout = get_config_parameter('RESPONSE_WAITING_TIMEOUT')
if counter == BOGUS_BIND: # send a bogus bindResponse
response = list()
result = {'description': 'success', 'referrals': None, 'type': 'bindResponse', 'result': 0, 'dn': '', 'message': '<bogus Bind response>', 'saslCreds': None}
elif counter == BOGUS_UNBIND: # bogus unbind response
response = None
result = None
elif counter == BOGUS_ABANDON: # abandon cannot be executed because of multiple connections
response = list()
result = {'result': 0, 'referrals': None, 'responseName': '1.3.6.1.4.1.1466.20037', 'type': 'extendedResp', 'description': 'success', 'responseValue': 'None', 'dn': '', 'message': '<bogus StartTls response>'}
elif counter == BOGUS_EXTENDED: # bogus startTls extended response
response = list()
result = {'result': 0, 'referrals': None, 'responseName': '1.3.6.1.4.1.1466.20037', 'type': 'extendedResp', 'description': 'success', 'responseValue': 'None', 'dn': '', 'message': '<bogus StartTls response>'}
self.connection.starting_tls = False
else:
response = None
result = None
while timeout >= 0: # waiting for completed message to appear in _incoming
try:
with self.connection.strategy.pool.pool_lock:
response, result, request = self.connection.strategy.pool._incoming.pop(counter)
except KeyError:
sleep(sleeptime)
timeout -= sleeptime
continue
break
if timeout <= 0:
if log_enabled(ERROR):
log(ERROR, 'no response from worker threads in Reusable connection')
raise LDAPResponseTimeoutError('no response from worker threads in Reusable connection')
if isinstance(response, LDAPOperationResult):
raise response # an exception has been raised with raise_exceptions
if get_request:
return response, result, request
return response, result
def post_send_single_response(self, counter):
return counter
def post_send_search(self, counter):
return counter

View File

@@ -0,0 +1,32 @@
"""
"""
# Created on 2020.07.12
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from .restartable import RestartableStrategy
class SafeRestartableStrategy(RestartableStrategy):
def __init__(self, ldap_connection):
RestartableStrategy.__init__(self, ldap_connection)
self.thread_safe = True

View File

@@ -0,0 +1,32 @@
"""
"""
# Created on 2020.07.12
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
from .sync import SyncStrategy
class SafeSyncStrategy(SyncStrategy):
def __init__(self, ldap_connection):
SyncStrategy.__init__(self, ldap_connection)
self.thread_safe = True

View File

@@ -0,0 +1,251 @@
"""
"""
# Created on 2013.07.15
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.
import socket
from .. import SEQUENCE_TYPES, get_config_parameter, DIGEST_MD5
from ..core.exceptions import LDAPSocketReceiveError, communication_exception_factory, LDAPExceptionError, LDAPExtensionError, LDAPOperationResult, LDAPSignatureVerificationFailedError
from ..strategy.base import BaseStrategy, SESSION_TERMINATED_BY_SERVER, RESPONSE_COMPLETE, TRANSACTION_ERROR
from ..protocol.rfc4511 import LDAPMessage
from ..utils.log import log, log_enabled, ERROR, NETWORK, EXTENDED, format_ldap_message
from ..utils.asn1 import decoder, decode_message_fast
from ..protocol.sasl.digestMd5 import md5_hmac
LDAP_MESSAGE_TEMPLATE = LDAPMessage()
# noinspection PyProtectedMember
class SyncStrategy(BaseStrategy):
"""
This strategy is synchronous. You send the request and get the response
Requests return a boolean value to indicate the result of the requested Operation
Connection.response will contain the whole LDAP response for the messageId requested in a dict form
Connection.request will contain the result LDAP message in a dict form
"""
def __init__(self, ldap_connection):
BaseStrategy.__init__(self, ldap_connection)
self.sync = True
self.no_real_dsa = False
self.pooled = False
self.can_stream = False
self.socket_size = get_config_parameter('SOCKET_SIZE')
def open(self, reset_usage=True, read_server_info=True):
BaseStrategy.open(self, reset_usage, read_server_info)
if read_server_info and not self.connection._deferred_open:
try:
self.connection.refresh_server_info()
except LDAPOperationResult: # catch errors from server if raise_exception = True
self.connection.server._dsa_info = None
self.connection.server._schema_info = None
def _start_listen(self):
if not self.connection.listening and not self.connection.closed:
self.connection.listening = True
def receiving(self):
"""
Receives data over the socket
Checks if the socket is closed
"""
messages = []
receiving = True
unprocessed = b''
data = b''
get_more_data = True
# exc = None # not needed here GC
sasl_total_bytes_recieved = 0
sasl_received_data = b'' # used to verify the signature
sasl_next_packet = b''
# sasl_signature = b'' # not needed here? GC
# sasl_sec_num = b'' # used to verify the signature # not needed here, reformatted to lowercase GC
sasl_buffer_length = -1 # added, not initialized? GC
while receiving:
if get_more_data:
try:
data = self.connection.socket.recv(self.socket_size)
except (OSError, socket.error, AttributeError) as e:
self.connection.last_error = 'error receiving data: ' + str(e)
try: # try to close the connection before raising exception
self.close()
except (socket.error, LDAPExceptionError):
pass
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise communication_exception_factory(LDAPSocketReceiveError, exc)(self.connection.last_error)
raise communication_exception_factory(LDAPSocketReceiveError, type(e)(str(e)))(self.connection.last_error)
# If we are using DIGEST-MD5 and LDAP signing is set : verify & remove the signature from the message
if self.connection.sasl_mechanism == DIGEST_MD5 and self.connection._digest_md5_kis and not self.connection.sasl_in_progress:
data = sasl_next_packet + data
if sasl_received_data == b'' or sasl_next_packet:
# Remove the sizeOf(encoded_message + signature + 0x0001 + secNum) from data.
sasl_buffer_length = int.from_bytes(data[0:4], "big")
data = data[4:]
sasl_next_packet = b''
sasl_total_bytes_recieved += len(data)
sasl_received_data += data
if sasl_total_bytes_recieved >= sasl_buffer_length:
# When the LDAP response is splitted accross multiple TCP packets, the SASL buffer length is equal to the MTU of each packet..Which is usually not equal to self.socket_size
# This means that the end of one SASL packet/beginning of one other....could be located in the middle of data
# We are using "sasl_received_data" instead of "data" & "unprocessed" for this reason
# structure of messages when LDAP signing is enabled : sizeOf(encoded_message + signature + 0x0001 + secNum) + encoded_message + signature + 0x0001 + secNum
sasl_signature = sasl_received_data[sasl_buffer_length - 16:sasl_buffer_length - 6]
sasl_sec_num = sasl_received_data[sasl_buffer_length - 4:sasl_buffer_length]
sasl_next_packet = sasl_received_data[sasl_buffer_length:] # the last "data" variable may contain another sasl packet. We'll process it at the next iteration.
sasl_received_data = sasl_received_data[:sasl_buffer_length - 16] # remove signature + 0x0001 + secNum + the next packet if any, from sasl_received_data
kis = self.connection._digest_md5_kis # renamed to lowercase GC
calculated_signature = bytes.fromhex(md5_hmac(kis, sasl_sec_num + sasl_received_data)[0:20])
if sasl_signature != calculated_signature:
raise LDAPSignatureVerificationFailedError("Signature verification failed for the recieved LDAP message number " + str(int.from_bytes(sasl_sec_num, 'big')) + ". Expected signature " + calculated_signature.hex() + " but got " + sasl_signature.hex() + ".")
sasl_total_bytes_recieved = 0
unprocessed += sasl_received_data
sasl_received_data = b''
else:
unprocessed += data
if len(data) > 0:
length = BaseStrategy.compute_ldap_message_size(unprocessed)
if length == -1: # too few data to decode message length
get_more_data = True
continue
if len(unprocessed) < length:
get_more_data = True
else:
if log_enabled(NETWORK):
log(NETWORK, 'received %d bytes via <%s>', len(unprocessed[:length]), self.connection)
messages.append(unprocessed[:length])
unprocessed = unprocessed[length:]
get_more_data = False
if len(unprocessed) == 0:
receiving = False
else:
receiving = False
if log_enabled(NETWORK):
log(NETWORK, 'received %d ldap messages via <%s>', len(messages), self.connection)
return messages
def post_send_single_response(self, message_id):
"""
Executed after an Operation Request (except Search)
Returns the result message or None
"""
responses, result = self.get_response(message_id)
self.connection.result = result
if result['type'] == 'intermediateResponse': # checks that all responses are intermediates (there should be only one)
for response in responses:
if response['type'] != 'intermediateResponse':
self.connection.last_error = 'multiple messages received error'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketReceiveError(self.connection.last_error)
responses.append(result)
return responses
def post_send_search(self, message_id):
"""
Executed after a search request
Returns the result message and store in connection.response the objects found
"""
responses, result = self.get_response(message_id)
self.connection.result = result
if isinstance(responses, SEQUENCE_TYPES):
self.connection.response = responses[:] # copy search result entries
return responses
self.connection.last_error = 'error receiving response'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketReceiveError(self.connection.last_error)
def _get_response(self, message_id, timeout):
"""
Performs the capture of LDAP response for SyncStrategy
"""
ldap_responses = []
response_complete = False
while not response_complete:
responses = self.receiving()
if responses:
for response in responses:
if len(response) > 0:
if self.connection.usage:
self.connection._usage.update_received_message(len(response))
if self.connection.fast_decoder:
ldap_resp = decode_message_fast(response)
dict_response = self.decode_response_fast(ldap_resp)
else:
ldap_resp, _ = decoder.decode(response, asn1Spec=LDAP_MESSAGE_TEMPLATE) # unprocessed unused because receiving() waits for the whole message
dict_response = self.decode_response(ldap_resp)
if log_enabled(EXTENDED):
log(EXTENDED, 'ldap message received via <%s>:%s', self.connection, format_ldap_message(ldap_resp, '<<'))
if int(ldap_resp['messageID']) == message_id:
ldap_responses.append(dict_response)
if dict_response['type'] not in ['searchResEntry', 'searchResRef', 'intermediateResponse']:
response_complete = True
elif int(ldap_resp['messageID']) == 0: # 0 is reserved for 'Unsolicited Notification' from server as per RFC4511 (paragraph 4.4)
if dict_response['responseName'] == '1.3.6.1.4.1.1466.20036': # Notice of Disconnection as per RFC4511 (paragraph 4.4.1)
return SESSION_TERMINATED_BY_SERVER
elif dict_response['responseName'] == '2.16.840.1.113719.1.27.103.4': # Novell LDAP transaction error unsolicited notification
return TRANSACTION_ERROR
else:
self.connection.last_error = 'unknown unsolicited notification from server'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketReceiveError(self.connection.last_error)
elif int(ldap_resp['messageID']) != message_id and dict_response['type'] == 'extendedResp':
self.connection.last_error = 'multiple extended responses to a single extended request'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPExtensionError(self.connection.last_error)
# pass # ignore message with invalid messageId when receiving multiple extendedResp. This is not allowed by RFC4511 but some LDAP server do it
else:
self.connection.last_error = 'invalid messageId received'
if log_enabled(ERROR):
log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
raise LDAPSocketReceiveError(self.connection.last_error)
# response = unprocessed
# if response: # if this statement is removed unprocessed data will be processed as another message
# self.connection.last_error = 'unprocessed substrate error'
# if log_enabled(ERROR):
# log(ERROR, '<%s> for <%s>', self.connection.last_error, self.connection)
# raise LDAPSocketReceiveError(self.connection.last_error)
else:
return SESSION_TERMINATED_BY_SERVER
ldap_responses.append(RESPONSE_COMPLETE)
return ldap_responses
def set_stream(self, value):
raise NotImplementedError
def get_stream(self):
raise NotImplementedError