# SPDX-FileCopyrightText: 2004-2026 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""|UDM| module for |DNS| reverse pointer records (PTR)"""

from __future__ import annotations

import ipaddress

import univention.admin
import univention.admin.handlers
import univention.admin.localization
from univention.admin.filter import conjunction, expression
from univention.admin.handlers.dns import ARPA_IP4, ARPA_IP6, DNSBase, is_dns
from univention.admin.layout import Group, Tab


# Translation catalogue of the DNS handler package; `_` marks user-visible
# strings for translation.
translation = univention.admin.localization.translation('univention.admin.handlers.dns')
_ = translation.translate


# UDM module metadata. Objects of this module are placed below a
# 'dns/reverse_zone' object, their superordinate, whose 'subnet' is needed to
# convert between 'ip' and 'address' (see the handler class below).
module = 'dns/ptr_record'
operations = ['add', 'edit', 'remove', 'search']
columns = ['ptr_record']
superordinate = 'dns/reverse_zone'
childs = False  # no child objects below a PTR record
short_description = _('DNS: Pointer record')
object_name = _('Pointer record')
object_name_plural = _('Pointer records')
long_description = _('Map IP addresses back to hostnames.')
# fmt: off
# The single (default) option adds the LDAP object classes of a DNS zone entry.
options = {
    'default': univention.admin.option(
        short_description=short_description,
        default=True,
        objectClasses=['top', 'dNSZone'],
    ),
}
property_descriptions = {
    # Host part of the address in reversed notation, relative to the zone's
    # subnet; stored as relativeDomainName (see mapping below) and recomputed
    # from 'ip' in object.ready() whenever 'ip' changes.
    'address': univention.admin.property(
        short_description=_('Reverse address'),
        long_description=_('The host part of the IP address in reverse notation (e.g. "172.16.1.2/16" -> "2.1" or "2001:0db8:0100::0007:0008/96" -> "8.0.0.0.7.0.0.0").'),
        syntax=univention.admin.syntax.dnsPTR,
        required=True,
        identifies=True,
    ),
    # Virtual property without an LDAP mapping: filled in object.open() from
    # 'address' and the zone's subnet. Searches on it are rewritten by
    # object.rewrite_filter().
    'ip': univention.admin.property(
        short_description=_('IP Address'),
        long_description='',
        syntax=univention.admin.syntax.ipAddress,
        include_in_default_search=True,
    ),
    # Host name(s) the pointer resolves to; stored as pTRRecord.
    'ptr_record': univention.admin.property(
        short_description=_('Pointer record'),
        long_description=_("FQDNs must end with a dot."),
        syntax=univention.admin.syntax.dnsName,
        multivalue=True,
        include_in_default_search=True,
        required=True,
    ),
}

# 'address' is not part of the layout: it is edited indirectly via 'ip'.
layout = [
    Tab(_('General'), _('Basic settings'), layout=[
        Group(_('General pointer record settings'), layout=[
            ['ip', 'ptr_record'],
        ]),
    ]),
]

# UDM property -> LDAP attribute mapping; 'ip' has no mapping (it is computed).
mapping = univention.admin.mapping.mapping()
mapping.register('address', 'relativeDomainName', None, univention.admin.mapping.ListToString, encoding='ASCII')
mapping.register('ptr_record', 'pTRRecord', encoding='ASCII')
# fmt: on


def ipv6(string: str) -> str:
    """
    Insert a colon after every 4 characters of a 32 character nibble string,
    giving the exploded textual form of an IPv6 address.

    >>> ipv6('0123456789abcdef0123456789abcdef')
    '0123:4567:89ab:cdef:0123:4567:89ab:cdef'

    :raises AssertionError: if `string` is not exactly 32 characters long.
    """
    assert len(string) == 32, string
    groups = [string[offset:offset + 4] for offset in range(0, 32, 4)]
    return ':'.join(groups)


