Source code for aioxmpp.sasl

"""
:mod:`~aioxmpp.sasl` -- SASL helpers
####################################

This module is used to implement SASL in :mod:`aioxmpp.security_layer`. It
provides a state machine for use by the different SASL mechanisms and
implementations of some SASL mechansims.

SASL mechansims
===============

.. autoclass:: PLAIN

.. autoclass:: SCRAM

Base class
----------

.. autoclass:: SASLMechanism

SASL state machine and XSOs
===========================

.. autoclass:: SASLStateMachine

.. autoclass:: SASLAuth

.. autoclass:: SASLChallenge

.. autoclass:: SASLResponse

.. autoclass:: SASLFailure

.. autoclass:: SASLSuccess

.. autoclass:: SASLAbort

"""

import abc
import asyncio
import base64
import functools
import hashlib
import hmac
import itertools
import logging
import operator
import random
import time

from .stringprep import saslprep
from . import errors, xso, protocol

from .utils import namespaces

logger = logging.getLogger(__name__)

_system_random = random.SystemRandom()

try:
    from hashlib import pbkdf2_hmac as pbkdf2
except ImportError:
    # this is untested if you have pbkdf2_hmac
    def pbkdf2(hashfun_name, input_data, salt, iterations, dklen=None):
        """
        Derivate a key from a password. `input_data` is taken as the bytes
        object resembling the password (or other input). `hashfun` must be a
        callable returning a :mod:`hashlib`-compatible hash function. `salt` is
        the salt to be used in the PBKDF2 run, `iterations` the count of
        iterations. `dklen` is the length in bytes of the key to be derived.

        Return the derived key as :class:`bytes` object.
        """

        if dklen is not None and dklen <= 0:
            raise ValueError("Invalid length for derived key: {}".format(
                dklen))

        hashfun = lambda: hashlib.new(hashfun_name)

        hlen = hashfun().digest_size
        if dklen is None:
            dklen = hlen

        block_count = (dklen + (hlen - 1)) // hlen

        mac_base = hmac.new(input_data, None, hashfun)

        def do_hmac(data):
            mac = mac_base.copy()
            mac.update(data)
            return mac.digest()

        def calc_block(i):
            u_prev = do_hmac(salt + i.to_bytes(4, "big"))
            u_accum = u_prev
            for k in range(1, iterations):
                u_curr = do_hmac(u_prev)
                u_accum = bytes(itertools.starmap(
                    operator.xor,
                    zip(u_accum, u_curr)))
                u_prev = u_curr

            return u_accum

        result = b"".join(
            calc_block(i)
            for i in range(1, block_count + 1))

        return result[:dklen]


