#!/usr/bin/python

"""Aggregate pwrkap servers into one virtual pwrkap server."""
# (C) Copyright IBM Corp. 2008-2009
# Licensed under the GPLv2.
import asyncore
import socket
import cPickle as pickle
import StringIO
import lazy_log
import traceback
import datetime
import thread
import time
import sys
import pwrkap_data

# Seconds between refresh cycles (reconnect attempts + aggregate publish).
AGGREGATE_REFRESH = 15

def read_objs_from_string(str):
	"""Unpickle as many complete objects as possible from a buffer.

	Returns a (pos, objs) tuple: objs lists the decoded objects in stream
	order and pos is the offset just past the last complete pickle, so the
	caller can keep str[pos:] as the not-yet-complete remainder.  Decoding
	stops at the first EOFError; any other unpickling error propagates.
	"""
	stream = StringIO.StringIO(str)
	decoded = []
	consumed = 0
	try:
		while True:
			item = pickle.load(stream)
			# Only advance past an object once it decoded completely.
			consumed = stream.tell()
			decoded.append(item)
	except EOFError:
		pass
	return consumed, decoded

class listen_dispatcher(asyncore.dispatcher):
	"""Accept pwrkap client connections for one aggregate controller."""

	def __init__(self, socket, controller):
		"""Start listening on an already-bound socket.

		socket: bound TCP socket that is not listening yet.
		controller: aggregate controller handed to every client_dispatcher
			created for an accepted connection.
		"""
		asyncore.dispatcher.__init__(self, socket)
		self.controller = controller
		self.set_reuse_addr()
		self.listen(5)

	def handle_accept(self):
		"""Wrap an incoming connection in a client_dispatcher."""
		# asyncore's accept() returns None instead of a (conn, addr) pair
		# when the pending connection vanished (EWOULDBLOCK, ECONNABORTED,
		# ...); unpacking that directly would raise inside the event loop.
		pair = self.accept()
		if pair is None:
			return
		conn, addr = pair
		client_dispatcher(conn, self.controller)

class base_dispatcher(asyncore.dispatcher):
	"""Buffered asyncore dispatcher shared by client and server links.

	Outgoing data queued with write() is drained by handle_write();
	incoming bytes accumulate in in_buf until they form complete pickles.
	"""

	def __init__(self, socket = None, map = None):
		"""Create a writable dispatcher around an optional socket and map."""
		asyncore.dispatcher.__init__(self, socket, map)
		self.out_buf = ""	# queued data not yet accepted by send()
		self.in_buf = ""	# received data not yet decoded into objects

	def handle_write(self):
		"""Push as much queued data as the socket will take."""
		# send() returns how many bytes were accepted (possibly fewer than
		# offered); keep the unsent tail for the next write event.  The old
		# code offered only out_buf[0], i.e. one byte per write event.
		sent = self.send(self.out_buf)
		self.out_buf = self.out_buf[sent:]

	def writable(self):
		"""Determine if there are data to write to the client."""
		return (len(self.out_buf) > 0)

	def write(self, buffer):
		"""Queue data to be sent when the socket becomes writable."""
		self.out_buf = self.out_buf + buffer

	def read_objs_from_socket(self):
		"""Read pending data and return the newly completed objects.

		Appends everything recv() returns to in_buf, reading until recv()
		raises (on a non-blocking socket that includes "no more data right
		now") or returns an empty string.  Then every complete pickle in
		in_buf is decoded and the incomplete tail is kept for next time.
		If the very first recv() raises, the connection is closed via
		handle_close() and an empty list is returned.
		"""
		try:
			buf = self.recv(4096)
		except Exception:
			self.handle_close()
			return []

		self.in_buf = self.in_buf + buf
		while len(buf) > 0:
			try:
				buf = self.recv(4096)
				self.in_buf = self.in_buf + buf
			except Exception:
				break

		pos, objs = read_objs_from_string(self.in_buf)
		self.in_buf = self.in_buf[pos:]

		return objs