def calc_ip(rev: str, subnet: str) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
    """
    Join the network part of a reverse zone with a reversed host part.

    `subnet` is the zone's network part: dotted octets for IPv4, or
    colon-separated nibble groups for IPv6. `rev` is the host part in reverse
    notation, i.e. the dot-separated octets/nibbles from least to most significant.

    >>> calc_ip(rev='8.0.0.0.7.0.0.0.6.0.0.0.5.0.0.0.4.0.0', subnet='0001:0002:0003:0').exploded
    '0001:0002:0003:0004:0005:0006:0007:0008'
    >>> calc_ip(rev='4.3', subnet='1.2').exploded
    '1.2.3.4'

    :raises AssertionError: if the IPv4 result does not have exactly 4 octets,
        or the IPv6 result does not have exactly 32 nibbles.
    :raises ValueError: if the combined text is not a valid address.
    """
    host = rev.split('.')[::-1]
    if ':' not in subnet:
        octets = subnet.split('.') + host
        assert len(octets) == 4, octets
        return ipaddress.IPv4Address('.'.join(octets))
    nibbles = ''.join(subnet.split(':') + host)
    return ipaddress.IPv6Address(ipv6(nibbles))


def calc_rev(ip: str, subnet: str) -> str:
    """
    Return the host part of `ip` relative to the reverse zone `subnet`, in reverse notation.

    `subnet` is the zone's network part: dotted octets for IPv4, or
    colon-separated nibble groups for IPv6.

    >>> calc_rev(ip='1.2.3.4', subnet='1.2')
    '4.3'
    >>> calc_rev(ip='0001:0002:0003:0004:0005:0006:0007:0008', subnet='0001:0002:0003:0')
    '8.0.0.0.7.0.0.0.6.0.0.0.5.0.0.0.4.0.0'
    >>> calc_rev(ip='1:2:3:4:5:6:7:8', subnet='0001:0002:0003:0')
    '8.0.0.0.7.0.0.0.6.0.0.0.5.0.0.0.4.0.0'

    :raises AssertionError: if `subnet` has an unsupported number of octets/nibbles.
    :raises ValueError: if `ip` is not a valid address or lies outside `subnet`.
    """
    if ':' not in subnet:
        net_octets = subnet.split('.')
        width = len(net_octets)
        assert 1 <= width < 4
        network = ipaddress.IPv4Network(
            '%s/%d' % ('.'.join(net_octets + ['0'] * (4 - width)), 8 * width),
            strict=False,
        )
        address = ipaddress.IPv4Address(str(ip))
        if address not in network:
            raise ValueError()
        return '.'.join(address.exploded.split('.')[width:][::-1])

    nibbles = subnet.replace(':', '')
    width = len(nibbles)
    assert 1 <= width < 32
    network6 = ipaddress.IPv6Network(
        '%s/%d' % (ipv6(nibbles + '0' * (32 - width)), 4 * width),
        strict=False,
    )
    address6 = ipaddress.IPv6Address(str(ip))
    if address6 not in network6:
        raise ValueError()
    # Each nibble of the host part becomes one label, least significant first.
    return '.'.join(address6.exploded.replace(':', '')[width:][::-1])


