#!/usr/bin/python2.7
# -*- coding: utf-8 -*-
# pylint: disable-msg=C0103
#
# Univention Directory Manager Modules
"""Check if users are member of their primary group."""
#
# Copyright 2004-2014 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.

import ldap
import sys
from univention.config_registry import ConfigRegistry
from optparse import OptionParser


def info(msg, *args, **kwargs):
	"""Write an informational message to standard output.

	msg -- %-style format string
	args, kwargs -- values for the format string; the positional values
		are used if any were given, otherwise the keyword values
	"""
	values = args or kwargs
	sys.stdout.write((msg % values) + '\n')


def warn(msg, *args, **kwargs):
	"""Write a warning to standard error and count it.

	msg -- %-style format string
	args, kwargs -- values for the format string; the positional values
		are used if any were given, otherwise the keyword values

	Every call increments the counter kept in the function attribute
	warn.warnings.
	"""
	text = msg % (args or kwargs)
	sys.stderr.write('Warning: ' + text + '\n')
	warn.warnings = warn.warnings + 1
# Number of warnings emitted so far; kept as an attribute of warn() itself.
warn.warnings = 0


def fatal(msg, *args, **kwargs):
	"""Write an error message to standard error and exit with status 1.

	msg -- %-style format string
	args, kwargs -- values for the format string; the positional values
		are used if any were given, otherwise the keyword values
	"""
	text = msg % (args or kwargs)
	sys.stderr.write('Error: ' + text + '\n')
	sys.exit(1)


def ldapbind(basedn):
	"""Open local LDAP connection and do bind.

	Connects to the LDAP server on localhost port 7389 and binds as
	cn=admin,<basedn> with the password stored in /etc/ldap.secret.

	basedn -- LDAP base DN used to build the admin bind DN

	Returns the bound connection.  Exits through fatal() if the secret file
	cannot be read or the bind fails.
	"""
	binddn = "cn=admin,%s" % (basedn,)
	try:
		# The context manager closes the secret file even if reading fails.
		with open('/etc/ldap.secret') as secret_file:
			bindpw = secret_file.read().rstrip('\n')
	except IOError:
		fatal('Could not read /etc/ldap.secret')
	# ldap.initialize() with an URI replaces the deprecated ldap.open().
	conn = ldap.initialize('ldap://localhost:7389')
	try:
		conn.simple_bind_s(binddn, bindpw)
	except ldap.LDAPError:
		fatal("Could not bind to %s", binddn)

	return conn


def check_primary(conn, basedn):
	"""Check if users are member of their primary group.

	Searches all posixAccount entries below basedn.  For each account the
	univentionGroup entries carrying the account's gidNumber are looked up
	and the account is added to 'uniqueMember' (by DN) and 'memberUid'
	(by uid) wherever it is missing.

	conn -- LDAP connection used for searching and modifying
	basedn -- base DN of the subtree to check

	Exits through fatal() if basedn does not exist.  Accounts without a
	gidNumber are reported with a warning and skipped.
	"""
	info("Checking if users are member of their primary group...")
	try:
		# GID's will only be found in posixAccount
		user_result = conn.search_s(basedn, ldap.SCOPE_SUBTREE,
				'(objectClass=posixAccount)', ['gidNumber', 'uid'])
	except ldap.NO_SUCH_OBJECT:
		fatal("ldap search in %s failed (no such base dn)", basedn)
	count_changes = 0
	for user_dn, account in user_result:
		user_uid = account['uid'][0]
		# The attribute may be absent altogether, so do not index blindly.
		gid_numbers = account.get('gidNumber', [])
		if not gid_numbers or not gid_numbers[0]:
			warn("posixAccount without gidNumber: %s", user_dn)
			# Without a gid there is no primary group to look up.
			continue
		user_gid = gid_numbers[0]

		# search corresponding group
		group_result = conn.search_s(basedn, ldap.SCOPE_SUBTREE,
				'(&(objectClass=univentionGroup)(gidNumber=%s))' % (user_gid,),
				['uniqueMember', 'memberUid'])

		# there must be exactly one group with this gid
		if len(group_result) > 1:
			warn("found more than one univentionGroup for gidNumber=%s!",
					user_gid)
		elif len(group_result) < 1 and not user_gid == "0":
			warn("found no univentionGroup for gidNumber=%s!", user_gid)
		# we change them all -- the user needs to delete all but one of them
		for group_dn, group in group_result:
			# look for the needed entry; members are compared ignoring case
			group_member_dns = set(
					dn.lower() for dn in group.get('uniqueMember', []))
			group_member_uids = set(
					uid.lower() for uid in group.get('memberUid', []))
			modlist = []
			if user_dn.lower() not in group_member_dns:
				modlist.append((ldap.MOD_ADD, 'uniqueMember', user_dn))
			if user_uid.lower() not in group_member_uids:
				modlist.append((ldap.MOD_ADD, 'memberUid', user_uid))
			# no entry found, gonna add one
			if modlist:
				info("Adding uniqueMember and memberUid entry for '%s' in '%s'",
						user_dn, group_dn)
				try:
					conn.modify_s(group_dn, modlist)
					count_changes += 1
				except ldap.LDAPError:
					warn("failed to modify group %s", group_dn)
	info("Checked %d posixAccounts, fixed %d issues.",
			len(user_result), count_changes)


