import sys
import os

import io
from contextlib import redirect_stdout
import traceback

import json
from collections import defaultdict

import numpy
from scipy.spatial import cKDTree as KDTree
import matplotlib



# Force matplotlib onto the non-interactive 'svg' backend so any real plotting calls made by the loaded
# notebook cannot try to open windows. This should rarely matter, as InjestJupyter normally swaps `plt` for
# the MatPlotLazy interceptor. However, interception can be disabled (graphs=False), and a student may reach
# pyplot under another name or use it in the same cell that imports it.
matplotlib.use('svg')



###############################################################################
# Loading...
###############################################################################



def victim():
  """Fetch the path of the file to mark from the command line.

  Exactly one argument is expected after the script name. With any other
  argument count a critical message (plus a blank line) is printed and the
  process exits with status 1."""
  arguments = sys.argv
  if len(arguments) == 2:
    return arguments[1]

  print('Critical: File to read not provided')
  print()
  sys.exit(1)



class Namespace:
  """Deliberately empty attribute container.

  Loaded code is executed with an instance's __dict__ as its global scope, so
  every global the code defines becomes an attribute of that instance."""



class Ingest:
  """Parent for classes that load in code - an interface plus some helpers that make it behave a bit like a dictionary, for when you want to get functions.

  `name in ingest` and `ingest[name]` look the name up as an attribute of the Namespace returned by calling the instance."""

  def __call__(self):
    """Returns the Namespace that contains the ingested global variables."""
    raise NotImplementedError


  def __contains__(self, name):
    """True if the ingested code defined a global with the given name."""
    return hasattr(self(), name)


  def __getitem__(self, name):
    """Returns the ingested global with the given name; raises AttributeError if it does not exist."""
    return getattr(self(), name)


  def block(self):
    """Iterates over the code blocks contained within each cell, yielding a string for each."""
    raise NotImplementedError


  def cells(self):
    """Iterates over the text output by each cell, in each case it yields a string."""
    raise NotImplementedError


  def graphs(self):
    """Iterates over the captured graphs (proxies for matplotlib figures), yielding each in turn. Needed by judges that check plots, such as MatchScatter."""
    raise NotImplementedError


  def run(self, code):
    """Runs the given code block within the context of the workbook, returning a tuple of (new local variable dictionary, output of any print statements as one big string). For the local variable dictionary it takes a *shallow* copy of the original and then updates with the code. Because the copy is shallow, rebinding a name in the code leaves the workbook untouched, but mutating an object that already existed (e.g. appending to a list) does change the workbook's copy too."""
    raise NotImplementedError



###############################################################################
# Core...
###############################################################################



class Judge:
  """Base interface for a single unit test of the student's code.

  Calling a judge produces one of three verdicts: True (pass), False (fail) or
  None (the judge cannot tell)."""

  def __call__(self, injest):
    """Judge the code held by the given Ingest instance, returning True for a pass, False for a fail, or None when it cannot tell. Subclasses must override this."""
    raise NotImplementedError()