class object(DNSBase):
    """
    |UDM| object for a |DNS| pointer (PTR) record below a reverse zone.

    Only the reversed host part (`address`, LDAP ``relativeDomainName``) is
    stored. The virtual `ip` property is computed from it and the ``subnet``
    of the superordinate reverse zone; changing `ip` recomputes `address`.
    """

    module = module

    def description(self) -> str:
        """Return the full IP address, or the generic description if it cannot be computed."""
        try:
            if self.superordinate:
                return calc_ip(self.info['address'] or '', self.superordinate.info['subnet'] or '').compressed
        except (LookupError, ValueError, AssertionError) as ex:
            self.log.warning('Failed to parse address/subnet', dn=self.dn, error=ex)
        return super().description()

    def open(self) -> None:
        """Load the object and fill the virtual `ip` property from `address` and the zone subnet."""
        super().open()
        if not self.superordinate:
            # No reverse zone means no subnet to combine the host part with;
            # skip instead of failing with AttributeError (same guard as description()).
            return
        try:
            self.info['ip'] = calc_ip(self.info['address'], self.superordinate.info['subnet']).compressed
            # Presumably snapshots info as the unmodified state, so ready() sees
            # `ip` as unchanged unless the caller modifies it — confirm in simpleLdap.
            self.save()
        except (LookupError, ValueError, AssertionError) as ex:
            self.log.warning('Failed to parse address/subnet', dn=self.dn, error=ex)

    def ready(self) -> None:
        """
        Recompute `address` from a changed `ip`, then run the generic checks.

        :raises univention.admin.uexceptions.InvalidDNS_Information: if the new
            `ip` cannot be parsed or is not inside the reverse zone's subnet.
        """
        old_ip = self.oldinfo.get('ip')
        new_ip = self.info.get('ip')
        if old_ip != new_ip:
            try:
                self.info['address'] = calc_rev(new_ip, self.superordinate.info['subnet'])
            except (LookupError, ValueError, AssertionError) as ex:
                self.log.warning('Failed to handle address', dn=self.dn, ip=new_ip, error=ex)
                raise univention.admin.uexceptions.InvalidDNS_Information(_('Reverse zone and IP address are incompatible.')) from ex
        super().ready()

    @classmethod
    def rewrite_filter(cls, filter, mapping):
        """
        Rewrite a search on the virtual `ip` property into LDAP attributes.

        The reverse zone of an IPv4 address may cover its first 3, 2 or 1
        octets, so the expression becomes an OR over all three candidate zones.
        Other properties are handled by the parent class.

        :raises NotImplementedError: for IPv6 addresses.
        :raises ValueError: if the value is not a valid IPv4 address.
        """
        if filter.variable != 'ip':
            super().rewrite_filter(filter, mapping)
            return
        filter.variable = 'relativeDomainName'
        if not filter.value:
            return
        if ':' in filter.value:
            raise NotImplementedError('IPv6')
        # Take the leading octets verbatim. Zero octets must be kept, e.g.
        # 192.168.0.5 has to be looked up in the zone with subnet '192.168.0'.
        octets = ipaddress.IPv4Address(filter.value).exploded.split('.')
        subnets = ['.'.join(octets[:count]) for count in (3, 2, 1)]
        filter.transform_to_conjunction(conjunction('|', [
            rewrite_rev(expression('ip', filter.value), subnet=subnet) for subnet in subnets
        ]))

    @classmethod
    def lookup_filter_superordinate(cls, filter: univention.admin.filter.conjunction, superordinate: univention.admin.handlers.simpleLdap) -> univention.admin.filter.conjunction:
        """Apply the parent's superordinate restriction, then rewrite `ip` expressions relative to the zone subnet."""
        # NOTE(review): the parent's return value is discarded; this relies on it
        # modifying `filter` in place — confirm in DNSBase.
        super().lookup_filter_superordinate(filter, superordinate)
        filter = rewrite_rev(filter, superordinate.info['subnet'])
        return filter

    @classmethod
    def unmapped_lookup_filter(cls) -> univention.admin.filter.conjunction:
        """Return the base filter: dNSZone entries that have any pTRRecord value."""
        return univention.admin.filter.conjunction('&', [
            univention.admin.filter.expression('objectClass', 'dNSZone'),
            univention.admin.filter.expression('pTRRecord', '*', escape=False),
        ])  # fmt: skip


