# Version 1.1-1

import re
import sys
from hostlist import expand_hostlist, collect_hostlist

class NodesException(Exception):
    """Base class for every error raised by this module.

    The human-readable message is also kept in ``msg``.
    """
    def __init__(self, msg):
        # Chain to Exception so args/str() are set explicitly rather than
        # relying on BaseException.__new__ side effects.
        super(NodesException, self).__init__(msg)
        self.msg = msg

class UnknownName(NodesException):
    """A name is neither a registered node nor a registered group."""

class DuplicatedGroup(NodesException):
    """A group was created with a name that is already registered."""

class BadConfigSyntax(NodesException):
    """A configuration line matches none of the recognised forms."""

class MissingGroup(NodesException):
    """Some node has no group at the requested level."""

class DuplicateMembership(NodesException):
    """A node was assigned to more than one group at the same level."""

class UnknownLevel(NodesException):
    """A level name (or abbreviation) matches no known level."""

class AmbiguousLevel(NodesException):
    """A level abbreviation matches more than one known level."""

class Entity:
    """Common marker base for Node and Group (anything a name can resolve to)."""

class Node(Entity):
    """A single leaf host in the hierarchy.

    ``part_of`` maps a level name to the Group that holds this node at that
    level; a node belongs to at most one group per level.
    """

    def __init__(self, name):
        self.name = name
        self.part_of = {} # level_name -> Group

    def __repr__(self):
        return "<Node: %s>" % (self.name)

    def expand_to_nodes(self):
        """Return a fresh set containing only this node."""
        return set([self])

    def add_to_group(self, group):
        """Record that this node belongs to ``group`` at ``group.level``.

        Adding the node again to the group it already belongs to at that
        level is a no-op. This happens when a group's definition lists the
        same node more than once, e.g. across continuation lines.

        Raises:
            DuplicateMembership: if the node already belongs to a *different*
                group at the same level.
        """
        current = self.part_of.get(group.level)
        if current is not None and current is not group:
            raise DuplicateMembership("%s cannot belong to %s %s and %s at the same time" % \
                                          (self.name,
                                           group.level,
                                           current.name,
                                           group.name))
        self.part_of[group.level] = group

    def get_group(self, level):
        """Return the group this node belongs to at ``level``, or None."""
        return self.part_of.get(level)

class Group(Entity):
    """A named set of nodes at one hierarchy level.

    Args:
        hierarchy: the Hierarchy used to resolve member names.
        level: the level name this group lives at.
        name: the group's name.
        all_nodes_group: if true, member names that are not yet registered
            are created as new Nodes instead of raising (used for the
            ``nodes:`` group).
    """

    def __init__(self, hierarchy, level, name, all_nodes_group = False):
        self.hierarchy = hierarchy
        self.level = level
        self.name = name
        self.nodes = set() # Node set
        self.all_nodes_group = all_nodes_group

    def __repr__(self):
        return "<Group %s: %s>" % (self.level, self.name)

    def expand_to_nodes(self):
        """Return the set of member nodes.

        NOTE: this is the group's own set, not a copy; callers must not
        mutate it.
        """
        return self.nodes

    def add_nodes(self, name_string):
        """Add every entity named by the hostlist ``name_string`` to this group.

        Each name is resolved through ``hierarchy.get_entity`` and expanded
        to leaf nodes, so naming an existing group adds all of its nodes.

        Raises:
            UnknownName: a name is not registered and this is not the
                all-nodes group.
            DuplicateMembership: a node already belongs to another group at
                this level.
        """
        node_set = set()
        for name in expand_hostlist(name_string):
            entity = self.hierarchy.get_entity(name, auto_add_node = self.all_nodes_group)
            node_set |= entity.expand_to_nodes()
        # Register membership before adding to self.nodes. If add_to_group
        # raises, self.nodes then holds only nodes whose part_of points back
        # at this group.
        for node in node_set:
            node.add_to_group(self)
            self.nodes.add(node)

