#!/usr/bin/python2.7
# -*- coding: utf-8 -*-
#
# Univention Samba
#  helper script: kerberize a user account
#
# Copyright 2001-2014 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.

import ldap, sys, getopt, subprocess, string, hashlib, struct
import univention.config_registry
import univention.debug
univention.debug.init('/dev/null', 1, 1)
import univention.uldap
from datetime import datetime
from time import strftime

# Module-level setup shared by all functions below. Note that this runs at
# import time, i.e. the registry is read and the LDAP connection is opened
# as soon as the module is loaded.
configRegistry = univention.config_registry.ConfigRegistry()
configRegistry.load()
# admin LDAP connection; values of 'krb5Key' are excluded from decoding
lo = univention.uldap.getAdminConnection(decode_ignorelist = [ 'krb5Key' ])
# NOTE(review): krbbase is not referenced anywhere else in the visible source
krbbase='ou=krb5,'+configRegistry['ldap/base']
# Kerberos realm appended to the uid to build the principal name
realm=configRegistry['kerberos/realm']

# name of the policy attribute queried via getPolicyResult() for the
# password expiry interval
univentionPWExpiryInterval = 'univentionPWExpiryInterval'
# number of seconds per day
UNIXDAY=3600*24

def __getsmbPWHistory(newpassword, smbpwhistory, smbpwhlen):
	"""Return the password history string extended by newpassword.

	newpassword  -- password hash as hex string (converted with getbytes())
	smbpwhistory -- current history; entries are split at single spaces
	smbpwhlen    -- maximum number of entries the history may hold

	A new entry has the form [Salt][MD5(Salt+Hash)], both parts written as
	32 upper case hex characters. At most smbpwhlen - 1 of the newest old
	entries are kept in front of the new one; the new entry is always
	stored, even if smbpwhlen is 0. The entries are returned concatenated
	without a separator.
	"""
	# split the history
	if smbpwhistory.strip():
		pwlist = smbpwhistory.split(' ')
	else:
		pwlist = []

	# get 16 bytes from urandom for salting our hash; the context manager
	# makes sure the device is closed again
	with open('/dev/urandom', 'rb') as urandom:
		rand = urandom.read(16)
	# we have to have the salt in hex ...
	hexsalt = ''.join(['%.2X' % byte for byte in bytearray(rand)])
	# ... and binary for calculating the md5
	salt = getbytes(hexsalt)
	# we need the password hash as binary data, too
	pwd = getbytes(newpassword)
	pwdhash = hashlib.md5(salt + pwd).hexdigest().upper()
	smbpwhash = hexsalt + pwdhash

	if len(pwlist) >= smbpwhlen:
		# history is full: drop the oldest entries to make room for the
		# new one. For smbpwhlen <= 1 this empties the list completely.
		cut = 1 + len(pwlist) - smbpwhlen
		del pwlist[:cut]
	pwlist.append(smbpwhash)

	# and build the new history
	return ''.join(pwlist)

def getbytes(string):
	"""Convert a hex string into the packed bytes it describes.

	The input is consumed two characters at a time, each pair being parsed
	as one base 16 number; a trailing single character is parsed on its
	own. The numbers are packed as unsigned bytes (for smbPWHistory).
	"""
	values = []
	offset = 0
	while offset < len(string):
		values.append(int(string[offset:offset + 2], 16))
		offset += 2
	layout = "%iB" % len(values)
	return struct.pack(layout, *values)

def getPolicyResult(dn, attr):
	"""Return the policy values reported for dn whose attribute starts with attr.

	Runs 'univention-policy-result' for dn (binding as ldap/hostdn with
	/etc/machine.secret) and filters its output with grep for
	'Attribute: <attr>' lines plus the line following each of them.

	dn   -- DN of the object the policies are evaluated for
	attr -- attribute name (prefix) to look for; it is handed to grep
	        as part of a regular expression

	Returns a dict mapping attribute name to its value string. Records that
	cannot be parsed are reported on stdout and left out of the result.
	If univention-policy-result fails, the script is terminated via
	sys.exit(0).
	"""
	policy = {}
	p1 = subprocess.Popen(['univention-policy-result', '-D', configRegistry['ldap/hostdn'], '-y', '/etc/machine.secret', dn], stdout=subprocess.PIPE)
	p2 = subprocess.Popen(['grep', '-A1', '^Attribute: '+attr], stdin=p1.stdout, stdout=subprocess.PIPE)
	# grep owns the read end now; drop our copy so that p1 is notified
	# (SIGPIPE) should grep exit before having consumed all output
	p1.stdout.close()
	result = p2.communicate()[0]
	# if univention-policy-result fails then quit and do not parse output
	if p1.wait() != 0:
		print 'ERROR: univention-policy-result failed - LDAP server may be down'
		print 'ERROR: no password set'
		# NOTE(review): exit status 0 although this is an error - kept
		# unchanged, callers may rely on it
		sys.exit(0)

	result = result.strip('\n')
	if result:
		# grep -A1 separates the matched groups by a line containing '--'
		for record in result.split('\n--\n'):
			record = record.strip('\t\n\r ')

			lines = record.splitlines()
			if len(lines) != 2:
				print "WARN: cannot parse following lines:"
				print "==> %s" % '\n==> '.join(lines)
				continue

			key = None
			value = None
			if lines[0].startswith('Attribute: '+attr):
				key = lines[0][ len('Attribute: ') : ]
			else:
				print 'WARN: cannot parse key line:', lines[0]

			if lines[1].startswith('Value: '):
				value = lines[1][ len('Value: ') : ]
			else:
				print 'WARN: cannot parse value line:', lines[1]

			# store complete records only: a None key is useless and a
			# None value would break callers converting the value with int()
			if key is not None and value is not None:
				policy[key] = value
	return policy

