# Version 1.4-1

import re
import sys
from hostlist import expand_hostlist, collect_hostlist
import commands

# Exceptions

class NodesException(Exception):
    """Base class for every error raised by this module.

    The human-readable message is stored in ``msg`` and is also what
    ``str()`` of the exception returns.
    """
    def __init__(self, msg):
        # Initialize the base class as well, so that ``args``/``str()`` carry
        # the message explicitly instead of relying on BaseException.__new__.
        Exception.__init__(self, msg)
        self.msg = msg

class UnknownName(NodesException):
    """A name is neither a registered node nor a registered group."""

class DuplicatedGroup(NodesException):
    """A group name was defined more than once."""

class BadConfigSyntax(NodesException):
    """A configuration file line could not be parsed."""

class MissingGroup(NodesException):
    """Some nodes do not belong to any group at the requested level."""

class DuplicateMembership(NodesException):
    """A node was put in two different groups of the same level."""

class UnknownLevel(NodesException):
    """A (possibly abbreviated) level name matches no known level."""

class AmbiguousLevel(NodesException):
    """An abbreviated level name matches several known levels."""

class FailedSLURM(NodesException):
    """Querying SLURM for job data failed."""

# Helper functions

def collect_name_hostlist(node_iterable):
    """Compress the names of the given entities into a single hostlist.

    Every item of *node_iterable* must expose a ``name`` attribute; the list
    of those names is passed to ``hostlist.collect_hostlist`` and its result
    is returned unchanged.
    """
    names = [entity.name for entity in node_iterable]
    return collect_hostlist(names)

# Sort a list of entities numerically

def numerically_sorted_entities(l):
    """Sort a list of entities numerically.

    E.g. sorted order should be n1, n2, n10; not n1, n10, n2.

    :param l: iterable of objects with a ``name`` attribute.
    :returns: a new list, ordered by a digit-aware key of each ``name``.
    """
    # Use the tagged key: it orders exactly like entitiy_numeric_sort_key did
    # under Python 2, but never compares an int with a str.
    return sorted(l, key=_tagged_numeric_sort_key)

# Splits a name into alternating runs of digits and non-digits; findall()
# yields (digits, non_digits) tuples with exactly one side non-empty.
nsk_re = re.compile("([0-9]+)|([^0-9]+)")

def entitiy_numeric_sort_key(x):
    """Return the digit-aware sort key of entity *x*.

    Digit runs of ``x.name`` become ``int`` and other runs stay strings,
    e.g. ``"n10"`` -> ``["n", 10]``. (The misspelled name is kept for
    backward compatibility.)
    """
    return [handle_int_nonint(i_ni) for i_ni in nsk_re.findall(x.name)]

def handle_int_nonint(int_nonint_tuple):
    """Turn one ``nsk_re`` match tuple into an ``int`` (digits) or its text."""
    if int_nonint_tuple[0]:
        return int(int_nonint_tuple[0])
    else:
        return int_nonint_tuple[1]

def _tagged_segment(int_nonint_tuple):
    """Tag one ``nsk_re`` match as ``(0, int)`` for digits or ``(1, str)``."""
    if int_nonint_tuple[0]:
        return (0, int(int_nonint_tuple[0]))
    return (1, int_nonint_tuple[1])

def _tagged_numeric_sort_key(x):
    """Digit-aware sort key for *x* that is safe to compare across entities.

    A digit run and a text run at the same position compare by their tag
    (digits first) rather than comparing int with str, which raises
    TypeError on Python 3. Digits-first matches Python 2's rule that numbers
    order before strings, so orderings are the same as before.
    """
    return [_tagged_segment(pair) for pair in nsk_re.findall(x.name)]


# Classes

class Entity:
    """Common base class of everything a name can refer to (nodes and groups).

    It has no behaviour of its own; Node and Group each provide ``name`` and
    ``expand_to_node_set()``.
    """
    pass

