#!/usr/bin/python3
# SPDX-FileCopyrightText: 2021-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

from __future__ import annotations

import os
import re
import subprocess
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
from glob import glob
from hashlib import md5
from typing import TYPE_CHECKING

from univention.ldap_cache.cache import Caches, get_cache
from univention.ldap_cache.cache.shard_config import add_shard_to_config, rm_shard_from_config
from univention.uldap import getMachineConnection


if TYPE_CHECKING:
    from collections.abc import Container, Iterable


# Source code template for the listener modules written by
# create_listener_modules(). It is filled in with str.format() using the
# fields ``name``, ``desc``, ``filter`` and ``attrs``; every field uses the
# ``!r`` conversion, so the values end up in the generated file as Python
# literals (their repr()).
listener_template = '''#!/usr/bin/python3

from univention.ldap_cache.listener_module import LdapCacheHandler

name = {name!r}

class QueryHandler(LdapCacheHandler):
    class Configuration(LdapCacheHandler.Configuration):
        name = {name!r}
        def get_description(self):
            return 'Automatically created listener module for the univention-group-membership-member cache (%s)' % ({desc!r},)
        def get_ldap_filter(self):
            return {filter!r}
        def get_attributes(self):
            return {attrs!r}
'''


def _md5sum(value: str) -> str:
    m = md5()
    m.update(value.encode('utf-8'))
    return m.hexdigest()


def _query_objects(query: str, attrs: Iterable[bytes | str]) -> list[tuple[str, dict[str, list[bytes]]]]:
    """
    Search LDAP with the machine connection.

    :param query: The LDAP filter to search for.
    :param attrs: The attribute names to request; ``bytes`` names are decoded as UTF-8.
    :returns: Whatever the connection's ``search()`` returns for the filter and attributes.
    """
    connection = getMachineConnection()
    attr_names = []
    for attr in attrs:
        if isinstance(attr, bytes):
            attr = attr.decode('utf-8')
        attr_names.append(attr)
    print('Querying', query, 'with', attr_names)
    return connection.search(query, attr=attr_names)


def _get_queries(caches: Caches, cache_names: Container[str] | None = None) -> dict[str, tuple[list, set]]:
    queries: dict[str, tuple[list, set]] = {}
    for name, cache in caches:
        if cache_names is not None and name not in cache_names:
            continue
        for shard in cache.shards:
            shards, attrs = queries.setdefault(shard.ldap_filter, ([], set()))
            shards.append(shard)
            attrs.add(shard.key)
            attrs.add(shard.value)
            attrs.update(shard.attributes)
    return queries


def add_cache(args: Namespace) -> None:
    """Handle the ``add-cache`` action: pass the shard described by the parsed arguments to add_shard_to_config()."""
    shard_spec = (
        args.db_name,
        args.single_value,
        args.reverse,
        args.key,
        args.value,
        args.ldap_filter,
    )
    add_shard_to_config(*shard_spec)


def rm_cache(args: Namespace) -> None:
    """Handle the ``rm-cache`` action: pass the shard described by the parsed arguments to rm_shard_from_config()."""
    shard_spec = (
        args.db_name,
        args.single_value,
        args.reverse,
        args.key,
        args.value,
        args.ldap_filter,
    )
    rm_shard_from_config(*shard_spec)


def cleanup(args: Namespace) -> None:
    """
    Handle the ``cleanup`` action: call ``cleanup()`` on every cache.

    :param args: Parsed command line arguments (not used here).
    """
    all_caches = get_cache()
    for _cache_name, database in all_caches:
        database.cleanup()


def rebuild(args: Namespace) -> None:
    """
    Handle the ``rebuild`` action: clear the selected caches and refill them from LDAP.

    :param args: Parsed command line arguments; ``args.cache_name`` lists the
        caches to rebuild, an empty list selects every cache.
    """
    caches = get_cache()
    cache_names = args.cache_name or sorted(name for name, _cache in caches)
    print('Rebuilding', cache_names)

    for name, cache in caches:
        if name not in cache_names:
            continue
        cache.clear()

    for ldap_filter, (shards, attrs) in _get_queries(caches, cache_names).items():
        print('Searching for', ldap_filter)
        attrs.discard('dn')
        processed = 0
        for processed, obj in enumerate(_query_objects(ldap_filter, attrs), 1):
            if processed % 1000 == 0:
                # progress output plus an intermediate cleanup of all caches every 1000 objects
                print('\rProcessing object #', processed, end='')
                sys.stdout.flush()
                cleanup(args)
            for shard in shards:
                shard.add_object(obj)
        if processed >= 1000:
            # terminate the progress line that was written without a newline
            print()
        print('Added', processed, 'objects')