def check_groups(conn, basedn):
	"""Verify that the members listed in each posixGroup exist.

	Every posixGroup below basedn is passed to check_groups_by_dn() and
	check_groups_by_uid(); the number of changes they report is summed up
	and printed.

	conn -- LDAP connection used for searching and modifying
	basedn -- base DN of the subtree to check

	Exits through fatal() if basedn does not exist.
	"""
	info("Checking if group-members exist...")
	wanted_attrs = ['uniqueMember', 'memberUid']
	try:
		groups = conn.search_s(basedn, ldap.SCOPE_SUBTREE,
				'(objectClass=posixGroup)', wanted_attrs)
	except ldap.NO_SUCH_OBJECT:
		fatal("ldap search in %s failed (no such base dn)", basedn)

	fixed_by_dn = 0
	fixed_by_uid = 0
	for dn, attrs in groups:
		# DN based members are handled before UID based members.
		fixed_by_dn += check_groups_by_dn(conn, dn, attrs)
		fixed_by_uid += check_groups_by_uid(conn, basedn, dn, attrs)

	info("Checked %d posixGroups, fixed %d issues.",
			len(groups), fixed_by_dn + fixed_by_uid)


def check_groups_by_dn(conn, group_dn, group):
	"""Check by 'uniqueMember'.

	Each 'uniqueMember' value is split into its first RDN and its parent
	DN; the parent is then searched one level deep for an entry matching
	the RDN.  Values for which no entry is found, and values which do not
	consist of an RDN plus a parent DN, are removed from the group.

	conn -- LDAP connection used for searching and modifying
	group_dn -- DN of the group being checked
	group -- attribute dictionary of the group

	Returns the number of members successfully removed.
	"""
	# Sub-modules of the python-ldap package the file already depends on.
	from ldap.dn import dn2str, str2dn
	from ldap.filter import escape_filter_chars

	group_member_dns = group.get('uniqueMember', [])
	count_changes = 0
	remmembers = set()
	for member_dn in group_member_dns:
		# A value without any comma cannot be RDN plus parent DN.
		if ',' not in member_dn:
			remmembers.add(member_dn)
			continue

		# Split uid=USER,cn=users,dc=FQDN into RDN and parent.  Parsing the
		# DN keeps escaped commas and multi-valued RDNs intact, which a
		# plain split(',') does not.
		try:
			rdns = str2dn(member_dn)
		except ldap.LDAPError:
			warn("Manual: Search for member DN '%s' of group '%s' failed",
					member_dn, group_dn)
			continue
		if len(rdns) < 2:
			# Only a single RDN, e.g. the sole comma was an escaped one.
			remmembers.add(member_dn)
			continue

		# Build the filter from the parsed RDN; values are escaped so that
		# '*', '(', ')' and '\' are matched literally.
		clauses = ['(%s=%s)' % (attr, escape_filter_chars(value))
				for attr, value, _flags in rdns[0]]
		if len(clauses) == 1:
			member_filter = clauses[0]
		else:
			member_filter = '(&%s)' % (''.join(clauses),)
		base = dn2str(rdns[1:])

		try:
			member_result = conn.search_s(base, ldap.SCOPE_ONELEVEL,
					member_filter, ['objectClass'])
		except ldap.LDAPError:
			warn("Manual: Search for member DN '%s' of group '%s' failed",
					member_dn, group_dn)
			continue

		if len(member_result) > 1:
			warn("Manual: Multiple members for DN '%s' of group '%s'",
					member_dn, group_dn)
		elif len(member_result) < 1:
			warn("No member for DN '%s', will be removed", member_dn)
			remmembers.add(member_dn)
	for member_dn in remmembers:
		info("Removing member DN '%s' from '%s'", member_dn, group_dn)
		modlist = [(ldap.MOD_DELETE, 'uniqueMember', member_dn)]
		try:
			conn.modify_s(group_dn, modlist)
			count_changes += 1
		except ldap.LDAPError:
			warn("failed to remove DN '%s' from group '%s'",
					member_dn, group_dn)
	return count_changes