def rewrite_rev(filter: conjunction | expression, subnet: str) -> conjunction | expression:
    """
    Rewrite LDAP filter expression and convert (ip) -> (zone,reversed)

    Every ``ip`` expression (recursing into conjunctions) is replaced by a
    conjunction matching ``zoneName`` (network part, reversed, plus the ARPA
    suffix) and ``relativeDomainName`` (host part, reversed; a presence match
    when empty). Only the length of `subnet` is used; it decides how many
    octets (IPv4) or nibbles (IPv6) belong to the zone. IPv6 values without
    wildcards may be written in compressed ``::`` notation.

    >>> rewrite_rev(expression('ip', '1.2.3.4'), subnet='1.2')
    conjunction('&', [expression('zoneName', '2.1.in-addr.arpa', '='), expression('relativeDomainName', '4.3', '=')])
    >>> rewrite_rev(expression('ip', '1.2.3.*', escape=False), subnet='1.2')
    conjunction('&', [expression('zoneName', '2.1.in-addr.arpa', '='), expression('relativeDomainName', '*.3', '=')])
    >>> rewrite_rev(expression('ip', '1.2.*.*', escape=False), subnet='1.2')
    conjunction('&', [expression('zoneName', '2.1.in-addr.arpa', '='), expression('relativeDomainName', '*.*', '=')])
    >>> rewrite_rev(expression('ip', '1.2.*.4', escape=False), subnet='1.2')
    conjunction('&', [expression('zoneName', '2.1.in-addr.arpa', '='), expression('relativeDomainName', '4.*', '=')])
    >>> rewrite_rev(expression('ip', '1.2.*', escape=False), subnet='1.2')
    conjunction('&', [expression('zoneName', '2.1.in-addr.arpa', '='), expression('relativeDomainName', '', '=*')])
    >>> rewrite_rev(expression('ip', '1:2:3:4:5:6:7:8'), subnet='0001:0002')
    conjunction('&', [expression('zoneName', '2.0.0.0.1.0.0.0.ip6.arpa', '='), expression('relativeDomainName', '8.0.0.0.7.0.0.0.6.0.0.0.5.0.0.0.4.0.0.0.3.0.0.0', '=')])
    >>> rewrite_rev(expression('ip', '1:2:3:4::7:8'), subnet='0001:0002')
    conjunction('&', [expression('zoneName', '2.0.0.0.1.0.0.0.ip6.arpa', '='), expression('relativeDomainName', '8.0.0.0.7.0.0.0.0.0.0.0.0.0.0.0.4.0.0.0.3.0.0.0', '=')])
    >>> rewrite_rev(expression('ip', '1:2:3:4:5:6:7:*', escape=False), subnet='0001:0002')
    conjunction('&', [expression('zoneName', '2.0.0.0.1.0.0.0.ip6.arpa', '='), expression('relativeDomainName', '*.7.0.0.0.6.0.0.0.5.0.0.0.4.0.0.0.3.0.0.0', '=')])
    >>> rewrite_rev(expression('ip', '1:2:3:4:5:6:*:8', escape=False), subnet='0001:0002')
    conjunction('&', [expression('zoneName', '2.0.0.0.1.0.0.0.ip6.arpa', '='), expression('relativeDomainName', '8.0.0.0.*.6.0.0.0.5.0.0.0.4.0.0.0.3.0.0.0', '=')])
    >>> rewrite_rev(expression('ip', '1:2:3:*', escape=False), subnet='0001:0002')
    conjunction('&', [expression('zoneName', '2.0.0.0.1.0.0.0.ip6.arpa', '='), expression('relativeDomainName', '*.3.0.0.0', '=')])
    """
    if isinstance(filter, conjunction):
        filter.expressions = [rewrite_rev(expr, subnet) for expr in filter.expressions]
    if isinstance(filter, expression) and filter.variable == 'ip':
        value = filter.value
        parts: list[str]
        if ':' in subnet:
            prefix = len(''.join(subnet.split(':')))
            assert 1 <= prefix < 32
            nibbles: str | None = None
            if '*' not in value:
                try:
                    # The exploded form expands '::' and zero-pads every group,
                    # so compressed addresses yield all 32 nibbles.
                    nibbles = ipaddress.IPv6Address(value).exploded.replace(':', '')
                except ValueError:
                    pass  # not a plain address: fall back to per-group padding
            if nibbles is None:
                # Wildcard pattern: pad each group to 4 nibbles, keep wildcard groups verbatim.
                nibbles = ''.join(
                    part if '*' in part else part.rjust(4, '0')[-4:]
                    for part in value.split(':')
                )
            parts = list(nibbles)
            suffix = ARPA_IP6
        else:
            prefix = len(subnet.split('.'))
            assert 1 <= prefix < 4
            parts = value.split('.')
            suffix = ARPA_IP4
        addr_net = '.'.join(reversed(parts[:prefix]))
        addr_host = '.'.join(reversed(parts[prefix:]))
        filter = conjunction('&', [
            expression('zoneName', addr_net + suffix),
            expression('relativeDomainName', addr_host or '*', escape=False),
        ])  # fmt: skip
    return filter


# Module-level aliases of the handler class' lookup entry points.
lookup, lookup_filter = object.lookup, object.lookup_filter


def identify(dn: str, attr: univention.admin.handlers._Attributes) -> bool:
    """
    Tell whether an LDAP entry is a PTR record handled by this module.

    :param dn: DN of the entry (not inspected).
    :param attr: LDAP attributes of the entry.
    :returns: `True` if the entry has a non-empty pTRRecord and `is_dns()` accepts it.
    """
    if not attr.get('pTRRecord'):
        return False
    return bool(is_dns(attr))
