#!/usr/bin/env python
# -*- coding: utf-8 -*-
from livewires import *
from Numeric import *
import ml
import random


class envir:
    """Grid world shared by the learning agents, together with their tables.

    The instance keeps:
      * ``space``  -- the 2-D grid of cell values (0 is written for "empty"),
      * ``agents`` -- the ``ml.agent`` wrappers created by addagent(),
      * ``init``   -- agent name -> starting coordinates, replayed by reinit(),
      * ``Q``, ``pi`` -- dictionaries of dictionaries, keyed by ``str(state)``
        and then by action,
      * ``V``      -- dictionary keyed directly by state,
      * ``size``   -- ``[number of rows, number of columns]`` of the grid.
    """

    def __init__(self, envi):
        """Build the environment around the grid carried by *envi*.

        envi -- any object with a ``space`` attribute holding a non-empty,
                index-assignable 2-D grid.  The grid is referenced, not
                copied, so put()/remove() modify the caller's grid until
                reinit() replaces it.
        """
        self.space = envi.space
        self.agents = []
        # Totals accumulated by addagent() as agents get registered.
        self.nbr_actions = 0
        self.nbr_states = 0
        # Agent name -> list of starting coordinates.
        self.init = {}
        # Q and pi are dictionaries of dictionaries: the outer key is the
        # state name, the inner dictionary maps the actions tried so far
        # to their value.
        self.Q = {}
        self.V = {}
        self.pi = {}
        self.size = [len(self.space), len(self.space[0])]

    def get_agents(self):
        """Return the list of agents registered through addagent()."""
        return self.agents

    def get(self, xy):
        """Return the content of the cell at coordinates ``(xy[0], xy[1])``."""
        return self.space[xy[0]][xy[1]]

    def reinit(self):
        """Reset every agent and rebuild the grid from the starting positions.

        The grid is replaced by a fresh ``zeros(self.size)`` array (``zeros``
        comes from the module-level star import), then every agent name is
        written back on each of its recorded starting cells.
        """
        for agent in self.agents:
            agent.reinit()
        self.space = zeros(self.size)
        for name in self.init:
            for coord in self.init[name]:
                self.put(name, coord)

    def put(self, objet, coord):
        """Write *objet* into the cell at *coord*; return the previous content."""
        x = coord[0]
        y = coord[1]
        previous = self.space[x][y]
        self.space[x][y] = objet
        return previous

    def remove(self, coord):
        """Empty the cell at *coord* (write 0); return the previous content."""
        return self.put(0, coord)

    def addagent(self, objet, positions):
        """Register a new agent and place its name on the grid.

        objet     -- agent description; must expose a ``name`` attribute and
                     is handed unchanged to ``ml.agent``.
        positions -- iterable of coordinates on which the name is written;
                     also recorded as the starting positions for reinit().

        Returns the ``ml.agent`` wrapper that was created and stored.
        """
        name = objet.name
        self.init[name] = positions
        for coord in positions:
            self.put(name, coord)
        new_ag = ml.agent(objet, self.size, positions, self)
        self.agents.append(new_ag)
        # Does this agent count towards the possible actions?  Only the
        # agent whose name equals 1 contributes to nbr_actions; every agent
        # contributes to nbr_states.
        if name == 1:
            self.nbr_actions += new_ag.get_nbr_actions()
        self.nbr_states += new_ag.get_nbr_states()
        return new_ag

    def get_state_name(self, anchor='2'):
        """Return a string naming the current grid configuration.

        The cells are concatenated row by row, then the string is rotated
        so that it starts at the first occurrence of *anchor* (the mouse,
        '2' by default).  This takes advantage of the symmetry of the grid
        and shrinks the number of distinct states a lot.  Zeros are then
        stripped; as the result starts with the anchor, only trailing zeros
        are actually removed.

        anchor -- marker the name must start with; converted with str().

        Raises IndexError when the anchor does not appear on the grid.
        """
        anchor = str(anchor)
        flat = ''.join([str(cell) for row in self.space for cell in row])
        parts = flat.split(anchor, 1)
        if len(parts) < 2:
            raise IndexError(
                "marker %r not found on the grid, cannot name the state"
                % anchor)
        before, after = parts
        return (anchor + after + before).strip('0')

    def _axis_distance(self, p, q, span):
        """Shortest gap between *p* and *q* on a cyclic axis of length *span*."""
        direct = abs(p - q)
        wrapped = span - direct
        if direct < wrapped:
            return direct
        return wrapped

    def distance(self, a, b):
        """Return the distance between cells *a* and *b* with wrap-around.

        On each axis the shorter of the direct gap and the gap going around
        the edge of the grid is kept; the two axis gaps are summed.
        """
        return (self._axis_distance(a[0], b[0], self.size[0])
                + self._axis_distance(a[1], b[1], self.size[1]))

    def get_q_value(self, state, action):
        """Return Q[state][action]; Q is virtually filled with 1.

        The state is looked up under ``str(state)``, the key used by
        set_q_value().  A missing entry (or an unhashable action) yields 1.
        """
        try:
            return self.Q[str(state)][action]
        except (KeyError, TypeError):
            return 1

    def set_q_value(self, state, action, value):
        """Store *value* in Q under ``str(state)`` and *action*."""
        self.Q.setdefault(str(state), {})[action] = value

    def get_v_value(self, state):
        """Return V[state]; V is virtually filled with 1.

        The state is used as the key as-is, like in set_v_value().  A
        missing entry (or an unhashable state) yields 1.
        """
        try:
            return self.V[state]
        except (KeyError, TypeError):
            return 1

    def set_v_value(self, state, value):
        """Store *value* in V under *state* (no string conversion)."""
        self.V[state] = value

    def get_pi_value(self, state, action):
        """Return pi[state][action]; pi is virtually filled with 1/A.

        The state is looked up under ``str(state)``, the key used by
        set_pi_value().  When the entry is missing the default is
        ``1. / nbr_actions``, or 0 while no action has been registered.
        """
        try:
            return self.pi[str(state)][action]
        except (KeyError, TypeError):
            if self.nbr_actions == 0:
                return 0
            return 1. / self.nbr_actions

    def set_pi_value(self, state, action, value):
        """Store *value* in pi under ``str(state)`` and *action*."""
        self.pi.setdefault(str(state), {})[action] = value

    def display(self):
        """Return the grid itself (nothing is printed)."""
        return self.space

    def q_infos(self):
        """Print the size of Q and return the largest per-state action count.

        Returns 0 when Q is empty.
        """
        sizes = [len(actions) for actions in self.Q.values()]
        total = sum(sizes)
        # The leading 0 keeps max() defined while Q is still empty.
        maxi = max([0] + sizes)
        print("longueur de Q : %d ( %d actions testees max par etat)" % (total, maxi))
        return maxi

    def v_infos(self):
        """Print the number of states stored in V."""
        print("longueur de V : %d" % len(self.V))

    def pi_infos(self):
        """Print the total number of (state, action) entries stored in pi."""
        total = sum([len(actions) for actions in self.pi.values()])
        print("longueur de pi : %d" % total)