def check_groups_by_uid(conn, basedn, group_dn, group):
	"""Check by 'memberUid'.

	For each 'memberUid' value the subtree below basedn is searched for
	an entry with that uid.  Values for which no entry is found are removed
	from the group.

	conn -- LDAP connection used for searching and modifying
	basedn -- base DN of the subtree searched for the users
	group_dn -- DN of the group being checked
	group -- attribute dictionary of the group

	Returns the number of members successfully removed.
	"""
	# Sub-module of the python-ldap package the file already depends on.
	from ldap.filter import escape_filter_chars

	group_member_uids = group.get('memberUid', [])
	count_changes = 0
	remmembers = set()
	for member_uid in group_member_uids:
		# Escape the value so that '*', '(', ')' and '\' inside a uid are
		# matched literally instead of altering or breaking the filter.
		member_filter = '(uid=%s)' % (escape_filter_chars(member_uid),)
		try:
			member_result = conn.search_s(basedn, ldap.SCOPE_SUBTREE,
					member_filter, ['objectClass'])
		except ldap.LDAPError:
			warn("Manual: Search for member UID '%s' of group '%s' failed",
					member_uid, group_dn)
			continue

		if len(member_result) > 1:
			warn("Manual: Multiple members for UID '%s' of group '%s'",
					member_uid, group_dn)
		elif len(member_result) < 1:
			warn("No member for UID '%s', will be removed", member_uid)
			remmembers.add(member_uid)
	for member_uid in remmembers:
		info("Removing member UID '%s' from '%s'", member_uid, group_dn)
		modlist = [(ldap.MOD_DELETE, 'memberUid', member_uid)]
		try:
			conn.modify_s(group_dn, modlist)
			count_changes += 1
		except ldap.LDAPError:
			warn("Failed to remove UID '%s' from group '%s'",
					member_uid, group_dn)
	return count_changes


def main():
	"""Parse the command line, bind to LDAP and run both checks.

	The bind DN is always derived from the UCR variable ldap/base; the
	option -b only replaces the base DN used by the checks.  With -c the
	connection's modify_s is replaced by a function doing nothing.

	Exits with status 2 if any warning was issued, otherwise with status 0.
	"""
	cli = OptionParser()
	cli.add_option("-b", "--base-dn",
			dest="basedn", action="store",
			help="ldap base DN for user search")
	cli.add_option("-c", "--check",
			dest="check", action="store_true",
			help="Only check, do not modify")
	options = cli.parse_args()[0]

	config = ConfigRegistry()
	config.load()
	configured_base = config['ldap/base']

	conn = ldapbind(configured_base)

	# A base DN given on the command line takes precedence for the checks.
	basedn = options.basedn if options.basedn else configured_base

	if options.check:
		def skip_modify(dn, modlist):
			"""Stand-in for modify_s which leaves the directory untouched."""
			return None
		conn.modify_s = skip_modify

	check_primary(conn, basedn)
	check_groups(conn, basedn)

	if not warn.warnings:
		sys.exit(0)
	info("There were %d warning(s)!", warn.warnings)
	sys.exit(2)


if __name__ == '__main__':
	main()