class client_dispatcher(base_dispatcher):
	"""Talk to one connected pwrkap client.

	Registers with the controller on creation, forwards every decoded
	command object to controller.command(), and unregisters on close.
	"""

	def __init__(self, socket, controller):
		"""Wrap an accepted client socket and register it with controller."""
		base_dispatcher.__init__(self, socket)
		self.controller = controller
		# Set before add_client() so handle_close() is safe at any point.
		self._unregistered = False
		self.controller.add_client(self)

	def handle_read(self):
		"""Read and dispatch commands."""
		for obj in self.read_objs_from_socket():
			self.controller.command(obj)

	def handle_close(self):
		"""Unregister from the controller (only once) and close the socket."""
		# handle_close can run twice, e.g. once from read_objs_from_socket()
		# and again from asyncore's error handler if a command then raises;
		# a second remove_client() would fail on an already-removed client.
		if not self._unregistered:
			self._unregistered = True
			self.controller.remove_client(self)
		asyncore.dispatcher.close(self)
		self.connected = False

class server_dispatcher(base_dispatcher):
	"""Talk to a pwrkap server.

	Owns a reconnectable, non-blocking client socket to one pwrkap server
	at connect_opts = (host, port).  Objects read from the server are
	passed to controller.status(); the first object after each
	(re)connection is skipped.
	"""

	def __init__(self, controller, sock_func, connect_opts):
		"""Create an unconnected server dispatcher.

		controller: aggregate controller that receives status objects.
		sock_func: zero-argument factory returning a fresh socket.
		connect_opts: (host, port) address to connect to.
		"""
		base_dispatcher.__init__(self)
		self.controller = controller
		self.connect_opts = connect_opts
		self.sock_func = sock_func
		self.usock = None
		# Peer address once connected; None until handle_connect() runs so
		# that readers of .peer never hit a missing attribute.
		self.peer = None
		self.ignore_next = True

	def handle_read(self):
		"""Forward decoded status objects to the controller.

		The first object after (re)connecting is dropped.  Errors raised by
		controller.status() are printed rather than propagated.
		"""
		for obj in self.read_objs_from_socket():
			if self.ignore_next:
				self.ignore_next = False
				continue
			try:
				self.controller.status(self, obj)
			except Exception:
				traceback.print_exc()

	def handle_close(self):
		"""Handle connection closing."""
		self.controller.remove_server(self)
		self.close()
		self.usock = None
		self.connected = False
		self.peer = None

	def _discard_socket(self):
		"""Close the current socket without notifying the controller."""
		self.close()
		self.usock = None
		self.connected = False

	def try_connect(self):
		"""Start a non-blocking connect to connect_opts.

		A new socket is created from sock_func if the previous one was
		discarded.  Returns True if the connect completed or is in
		progress; returns False if it failed immediately, in which case
		the socket is discarded so the next attempt starts clean.
		"""
		import errno

		assert not self.connected

		if self.usock is None:
			self.usock = self.sock_func()
			self.usock.setblocking(0)
			self.set_socket(self.usock)

		try:
			print("Connecting to %s:%d..." % self.connect_opts)
			err = self.socket.connect_ex(self.connect_opts)
		except Exception:
			# e.g. name resolution failure: retry later with a new socket.
			self._discard_socket()
			return False

		if err not in (0, errno.EINPROGRESS, errno.EALREADY,
			       errno.EWOULDBLOCK, errno.EISCONN):
			print("Connecting to %s:%d failed: %s" % (self.connect_opts +
				(errno.errorcode.get(err, err),)))
			self._discard_socket()
			return False

		# dispatcher.connect() normally sets this flag; newer asyncore
		# versions only call handle_connect() for sockets marked as
		# connecting.  Older versions simply ignore the attribute.
		self.connecting = True
		return True

	def handle_connect(self):
		"""Handle a successful connection."""
		self.ignore_next = True
		try:
			self.peer = self.socket.getpeername()
		except socket.error:
			self.peer = None
		self.controller.add_server(self)

