Source code for aioxmpp.network

"""
:mod:`~aioxmpp.network` --- DNS resolution utilities
####################################################

This module uses DNSPython to resolve SRV records.

Querying SRV records
====================

.. autofunction:: find_xmpp_host_addr

.. autofunction:: lookup_srv

.. autofunction:: repeated_query

Ordering SRV records
====================

.. autofunction:: group_and_order_srv_records


"""

import asyncio
import functools
import itertools
import logging
import random

import dns
import dns.flags
import dns.resolver

logger = logging.getLogger(__name__)


def repeated_query(qname, rdtype, nattempts=3, resolver=None,
                   require_ad=False):
    """
    Repeatedly fire a DNS query until either the number of allowed attempts
    (``nattempts``) is exceeded or a result is found.

    ``qname`` must be the (IDNA encoded, as :class:`bytes`) name to query,
    ``rdtype`` the record type to query for.

    If `resolver` is not :data:`None`, it must be a DNSPython
    :class:`dns.resolver.Resolver` instance; if it is :data:`None`, the
    current default resolver is used. Note that the flags of the resolver
    which is used are overwritten via ``set_flags`` before each attempt.

    If `require_ad` is :data:`True`, the peer resolver is asked to do DNSSEC
    validation and if the AD flag is missing in the response,
    :class:`ValueError` is raised.

    The first attempt is made with ``tcp=False``; every later attempt is made
    with ``tcp=True``.

    If no result is received before the number of allowed attempts is
    exceeded, :class:`TimeoutError` is raised. If ``nattempts`` is zero or
    negative, :class:`ValueError` is raised without querying.

    Return the result set or :data:`None` if the domain does not exist or has
    no records of the requested type.
    """
    if nattempts <= 0:
        raise ValueError("Query cannot succeed with zero or less attempts")

    resolver = resolver or dns.resolver.get_default_resolver()

    for i in range(nattempts):
        try:
            if require_ad:
                resolver.set_flags(dns.flags.AD | dns.flags.RD)
            else:
                # None resets the resolver to its default flags
                resolver.set_flags(None)

            answer = resolver.query(
                qname.decode("ascii"),
                rdtype,
                tcp=(i > 0),
            )

            if require_ad:
                if not (answer.response.flags & dns.flags.AD):
                    # not caught below: propagates straight to the caller
                    raise ValueError("DNSSEC validation not available")

            break
        except (TimeoutError, dns.resolver.Timeout):
            # only announce the transport switch once, after the first
            # (non-TCP) attempt failed
            if i == 0:
                logger.warning("DNS is timing out, switching to TCP")
        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
            return None
    else:
        # the loop ran out of attempts without ever reaching ``break``
        raise TimeoutError("SRV query timed out")

    return answer
def lookup_srv(domain, service, transport=b"tcp", **kwargs):
    """
    Query the SRV records for ``service`` over ``transport`` at ``domain``
    and return them in a normalised form.

    ``domain``, ``service`` and ``transport`` are :class:`bytes`; the queried
    name is ``_service._transport.domain``. Any additional keyword arguments
    are handed on to :func:`repeated_query`.

    Return a list of ``(prio, weight, (hostname, port))`` tuples, one per SRV
    record. ``prio``, ``weight`` and ``port`` are taken from the record as-is;
    ``hostname`` is the record's target with trailing dots stripped, passed
    through the ``IDNA`` codec's decoder.

    If :func:`repeated_query` returns :data:`None`, :data:`None` is returned.

    If any of the SRV records has the root name (``.``) as its target, the
    domain indicates that the service is not available and
    :class:`ValueError` is raised.
    """
    qname = b".".join((b"_" + service, b"_" + transport, domain))

    response = repeated_query(qname, dns.rdatatype.SRV, **kwargs)
    if response is None:
        return None

    # Read all records first, then validate/convert them.
    raw_records = [
        (rec.priority, rec.weight, str(rec.target), rec.port)
        for rec in response
    ]

    result = []
    for priority, weight, target, port in raw_records:
        if target == ".":
            raise ValueError("Protocol explicitly not supported")
        # NOTE(review): decoding through the IDNA codec produces a str;
        # confirm callers expect text rather than IDNA-encoded bytes.
        hostname = target.rstrip(".").encode("ascii").decode("IDNA")
        result.append((priority, weight, (hostname, port)))

    return result