[docs]class SASLAuth(xso.XSO): """ Start SASL authentication. .. attribute:: mechanism The mechanism to authenticate with. .. attribute:: payload For mechanisms which use an initial client-supplied payload, this can be a string. It is automatically encoded as base64 according to the XMPP SASL specification. """ TAG = (namespaces.sasl, "auth") mechanism = xso.Attr("mechanism") payload = xso.Text(type_=xso.Base64Binary(empty_as_equal=True)) def __init__(self, mechanism=None, payload=None): super().__init__() self.mechanism = mechanism self.payload = payload
[docs]class SASLChallenge(xso.XSO): """ A SASL challenge sent by the server. .. attribute:: payload The (decoded) SASL payload. Base64 en/decoding is handled by the XSO stack. """ TAG = (namespaces.sasl, "challenge") payload = xso.Text(type_=xso.Base64Binary(empty_as_equal=True)) def __init__(self, payload=None): super().__init__() self.payload = payload
[docs]class SASLResponse(xso.XSO): """ A SASL challenge sent by the client. .. attribute:: payload The (decoded) SASL payload. Base64 en/decoding is handled by the XSO stack. """ TAG = (namespaces.sasl, "response") payload = xso.Text(type_=xso.Base64Binary(empty_as_equal=True)) def __init__(self, payload=None): super().__init__() self.payload = payload
[docs]class SASLFailure(xso.XSO): """ Indication of SASL failure. .. attribute:: condition The condition which caused the authentication to fail. .. attribute:: text Optional human-readable text. """ TAG = (namespaces.sasl, "failure") condition = xso.ChildTag( tags=[ "aborted", "account-disabled", "credentials-expired", "encryption-required", "incorrect-encoding", "invalid-authzid", "invalid-mechanism", "malformed-request", "mechanism-too-weak", "not-authorized", "temporary-auth-failure", ], default_ns=namespaces.sasl, allow_none=False, declare_prefix=None, ) text = xso.ChildText( tag=(namespaces.sasl, "text"), attr_policy=xso.UnknownAttrPolicy.DROP, default=None, declare_prefix=None) def __init__(self, condition=(namespaces.sasl, "temporary-auth-failure")): super().__init__() self.condition = condition
[docs]class SASLSuccess(xso.XSO): """ Indication of SASL success, with optional final payload supplied by the server. .. attribute:: payload The (decoded) SASL payload. Base64 en/decoding is handled by the XSO stack. """ TAG = (namespaces.sasl, "success") payload = xso.Text(type_=xso.Base64Binary(empty_as_equal=True))
[docs]class SASLAbort(xso.XSO): """ Request to abort the SASL authentication. """ TAG = (namespaces.sasl, "abort")
[docs]class SASLStateMachine: """ A state machine to reduce code duplication during SASL handshake. The state methods change the state and return the next client state of the SASL handshake, optionally with server-supplied payload. Valid next states are: * ``('challenge', payload)`` (with `payload` being a :class:`bytes` object obtained from base64-decoding the servers challenge) * ``('success', None)`` – after successful authentication * ``('failure', None)`` – after failed authentication (e.g. after a call to :meth:`abort`) Note that, with the notable exception of :meth:`abort`, ``failure`` states are never returned but thrown as :class:`errors.SASLFailure` instead. The initial state is never returned. """ def __init__(self, xmlstream): super().__init__() self.xmlstream = xmlstream self._state = "initial" self.timeout = None @asyncio.coroutine def _send_sasl_node_and_wait_for(self, node): node = yield from protocol.send_and_wait_for( self.xmlstream, [node], [ SASLChallenge, SASLFailure, SASLSuccess ], timeout=self.timeout ) state = node.TAG[1] self._state = state if state == "failure": xmpp_error = node.condition[1] text = node.text raise errors.SASLFailure(xmpp_error, text=text) if hasattr(node, "payload"): payload = node.payload else: payload = None return state, payload @asyncio.coroutine def initiate(self, mechanism, payload=None): """ Initiate the SASL handshake and advertise the use of the given `mechanism`. If `payload` is not :data:`None`, it will be base64 encoded and sent as initial client response along with the ``<auth />`` element. Return the next state of the state machine as tuple (see :class:`SASLStateMachine` for details). """ if self._state != "initial": raise RuntimeError("initiate has already been called") result = yield from self._send_sasl_node_and_wait_for( SASLAuth(mechanism=mechanism, payload=payload)) return result @asyncio.coroutine def response(self, payload): """ Send a response to the previously received challenge, with the given `payload`. The payload is encoded using base64 and transmitted to the server. Return the next state of the state machine as tuple (see :class:`SASLStateMachine` for details). """ if self._state != "challenge": raise RuntimeError( "no challenge has been made or negotiation failed") result = yield from self._send_sasl_node_and_wait_for( SASLResponse(payload=payload) ) return result @asyncio.coroutine def abort(self): """ Abort an initiated SASL authentication process. The expected result state is ``failure``. """ if self._state == "initial": raise RuntimeError("SASL authentication hasn't started yet") try: next_state, payload = yield from self._send_sasl_node_and_wait_for( SASLAbort() ) except errors.SASLFailure as err: self._state = "failure" if err.xmpp_error != "aborted": raise return "failure", None else: raise errors.SASLFailure( "aborted", text="unexpected non-failure after abort: {}".format(self._state) )
[docs]class SASLMechanism(metaclass=abc.ABCMeta): """ Implementation of a SASL mechanism. Each SASLMechanism `class` must have a `class` attribute :attr:`handled_mechanisms`, which must be a container of strings holding the SASL mechanism names supported by that class. """ @abc.abstractclassmethod def any_supported(cls, mechanisms): """ Return the argument to be passed to :meth:`authenticate`, if any of the `mechanisms` (which is a :cls:`set`) is supported. Return :data:`None` otherwise. """ @asyncio.coroutine @abc.abstractmethod def authenticate(self, sm, token): """ Execute the mechanism identified by `token` (the value which has been returned by :meth:`any_supported` before) using the given :class:`SASLStateMachine` `sm`. If authentication fails, an appropriate exception is raised (:class:`~.errors.AuthenticationFailure`). """
[docs]class PLAIN(SASLMechanism): """ The ``PLAIN`` SASL mechanism (see RFC 4616). `credential_provider` must be coroutine which returns a ``(user, password)`` tuple. """ def __init__(self, credential_provider): super().__init__() self._credential_provider = credential_provider @classmethod def any_supported(cls, mechanisms): if "PLAIN" in mechanisms: return "PLAIN" return None @asyncio.coroutine def authenticate(self, sm, mechanism): logger.info("attempting PLAIN mechanism") username, password = yield from self._credential_provider() username = saslprep(username).encode("utf8") password = saslprep(password).encode("utf8") state, _ = yield from sm.initiate( mechanism="PLAIN", payload=b"\0" + username + b"\0" + password) if state != "success": raise errors.SASLFailure( "malformed-request", text="SASL protocol violation") return True
[docs]class SCRAM(SASLMechanism): """ The SCRAM SASL mechanism (see RFC 5802). `credential_provider` must be coroutine which returns a ``(user, password)`` tuple. """ def __init__(self, credential_provider): super().__init__() self._credential_provider = credential_provider self.nonce_length = 15 _supported_hashalgos = { # the second argument is for preference ordering (highest first) # if anyone has a better hash ordering suggestion, I’m open for it # a value of 1 is added if the -PLUS variant is used # -- JWI "SHA-1": ("sha1", 1), "SHA-224": ("sha224", 224), "SHA-512": ("sha512", 512), "SHA-384": ("sha384", 384), "SHA-256": ("sha256", 256), } @classmethod def any_supported(cls, mechanisms): supported = [] for mechanism in mechanisms: if not mechanism.startswith("SCRAM-"): continue if mechanism.endswith("-PLUS"): # channel binding is not supported continue hashfun_key = mechanism[6:] try: hashfun_name, quality = cls._supported_hashalgos[hashfun_key] except KeyError: continue supported.append(((1, quality), (mechanism, hashfun_name,))) if not supported: return None supported.sort() return supported.pop()[1] @classmethod def parse_message(cls, msg): parts = ( part for part in msg.split(b",") if part) for part in parts: key, _, value = part.partition(b"=") if len(key) > 1 or key == b"m": raise Exception("SCRAM protocol violation / unknown " "future extension") if key == b"n" or key == b"a": value = value.replace(b"=2C", b",").replace(b"=3D", b"=") yield key, value @asyncio.coroutine def authenticate(self, sm, token): mechanism, hashfun_name, = token logger.info("attempting %s mechanism (using %s hashfun)", mechanism, hashfun_name) # this is pretty much a verbatim implementation of RFC 5802. hashfun_factory = functools.partial(hashlib.new, hashfun_name) digest_size = hashfun_factory().digest_size # we don’t support channel binding gs2_header = b"n,," username, password = yield from self._credential_provider() username = saslprep(username).encode("utf8") password = saslprep(password).encode("utf8") our_nonce = base64.b64encode(_system_random.getrandbits( self.nonce_length * 8 ).to_bytes( self.nonce_length, "little" )) auth_message = b"n=" + username + b",r=" + our_nonce _, payload = yield from sm.initiate( mechanism, gs2_header + auth_message) auth_message += b"," + payload payload = dict(self.parse_message(payload)) try: iteration_count = int(payload[b"i"]) nonce = payload[b"r"] salt = base64.b64decode(payload[b"s"]) except (ValueError, KeyError): yield from sm.abort() raise errors.SASLFailure( "Malformed server message: {!r}".format(payload)) if not nonce.startswith(our_nonce): yield from sm.abort() raise errors.SASLFailure( "Server nonce doesn't fit our nonce") t0 = time.time() salted_password = pbkdf2( hashfun_name, password, salt, iteration_count) logger.debug("pbkdf2 timing: %f seconds", time.time() - t0) client_key = hmac.new( salted_password, b"Client Key", hashfun_factory).digest() stored_key = hashfun_factory(client_key).digest() reply = b"c=" + base64.b64encode(b"n,,") + b",r=" + nonce auth_message += b"," + reply client_proof = ( int.from_bytes( hmac.new( stored_key, auth_message, hashfun_factory).digest(), "big") ^ int.from_bytes(client_key, "big")).to_bytes(digest_size, "big") logger.debug("response generation time: %f seconds", time.time() - t0) try: state, payload = yield from sm.response( reply + b",p=" + base64.b64encode(client_proof) ) except errors.SASLFailure as err: raise err.promote_to_authentication_failure() from None if state != "success": raise errors.SASLFailure( "malformed-request", text="SCRAM protocol violation") server_signature = hmac.new( hmac.new( salted_password, b"Server Key", hashfun_factory).digest(), auth_message, hashfun_factory).digest() payload = dict(self.parse_message(payload)) if base64.b64decode(payload[b"v"]) != server_signature: raise errors.SASLFailure( "Authentication successful, but server signature invalid") return True

aioxmpp

Navigation