class aggregate_controller:
	"""Aggregate a bunch of pwrkap servers to pwrkap clients.

	One controller listens for pwrkap clients on its own socket, keeps one
	server_dispatcher per distinct (host, port) pwrkap server, records the
	last status reported for every watched (host, port, domain), and
	periodically publishes the combined status under its own name.
	"""

	def __init__(self, name, server_socket_info, domains):
		"""Create a pwrkap aggregator.

		name: aggregate domain name presented to clients.
		server_socket_info: (address, port) to bind the client listener to.
		domains: non-empty list of (host, port, domain) tuples to watch.
		"""
		assert len(domains) > 0

		self.clients = []
		self.servers = []
		self.name = name
		self.logger = lazy_log.lazy_log(self, 3600, 2)

		# Client command verb -> handler taking the whole command list.
		self.command_table = {
			"cap": self.cap_command}

		self.ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
		self.ssock.bind(server_socket_info)
		self.listen_dispatcher = listen_dispatcher(self.ssock, self)

		# Last reported status per watched (host, port, domain); fields
		# stay None until that domain's server reports them.
		self.domains = {}
		for hpd in domains:
			self.domains[hpd] = {"cap": None, "energy": None,
					     "power": None, "utilization": None}

		# One dispatcher per distinct (host, port): several watched
		# domains may live on the same pwrkap server.
		self.dispatchers = {}
		self.hostportdispatcher = {}
		self.inactive_dispatchers = []
		self.active_dispatchers = []
		for host, port, dom in domains:
			if (host, port) in self.hostportdispatcher:
				continue
			dispatcher = server_dispatcher(self,
				lambda: socket.socket(socket.AF_INET, socket.SOCK_STREAM),
				(host, port))
			self.dispatchers[dispatcher] = (host, port)
			self.hostportdispatcher[(host, port)] = dispatcher
			self.inactive_dispatchers.append(dispatcher)

		print ("dispatchers ready" , self.dispatchers)
		self.try_to_activate_dispatchers()

	def try_to_activate_dispatchers(self):
		"""Start connecting every inactive dispatcher.

		Dispatchers whose try_connect() returns True move from
		inactive_dispatchers to active_dispatchers.
		"""
		# NOTE(review): this also runs on the do_periodic_updates thread
		# while asyncore.loop() may call remove_server() on another thread;
		# these lists are shared without a lock.
		connected = [d for d in self.inactive_dispatchers if d.try_connect()]
		for c in connected:
			self.inactive_dispatchers.remove(c)
		self.active_dispatchers.extend(connected)

	def add_client(self, client_dispatcher):
		"""Register a client and send it its initial stream.

		The stream is: a fake inventory listing this aggregate with no
		sub-domains, the snapshots returned by self.logger.dump_log(), and
		finally a (utc timestamp, "live") marker.
		"""
		self.clients.append(client_dispatcher)

		# Write fake inventory
		data = [(self.name, {"domains": [], "meter": {"aggregate": {}}})]
		datastr = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
		client_dispatcher.write(datastr)

		# Write old snapshots
		pickles = self.logger.dump_log()
		for apickle in pickles:
			client_dispatcher.write(apickle)

		# Write live stamp
		data = (datetime.datetime.utcnow(), "live")
		datastr = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
		client_dispatcher.write(datastr)

	def remove_client(self, client_dispatcher):
		"""Forget a client; removing an unknown client is a no-op."""
		if client_dispatcher in self.clients:
			self.clients.remove(client_dispatcher)

	def add_server(self, server_dispatcher):
		"""Add a server."""
		self.servers.append(server_dispatcher)

	def remove_server(self, server_dispatcher):
		"""Mark a server dispatcher as disconnected so it is retried.

		Tolerates a dispatcher that never reached handle_connect() (so was
		never in servers) or that was already removed.
		"""
		if server_dispatcher in self.servers:
			self.servers.remove(server_dispatcher)
		if server_dispatcher in self.active_dispatchers:
			self.active_dispatchers.remove(server_dispatcher)
		if server_dispatcher not in self.inactive_dispatchers:
			self.inactive_dispatchers.append(server_dispatcher)

	def write(self, buffer):
		"""Write data to all clients."""
		for client in self.clients:
			client.write(buffer)

	def cap_command(self, command):
		"""Split a new aggregate cap across the watched domains.

		command is [name, "cap", new_cap].  Each domain with a known cap
		receives a [domain, "cap", share] command, where share is
		proportional to its last reported cap.  Domains whose cap is still
		unknown are skipped; if no positive total cap is known yet the
		command is ignored (it used to raise TypeError/ZeroDivisionError).
		"""
		new_cap = float(command[2])

		known = [(key, dom["cap"]) for key, dom in self.domains.items()
			 if dom["cap"] is not None]
		total_cap = 0.0		# float start keeps the division below float
		for key, cap in known:
			total_cap = total_cap + cap
		if total_cap <= 0:
			print("Ignoring cap command for %s: no domain caps reported yet" % self.name)
			return

		for (host, port, domname), cap in known:
			subcommand = [domname, "cap", (cap / total_cap) * new_cap]
			cmdstr = pickle.dumps(subcommand, pickle.HIGHEST_PROTOCOL)
			self.hostportdispatcher[(host, port)].write(cmdstr)

	def command(self, command):
		"""Run a client command addressed to this aggregate.

		command is a sequence [target_name, verb, args...].  Commands for
		other names are ignored, as are verbs missing from command_table
		(previously a KeyError that dropped the client connection).
		"""
		if command[0] != self.name:
			return
		handler = self.command_table.get(command[1])
		if handler is None:
			print("Ignoring unknown command %r for %s" % (command[1], self.name))
			return
		handler(command)

	def _configured_hostport(self, dispatcher):
		"""Return the configured (host, port) of dispatcher, or None.

		Matched by identity against hostportdispatcher, so the result does
		not depend on how dispatcher objects hash or compare.
		"""
		for hostport, candidate in self.hostportdispatcher.items():
			if candidate is dispatcher:
				return hostport
		return None

	def status(self, dispatcher, status):
		"""Record one status report received from a server dispatcher.

		status is (timestamp, data); data is either the string "live"
		(ignored) or (domain_name, status_dict).  The report is stored only
		if (host, port, domain_name) is watched.  host/port are the
		dispatcher's configured address, so hostnames in the config match;
		the peer address is only a fallback.
		"""
		(timestamp, some_data) = status
		if some_data == "live":
			return

		(domname, domstatus) = some_data
		hostport = self._configured_hostport(dispatcher)
		if hostport is None:
			hostport = dispatcher.peer
		(host, port) = hostport

		# Are we watching this domain?
		domkey = (host, port, domname)
		if domkey not in self.domains:
			return

		# Collect status data
		dom = self.domains[domkey]
		dom["cap"] = domstatus["cap"]
		dom["power"] = domstatus["power"]
		dom["energy"] = domstatus.get("energy")	# optional; None if absent
		dom["utilization"] = domstatus["utilization"]
		if "util_details" in domstatus:
			# Prefix detail keys so they stay unique across servers.
			ud = {}
			for detail in domstatus["util_details"].keys():
				ud["%s:%d:%s:%s" % (host, port, domname, detail)] = domstatus["util_details"][detail]
		else:
			ud = {domname: dom["utilization"]}
		dom["util_details"] = ud

	def run(self):
		"""Start the periodic update thread for this controller."""
		thread.start_new_thread(self.do_periodic_updates, ())

	def do_periodic_updates(self):
		"""Every AGGREGATE_REFRESH seconds, reconnect servers and publish."""
		while True:
			print("Refresh")
			try:
				self.try_to_activate_dispatchers()
				self.update_clients()
			except Exception:
				# One bad cycle must not kill the refresh thread for good.
				traceback.print_exc()
			time.sleep(AGGREGATE_REFRESH)

	def _aggregate_domains(self):
		"""Sum the latest status of every domain that has fully reported.

		Returns (doms_found, total_cap, total_power, total_energy,
		total_utilization, util_details).  A domain is left out entirely
		until cap, power, utilization and util_details are all known.
		Energy is optional and contributes 0 when unknown.  Previously a
		None energy raised mid-sum, so cap/power were added but the domain
		was counted as missing.
		"""
		total_cap = total_energy = total_power = total_utilization = 0.0
		util_details = {}
		doms_found = 0
		for dom in self.domains.values():
			cap = dom.get("cap")
			power = dom.get("power")
			utilization = dom.get("utilization")
			details = dom.get("util_details")
			if cap is None or power is None or utilization is None or details is None:
				continue
			energy = dom.get("energy")
			if energy is None:
				energy = 0.0
			total_cap = total_cap + cap
			total_power = total_power + power
			total_energy = total_energy + energy
			total_utilization = total_utilization + utilization
			util_details.update(details)
			doms_found = doms_found + 1
		return (doms_found, total_cap, total_power, total_energy,
			total_utilization, util_details)

	def update_clients(self):
		"""Build the aggregate status record and hand it to self.logger.

		Does nothing until at least one domain has reported.  The record is
		passed to self.logger.log(); the logger was built with this
		controller, presumably so it can forward records to clients via
		write() -- confirm in lazy_log.
		"""
		(doms_found, total_cap, total_power, total_energy,
		 total_utilization, util_details) = self._aggregate_domains()
		if doms_found == 0:
			return
		avg_utilization = pwrkap_data.average_utilization(util_details)
		old_avg_util = total_utilization / doms_found

		data = (self.name, {"domains": [],
				    "cap": total_cap,
				    "power": total_power,
				    "energy": total_energy,
				    "utilization": avg_utilization,
				    "old_avg_util": old_avg_util,
				    "util_details": util_details})

		self.logger.log(data)

