#!/usr/bin/env python
# -*- coding: utf-8 -*-
import random
from ml import *
from ml_envir import *
import values


class obj_stone :
	"""Static obstacle agent (name code 3): it offers no moves and never acts."""

	def __init__(self, pos) :
		# Same side effect as the other agent classes in this module:
		# (re)create the module-level ``valeurs`` parameter object.
		global valeurs
		valeurs = values.parameters()
		self.name = 3

	def nbr_actions(self) :
		"""Return the number of available actions, which is always zero."""
		no_actions = 0
		return no_actions

	def actions(self,i,agents,env):
		"""Print a marker line; all arguments are ignored and None is returned.

		NOTE(review): the sibling agent classes call this method ``action``
		(singular) -- confirm which name the caller actually uses.
		"""
		message = "I'm a stone"
		print(message)

	def find_best_move(self,env,agents) :
		"""Always choose move 0."""
		return 0

#####################################

class obj_chat :
	"""Cat-team agent (name code 1) that learns a joint policy by Q-learning.

	All ``len(pos)`` cats are driven by a single joint action index in
	``[0, base**nbr)``. The index is read as a base-``self.base`` number whose
	k-th digit (cat 0 = least significant) is the move of cat k.
	Learning parameters (alpha, gamma, decay, exploration_coeff,
	explore_decrease) are read from the module-level ``valeurs`` object
	that ``__init__`` (re)creates.
	"""

	def __init__(self, pos) :
		global valeurs
		valeurs = values.parameters()
		self.name = 1
		# number of cats (sub-agents) driven by this agent
		self.nbr = len(pos)
		# number of elementary moves per cat
		self.base = 5

	def nbr_actions(self):
		"""Return the number of joint actions, ``base ** nbr``."""
		n = self.base**self.nbr
		return n

	def _decode_action(self, i):
		"""Split joint action *i* into one move index per cat (cat 0 first)."""
		moves = []
		for _ in range(self.nbr):
			moves.append(i % self.base)
			# floor division keeps the index an int under Python 3 semantics too
			i //= self.base
		return moves

	def _encode_move(self, level, move):
		"""Return the joint action where cat *level* plays *move* and every other cat plays 0."""
		return move * self.base ** level

	@staticmethod
	def _argmax_all(scores):
		"""Return ``(best, [indices k with scores[k] == best])`` for a non-empty sequence."""
		best = max(scores)
		return best, [k for k, v in enumerate(scores) if v == best]

	def action(self,i,agents,env) :
		"""Play joint action *i*, then apply one Q-learning update.

		Parameters:
			i      -- joint action index (see class docstring)
			agents -- cat sub-agents; agents[k] receives the k-th decoded move
			env    -- environment holding the Q, V and pi tables

		Returns the list of values returned by each ``agents[k].move(...)``.

		Side effects: updates Q(s, i), V(s) and pi(s, .) for the state s seen
		*before* moving, and decays ``valeurs.alpha`` and
		``valeurs.exploration_coeff``.
		"""
		previous_state = env.get_state_name()
		depl = []
		for level, move in enumerate(self._decode_action(i)):
			depl.append(agents[level].move(move))
		current_state = env.get_state_name()
		rew = self.reward(depl,agents,env)

		# Q(s,a) <- (1-alpha)*Q(s,a) + alpha*(r + gamma*V(s')).
		# The blended-in old value must be Q of the state the action was taken
		# in (previous_state), since that is the entry being overwritten.
		q_old = env.get_q_value(previous_state,i)
		v_next = env.get_v_value(current_state)
		new_q = (1-valeurs.alpha)*q_old + valeurs.alpha*(rew + valeurs.gamma*v_next)
		env.set_q_value(previous_state,i,new_q)

		# V(s) <- max_a Q(s,a) (true maximum, even when every Q is negative);
		# pi(s,.) is uniform over exactly the maximising actions.
		n = self.nbr_actions()
		q_values = [env.get_q_value(previous_state,jj) for jj in range(n)]
		bestv, bestmv = self._argmax_all(q_values)
		env.set_v_value(previous_state,bestv)
		share = 1./len(bestmv)
		for jj in range(n):
			env.set_pi_value(previous_state,jj,0)
		for kk in bestmv:
			env.set_pi_value(previous_state,kk,share)

		# decay the learning rate and the exploration coefficient
		valeurs.alpha = valeurs.alpha*valeurs.decay
		valeurs.exploration_coeff = valeurs.exploration_coeff*valeurs.explore_decrease

		return depl

	def find_best_move(self,env,agents) :
		"""Choose the next joint action index.

		1. If some cat can reach a mouse with one move (``test`` returns 2),
		   return the joint action in which only that cat plays that move.
		   Moves are tried in order 0..base-1, cats in list order.
		2. Otherwise, with probability ``valeurs.exploration_coeff``, return
		   a uniformly random joint action.
		3. Otherwise sample an action from pi(s, .). If rounding leaves the
		   draw unmatched, fall back to the last action with positive
		   probability, or to a random action if pi is all zero.
		"""
		n = self.nbr_actions()
		for zem in range(self.base):
			for level, a in enumerate(agents):
				if a.test(zem) == 2:
					# encode the move for *this* cat, not always for cat 0
					return self._encode_move(level, zem)

		if random.random() < valeurs.exploration_coeff:
			return random.randrange(n)

		state = env.get_state_name()
		r = random.random()
		step = 0.
		fallback = None
		for j in range(n):
			p = env.get_pi_value(state,j)
			if p > 0:
				fallback = j
			step += p
			if r <= step:
				return j
		if fallback is not None:
			return fallback
		return random.randrange(n)

	def reward(self,target,agents,env):
		"""Score the outcome of a joint move.

		Parameters:
			target -- results of each cat's ``move``; a 2 counts as a capture,
			          a 3 as bumping into a wall
			agents -- cat sub-agents (their ``get_pos()`` is used)
			env    -- environment providing ``get_agents()`` and ``distance()``

		Returns -1000 if any wall was hit, plus:
		  * +100000 if any capture happened, or otherwise
		  * +10000 for every (cat, mouse) pair at distance 1, minus, for each
		    mouse agent (name 2), the smallest over its mice of the sum of
		    squared cat-to-mouse distances.
		"""
		captures = target.count(2)
		walls = target.count(3)
		reward = 0
		if walls >= 1 :
			reward -= 1000
		if captures >= 1 :
			print("GOAL !")
			reward += 100000
			return reward

		dist = 0
		for a in env.get_agents() :
			if a.get_name() != 2 :
				continue
			# start above any reachable sum of squared board distances
			closest = 100000
			for mouse in a.get_subagents() :
				spread = 0
				for cat in agents :
					dis = env.distance(cat.get_pos(),mouse)
					if dis == 1 :
						reward += 10000
					spread += dis*dis
				closest = min(closest, spread)
			dist += closest
		return reward - dist