def gen_samba_password_history(dn, password_nt):
	"""Create an initial password history value for the object dn.

	The history length is taken from the 'univentionPWHistoryLen' policy
	result for dn (0 if the policy does not provide it). The history is
	built from an empty previous history, i.e. it contains only the entry
	generated for password_nt.
	"""
	policy = getPolicyResult(dn, 'univentionPWHistoryLen')
	history_length = int(policy.get('univentionPWHistoryLen', 0))
	# there is no previous history to extend
	return __getsmbPWHistory(password_nt, '', history_length)

def unixdayToKrb5Date(unixday):
	"""Format a unix timestamp as 'YYYYmmddHHMMSSZ'.

	unixday -- seconds since the epoch (despite the name, not a number of
	           days); anything accepted by float(), e.g. an int or a
	           numeric string

	The trailing 'Z' marks the value as UTC, therefore the timestamp is
	converted with gmtime() and not with the local timezone of the host.
	"""
	from time import gmtime, strftime
	return strftime("%Y%m%d%H%M%S", gmtime(float(unixday))) + 'Z'

def nt_password_to_arcfour_hmac_md5(nt_password):
	"""Build a krb5Key value from an NT password hash.

	nt_password -- hex string; the first 32 characters (16 octets) are used

	Returns the fixed 15 character header every arcfour-hmac-md5 key
	begins with, followed by the 16 decoded octets of the hash.
	"""
	# all arcfour-hmac-md5 keys begin this way
	header = '0\x1d\xa1\x1b0\x19\xa0\x03\x02\x01\x17\xa1\x12\x04\x10'
	octets = [chr(int(nt_password[pos:pos + 2], 16)) for pos in range(0, 32, 2)]
	return header + ''.join(octets)