#ac = aggregate_controller("agg0", ('0.0.0.0', 9410), [
#	('9.47.66.63', 9410, 'pwrdom0'),
#	('9.47.66.254', 9410, 'pwrdom0')
#])
#ac.run()

def read_config_file(file):
	"""Read config file and set up aggregates.

	Config file format (blank lines and lines starting with '#' ignored):

	    domain $listen_addr $port $aggregate_name
	    system $hostname $port $domain
	    system ...

	Each "domain" line starts a new aggregate listening on
	$listen_addr:$port; the "system" lines after it name the
	(host, port, domain) members it aggregates.  "system" lines before the
	first "domain" line and lines with unknown keywords are ignored.

	file: any iterable of text lines, e.g. an open file.
	Returns the aggregate_controller objects in file order.
	Raises ValueError (naming the line) for a "domain"/"system" line with
	missing fields or a non-integer port.
	"""
	listen_addr = None
	port = None
	domain = None
	systems = []
	controllers = []
	for lineno, line in enumerate(file):
		components = line.split()
		if len(components) < 1 or components[0][0] == "#":
			continue
		keyword = components[0]
		if port is None and keyword != "domain":
			continue
		if keyword not in ("domain", "system"):
			continue

		# Validate before acting so a bad line reports where it is.
		if len(components) < 4:
			raise ValueError("line %d: '%s' needs 3 arguments: %r" %
					 (lineno + 1, keyword, line))
		try:
			num = int(components[2])
		except ValueError:
			raise ValueError("line %d: bad port %r" % (lineno + 1, components[2]))

		if keyword == "domain":
			if port is not None:
				print ("C", domain, listen_addr, port)
				ac = aggregate_controller(domain, (listen_addr, port), systems)
				controllers.append(ac)
				systems = []
			listen_addr = components[1]
			port = num
			domain = components[3]
		else:
			systems.append((components[1], num, components[3]))
	if port is not None:
		print ("D", domain, listen_addr, port)
		ac = aggregate_controller(domain, (listen_addr, port), systems)
		controllers.append(ac)

	return controllers

fname = "./pwrkap_aggregate.conf"
if len(sys.argv) > 1:
	fname = sys.argv[1]
# Parse the whole config up front and close it: the process then stays
# inside asyncore.loop() forever and would otherwise hold the file open.
conf_file = open(fname)
try:
	ctrls = read_config_file(conf_file)
finally:
	conf_file.close()
for ac in ctrls:
	ac.run()
asyncore.loop()