class Node(Entity):
    """A single node, identified by its name.

    ``part_of`` maps a level name to the Group of that level the node is in;
    a node can be in at most one group per level.
    """

    def __init__(self, name):
        self.name = name
        self.part_of = {} # level_name -> Group

    def __repr__(self):
        return "<Node: %s>" % (self.name)

    def expand_to_node_set(self):
        """Return a one-element set containing this node."""
        return set([self])

    def add_to_group(self, group):
        """Record that this node belongs to *group* at level ``group.level``.

        Adding the node again to the group it already belongs to at that level
        is a no-op (e.g. when the same node is registered with the same group
        more than once).

        :raises DuplicateMembership: if the node already belongs to a
            different group at the same level.
        """
        current = self.part_of.get(group.level)
        if current is group:
            return
        if group.level in self.part_of:
            raise DuplicateMembership("%s cannot belong to %s %s and %s at the same time" %
                                      (self.name,
                                       group.level,
                                       current.name,
                                       group.name))
        self.part_of[group.level] = group

    def get_group(self, level):
        """Return this node's group at *level*, or None if it has none."""
        return self.part_of.get(level)

class Group(Entity):
    """A named set of nodes at one level of a Hierarchy.

    :param hierarchy: the Hierarchy used to resolve names in nodespecs.
    :param level: name of the level this group belongs to.
    :param name: name of the group.
    :param all_nodes_group: when True, unknown names in nodespecs are
        registered as new Nodes instead of raising UnknownName.
    """

    def __init__(self, hierarchy, level, name, all_nodes_group = False):
        self.hierarchy = hierarchy
        self.level = level
        self.name = name
        self.node_set = set() # Node set
        self.representative = None # cached result of get_representative()
        self.all_nodes_group = all_nodes_group

    def __repr__(self):
        return "<Group %s: %s>" % (self.level, self.name)

    def expand_to_node_set(self):
        """Return this group's own member set (not a copy)."""
        return self.node_set

    def get_representative(self):
        """Return the numerically first member node, caching the result.

        Raises IndexError if the group has no members.
        """
        if self.representative is None:
            self.representative = numerically_sorted_entities(self.node_set)[0]
        return self.representative

    def add_nodespec(self, nodespec):
        """Add every node named by the hostlist expression *nodespec*.

        Each expanded name may be a node or any other registered group, in
        which case all of that group's nodes are added. Nodes that are already
        members are skipped, so repeating a node (on a continuation line, or
        in several SLURM jobs of one user) does not raise DuplicateMembership
        against this very group.
        """
        node_set = set()
        for name in expand_hostlist(nodespec):
            entity = self.hierarchy.get_entity(name, auto_add_node = self.all_nodes_group)
            node_set |= entity.expand_to_node_set()
        new_nodes = node_set - self.node_set
        self.node_set |= new_nodes
        for node in new_nodes:
            node.add_to_group(self)
        # Membership changed, so the cached representative may be stale.
        self.representative = None