######################################################

class obj_souris :
	"""Mouse agent (name code 2) that flees the cats with a randomised escape.

	Each legal move is scored by the sum of squared distances to every cat
	(sub-agents of every name-1 agent). Moves ending at distance <= 1 of
	any cat are rejected. One of the remaining moves is then drawn with
	probability proportional to its score ("stochastic best escape").
	"""

	def __init__(self, pos) :
		global valeurs
		valeurs = values.parameters()
		self.name = 2

	def nbr_actions(self):
		"""Return the number of elementary moves available to the mouse."""
		return 5

	def action(self,i,agents,env) :
		"""Apply move *i* to the mouse.

		Returns ``[result]`` where result is what ``agents[0].move`` returned,
		or None when *i* is not one of the move indices 0..nbr_actions()-1.
		"""
		for move in range(self.nbr_actions()):
			if i == move:
				return [agents[0].move(move)]
		return None

	@staticmethod
	def _escape_score(me, cat_groups, distance):
		"""Score position *me* against every cat.

		cat_groups -- one list of cat positions per cat agent
		distance   -- callable ``distance(a, b)``

		Returns the sum of squared distances to all cats. Returns None when
		there is no cat agent at all, or when any cat is within distance 1.
		"""
		if not cat_groups:
			return None
		score = 0.
		for cats in cat_groups:
			for p in cats:
				dis = distance(me, p)
				if dis <= 1:
					return None
				score += dis*dis
		return score

	@staticmethod
	def _weighted_pick(weighted, r):
		"""Pick a key from ``[(key, weight), ...]`` proportionally to weight.

		*r* is a uniform draw in [0, 1). If all weights are zero, the pick is
		uniform over the keys.
		"""
		total = sum(w for _, w in weighted)
		if total <= 0:
			return weighted[int(r * len(weighted))][0]
		target = r * total
		step = 0.
		for key, w in weighted:
			step += w
			# strict '<' gives each key exactly its share of [0, total)
			if target < step:
				return key
		# only reachable through float rounding at the upper edge
		return weighted[-1][0]

	def find_best_move(self,env,agents):
		"""Return the mouse's next move index (0 = stay when no safe move exists).

		Each move is tried for real with ``move``. If it succeeds (returns 0),
		the resulting position is scored, then the move is undone with
		``back``.
		"""
		myself = agents[0]
		# cat positions do not change while the mouse probes its own moves
		cat_groups = [a.get_subagents() for a in env.get_agents() if a.get_name() == 1]
		weighted = []
		for m in range(self.nbr_actions()):
			if myself.move(m) == 0:
				score = self._escape_score(myself.get_pos(), cat_groups, env.distance)
				if score is not None:
					weighted.append((m, score))
				myself.back(m)
		if not weighted:
			# no safe move found: stay in place
			return 0
		return self._weighted_pick(weighted, random.random())



################################################

class obj_env :
	"""Environment object; its constructor allocates the ``space`` grid.

	NOTE(review): the rest of this class (state names, Q/V/pi tables,
	distance, agents) is not visible here -- document it where defined.
	"""
	def __init__(self,xy):
		"""Build the environment.

		xy -- shape passed to ``zeros`` for the space grid (presumably the
		      board dimensions -- confirm against callers).
		"""
		# same side effect as the agent classes: (re)create the
		# module-level ``valeurs`` parameter object
		global valeurs
		valeurs = values.parameters()
		# definition of the space
		# zeros are empty cells
		# NOTE(review): ``zeros`` comes from a star import (presumably a
		# numpy-style zeros) -- confirm its source
		self.space = zeros(xy)