def lookup_tlsa(domain, port, transport=b"tcp", require_ad=True, **kwargs):
    """
    Query the TLSA records for ``port`` over ``transport`` at ``domain``.

    ``domain`` and ``transport`` are :class:`bytes`; ``port`` is converted
    with :func:`str` and ASCII-encoded. The queried name is
    ``_port._transport.domain``. `require_ad` and any additional keyword
    arguments are handed on to :func:`repeated_query`.

    Return a list of ``(usage, selector, mtype, cert)`` tuples, one per
    record, or :data:`None` if :func:`repeated_query` returns :data:`None`.
    """
    port_label = b"_" + str(port).encode("ascii")
    transport_label = b"_" + transport
    qname = b".".join((port_label, transport_label, domain))

    response = repeated_query(
        qname,
        dns.rdatatype.TLSA,
        require_ad=require_ad,
        **kwargs)

    if response is None:
        return None

    return [
        (rec.usage, rec.selector, rec.mtype, rec.cert)
        for rec in response
    ]
def group_and_order_srv_records(all_records, rng=None):
    """
    Order a list of SRV record information (as returned by
    :func:`lookup_srv`) and group and order them as specified by the RFC.

    ``all_records`` is an iterable of ``(prio, weight, (hostname, port))``
    tuples. It is not modified; a sorted copy is used internally.

    Return an iterable, yielding each ``(hostname, port)`` tuple inside the
    SRV records in the order specified by the RFC: groups are emitted in
    ascending priority, and inside one priority group the next record is
    picked by a weighted random choice until the group is exhausted.

    For hosts with the same priority, the given `rng` implementation is used
    (if none is given, the :mod:`random` module is used). Only its
    ``randint(a, b)`` method is called.

    This is a generator: sorting and selection only happen once iteration
    starts.
    """
    rng = rng or random

    # sorted() instead of list.sort(): do not reorder the caller's list as a
    # side effect, and accept any iterable of records.
    ordered_records = sorted(all_records)

    for priority, records in itertools.groupby(
            ordered_records,
            lambda x: x[0]):

        records = list(records)
        total_weight = sum(weight for _, weight, _ in records)

        while records:
            if len(records) == 1:
                # nothing left to choose between
                yield records[0][-1]
                break

            # pick a point in [0, total_weight] and take the first record
            # whose running weight sum reaches it
            value = rng.randint(0, total_weight)
            running_weight_sum = 0
            for i, (_, weight, addr) in enumerate(records):
                running_weight_sum += weight
                if running_weight_sum >= value:
                    yield addr
                    del records[i]
                    # keep the total in sync with the remaining records
                    total_weight -= weight
                    break
@asyncio.coroutine
def find_xmpp_host_addr(loop, domain, attempts=3):
    """
    Look up the ``xmpp-client`` SRV records for ``domain`` without blocking
    the event loop.

    ``domain`` is a :class:`str`; it is encoded with the ``IDNA`` codec and
    passed to :func:`lookup_srv`, which is run in the default executor of
    ``loop``. ``attempts`` is forwarded as ``nattempts``.

    Return the list produced by :func:`lookup_srv`. If that is :data:`None`,
    return a single fallback entry ``(0, 0, (domain, 5222))`` where
    ``domain`` is the IDNA-encoded :class:`bytes` form of the argument.
    """
    encoded_domain = domain.encode("IDNA")

    srv_lookup = functools.partial(
        lookup_srv,
        service=b"xmpp-client",
        domain=encoded_domain,
        nattempts=attempts,
    )

    # the DNS query is blocking, so it is pushed to the executor
    records = yield from loop.run_in_executor(None, srv_lookup)

    if records is None:
        # no SRV records: fall back to the domain itself on port 5222
        return [(0, 0, (encoded_domain, 5222))]

    return records
@asyncio.coroutine
def find_xmpp_host_tlsa(loop, domain, attempts=3, require_ad=True):
    """
    Look up the TLSA records for port 5222 at ``domain`` without blocking
    the event loop.

    ``domain`` is a :class:`str`; it is encoded with the ``IDNA`` codec and
    passed to :func:`lookup_tlsa`, which is run in the default executor of
    ``loop``. ``attempts`` is forwarded as ``nattempts`` and `require_ad` is
    forwarded unchanged.

    Return the list produced by :func:`lookup_tlsa`, or an empty list if that
    is :data:`None`.
    """
    encoded_domain = domain.encode("IDNA")

    tlsa_lookup = functools.partial(
        lookup_tlsa,
        domain=encoded_domain,
        port=5222,
        nattempts=attempts,
        require_ad=require_ad,
    )

    # the DNS query is blocking, so it is pushed to the executor
    records = yield from loop.run_in_executor(None, tlsa_lookup)

    if records is None:
        return []

    return records

aioxmpp

Navigation