class Hierarchy:
    """Registry of nodes and groups, populated from a config file.

    Offers level-based queries over hostlist strings: ``up`` (groups that
    contain the given nodes), ``gather`` (collapse fully covered groups) and
    ``fill`` (expand to all nodes of the touched groups).
    """

    def __init__(self):
        self.names = {} # name -> Node or Group
        self.level_names = set() # level names used for abbreviation lookup

    def create_group(self, level, name, all_nodes_group = False):
        """Create, register and return a new Group.

        Raises:
            DuplicatedGroup: if ``name`` is already registered.
        """
        if name in self.names:
            raise DuplicatedGroup("group %s seen twice" % name)
        group = Group(self, level, name, all_nodes_group)
        self.names[name] = group
        return group

    def get_entity(self, name, auto_add_node = False):
        """Return the Node or Group registered as ``name``.

        If the name is unknown and ``auto_add_node`` is true, a new Node is
        registered under it and returned.

        Raises:
            UnknownName: if the name is unknown and ``auto_add_node`` is false.
        """
        if name not in self.names:
            if not auto_add_node:
                raise UnknownName("unknown name " + name)
            self.names[name] = Node(name)
        return self.names[name]

    def parse_file(self, filename):
        """Populate the hierarchy from the config file ``filename``.

        Recognised line forms (text from `` *#`` to end of line is a comment;
        lines left blank are skipped):

            nodes: <hostlist>              the all-nodes group (level "all")
            <level> <name>: <hostlist>     a group at <level>
            <indent><hostlist>             more members for the previous group

        Raises:
            BadConfigSyntax: on a line matching none of the forms above.
        """
        current_group = None
        # Use a context manager so the file is closed even if parsing raises.
        with open(filename) as f:
            for line in f:
                line = re.sub(r' *#.*', '', line)
                line = line.rstrip()
                if line.strip() == "":
                    continue

                # nodes: <nodes>
                m = re.match(r'^nodes:\s*(.*)', line)
                if m:
                    rest = m.group(1)
                    # Note: level "all" is not added to level_names.
                    current_group = self.create_group("all", "nodes", all_nodes_group = True)
                    if rest:
                        current_group.add_nodes(rest)
                    continue

                # <level> <name>: <parts>
                m = re.match(r'^([a-z]+)\s+([a-z0-9]+):\s*(.*)', line)
                if m:
                    (level, name, rest) = m.group(1, 2, 3)
                    current_group = self.create_group(level, name)
                    self.level_names.add(level)
                    if rest:
                        current_group.add_nodes(rest)
                    continue

                # <parts> (indented continuation of the current group)
                if current_group is not None:
                    m = re.match(r'^\s+(.*)', line)
                    if m:
                        current_group.add_nodes(m.group(1))
                        continue

                # Fail
                raise BadConfigSyntax(line)

    def expand_abbreviated_level(self, level):
        """Resolve a possibly abbreviated level name to a known level name.

        An exact match always wins. A level that is also a prefix of another
        level (e.g. ``row`` vs ``rowgroup``) is therefore not ambiguous.

        Raises:
            UnknownLevel: if no known level starts with ``level``.
            AmbiguousLevel: if several known levels start with ``level`` and
                none equals it.
        """
        if level in self.level_names:
            return level
        # Sorted so the ambiguity message is deterministic.
        matching = sorted(name for name in self.level_names if name.startswith(level))
        if len(matching) == 1:
            return matching[0]
        elif not matching:
            raise UnknownLevel("unknown level %s" % level)
        else:
            raise AmbiguousLevel("ambiguous level %s matching %s" % (level,
                                                                     ", ".join(matching)))

    def expand_to_node_set(self, name_string):
        """Return the set of Node objects named (directly or via groups) by a hostlist.

        Raises:
            UnknownName: if any name is not registered.
        """
        res = set()
        for name in expand_hostlist(name_string):
            res |= self.get_entity(name).expand_to_nodes()
        return res

    def expand_to_nodes(self, name_string):
        """Return a collected hostlist string of all nodes named by ``name_string``."""
        return collect_hostlist([node.name for node in self.expand_to_node_set(name_string)])

    def up_set(self, level, name_string):
        """Split the nodes of ``name_string`` by membership at (full) ``level``.

        Returns:
            (groups, missing): the set of groups at ``level`` containing the
            nodes, and the set of nodes that have no group at ``level``.
        """
        res = set()
        missing = set()
        for node in self.expand_to_node_set(name_string):
            group = node.get_group(level)
            if group is None:
                missing.add(node)
            else:
                res.add(group)
        return res, missing

    @staticmethod
    def _missing_group_error(level, nodes):
        """Build the MissingGroup error for ``nodes`` lacking a group at ``level``."""
        return MissingGroup("missing %s for %s" % \
                                (level,
                                 collect_hostlist([node.name for node in nodes])))

    def up(self, level, name_string):
        """Return a hostlist of the ``level`` groups containing the given nodes.

        Raises:
            UnknownLevel, AmbiguousLevel: if ``level`` cannot be resolved.
            MissingGroup: if any node has no group at that level.
        """
        level = self.expand_abbreviated_level(level)
        groups, missing = self.up_set(level, name_string)
        if missing:
            raise self._missing_group_error(level, missing)

        return collect_hostlist([group.name for group in groups])

    def gather(self, level, name_string):
        """Collapse nodes into ``level`` groups where every group member is present.

        Returns a hostlist of the fully covered groups, plus the individual
        nodes whose group is only partly covered by ``name_string``.

        Raises:
            UnknownLevel, AmbiguousLevel: if ``level`` cannot be resolved.
            MissingGroup: if any node has no group at that level.
        """
        level = self.expand_abbreviated_level(level)
        groups = set()
        missing = set()
        not_groupable = set()
        node_set = self.expand_to_node_set(name_string)
        for node in node_set:
            group = node.get_group(level)
            if group is None:
                missing.add(node)
            elif group.expand_to_nodes() <= node_set:
                groups.add(group)
            else:
                not_groupable.add(node)
        if missing:
            raise self._missing_group_error(level, missing)

        return collect_hostlist([entity.name for entity in groups | not_groupable])

    def fill(self, level, name_string):
        """Return a hostlist of every node in the ``level`` groups that the input touches.

        Nodes with no group at ``level`` trigger a warning on stderr and are
        kept in the result as-is.

        Raises:
            UnknownLevel, AmbiguousLevel: if ``level`` cannot be resolved.
        """
        level = self.expand_abbreviated_level(level)
        groups, res = self.up_set(level, name_string)
        if res:
            sys.stderr.write("Warning: missing %s for %s\n" % \
                                 (level,
                                  collect_hostlist([node.name for node in res])))
        for group in groups:
            res |= group.expand_to_nodes()
        return collect_hostlist([node.name for node in res])