class Hierarchy:
    """Registry of every known node and group, plus the queries built on it.

    ``names`` maps each registered name to its Node or Group (nodes and
    groups of all levels share one namespace); ``level_names`` records every
    level a group was created at, so levels can be given abbreviated.
    """

    def __init__(self):
        self.names = {} # name -> Node or Group
        self.level_names = set() # level names used for abbreviation lookup

    def create_group(self, level, name, all_nodes_group = False):
        """Create, register and return a new Group *name* at *level*.

        *level* is recorded in ``level_names`` even if creation then fails.

        :raises DuplicatedGroup: if *name* is already registered, either as
            a group or as a node.
        """
        self.level_names.add(level)
        if name in self.names:
            raise DuplicatedGroup("group %s seen twice" % name)
        group = Group(self, level, name, all_nodes_group)
        self.names[name] = group
        return group

    def create_or_get_group(self, level, name):
        """Return the entity registered as *name*, creating a group if absent.

        NOTE(review): an existing entry is returned as-is, even if it is a
        Node or a group at a level other than *level*.
        """
        if name not in self.names:
            return self.create_group(level, name)
        return self.names[name]

    def get_entity(self, name, auto_add_node = False):
        """Look up the Node or Group called *name*.

        :param auto_add_node: when True, an unknown *name* is registered as a
            new Node instead of being an error.
        :raises UnknownName: if *name* is unknown and *auto_add_node* is false.
        """
        if name not in self.names:
            if not auto_add_node:
                raise UnknownName("unknown name " + name)
            self.names[name] = Node(name)
        return self.names[name]

    def parse_file(self, file_or_filename):
        """Populate the hierarchy from a configuration file.

        *file_or_filename* is either a path (``str``), which is opened and
        always closed here, or an already-open iterable of lines, which is
        left open for the caller. After stripping ``#`` comments, each
        non-blank line must be one of:

        * ``nodes: <nodespec>`` -- the all-nodes group (level "all");
        * ``<level> <name>: <nodespec>`` -- a new group;
        * an indented ``<nodespec>`` -- more members for the previous group.

        :raises BadConfigSyntax: for any other non-blank line.
        """
        if isinstance(file_or_filename, str):
            f = open(file_or_filename)
            try:
                self._parse_lines(f)
            finally:
                f.close()
        else:
            self._parse_lines(file_or_filename)

    def _parse_lines(self, lines):
        """Parse configuration *lines* (syntax as documented in parse_file)."""
        current_group = None
        for line in lines:
            line = re.sub(r' *#.*', '', line)
            line = line.rstrip()
            if not line.strip():
                continue

            # nodes: <nodes>
            m = re.match(r'^nodes:\s*(.*)', line)
            if m:
                rest = m.group(1)
                current_group = self.create_group("all", "nodes", all_nodes_group = True)
                if rest:
                    current_group.add_nodespec(rest)
                continue

            # <level> <name>: <parts>
            m = re.match(r'^([a-z]+)\s+([a-z0-9]+):\s*(.*)', line)
            if m:
                (level, name, rest) = m.group(1, 2, 3)
                current_group = self.create_group(level, name)
                if rest:
                    current_group.add_nodespec(rest)
                continue

            # <parts> (indented continuation of the current group)
            if current_group is not None:
                m = re.match(r'^\s+(.*)', line)
                if m:
                    current_group.add_nodespec(m.group(1))
                    continue

            raise BadConfigSyntax(line)

    def parse_slurm(self):
        """Add one "job" group per squeue line and one "user" group per user.

        Runs ``squeue`` and reads three fields per line (job id, user, node
        list). Lines without exactly three fields -- presumably jobs with no
        nodes allocated yet -- are skipped.

        :raises FailedSLURM: if the ``squeue`` command exits non-zero.
        """
        rc, slurm_data = commands.getstatusoutput('/usr/bin/squeue -aho "%i %u %N"')
        if rc != 0:
            raise FailedSLURM("Failed to get data from SLURM")

        for line in slurm_data.split("\n"):
            fields = line.split()
            if len(fields) != 3:
                continue
            job, user, nodes = fields

            gj = self.create_group("job", "job" + job)
            gj.add_nodespec(nodes)

            # Users accumulate the nodes of all their jobs.
            gu = self.create_or_get_group("user", user)
            gu.add_nodespec(nodes)

    def expand_abbreviated_level(self, level):
        """Resolve *level*, possibly an abbreviation, to a known level name.

        An exact level name always wins, even when it is also a prefix of
        other level names; otherwise *level* must be a prefix of exactly one.

        :raises UnknownLevel: if no level name starts with *level*.
        :raises AmbiguousLevel: if several level names start with *level*.
        """
        if level in self.level_names:
            return level
        # Sorted so the ambiguity message is deterministic.
        matching = sorted([name for name in self.level_names if name.startswith(level)])
        if len(matching) == 1:
            return matching[0]
        elif not matching:
            raise UnknownLevel("unknown level %s" % level)
        else:
            raise AmbiguousLevel("ambiguous level %s matching %s" % (level,
                                                                     ", ".join(matching)))

    def expand_to_node_set(self, nodespec):
        """Return the set of Nodes named (directly or via groups) by *nodespec*.

        :raises UnknownName: if any expanded name is not registered.
        """
        res = set()
        for name in expand_hostlist(nodespec):
            res |= self.get_entity(name).expand_to_node_set()
        return res

    def expand(self, nodespec):
        """Return *nodespec* expanded to a hostlist of nodes ("nodes --to-nodes")."""
        return collect_name_hostlist(self.expand_to_node_set(nodespec))

    def up_set(self, level, nodespec, fill = True, missing_group = None):
        """Map the nodes of *nodespec* to their groups at *level*.

        This is the central logic behind up, fill, representative and gather.

        :param fill: when True, every group containing one of the nodes is
            returned; when False, only groups wholly inside *nodespec* are,
            and nodes of partially covered groups go to ``leftovers``.
        :param missing_group: if given, a node with no *level* group is
            represented by the group of this name (created if absent) instead
            of being reported as missing.
        :returns: a tuple of sets ``(groups, leftovers, missing)`` -- groups
            the nodes belong to, nodes that did not fill whole groups (always
            empty when *fill* is True), and nodes with no *level* group.
        """
        groups = set()
        leftovers = set()
        missing = set()
        node_set = self.expand_to_node_set(nodespec)
        for node in node_set:
            group = node.get_group(level)
            if group is None:
                if missing_group is None:
                    missing.add(node)
                else:
                    groups.add(self.create_or_get_group(level, missing_group))
            elif fill or group.expand_to_node_set() <= node_set:
                groups.add(group)
            else:
                leftovers.add(node)
        return groups, leftovers, missing

    def _warn_missing(self, level, missing):
        """Write a warning to stderr naming the nodes with no *level* group."""
        if missing:
            sys.stderr.write("Warning: missing %s for %s\n" %
                             (level, collect_name_hostlist(missing)))

    def up(self, level, nodespec, missing_group = None):
        """Return the hostlist of *level* groups containing *nodespec* ("nodes --up").

        :raises MissingGroup: if some node has no *level* group and no
            *missing_group* was given.
        """
        level = self.expand_abbreviated_level(level)
        groups, leftovers, missing = self.up_set(level, nodespec,
                                                 missing_group = missing_group)

        assert len(leftovers) == 0 # up_set with fill=True leaves none
        if missing:
            raise MissingGroup("missing %s for %s" %
                               (level, collect_name_hostlist(missing)))
        return collect_name_hostlist(groups)

    def fill(self, level, nodespec):
        """Return all nodes of every *level* group touched by *nodespec* ("nodes --fill").

        Nodes with no *level* group are kept as-is, with a warning on stderr.
        """
        level = self.expand_abbreviated_level(level)
        groups, leftovers, missing = self.up_set(level, nodespec)

        assert len(leftovers) == 0
        self._warn_missing(level, missing)

        # FIXME: Check if it makes more sense to signal error above?
        res = set(missing) # copy rather than mutate up_set's result
        for group in groups:
            res |= group.expand_to_node_set()
        return collect_name_hostlist(res)

    def representative(self, level, nodespec):
        """Return one representative node per *level* group touched ("nodes --representative").

        Nodes with no *level* group are kept as-is, with a warning on stderr.
        """
        level = self.expand_abbreviated_level(level)
        groups, leftovers, missing = self.up_set(level, nodespec)

        assert len(leftovers) == 0
        self._warn_missing(level, missing)

        # FIXME: Check if it makes more sense to signal error above?
        res = set(missing) # copy rather than mutate up_set's result
        for group in groups:
            res.add(group.get_representative())
        return collect_name_hostlist(res)

    def gather(self, level, nodespec):
        """Replace wholly covered *level* groups by their names ("nodes --gather").

        Nodes of partially covered groups and nodes with no *level* group are
        listed individually; the latter also trigger a warning on stderr.
        """
        level = self.expand_abbreviated_level(level)
        groups, leftovers, missing = self.up_set(level, nodespec, fill=False)
        self._warn_missing(level, missing)
        return collect_name_hostlist(groups | leftovers | missing)