class Question:
  """Combines many Judge-s to determine a final mark for a single question, with support for various rules. A functor, so create, configure, then call with an injest object and it will print out the mark for the question.

  Judge-s are organised into groups. Each group is worth an integer number of marks and has a mode that turns its judgements into a mark. A judgement of None (can't tell) counts as a fail for the low estimate and a pass for the high estimate, so marks are printed as a low--high range whenever that matters."""

  def __init__(self, name, maximum, divider = 1):
    """Initialised with the name of the question, which is usually the question number, and the maximum number of marks it is worth (must be an integer). The maximum mark count is only used as a check that the groups sum to this, to avoid mistakes. There is also an optional integer divider. It is only applied when outputting marks, to support fractional total marks or for when it's easier to reason with more marks than the question is worth."""
    assert(isinstance(maximum, int))
    assert(isinstance(divider, int))

    self._name = name
    self._maximum = maximum
    self._divider = divider

    self._worth = defaultdict(lambda: (1,2)) # group -> (marks, quant)
    self._mode = defaultdict(lambda: 'down') # group -> mode string
    self._mob = defaultdict(list) # group -> list of Judge-s. Groups only enter this dict via add(), so always loop over its keys rather than those of the two dicts above.


  def worth(self, group, worth, quant = 2):
    """Sets how many marks a group is worth; a group that is never set defaults to 1. Must be an integer. There is also the quant value, which indicates how fine grained the marking is. 1 means it only awards whole marks, 2 half marks, 3 thirds of marks and so on. Defaults to 2 for half marks."""
    assert(isinstance(worth, int))
    assert(isinstance(quant, int))

    self._worth[group] = (worth, quant)


  def mode(self, group, mode):
    """Sets the mode of a given group; every group defaults to 'down'. 'down' means the mark is the ratio of passed tests multiplied by the group mark, rounded down to the nearest quantisation step. Other options are: 'up', scaled ratio but rounded up; 'any', full marks if any test is passed; 'all', no marks if any test fails, otherwise all."""
    assert(mode in ['all', 'down', 'up', 'any'])
    self._mode[group] = mode


  def add(self, group, judge):
    """Adds a Judge to a group. There is nothing to stop you adding the same Judge to multiple groups with several calls (and it's smart enough to only evaluate the Judge once)."""
    self._mob[group].append(judge)


  def __call__(self, injest, decimals = 2):
    """Runs the Judge-s, calculates the final score, and prints it out.

    injest - passed to every judge so it can inspect the loaded code.
    decimals - the printed totals are rounded to this many decimal places.
    Returns nothing. A named group (any group other than None) gets its own Submark line. If the group worths do not sum to the question maximum, or a group has an unrecognised mode, a critical message is printed instead and it stops early."""
    print('Question:', self._name)

    # Verify that the maximum score matches up with the groups...
    maximum = sum(self._worth[group][0] for group in self._mob.keys())

    if maximum!=self._maximum:
      print('Critical: Scores do not add up.')
      print()
      return

    # Cache so any Judge added more than once is only run once...
    cache = dict()

    # Loop and sum score of each group...
    total_low = 0.0
    total_high = 0.0
    for group, mob in self._mob.items():
      judgement = [self._judge(judge, injest, cache) for judge in mob]

      worth, quant = self._worth[group]
      scores = self._score(self._mode[group], worth, quant, judgement)
      if scores is None:
        print('Critical: Unrecognised group mode for group {}'.format(group))
        print()
        return
      score_low, score_high = scores

      # Add in score, and print it out...
      total_low += score_low
      total_high += score_high
      if group is not None:
        if numpy.isclose(score_low, score_high):
          print('Submark: {} = {:g}'.format(group, score_low / self._divider))

        else:
          print('Submark: {} = {:g}--{:g}'.format(group, score_low / self._divider, score_high / self._divider))

    # Report final mark...
    total_low = numpy.around(total_low, decimals)
    total_high = numpy.around(total_high, decimals)
    if numpy.isclose(total_low, total_high):
      print('Mark: {:g} / {:g}'.format(total_low / self._divider, maximum / self._divider))

    else:
      print('Mark: {:g}--{:g} / {:g}'.format(total_low / self._divider, total_high / self._divider, maximum / self._divider))

    print()


  @staticmethod
  def _judge(judge, injest, cache):
    """Runs a single judge, memoised in cache by object identity. A judge that raises is reported (with a short traceback) and treated as None."""
    key = id(judge)
    if key not in cache:
      try:
        cache[key] = judge(injest)
      except Exception as e:
        cache[key] = None
        for line in traceback.format_exc(4).split('\n'):
          print('Info: {}'.format(line))
        print('Critical: Judge {} failed with error {}'.format(str(judge), type(e).__name__))
    return cache[key]


  @staticmethod
  def _score(mode, worth, quant, judgement):
    """Converts a group's judgements (True/False/None) into a (low, high) pair of marks for the given mode. Returns None if the mode is not recognised. Low counts every None as a fail, high counts every None as a pass.

    'down'/'up' use exact integer floor/ceil. A float ratio can land a hair above or below a quantisation step and then round the wrong way (e.g. ceil(9 * (1.0 - 1/3)) == 7)."""
    if mode=='all':
      if False in judgement:
        return 0, 0
      if None in judgement:
        return 0, worth
      return worth, worth

    if mode=='any':
      if True in judgement:
        return worth, worth
      if None in judgement:
        return 0, worth
      return 0, 0

    if mode in ('down', 'up'):
      n = len(judgement)
      steps = worth * quant # Number of quantisation steps the group is worth.
      passed = judgement.count(True) # Passes, for the low estimate.
      not_failed = n - judgement.count(False) # Passes plus unknowns, for the high estimate.
      if mode=='down':
        return (steps * passed // n) / quant, (steps * not_failed // n) / quant
      # For integers ceil(a / n) == -((-a) // n)...
      return -(-steps * passed // n) / quant, -(-steps * not_failed // n) / quant

    return None



###############################################################################
# Loading instances...
###############################################################################



class InjestJupyter(Ingest):
  """Given the file name of a .ipynb file this loads and runs all of the code, within its own namespace, and then provides an interface to inspect the result.

  Parameters:
    fn - path of the notebook, parsed as JSON read as UTF-8.
    cwd - directory to switch into while the notebook's code runs, so relative paths in it resolve. The previous working directory is always restored afterwards.
    graphs - if True, whenever the notebook's `plt` is matplotlib.pyplot it is replaced with a MatPlotLazy interceptor, so plotted data is captured.
  """
  def __init__(self, fn, cwd='.', graphs=True):
    super().__init__()

    # Load in file (.ipynb files are JSON stored as UTF-8, whatever the platform default encoding is)...
    with open(fn, 'r', encoding='utf-8') as fin:
      ipynb = json.load(fin)

    # Create namespace to keep execution variables in, plus lists for the code of each cell, the printed output of each cell and all of the Graph objects...
    self._ns = Namespace()
    self._code = []
    self._cells = []
    self._graphs = []

    if graphs:
      mpl = MatPlotLazy()

    # Move to a directory that contains files the script needs to load; the finally below restores it even if something escapes the per-cell error handling (e.g. SystemExit)...
    owd = os.getcwd()
    os.chdir(os.path.abspath(cwd))
    try:
      # Load every code cell into the namespace by executing it...
      for cell in ipynb['cells']:
        if cell['cell_type']!='code':
          continue
        code = self._cell_code(cell['source'])
        self._code.append(code)

        buf = io.StringIO()
        error = None

        with redirect_stdout(buf):
          try:
            exec(code, self._ns.__dict__)

          except Exception as e:
            error = type(e).__name__
            msg = str(e)

        s = buf.getvalue()
        if len(s)>0:
          self._cells.append(s)

        if error is not None:
          print('Severe: Error executing code block - {}'.format(error))
          for line in msg.split('\n'):
            print('Info: {}'.format(line))

        # If we're capturing graphs and plt is matplotlib.pyplot, replace it with the interceptor...
        # (obviously if the student imports and then immediately uses it, or gives it a different name, this won't work, but should work 99% of the time given I always add the code at the top in a separate block)
        if graphs:
          self._intercept_plt(mpl)

    finally:
      # Revert to old directory...
      os.chdir(owd)

    # Copy over graphs...
    if graphs:
      self._graphs.extend(mpl.graphs())

    # Print out some statistics...
    if len(self._cells)>0:
      print('Info: Captured {} cell outputs'.format(len(self._cells)))

    if len(self._graphs)>0:
      print('Info: Captured {} graphs'.format(len(self._graphs)))


  @staticmethod
  def _cell_code(source):
    """Turns a cell's 'source' entry into executable code, dropping IPython magic lines (those starting with '%').

    The notebook format stores source either as one string or as a list of lines that already end in a newline. Both forms are reduced to bare lines and rejoined with a single newline. Naively joining the list with '\\n' would double every newline, breaking backslash continuations and altering multi-line strings."""
    if isinstance(source, str):
      source = source.split('\n')
    lines = [line.rstrip('\n') for line in source]
    return '\n'.join(line for line in lines if not line.strip().startswith('%'))


  def _intercept_plt(self, mpl):
    """Replaces the namespace's `plt` with the interceptor if it is currently matplotlib.pyplot. pyplot is looked up in sys.modules because the `matplotlib.pyplot` attribute only exists once something has imported it."""
    pyplot = sys.modules.get('matplotlib.pyplot')
    if pyplot is not None and self._ns.__dict__.get('plt') is pyplot:
      self._ns.__dict__['plt'] = mpl


  def __call__(self):
    return self._ns


  def block(self):
    yield from self._code


  def cells(self):
    yield from self._cells


  def graphs(self):
    yield from self._graphs


  def run(self, code):
    """Executes code against a shallow copy of the notebook's globals. Returns (that dictionary after execution, everything the code printed). Errors are reported rather than raised."""
    buf = io.StringIO()
    error = None
    local = self._ns.__dict__.copy()

    with redirect_stdout(buf):
      try:
        exec(code, local)

      except Exception as e:
        error = type(e).__name__
        msg = str(e)

    if error is not None:
      print('Severe: Error during run() - {}'.format(error))
      for line in msg.split('\n'):
        print('Info: {}'.format(line))

    return local, buf.getvalue()



###############################################################################
# The Judges...
###############################################################################



class Dredd(Judge):
  """Placeholder for a Judge that hasn't been written yet, for when you're feeling harsh: it warns loudly and always fails the code (returns False)."""
  def __call__(self, injest):
    verdict = False
    print('Critical: Dredd finds everyone guilty')
    return verdict



class MrBean(Judge):
  """Placeholder Judge that never decides: it notes its confusion and returns None, leaving the verdict to a human marker."""
  def __call__(self, injest):
    verdict = None
    print('Info: Bean is confused')
    return verdict



class Picard(Judge):
  """Placeholder Judge that passes everyone: it warns loudly and always returns True. Don't tell the university."""
  def __call__(self, injest):
    verdict = True
    print('Critical: Make it so')
    return verdict



class Match(Judge):
  """Calls a function with a specified set of parameters and returns True only if its return exactly matches a provided value. Anything the function prints is discarded. If the function is missing or raises, the judgement is False. Array-like returns (e.g. numpy arrays) are compared with numpy.array_equal, because == on them is elementwise and has no single truth value."""
  def __init__(self, ret, func, *param, **keywords):
    """ret - the expected return value; func - name of the function in the ingested code; param/keywords - arguments to call it with."""
    self._ret = ret
    self._func = func
    self._param = param
    self._keywords = keywords


  def __call__(self, injest):
    # Fetch function to judge...
    if self._func not in injest:
      return False
    f = injest[self._func]

    # Execute and get return, silencing anything it prints...
    error = None
    with io.StringIO() as buf, redirect_stdout(buf):
      try:
        r = f(*self._param, **self._keywords)

      except Exception as e:
        error = type(e).__name__

    if error is not None:
      print('Warning: Exception {} in Match judge for function {}'.format(error, self._func))
      return False

    # Check; always hand back a real bool, as Question tests judgements with `in`/count()...
    try:
      return bool(r==self._ret)
    except ValueError:
      # == was elementwise (or shapes didn't broadcast), so compare whole arrays instead...
      return bool(numpy.array_equal(r, self._ret))



class MatchUnordered(Judge):
  """Like Match, but both the expected value and the function's return are converted with set() before comparing, so ordering (and repeats) are ignored. A missing function, an exception, or a return that set() rejects all give False."""
  def __init__(self, ret, func, *param, **keywords):
    self._ret = set(ret)
    self._func, self._param, self._keywords = func, param, keywords


  def __call__(self, injest):
    if self._func not in injest:
      return False
    target = injest[self._func]

    # Call it with printing silenced, converting the result inside the guard so unhashable items count as an exception...
    failure = None
    with io.StringIO() as sink, redirect_stdout(sink):
      try:
        got = set(target(*self._param, **self._keywords))
      except Exception as exc:
        failure = type(exc).__name__

    if failure is None:
      return got==self._ret

    print('Warning: Exception {} in Match judge for function {}'.format(failure, self._func))
    return False



class MatchPartial(Judge):
  """Calls a function with a specified set of parameters and returns True if everything the user desires matches. The provided return to match against must be a dictionary. It indexes the function's return with each key in the dictionary and checks it matches the value; all must match to return True. A missing key or out of range index results in False, as does a return that cannot be indexed. Note the return from the function can be a list/tuple if the dictionary only contains integer keys. Array-like values are compared with numpy.array_equal, as == on them is elementwise."""
  def __init__(self, ret, func, *param, **keywords):
    assert(isinstance(ret, dict))
    self._ret = ret
    self._func = func
    self._param = param
    self._keywords = keywords


  def __call__(self, injest):
    # Fetch function to judge...
    if self._func not in injest:
      return False
    f = injest[self._func]

    # Execute and get return, silencing anything it prints...
    error = None
    with io.StringIO() as buf, redirect_stdout(buf):
      try:
        r = f(*self._param, **self._keywords)

      except Exception as e:
        error = type(e).__name__

    if error is not None:
      print('Warning: Exception {} in Match judge for function {}'.format(error, self._func))
      return False

    # Check; LookupError covers both a missing dict key (KeyError) and a missing list/tuple index (IndexError)...
    for key, value in self._ret.items():
      try:
        if not self._same(value, r[key]):
          return False
      except (LookupError, TypeError):
        return False
    return True


  @staticmethod
  def _same(expected, got):
    """True if expected equals got. Falls back to numpy.array_equal when == is elementwise and so has no single truth value."""
    try:
      return not (expected!=got)
    except ValueError:
      return bool(numpy.array_equal(expected, got))



class MatchPrint(Judge):
  """Calls a function with a specified set of parameters and passes only if what it prints equals a given string exactly. Both sides are strip()-ed, so stray newlines or spaces at either end are ignored. A missing function or an exception gives False."""
  def __init__(self, prints, func, *param, **keywords):
    self._prints = prints.strip()
    self._func = func
    self._param = param
    self._keywords = keywords


  def __call__(self, injest):
    if self._func not in injest:
      return False
    student_func = injest[self._func]

    # Capture stdout while it runs; only read it back if the call succeeded...
    with io.StringIO() as captured, redirect_stdout(captured):
      try:
        student_func(*self._param, **self._keywords)
      except Exception as exc:
        problem = type(exc).__name__
      else:
        problem = None
        printed = captured.getvalue().strip()

    if problem is None:
      return printed==self._prints

    print('Warning: Exception {} in Match judge for function {}'.format(problem, self._func))
    return False



class ContainsCell(Judge):
  """Passes if some single cell's printed output contains every term in a list. A term is either a string (which must appear) or a list of alternative strings (any one of which may appear)."""
  def __init__(self, match):
    self.match = match


  def __call__(self, injest):
    def present(term, cell):
      # A plain string must appear; anything else is a collection of alternatives...
      if isinstance(term, str):
        return term in cell
      return any(option in cell for option in term)

    return any(all(present(term, cell) for term in self.match) for cell in injest.cells())



class MatchScatter(Judge):
  """Looks for one captured graph carrying every expected scatter plot. Returns True when a single graph has all of them. Returns None when no graph has all of them but some graph has at least one. Otherwise returns False."""

  def __init__(self, *scatter):
    """Each positional argument is an (x, y) pair describing one scatter plot expected on the graph; at least one is required."""
    if not scatter:
      print('Critical: MatchScatter needs data to match against!')
      raise RuntimeError

    self._scatter = [self._tree(x, y) for x, y in scatter]


  @staticmethod
  def _tree(x, y):
    """Builds a kd-tree over the (x, y) points so graphs can be matched irrespective of point order."""
    xs = numpy.asarray(x)
    ys = numpy.asarray(y)
    return KDTree(numpy.concatenate((xs[:,None], ys[:,None]), axis=1))


  def __call__(self, injest):
    partial = False

    for graph in injest.graphs():
      hits = sum(1 for kd in self._scatter if graph.match_scatter(kd))
      if hits==len(self._scatter):
        return True
      partial = partial or hits!=0

    return None if partial else False



class VarClose(Judge):
  """Checks if a variable is close to a given value, using numpy.allclose so vectorised (and broadcasting). A missing variable, or one that cannot be compared (non-numeric, or a shape that doesn't broadcast), counts as a fail rather than crashing the judge."""
  def __init__(self, name, value):
    self._name = name
    self._value = value


  def __call__(self, injest):
    # Fetch variable to judge...
    if self._name not in injest:
      return False
    v = injest[self._name]

    # Compare; a value that can't even be compared is a wrong answer, not an undecidable one...
    try:
      return bool(numpy.allclose(v, self._value))
    except (TypeError, ValueError) as e:
      print('Warning: Exception {} in VarClose judge for variable {}'.format(type(e).__name__, self._name))
      return False



###############################################################################
# matplotlib intercept...
###############################################################################



class Graph:
  """Stand-in for a matplotlib figure: records only the data handed to it (nothing is rendered) and offers helpers for writing tests against that data."""
  def __init__(self):
    self._plot = []        # (x, y) pairs passed to add_plot, stored as given
    self._scatter = []     # (x, y) pairs passed to add_scatter, as numpy arrays
    self._scatter_kd = []  # one KDTree per entry of _scatter, in the same order


  def add_plot(self, x, y):
    """Records the data of one line plot."""
    self._plot.append((x, y))


  def add_scatter(self, x, y):
    """Records the data of one scatter plot, plus a kd-tree over its points for order-independent matching."""
    xs, ys = numpy.asarray(x), numpy.asarray(y)
    self._scatter.append((xs, ys))

    points = numpy.concatenate((xs[:,None], ys[:,None]), axis=1)
    self._scatter_kd.append(KDTree(points))


  def match_scatter(self, kd, threshold = 1e-3):
    """Returns True if any scatter on this graph is essentially the given one (supplied as a kd-tree). Every point of each must have a neighbour in the other within threshold, so point order doesn't matter. The test runs in both directions; it can be fooled by differing repeat counts, but that is invisible on a plot anyway."""
    return any(self._covers(tree, kd, threshold) for tree in self._scatter_kd)


  @staticmethod
  def _covers(a, b, threshold):
    """True if every point of b has a neighbour in a, and vice versa, within threshold. A query with no neighbour in range reports index == tree.n."""
    _, nearest_in_a = a.query(b.data, distance_upper_bound=threshold)
    _, nearest_in_b = b.query(a.data, distance_upper_bound=threshold)
    return bool((nearest_in_a<a.n).all() and (nearest_in_b<b.n).all())



class MatPlotLazy:
  """Enough of the matplotlib interface to fool typical student code. Ignores all beautification and just captures the data passed in, so it can be checked later. Data goes onto the current Graph until show() is called, after which the next data starts a new Graph."""
  def __init__(self):
    # List of graphs (move to next each time show() is called)...
    self._graphs = []

    # If True make a new graph next time a request is made...
    self._new = True


  def graphs(self):
    """Generator for all of the graphs it has captured."""
    for graph in self._graphs:
      yield graph


  def _current(self):
    """Returns the Graph that new data belongs on, starting a fresh one if show() was called since data was last added."""
    if self._new:
      self._graphs.append(Graph())
      self._new = False
    return self._graphs[-1]


  @staticmethod
  def _copy(data):
    """Snapshots numpy arrays and lists, so later in-place edits by the student don't alter what was captured; other values are returned as-is."""
    if isinstance(data, numpy.ndarray):
      return data.copy()
    if isinstance(data, list):
      return data[:]
    return data


  def plot(self, x, y = None, *args, **kw):
    # matplotlib accepts plot(y) and plot(y, fmt); in both the data given is y and x is its indices...
    if y is None or isinstance(y, str):
      y = x
      x = numpy.arange(len(y))

    # Add plot event to graph...
    self._current().add_plot(self._copy(x), self._copy(y))


  def scatter(self, x, y, *args, **kw):
    # Add scatter event to graph...
    self._current().add_scatter(self._copy(x), self._copy(y))


  def show(self, *args, **kw):
    """Indicate that the next time data is provided we need to move to the next graph."""
    self._new = True


  def __getattr__(self, key):
    """An insane hack - if you call any method that is not provided this gets called, and it returns a noop method so that the call is simply ignored."""
    return self.noop


  def noop(self, *args, **kw):
    """A noop method returned whenever an unrecognised function is called. It does nothing, so code that calls functions not replicated here still runs."""
    pass