def create_listener_modules(args: Namespace) -> None:
    """
    Handle the ``create-listener-modules`` action.

    Writes one listener module per distinct LDAP filter used by the configured
    shards, removes ``ldap-cache-*`` files in the listener directory that are
    no longer needed, restarts the univention-directory-listener service and
    finally removes the handler state files of the removed modules.

    :param args: Parsed command line arguments (not used here).
    """
    listener_dir = '/usr/lib/univention-directory-listener/system/'
    handlers_dir = '/var/lib/univention-directory-listener/handlers/'
    # start with everything that exists; whatever is not (re)written below is obsolete
    obsolete_listeners = set(glob(os.path.join(listener_dir, 'ldap-cache-*')))
    for query, (_caches, attrs) in _get_queries(get_cache()).items():
        attrs.discard('dn')
        attrs.discard('entryUUID')
        listener_name = 'ldap-cache-%s' % _md5sum(query)
        fname = os.path.join(listener_dir, '%s.py' % listener_name)
        obsolete_listeners.discard(fname)
        print('Writing', fname, 'for', query)
        # Python reads source files as UTF-8 by default, so the module must be
        # written as UTF-8 regardless of the locale: repr() keeps printable
        # non-ASCII characters of the filter / attributes as they are.
        with open(fname, 'w', encoding='utf-8') as fd:
            fd.write(listener_template.format(
                name=listener_name,
                desc=', '.join(sorted({cache.db_name for cache in _caches})),
                filter=query,
                attrs=sorted(attrs),
            ))
    for fname in obsolete_listeners:
        print('Removing', fname)
        os.unlink(fname)
    subprocess.call(['service', 'univention-directory-listener', 'restart'])
    for fname in obsolete_listeners:
        # also remove status "initialized" so that the listener may be added again in the future
        # needs to be done after the listener is running again
        # splitext() strips exactly the file extension instead of blindly cutting three characters
        listener_name = os.path.splitext(os.path.basename(fname))[0]
        try:
            os.unlink(os.path.join(handlers_dir, listener_name))
        except OSError:
            # best effort: the state file may not exist
            pass


def list_caches(args: Namespace) -> None:
    """
    Handle the ``list`` action: print every cache together with its shards.

    For each shard the LDAP filter and the ``key => value`` mapping are shown.
    Square brackets mark the value side of shards that are not single-valued;
    for reverse shards key and value are swapped and the original key is bracketed.

    :param args: Parsed command line arguments (not used here).
    """
    for cache_name, cache in get_cache():
        print(cache_name)
        print(' The following objects store data:')
        for shard in cache.shards:
            print('  ', shard.ldap_filter)
            if shard.reverse:
                shown_key, shown_value = shard.value, f'[{shard.key}]'
            elif shard.single_value:
                shown_key, shown_value = shard.key, shard.value
            else:
                shown_key, shown_value = shard.key, f'[{shard.value}]'
            print('    ', shown_key, '=>', shown_value)


def query(args: Namespace) -> None:
    """
    Handle the ``query`` action: print the entries of one sub cache.

    :param args: Parsed command line arguments; ``args.cache_name`` names the
        sub cache, the optional ``args.pattern`` is a regular expression that
        a key has to match (via ``search``) to be printed.
    """
    cache = get_cache().get_sub_cache(args.cache_name)
    if not cache:
        print('No cache named', args.cache_name)
        return

    data = cache.load()

    # None means "no pattern given": every key is printed
    matches = None
    if args.pattern:
        try:
            matches = re.compile(args.pattern).search
        except re.error:
            print('Broken pattern')
            return

    for key in sorted(data):
        if matches is not None and not matches(key):
            continue
        print(key, '=>', data[key])


def main() -> None:
    """Build the command line parser, parse the arguments and run the handler of the chosen action."""
    parser = ArgumentParser(
        description='The LDAP cache stores some portions of the LDAP database so that it can be accessed fast and reliably.',
    )
    actions = parser.add_subparsers(
        description='type %(prog)s <action> --help for further help and possible arguments',
        metavar='action',
        required=True,
    )

    query_parser = actions.add_parser(
        'query',
        description='Queries a sub cache and shows the value(s). Mainly for test purposes, this tool is not meant for real applications',
        help='Query the cache',
    )
    query_parser.add_argument('cache_name', help='The name of the sub cache. See "list"')
    query_parser.add_argument('pattern', nargs='?', help='Queries the key with this regexp')
    query_parser.set_defaults(func=query)

    list_parser = actions.add_parser(
        'list',
        description='Lists all sub caches of the cache. Each sub cache may be fed from multiple sources',
        help='List all parts of the cache',
    )
    list_parser.set_defaults(func=list_caches)

    rebuild_parser = actions.add_parser(
        'rebuild',
        description='Rebuild the cache completely, retrieve the objects and overwrite all previous data',
        help='Rebuild the cache',
    )
    rebuild_parser.add_argument(
        'cache_name',
        nargs='*',
        help='The cache consists of different parts. You can only rebuild certain parts of the cache. See "list"',
    )
    rebuild_parser.set_defaults(func=rebuild)

    listener_parser = actions.add_parser(
        'create-listener-modules',
        description='Automatically creates listener modules that will eventually fill the cache (and removes unnecessary); restarts the univention-directory-listener. May be needed after shards are added to /removed from the cache',
        help='Create listener modules',
    )
    listener_parser.set_defaults(func=create_listener_modules)

    # add-cache and rm-cache take exactly the same arguments and only differ in their handler
    for action_name, handler in (('add-cache', add_cache), ('rm-cache', rm_cache)):
        shard_parser = actions.add_parser(action_name, description=SUPPRESS)
        shard_parser.add_argument('db_name')
        shard_parser.add_argument('--single-value', action='store_true')
        shard_parser.add_argument('--reverse', action='store_true')
        shard_parser.add_argument('key')
        shard_parser.add_argument('value')
        shard_parser.add_argument('ldap_filter')
        shard_parser.set_defaults(func=handler)

    cleanup_parser = actions.add_parser(
        'cleanup',
        description='Over the time, the database may become polluted. Shrink all caches to their optimal size if possible.',
        help='Clean up cache',
    )
    cleanup_parser.set_defaults(func=cleanup)

    args = parser.parse_args()
    args.func(args)


# Run the command line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