if __name__ == '__main__':
	username = None
	optlist, mail_user = getopt.getopt(sys.argv[1:], 'u:')
	for option, value in optlist:
		if option == '-u':
			username = value

	if not username:
		sys.exit(0)

	for dn, attrs in lo.search(filter='(&(objectClass=sambaSamAccount)(sambaNTPassword=*)(uid=%s)(!(objectClass=univentionWindows)))' % username, attr=['uid', 'krb5Key', 'sambaNTPassword', 'sambaLMPassword', 'sambaAcctFlags', 'objectClass', 'userPassword', 'krb5PasswordEnd', 'sambaPwdCanChange', 'sambaPwdMustChange' ,'sambaPwdLastSet', 'krb5KDCFlags', 'shadowLastChange', 'sambaPasswordHistory', 'shadowMax']):
		if 'univentionHost' in attrs[ 'objectClass' ]:
			continue
		if not attrs['sambaNTPassword'][0] == "NO PASSWORDXXXXXXXXXXXXXXXXXXXXX":

			if attrs['uid'][0] == 'root':
				print 'Skipping user root '
				continue

			# check if the user was disabled
			disabled = False
			sambaAcctFlags = attrs.get('sambaAcctFlags', '')
			if len(sambaAcctFlags) > 0:
				if 'D' in sambaAcctFlags[0]:
					disabled = True

			principal=attrs['uid'][0]+'@'+realm

			ocs=[]
			ml=[]
			if not 'krb5Principal' in attrs['objectClass']:
				ocs.append('krb5Principal')
				ml.append(('krb5PrincipalName', None, principal))

			if disabled:
				flag = '256'
			else:
				flag = '126'
			
			if not sambaAcctFlags:
				ml.append(('sambaAcctFlags', None, '[U          ]'))

			if not attrs.get('sambaPasswordHistory', False):
				ml.append(('sambaPasswordHistory', None, gen_samba_password_history(dn, attrs.get('sambaNTPassword', [''])[0])))

			if attrs.get('sambaPwdLastSet', False):
				usersPWPolicy = getPolicyResult(dn, univentionPWExpiryInterval)
				# print 'DEBUG: usersPWPolicy: ', usersPWPolicy
				usersPWExpireInterval = int(usersPWPolicy.get(univentionPWExpiryInterval, 0))
				# print 'DEBUG: usersPWExpireInterval: ', usersPWExpireInterval
				
				#if not 'krb5PasswordEnd' in attrs['objectClass']:
				#	ocs.append('krb5PasswordEnd')
				if attrs.get('krb5PasswordEnd', False):
					oldkrb5PasswordEndValue = str(attrs[ 'krb5PasswordEnd' ][0])
				else:
					oldkrb5PasswordEndValue = None
				#if not 'sambaPwdCanChange' in attrs['objectClass']:
				#	ocs.append('sambaPwdCanChange')
				if attrs.get('sambaPwdCanChange', False):
					oldsambaPwdCanChangeValue = str(attrs['sambaPwdCanChange'][0])
				else:
					oldsambaPwdCanChangeValue = None
				#if not 'sambaPwdMustChange' in attrs['objectClass']:
				#	ocs.append('sambaPwdMustChange')
				if attrs.get('sambaPwdMustChange', False):
					oldsambaPwdMustChangeValue = str(attrs['sambaPwdMustChange'][0])
				else:
					oldsambaPwdMustChangeValue = None
				if attrs.get('shadowMax', False):
					oldshadowMaxValue = str(attrs['shadowMax'][0])
				else:
					oldshadowMaxValue = None
				if attrs.get('shadowLastChange', False):
					oldshadowLastChangeValue = str(attrs['shadowLastChange'][0])
				else:
					oldshadowLastChangeValue = None
				sambaPwdLastSetValue = int(attrs[('sambaPwdLastSet')][0])
				# Debug # print 'SambaPwdLastSet "%d", "%d", "%d"' %(sambaPwdLastSetTimestamp, pwdlifetime, unixday)
				shadowLastChangeValue = str(sambaPwdLastSetValue/UNIXDAY)
				sambaPwdCanChangeValue = str(sambaPwdLastSetValue+UNIXDAY)
				if usersPWExpireInterval:
					# print 'DEBUG: PWExpireInterval policy valid, calculating and setting expiring dates'
					sambaPwdMustChangeValue = str(sambaPwdLastSetValue+int(usersPWExpireInterval*UNIXDAY))
					krb5PasswordEndValue = str(unixdayToKrb5Date(sambaPwdMustChangeValue))
					shadowMaxValue = str(usersPWExpireInterval)
				else:
					# print 'DEBUG: PWExpireInterval policy not set, removing expire intervals and dates'
					sambaPwdMustChangeValue = None 
					krb5PasswordEndValue = None
					shadowMaxValue = None
				ml.extend([
					('krb5PasswordEnd', oldkrb5PasswordEndValue, krb5PasswordEndValue),
					('sambaPwdCanChange', oldsambaPwdCanChangeValue, sambaPwdCanChangeValue),
					('sambaPwdMustChange', oldsambaPwdMustChangeValue, sambaPwdMustChangeValue),
					('shadowMax', oldshadowMaxValue, shadowMaxValue),
					('shadowLastChange', oldshadowLastChangeValue, shadowLastChangeValue)
				])
			else:
				print 'Could not find attribute "sambaPwdLastSet". Skipping generating of "krb5PasswordEnd".'


			if not 'krb5KDCEntry' in attrs['objectClass']:
				ocs.append('krb5KDCEntry')
				ml.extend([
					('krb5MaxLife', None, '86400'),
					('krb5MaxRenew', None, '604800'),
					('krb5KeyVersionNumber', None, '1'),
				])

			old_flag = attrs.get( 'krb5KDCFlags', [] )
			old_keys = attrs.get( 'krb5Key', [] )

			ml.extend([
				('krb5Key', old_keys, nt_password_to_arcfour_hmac_md5(attrs['sambaNTPassword'][0])),
				('krb5KDCFlags', old_flag, flag)
			])

			if attrs.get( 'sambaLMPassword' ) not in [ "NO PASSWORDXXXXXXXXXXXXXXXXXXXXX", None ]:
				old_password = attrs.get( 'userPassword', [] )
				if disabled:
					ml.extend([
						('userPassword', old_password, [ '{LANMAN}!%s' % attrs['sambaLMPassword'][0]])
					])
				else:
					ml.extend([
						('userPassword', old_password, [ '{LANMAN}%s' % attrs['sambaLMPassword'][0]])
					])


			if ocs:
				print 'Adding Kerberos key for %s...' % repr(dn),
				ml.insert(0, ('objectClass', None, ocs))

			try:
				lo.modify(dn, ml)
			except ldap.ALREADY_EXISTS:
				print 'already exists'
			else:
				print 'done'

		else:
			print 'Can not add Kerberos key for %s...' % repr(dn),
			print 'no password set'
