# Copyright 2018 Tom SF Haines

import sys
import os
import resource

import types
import copy

import io
from contextlib import redirect_stdout
import re

import json
from collections import defaultdict
import h5py

import inspect
import traceback
import functools
import ast

import numpy
from scipy.spatial import cKDTree as KDTree
import matplotlib



###############################################################################
# helpers...
###############################################################################



def limit_memory(limit = 4):
  """Caps the address space (RLIMIT_AS) of this process at the given number of gigabytes (GiB).

  Both the soft and the hard limit are set to the same value, so an unprivileged process cannot raise it again afterwards."""
  gigabyte = 1 << 30
  cap = int(limit * gigabyte)
  resource.setrlimit(resource.RLIMIT_AS, (cap, cap))



###############################################################################
# matplotlib intercept...
###############################################################################

# Tell matplotlib to behave itself. Obviously the hope is that the matplotlib intercept means that matplotlib is never called, but because that can fail it's best to configure matplotlib to not be silly...
# NOTE(review): the intercept is only swapped in after a cell has run (see Cell.execute), so a cell that imports pyplot and plots in one go talks to the real library; selecting the 'svg' backend here is presumably so that such calls never try to open a window.
matplotlib.use('svg')



class Graph:
  """Represents a graph, as in all of the data fed into matplotlib.pyplot captured at the exact moment .show() is called."""
  
  def __init__(self):
    self._plot = [] # tuples of (x, y)
    self._scatter = [] # kd trees over the points (use tree.data to get original)
    self._hist = [] # One sorted, flattened array per dataset passed to hist().
    self._imshow = [] # numpy array, either (height, width) or (height, width, 3/4 channels)


  def add_plot(self, x, y):
    """Records a line plot as the (x, y) pair it was given."""
    self._plot.append((x, y))


  def add_scatter(self, x, y):
    """Records a scatter plot; the flattened points go into a KD tree so match_scatter() can compare them irrespective of order."""
    x = numpy.asarray(x).ravel()
    y = numpy.asarray(y).ravel()
    kd = KDTree(numpy.concatenate((x[:,None], y[:,None]), axis=1))
    self._scatter.append(kd)
  
  
  @staticmethod
  def _hist_datasets(x):
    """Splits the x argument of hist() into a list of separate datasets, following matplotlib's conventions: a 2D array is one dataset per column, a sequence containing sequences is one dataset per element, and anything else (including a lone scalar) is a single dataset."""
    # pandas Series/DataFrame and similar - work with the underlying array...
    if hasattr(x, 'to_numpy'):
      x = x.to_numpy()
    
    # Regular (non-ragged) arrays...
    if isinstance(x, numpy.ndarray) and x.dtype != object:
      if x.ndim == 2:
        return list(x.T)
      return [x]
    
    # Generic sequences, ragged object arrays and scalars...
    items = list(x) if numpy.iterable(x) else [x]
    if any(numpy.iterable(item) and not isinstance(item, (str, bytes)) for item in items):
      return items
    return [items]
  
  
  def add_hist(self, x):
    """Records the data given to hist(). Each dataset is flattened and stored sorted (a copy), so match_histo() can count the values below a threshold."""
    for data in self._hist_datasets(x):
      self._hist.append(numpy.sort(numpy.asarray(data).ravel()))
  
  
  def add_imshow(self, image):
    """Records an image, as a numpy array."""
    self._imshow.append(numpy.asarray(image))


  def match_scatter(self, kd, threshold = 1e-3):
    """Given a scatter plot (as a KDTree) it returns True if there is something in this graph that's basically the same - measured in terms of everything being close, accounting for reordering. Does it as two separate 'nearest point tests', so it doesn't mismatch due to different repeats, but it seems unlikely that would be a problem - you certainly couldn't see it if looking at the graph!"""
    for tree in self._scatter:
      # A query with no neighbour within threshold reports index == tree.n...
      _, indices_a = tree.query(kd.data, distance_upper_bound=threshold)
      _, indices_b = kd.query(tree.data, distance_upper_bound=threshold)
      
      if (indices_a<tree.n).all() and (indices_b<kd.n).all():
        return True

    return False
  
  
  def match_histo(self, conditions):
    """Given a dictionary each key is a threshold, each value a two value tuple of (min, max). It specifies the condition that the number of items that are less than key is within the given range, inclusive. It goes through all histograms and calculates how many conditions are matched, returning the largest count."""
    ret = 0
    
    for hist in self._hist:
      matches = 0
      for threshold, (cmin, cmax) in conditions.items():
        v = (hist<threshold).sum()
        if cmin<=v and v<=cmax:
          matches += 1
      
      if matches>ret:
        ret = matches
    
    return ret



class PyPlotIntercept:
  """Enough of the matplotlib.pyplot interface to fool typical student code. Ignores all beautification and just captures the data passed in, such that it can be checked later."""
  
  def __init__(self):
    self._subplot = 111 # Key of the current subplot, e.g. subplot(2, 2, 3) -> 223.
    self._graph = dict() # Indexed _subplot -> Graph(); for collection before .show()
    self._collected = [] # List of all Graph() objects, grabbed every time .show() is called.
  
  
  def collect(self):
    """Returns a list of all Graph() objects collected since this method was last called. Will not include any calls that have not been subject to .show()."""
    ret = self._collected
    self._collected = []
    return ret
  
  
  def _current(self):
    """Returns the Graph() of the current subplot, creating it on first use."""
    if self._subplot not in self._graph:
      self._graph[self._subplot] = Graph()
    return self._graph[self._subplot]
  
  
  def plot(self, x, y = None, *args, **kw):
    """Records a line plot. Handles plot(y), plot(y, fmt) and plot(x, y, ...); only the first (x, y) pair is captured."""
    # A format string in the y slot, as in plot(y, 'r-'), means y was omitted...
    if isinstance(y, str):
      y = None
    
    # If y is missing the x values default to 0, 1, 2, ...
    if y is None:
      y = x
      x = numpy.arange(len(y))
    
    # Record copies, so later mutation by the student code can't change what was plotted...
    self._current().add_plot(copy.deepcopy(x), copy.deepcopy(y))


  def scatter(self, x, y, *args, **kw):
    """Records a scatter plot of the given points."""
    self._current().add_scatter(copy.deepcopy(x), copy.deepcopy(y))
  
  
  def hist(self, x, *args, **kw):
    """Records the data of a histogram (Graph.add_hist() stores sorted copies, so no copy is taken here)."""
    self._current().add_hist(x)
  
  
  def imshow(self, image, *args, **kw):
    """Records an image."""
    self._current().add_imshow(copy.deepcopy(image))
  
  
  def subplot(self, *args, **kw):
    """Makes sure that each subplot ends up as a separate Graph object; subplot(2, 2, 3) and subplot(223) both select key 223."""
    self._subplot = 0
    for value in args:
      self._subplot *= 10
      self._subplot += value
  
  
  def show(self, *args, **kw):
    """Collect all graphs and reset ready for the next."""
    for graph in self._graph.values():
      self._collected.append(graph)
    
    self._subplot = 111
    self._graph = dict()


  def __getattr__(self, key):
    """An insane hack - if you call any method that is not provided this gets called, and it returns a noop method so that the call is simply ignored. Special (dunder) names are excluded: copy, pickle and hasattr() probe those and must see an AttributeError - a noop __deepcopy__ made copy.deepcopy() return None."""
    if key.startswith('__') and key.endswith('__'):
      raise AttributeError(key)
    return self.noop


  def noop(self, *args, **kw):
    """A noop method returned whenever an unrecognised function is called; does nothing and hence makes the code just work for all of the functions that have not been replicated."""
    pass



###############################################################################
# Loading...
###############################################################################

def victim():
  """Fetches the single expected command line argument, the filename of the notebook to mark. If the argument count is wrong it prints a critical message and exits with status 1."""
  arguments = sys.argv[1:]
  if len(arguments) == 1:
    return arguments[0]
  
  print('Critical: File to read not provided')
  print()
  sys.exit(1)



class Namespace:
  """Intentionally empty class; an instance's __dict__ serves as the global variable storage for code loaded from a notebook."""



def clone_namespace(ns):
  """Clones a Namespace object; separate to avoid contaminating said namespace. Does its best to be smart about how it does the copy:
  - objects that make no sense to copy (code, modules, bound/builtin methods, tracebacks, generators, open files) are shared by reference;
  - plain functions are recreated with the clone as their globals (keeping defaults, closure and attributes), so they see the cloned state;
  - everything else is deep copied, falling back to sharing if deep copying fails."""
  # Shared by reference - copying is meaningless, impossible (generators) or would break them (open files)...
  shared = (types.CodeType, types.MethodType, types.BuiltinFunctionType, types.BuiltinMethodType, types.ModuleType, types.TracebackType, types.GeneratorType, h5py.File, io.TextIOBase)
  
  ret = Namespace()
  
  for key, value in ns.__dict__.items():
    if isinstance(value, shared):
      ret.__dict__[key] = value
    
    elif isinstance(value, types.FunctionType):
      # Rebind to the clone's globals; the closure must be passed or functions with free variables fail to construct...
      func = types.FunctionType(value.__code__, ret.__dict__, value.__name__, value.__defaults__, value.__closure__)
      
      func.__annotations__ = value.__annotations__
      func.__doc__ = value.__doc__
      func.__kwdefaults__ = value.__kwdefaults__
      func.__module__ = value.__module__
      func.__qualname__ = value.__qualname__
      func.__dict__.update(value.__dict__) # e.g. __wrapped__ from functools.wraps
      
      ret.__dict__[key] = func
      
    else:
      try:
        ret.__dict__[key] = copy.deepcopy(value)
          
      except Exception: # Some objects can't be deep copied (TypeError from pickling, or errors from custom __deepcopy__/__reduce__) - share them instead.
        ret.__dict__[key] = value
  
  return ret



class Cell:
  """Represents a single cell of a Jupyter workbook - all of its code and outputs from execution in other words."""
  DeleteMe = 'The Black Knight' # Marker push() records for variables that did not previously exist, so pop() knows to delete rather than restore them.
  
  def __init__(self, index, code = None):
    """Constructed with the cell index and its code, as a string. Alternatively, if index is another cell (and code is omitted) it will be cloned, which is primarily used so the state can be modified before use; the clone gets its own shallow copy of the state and of the push()/pop() stack."""
    if code is None:
      # Clone...
      other = index
      
      self._index = other._index
    
      self._code = other._code
      self._ast = other._ast
      self._state = copy.copy(other._state)
    
      self._output = other._output
      self._graph = other._graph
    
      self._error = other._error
      self._tb = other._tb
      
      self._stack = list(other._stack) # Own copy, so push()/pop() on the clone leave the original alone.
    
    else:
      # Actually new...
      self._index = index

      # Variables to record everything that may be added...
      self._code = code # Code, as a string.
      self._ast = None # Parse tree, created lazily by ast()
      self._state = None # Complete copy of all state immediately after the cell ran
    
      self._output = None # Text output from running cell - all calls to print()
      self._graph = [] # Graph() objects, from matplotlib.pyplot intercept
    
      self._error = None # Name of error object, or None if no error occurred
      self._tb = None # Traceback of error, as a string, if one occurred
      
      self._stack = [] # Undo records for push()/pop()


  def execute(self, namespace, ppi, reactions = ()):
    """Given a Namespace() and a PyPlotIntercept() this runs the code and records everything in the cell. Should typically be called once only, immediately after construction. The namespace will be updated with the state after the cell, noting that it takes a complete copy for itself so future edits do not break judges who need to see intermediate state. You can optionally provide reactions, an iterable of functions that are called on the state dictionary (only parameter) after each cell to make whatever changes they feel like. Errors raised by the code, including SystemExit from a call to exit(), are recorded and reported rather than propagated."""
    
    # Actual execution, with error and output intercept...
    buf = io.StringIO()
      
    with redirect_stdout(buf):
      try:
        comp = compile(self._code, '<cell {}>'.format(self._index), 'exec')
        exec(comp, namespace.__dict__)
    
      except (Exception, SystemExit) as e:
        self._error = type(e).__name__
        self._tb = traceback.format_exc(4)
    
    # Swap in the matplotlib intercept as required; pyplot may never have been imported, in which case there is nothing to swap...
    pyplot = sys.modules.get('matplotlib.pyplot')
    if pyplot is not None:
      for name in ['plt', 'pyplot']:
        if namespace.__dict__.get(name) is pyplot:
          namespace.__dict__[name] = ppi

    # Run all reactions...
    for reaction in reactions:
      reaction(namespace.__dict__)
    
    # Record post-execution state, taking a copy so the state at each step is available...
    self._state = clone_namespace(namespace)
    
    # Record output...
    self._output = buf.getvalue()
    self._graph = ppi.collect()
    
    # Report any errors...
    if self._error is not None:
      print('Severe: Error executing code block {}: {}'.format(self._index, self._error))
      for line in self._tb.split('\n'):
        print('Info: {}'.format(line))
      print()
  
  
  def index(self):
    """Returns the index of the cell; note that because of markdown cells these are not always contiguous."""
    return self._index


  def code(self):
    """Returns the cell's code, as a string."""
    return self._code
  
  
  def ast(self):
    """Returns the parse tree of the code, as generated by the ast module (parsed on first request, then cached)."""
    if self._ast is None:
      self._ast = ast.parse(self._code)
    return self._ast
  
  
  def state(self):
    """Returns a dictionary containing the exact state of the global variables immediately after execution of the code."""
    return self._state.__dict__
  
  
  def output(self):
    """Returns the output, as generated by any calls to print within the code."""
    return self._output
  
  
  def graph(self):
    """Returns a list of Graph() objects, one for each graph output by the code."""
    return self._graph
  
  
  def error(self):
    """Returns True if there was an error running the code of this cell."""
    return self._error is not None
  
  
  def push(self, delta):
    """Given a dictionary of variable name : new value this updates the state accordingly; normally you then call pop() to revert to the previous state afterwards."""
    previous = {}
    for key, value in delta.items():
      previous[key] = self._state.__dict__.get(key, self.DeleteMe)
      self._state.__dict__[key] = value

    self._stack.append(previous)
  
  
  def pop(self):
    """Reverses the most recent push(): restores replaced variables and deletes those that did not exist before."""
    delta = self._stack.pop()
    
    for key, value in delta.items():
      if value is self.DeleteMe and key in self._state.__dict__:
        del self._state.__dict__[key]
      
      else:
        self._state.__dict__[key] = value



class Notebook:
  """This reads in a Jupyter Notebook, executes it, and then provides an interface to access all of the cells."""
  
  def __init__(self, fn = None, cwd='.', reactions = None):
    """fn is the .ipynb file (taken from the command line if None), cwd the directory the code runs in (also put at the front of sys.path while it runs), reactions an optional list of functions passed on to Cell.execute(). The working directory and sys.path are restored afterwards, even if execution raises."""
    if reactions is None:
      reactions = []
    
    # Use filename from command line if not provided...
    if fn is None:
      fn = victim()
    
    # Load in file (notebooks are JSON, hence UTF-8)...
    with open(fn, 'r', encoding='utf-8') as fin:
      ipynb = json.load(fin)
    
    # State needed for the coming insanity...
    namespace = Namespace()
    ppi = PyPlotIntercept()
    
    # Move to a directory that contains files the script needs to load...
    cwd = os.path.abspath(cwd)
    
    owd = os.getcwd()
    opath = sys.path[:]
    
    os.chdir(cwd)
    sys.path.insert(0, cwd)
    
    # Loop and collect cell objects...
    self._cell = []
    try:
      for index, cell in enumerate(ipynb['cells']):
        # Skip if not code...
        if cell['cell_type']!='code':
          continue
        
        # Create and execute cell object...
        c = Cell(index, self._cell_code(cell['source']))
        c.execute(namespace, ppi, reactions)
        self._cell.append(c)
    
    finally:
      # Restore original directory...
      sys.path = opath
      os.chdir(owd)
    
    # Print out a statistic...
    print('Info: Ran {} cells'.format(len(self._cell)))
  
  
  @staticmethod
  def _cell_code(source):
    """Converts the 'source' of a notebook code cell - a list of lines or a single string, both allowed by nbformat - into one string of code. Lines starting with % (IPython magics) are dropped, trailing whitespace is stripped, and lines ending in a backslash are joined with the next."""
    if isinstance(source, str):
      source = source.splitlines()
    
    code = []
    for line in source:
      if line.lstrip().startswith('%'):
        continue
      line = line.rstrip()
      
      if len(code)>0 and code[-1].endswith('\\'):
        code[-1] = code[-1][:-1] + line
      
      else:
        code.append(line)
    
    return '\n'.join(code) + '\n'
  
  
  def __len__(self):
    """So len gives how many cells have been captured."""
    return len(self._cell)
  
  
  def cells(self):
    """Iterator over all cells."""
    for c in self._cell:
      yield c
  
  
  def state(self):
    """Returns the dictionary of the state at the end of the notebook, for ease of access."""
    return self._cell[-1].state()
  
  
  def print_cells(self, start = None, stop = None):
    """Prints out everything that was printed by the code when it was executed, with suitable headings/escapes. Can also pass start/stop to limit range."""
    cells = self._cell[start:stop]
    
    for ci, cell in enumerate(cells):
      print('Cell: {}'.format(ci + (start if start is not None else 0)))
      
      for line in cell.output().split('\n'):
        print('Print: {}'.format(line))
      
      print()



###############################################################################
# Question object...
###############################################################################



class Judge:
  """Base interface for a judge: one unit test of the code, giving a pass (True), a fail (False) or confusion (None). Subclasses must override __call__."""
  
  def __call__(self, cell):
    """Given a cell, returns True if the code passes, False if it fails, None if it can't tell. Not implemented here."""
    raise NotImplementedError()



class Question:
  """Combines many Judges to determine a final mark for a single question, with support for various rules. A functor, so create, configure, then call with a notebook object and it will print out the mark for the question."""

  def __init__(self, name, maximum):
    """Initialised with the name of the question, which is usually the question number, and the maximum number of marks it is worth. The maximum mark count is only used as a check that the groups sum to this, to avoid mistakes."""
    self._name = name
    self._maximum = maximum
    
    self._worth = defaultdict(lambda: (1.0, 0.5)) # group -> (marks, resolution)
    self._mode = defaultdict(lambda: 'down') # group -> scoring mode
    self._panel = defaultdict(list) # group -> [(judge, evidence), ...]; always loop this dict's keys, as only groups with judges count.


  def worth(self, group, worth, resolution = 0.5):
    """Sets how many marks a group is worth. If not set for any group it defaults to 1. There is also the resolution value, which indicates how fine grained the marking is. Defaults to 0.5, so it will award no more precisely than half marks."""
    self._worth[group] = (worth, resolution)


  def mode(self, group, mode):
    """Sets the mode of a given group. By default they are all set to 'down', which means the mark is given as the ratio of passed tests multiplied by the group mark and then rounded down to the nearest quantisation step. Other options are: 'up', scaled ratio but rounded up; 'any', full marks if any test is passed; 'all', no marks if any test fails, otherwise all."""
    assert(mode in ['all', 'down', 'up', 'any'])
    self._mode[group] = mode
  
  
  def add(self, group, judge, evidence = 'all'):
    """Adds a Judge to a group. There is nothing to stop you adding the same Judge to multiple groups with several calls (and it's smart enough to only evaluate the Judge once). The optional evidence parameter tells it which Cells to run the judge on: 'all', the default, means run on all cells and take the best result [True>None>False>error]. 'end' means to only run on the last cell, usually because you only care about the final state. You can also provide a function/functor that takes a cell and returns True/False, with it only running on True (good for speed if the Judge is slow and you can identify the one cell to run it on; treats None as False so that a Judge object also works!). It's smart enough to stop early if it gets a True result."""
    self._panel[group].append((judge, evidence))
  
  
  def __judgement(self, notebook, judge, evidence):
    """Internal wrapper for making a judgement - handles running a Judge on all relevant cells and selecting the best judgement (True > None > False). Errors raised by the judge are captured; they are reported, and None returned, only if the judge never produced a result."""
    results = []
    last_error = None
    last_i = len(notebook) - 1
    
    for i, cell in enumerate(notebook.cells()):
      # Determine if we should run this...
      if evidence!='all':
        if evidence=='end':
          if i<last_i:
            continue
        
        elif evidence(cell)!=True:
          continue

      # Commence judgement! An error only means this cell can't be judged - later cells may still pass...
      try:
        j = judge(cell)
        results.append(j)
        
        if j==True: # Passed - no point testing further cells
          break
        
      except (Exception, SystemExit):
        last_error = traceback.format_exc(4)

    # Summarise what we have learned, returning the best result...
    if len(results)>0:
      if True in results:
        return True
      if None in results:
        return None
      return False
    
    # Nothing but errors - report the last one...
    if last_error is not None:
      print('Info: Judge {} only produced errors'.format(str(judge)))
      for line in last_error.split('\n'):
        print('Info: {}'.format(line))
      return None
    
    # Total failure...
    print('Info: Judge {} received no evidence'.format(str(judge)))
    return None
  
  
  @staticmethod
  def _group_score(mode, judgement, worth, resolution):
    """Converts one group's list of judgements (True/False/None) into a (lowest, highest) possible score tuple; None judgements count as fails for the lowest and passes for the highest. Returns None if the mode is not recognised."""
    def quantise(ratio, rounder):
      # Rounding first absorbs floating point error (e.g. 0.7 / 0.1 == 6.999999999999999), which would otherwise lose or gain a whole step...
      steps = numpy.round(worth * ratio / resolution, 6)
      return rounder(steps) * resolution
    
    if mode=='all':
      if False in judgement:
        return (0, 0)
      if None in judgement:
        return (0, worth)
      return (worth, worth)
    
    if mode=='any':
      if True in judgement:
        return (worth, worth)
      if None in judgement:
        return (0, worth)
      return (0, 0)
    
    if mode in ('down', 'up'):
      rounder = numpy.floor if mode=='down' else numpy.ceil
      ratio_low = judgement.count(True) / len(judgement)
      ratio_high = 1.0 - judgement.count(False) / len(judgement)
      return (quantise(ratio_low, rounder), quantise(ratio_high, rounder))
    
    return None


  def __call__(self, notebook, decimals = 2):
    """Runs the Judges, calculates the final score, and prints it out. Requires a notebook object as a parameter, so it can go through the cells and pass them to the judges. When some judgements are uncertain (None) a range low--high is printed."""
    print('Question:', self._name)
    
    # Verify that the maximum score matches up with the groups...
    maximum = 0.0
    for group in self._panel.keys():
      maximum += self._worth[group][0]
    
    if not numpy.allclose(maximum, self._maximum):
      print('Critical: Scores do not add up.')
      print()
      return
    
    # Cache so any Judge added more than once is only run once...
    cache = dict() #  Key is (id(judge), evidence) where strings are kept as strings but anything else has id() called on it.

    # Loop and sum score of each group...
    total_low = 0.0
    total_high = 0.0
    for group in self._panel.keys():
      # Get array of True/False/None judgements...
      panel = self._panel[group]
      judgement = [None] * len(panel)
      
      for i, (judge, evidence) in enumerate(panel):
        key = (id(judge), evidence if isinstance(evidence, str) else id(evidence))
        
        if key not in cache:
          cache[key] = self.__judgement(notebook, judge, evidence)
        judgement[i] = cache[key]
      
      # Calculate score depending on mode...
      worth, resolution = self._worth[group]
      scores = self._group_score(self._mode[group], judgement, worth, resolution)
      
      if scores is None:
        print('Critical: Unrecognised group mode for group {}'.format(group))
        print()
        return
      
      score_low, score_high = scores
      
      # Add in score, and print it out...
      total_low += score_low
      total_high += score_high
      if group is not None:
        if numpy.isclose(score_low, score_high):
          print('Submark: {} = {:g}'.format(group, score_low))
          
        else:
          print('Submark: {} = {:g}--{:g}'.format(group, score_low, score_high))
    
    # Report final mark...
    total_low = numpy.around(total_low, decimals)
    total_high = numpy.around(total_high, decimals)
    if numpy.isclose(total_low, total_high):
      print('Mark: {:g} / {:g}'.format(total_low, maximum))
      
    else:
      print('Mark: {:g}--{:g} / {:g}'.format(total_low, total_high, maximum))

    print()



###############################################################################
# Dummy and core judges...
###############################################################################



class Dredd(Judge):
  """Stand in for when a civilised Judge has not been coded and you're feeling like a fascist: the code is always found guilty."""
  
  def __call__(self, cell):
    # The cell is never examined - the verdict is predetermined...
    verdict = False
    return verdict



class MrBean(Judge):
  """Distracts the judge so they forget to pass judgement, leaving it to the meat sack: always uncertain."""
  
  def __call__(self, cell):
    # No verdict at all - a human has to look...
    verdict = None
    return verdict



class Picard(Judge):
  """For when the universe needs saving so everyone just gets the marks: always passes. Don't tell the university."""
  
  def __call__(self, cell):
    # Make it so...
    verdict = True
    return verdict



class Any(Judge):
  """Combines the Judges given to the constructor with a logical OR: the result is the best verdict any of them returns for the cell, ranked True > None > False."""
  
  def __init__(self, *judges):
    self._judges = judges
  
  
  def __call__(self, cell):
    # Every judge is consulted (no short-circuit), then the best verdict is selected...
    verdicts = []
    for judge in self._judges:
      verdicts.append(judge(cell))
    
    for best in (True, None):
      if best in verdicts:
        return best
    
    return False



class All(Judge):
  """Combines the Judges given to the constructor with a logical AND on a per cell basis: False if any of them fails, otherwise None if any is uncertain, otherwise True. Lets you combine Judges with a 'same cell' constraint."""
  
  def __init__(self, *judges):
    self._judges = judges
  
  
  def __call__(self, cell):
    # Every judge is consulted (no short-circuit), then the worst verdict wins...
    verdicts = tuple(judge(cell) for judge in self._judges)
    
    if False in verdicts:
      return False
    
    return None if None in verdicts else True



class Uncertain(Judge):
  """Wraps a judge and downgrades any False it returns to None, so the result is uncertain unless the wrapped judge can identify the code as correct. For forcing a meat sack to check things when it's not possible to identify all possible right answers."""
  
  def __init__(self, judge):
    self._judge = judge
  
  def __call__(self, cell):
    verdict = self._judge(cell)
    return None if verdict == False else verdict
  


###############################################################################
# Print judges...
###############################################################################



class PrintRE(Judge):
  """Searches the printed output of a cell with a regular expression and passes if it is found. A list of expressions can also be given, in which case every one of them must be found (in any order)."""
  
  def __init__(self, exp):
    expressions = exp if isinstance(exp, list) else [exp]
    self._patterns = list(map(re.compile, expressions))


  def __call__(self, cell):
    text = cell.output()
    return all(pattern.search(text) is not None for pattern in self._patterns)



class PrintGreaterThan(Judge):
  """Uses a regular expression to extract numbers from the printed output and returns True if any reaches a 'good' value, None if the best only reaches a 'bad' value, and False otherwise."""
  
  def __init__(self, exp, good, bad = None):
    """exp is the regular expression; it must have a group with the name 'num'. good is the threshold a number must reach (>=) for True. bad is optional; a number at or above it but below good gives None (human to check), while only numbers below it give False. If not provided it defaults to good."""
    self._pattern = re.compile(exp)
    self._good = good
    self._bad = good if bad is None else bad
  
  
  def __call__(self, cell):
    verdict = False
    
    for found in self._pattern.finditer(cell.output()):
      value = float(found.group('num'))
      
      if value >= self._good:
        verdict = True
      elif value >= self._bad and verdict is not True:
        verdict = None
    
    return verdict



###############################################################################
# Variable judges...
###############################################################################



class VarClose(Judge):
  """Checks that a variable exists in the cell's state and is close to a given value, using numpy.allclose so it handles ndarrays sensibly."""
  
  def __init__(self, name, value):
    """Constructed with the name of the variable followed by the value it must be close to."""
    self._name = name
    self._value = value
  
  
  def __call__(self, cell):
    variables = cell.state()
    
    if self._name in variables:
      return numpy.allclose(variables[self._name], self._value)
    
    # Missing variable is a fail...
    return False



class VarsClose(Judge):
  """Same as VarClose except it checks several variables at once: given a dictionary of variable name : expected value, every variable must exist and be close (numpy.allclose) to its value."""
  
  def __init__(self, vd):
    """Given a dictionary of variable name : value pairs."""
    self._vd = vd
  
  
  def __call__(self, cell):
    variables = cell.state()
    
    # Stops at the first missing or mismatched variable, in dictionary order...
    return all(name in variables and numpy.allclose(variables[name], expected) for name, expected in self._vd.items())



###############################################################################
# Function judges...
###############################################################################



class FuncClose(Judge):
  """Calls a named function from the cell's state with a fixed set of arguments and passes only if its return value is close to the expected one, as defined by numpy.allclose. Returns None if the call raises, after printing the traceback."""
  def __init__(self, ret, func, *args, **kwds):
    """ret is the expected return value and func the name of the function in the cell's state, followed by the positional/keyword arguments to call it with."""
    self._ret = ret
    self._func = func
    self._args = args
    self._kwds = kwds
  
  
  def __call__(self, cell):
    variables = cell.state()
    
    # Not defined - fail...
    if self._func not in variables:
      return False
    candidate = variables[self._func]
    
    # Run it with anything it prints swallowed; any exception means it can't be judged...
    try:
      with io.StringIO() as sink, redirect_stdout(sink):
        returned = candidate(*self._args, **self._kwds)
    
    except:
      report = traceback.format_exc(4)
      print('\n'.join('Info: {}'.format(line) for line in report.split('\n')))
      return None

    return numpy.allclose(returned, self._ret)



class FuncEqual(Judge):
  """Identical to FuncClose, except you provide the function equal(returned, expected) that decides if the results are a match (True) or not (False)."""
  def __init__(self, equal, ret, func, *args, **kwds):
    """equal is the comparison function, ret the expected return value and func the name of the function in the cell's state, followed by the positional/keyword arguments to call it with."""
    self._equal = equal
    self._ret = ret
    self._func = func
    self._args = args
    self._kwds = kwds
  
  
  def __call__(self, cell):
    variables = cell.state()
    
    # Not defined - fail...
    if self._func not in variables:
      return False
    candidate = variables[self._func]
    
    # Run it with anything it prints swallowed; any exception means it can't be judged...
    try:
      with io.StringIO() as sink, redirect_stdout(sink):
        returned = candidate(*self._args, **self._kwds)
    
    except:
      report = traceback.format_exc(4)
      print('\n'.join('Info: {}'.format(line) for line in report.split('\n')))
      return None

    return self._equal(returned, self._ret)



class FuncMatch(Judge):
  """Calls two named functions from the cell's state with the same arguments and passes if their returns are close (numpy.allclose). Returns None if either call raises, after printing the traceback."""
  
  def __init__(self, fun1, fun2, *args, **kwds):
    """Created with the name of two functions followed by arguments / keywords."""
    self._fun1 = fun1
    self._fun2 = fun2
    self._args = args
    self._kwds = kwds
  
  
  def __call__(self, cell):
    variables = cell.state()
    
    # Both must be defined...
    if self._fun1 not in variables or self._fun2 not in variables:
      return False
    first, second = variables[self._fun1], variables[self._fun2]
    
    # Run both, in order, with anything they print swallowed...
    try:
      with io.StringIO() as sink, redirect_stdout(sink):
        out1 = first(*self._args, **self._kwds)
        out2 = second(*self._args, **self._kwds)
    
    except:
      report = traceback.format_exc(4)
      print('\n'.join('Info: {}'.format(line) for line in report.split('\n')))
      return None
    
    return numpy.allclose(out1, out2)



###############################################################################
# Wrapping judges...
###############################################################################



class WrapPatch(Judge):
  """Wraps another judge, modifying the cell before handing it over by introducing/replacing some variables. The modification is always reverted afterwards, even if the inner judge raises."""
  
  def __init__(self, delta, inner):
    """delta is a dictionary of {variable : new/replacement value} while inner is the judge being wrapped."""
    self._delta = delta
    self._inner = inner
  
  
  def __call__(self, cell):
    # Adjust cell...
    cell.push(self._delta)
    
    # Go deeper, reverting the cell however we come back out - otherwise an exception would leave the patched state in place for every later judge...
    try:
      return self._inner(cell)
    
    finally:
      cell.pop()



###############################################################################
# Code judges...
###############################################################################



class CodeMatch(Judge):
  """Returns True if it can match a search expression, False if it cannot. The search expression is provided as a list of terms, each a string - it's a match if it can find a [-1] match within a [-2] match within a [-3] match and so on to [0] (very css). Here is a list of supported expressions:
  'class <name>' - matches a class with the given name, or any if name omitted, then searches within the class.
  'def <name>' - matches a function declaration with the given name, or any if omitted, then searches within function definition.
  '<name> =' - matches an assignment to the given variable; searches within the expression being assigned.
  'name(<parameters>)' - matches any call to a function/method with the given name, with optional parameter information. Must be the end of a search chain.
  <parameters> = comma separated list, of the parameters of the function (the keywords). Provides no constraint except when default values are given with the usual syntax, in which case those 'defaults' become required equalities for a match. Literal defaults (numbers, strings, None, True/False, tuples etc.) must be equal to the value in the call, with numbers compared using numpy.isclose; non-literal defaults only require the same kind of expression.
  'if,for,while' - matches any statement of one of the given comma separated kinds. Search continues inside matched term - (inc. else part)
  '!' - swaps things over, so matching everything after is a failure, but only when the terms before are matched.
  'class' and 'def' are only recognised as whole words, so terms such as 'default =' or 'classify()' are treated as assignments/calls.
  Note that most of the above are coded to be as general as possible and are hardly bullet proof - there is a good chance of this judge giving marks in cases of crazy code that's almost right."""
  
  def __init__(self, expr):
    self._expr = [term.strip() for term in expr]
  
  
  def __match_named(self, tree, expr, node_type, keyword):
    """Shared implementation of 'class <name>' and 'def <name>': finds nodes of node_type, with the given name if one was provided, and continues the search within their bodies."""
    term = expr[0].split()
    if len(term)==1: # Keyword alone - any name will do
      name = None
    
    else:
      if len(term)!=2:
        print('Severe: Code match {} term "{}" did not parse'.format(keyword, expr[0]))
      name = term[1]
    
    for node in ast.walk(tree):
      if isinstance(node, node_type) and (name is None or node.name==name):
        for stmt in node.body:
          if self.__match(stmt, expr[1:]):
            return True
    
    return False
  
  
  def __match_class(self, tree, expr):
    return self.__match_named(tree, expr, ast.ClassDef, 'class')


  def __match_def(self, tree, expr):
    return self.__match_named(tree, expr, ast.FunctionDef, 'def')


  def __match_assign(self, tree, expr):
    vname = expr[0][:-1].strip()
    
    for node in ast.walk(tree):
      if isinstance(node, ast.Assign):
        found = False
        for lhs in node.targets:
          for child in ast.walk(lhs):
            if isinstance(child, ast.Name) and child.id==vname:
              found = True
              break
            
          if found:
            break
        
        if found and self.__match(node.value, expr[1:]):
          return True

    return False


  @staticmethod
  def __same_value(pattern, node):
    """True if the ast expression node matches the pattern expression. Literal patterns need an equal literal (numbers compared with numpy.isclose, bools not counted as numbers); non-literal patterns only need the same node type."""
    try:
      want = ast.literal_eval(pattern)
    
    except (ValueError, TypeError):
      # Not a literal - keep the loose structural check...
      return type(pattern)==type(node)
    
    try:
      got = ast.literal_eval(node)
    
    except (ValueError, TypeError):
      return False
    
    def is_number(v):
      return isinstance(v, (int, float, complex)) and not isinstance(v, bool)
    
    if is_number(want) and is_number(got):
      return bool(numpy.isclose(want, got))
    
    return type(want)==type(got) and want==got


  def __match_call(self, tree, expr):
    name, params = expr[0][:-1].split('(', 1)
    name = name.strip()
    params = [term.strip() for term in params.split(',')]
    
    # Parse the required values once, as (position, keyword, parsed value) - only parameters with a default constrain the match...
    required = []
    for index, param in enumerate(params):
      if '=' in param:
        keyword, value = [term.strip() for term in param.split('=', 1)]
        required.append((index, keyword, ast.parse(value, '<match>', 'eval').body))
    
    for node in ast.walk(tree):
      if not isinstance(node, ast.Call):
        continue
      
      func = node.func
      if not ((isinstance(func, ast.Name) and func.id==name) or (isinstance(func, ast.Attribute) and func.attr==name)):
        continue
      
      match = True
      for index, keyword, want in required:
        # Find relevant expression in tree - positional if there are enough positional arguments, otherwise by keyword...
        if index < len(node.args):
          arg = node.args[index]
        
        else:
          arg = None
          for kw in node.keywords:
            if kw.arg==keyword:
              arg = kw.value
          
          if arg is None:
            print('instance not found')
            match = False
            break
        
        if not self.__same_value(want, arg):
          match = False
          break
      
      if match:
        return True
    
    return False
  
  
  def __match_keywords(self, tree, expr):
    match = []
    for word in expr[0].split(','):
      word = word.strip()
      if word=='if':
        match.append(ast.If)
      
      elif word=='for':
        match.append(ast.For)
      
      elif word=='while':
        match.append(ast.While)
      
      else:
        print('Severe: Code match keyword unrecognised: "{}"'.format(word))
    
    match = tuple(match)
    for node in ast.walk(tree):
      if isinstance(node, match):
        for stmt in node.body:
          if self.__match(stmt, expr[1:]):
            return True
        
        for stmt in node.orelse:
          if self.__match(stmt, expr[1:]):
            return True
    
    return False
  
  
  def __match(self, tree, expr):
    if len(expr)==0:
      return True
    
    # 'class'/'def' must be whole words, otherwise e.g. 'default =' would be mistaken for a def term...
    words = expr[0].split()
    head = words[0] if len(words)>0 else ''

    if head=='class':
      return self.__match_class(tree, expr)

    elif head=='def':
      return self.__match_def(tree, expr)

    elif expr[0].endswith('='):
      return self.__match_assign(tree, expr)
    
    elif expr[0].endswith(')'):
      return self.__match_call(tree, expr)
    
    elif expr[0]=='!':
      return not self.__match(tree, expr[1:])
    
    else: # keywords is only option left
      return self.__match_keywords(tree, expr)


  def __call__(self, cell):
    return self.__match(cell.ast(), self._expr)



class CodeIterCap(Judge):
  """Calls a function and counts how many iterations go on in it due to loops. Only passes if this number comes under a threshold."""
  
  def __init__(self, limit, name, *args, **kwds):
    """Maximum number of iterations (inclusive, for any loop) is the first parameter, name the second, and all further parameters are fed into the function, including keywords."""
    self._limit = limit
    self._name = name
    self._args = args
    self._kwds = kwds
    self.noisy = True # When True the per-loop counts are printed as an Info line
  

  def __call__(self, cell):
    """Returns False if the function is missing or any loop exceeds the limit, None if instrumenting/calling it raised, True otherwise."""
    # Check function is defined...
    for node in ast.walk(cell.ast()):
      if isinstance(node, ast.FunctionDef) and node.name==self._name:
        # Instrumenting edits the tree in place, so work on a copy - the cell's tree may be reused by other judges or a second call...
        func = copy.deepcopy(node)
        break

    else:
      return False

    loops = defaultdict(int)
    kwds = self._kwds.copy()
    kwds['loops'] = loops
    
    try:
      # Modify to count - this is insane... inside the try as student code can fail to recompile (e.g. if it already has a 'loops' parameter)...
      func = instrument_loops(func, cell.state())
      
      # Call, hiding anything the student code prints...
      with io.StringIO() as buf, redirect_stdout(buf):
        func(*self._args, **kwds)
    
    except: # Deliberately bare - student code may raise anything, even SystemExit
      for line in traceback.format_exc(4).split('\n'):
        print('Info: {}'.format(line))
      return None

    # Report, then check if the limit has been exceeded...
    if self.noisy:
      print('Info: Loop counts = {}'.format('; '.join(['L{} did {} reps'.format(k, v) for k,v in loops.items()])))

    return not any(value>self._limit for value in loops.values())



class CodeDelta(Judge):
  """Given two fragments of code - matching one is considered failure, matching the other is considered success. Matching neither is kicked to the meat sack to check. Matches the parse tree, so comments etc. are all ignored - for when the student is meant to edit some code in a very specific way."""
  
  def __init__(self, incorrect, correct):
    """Two blobs of code, noting that they must compile. Can use None if you want to disable a check or a list (or tuple) to provide multiple options."""
    self._incorrect = CodeDelta.__parse_options(incorrect)
    self._correct = CodeDelta.__parse_options(correct)


  @staticmethod
  def __parse_options(code):
    """Converts None, a code string, or a list/tuple of code strings into a list of parsed statement lists."""
    if code is None:
      return []
    
    if isinstance(code, (list, tuple)):
      return [ast.parse(frag).body for frag in code]
    
    return [ast.parse(code).body]


  @staticmethod
  def __match(a, b, tail = False):
    """Internal helper - given two code blobs from the ast parse tree returns True if they match, False otherwise. Does not handle root being a list. When tail is True lists only need to agree over their common prefix, so the final statement of a blob may be followed by extra code."""
    if type(a)!=type(b):
      return False
    
    if not isinstance(a, ast.AST):
      return a==b
    
    if set(a._fields)!=set(b._fields):
      return False
    
    for field in a._fields:
      af = getattr(a, field)
      bf = getattr(b, field)
      
      lsta = isinstance(af, list)
      lstb = isinstance(bf, list)
      
      if lsta!=lstb:
        return False
      
      if lsta:
        if not tail and len(af)!=len(bf):
          return False
        
        for al, bl in zip(af, bf):
          if not CodeDelta.__match(al, bl, tail):
            return False
      
      else:
        if not CodeDelta.__match(af, bf, tail):
          return False

    return True
  
  
  @staticmethod
  def __match_within(space, blob):
    """Returns True if it finds the blob within the space. blob must be a list of statements."""
    
    if not isinstance(space, ast.AST):
      return False
    
    for name in space._fields:
      field = getattr(space, name)
      
      if isinstance(field, list):
        # Within statement list...
        to = len(field) - len(blob) + 1
        if to>0:
          for offset in range(to):
            good = True
            for bl in range(len(blob)):
              if not CodeDelta.__match(field[offset+bl], blob[bl], bl+1==len(blob)):
                good = False
                break
            
            if good:
              return True
        
        # Within something within the statement list...
        for line in field:
          if CodeDelta.__match_within(line, blob):
            return True
      
      else:
        if CodeDelta.__match_within(field, blob):
          return True
    
    return False


  def __call__(self, cell):
    """True if any correct blob is found, else False if any incorrect blob is found, else None (needs a human)."""
    # Loop variable deliberately not called 'ast', which would shadow the module...
    for blob in self._correct:
      if CodeDelta.__match_within(cell.ast(), blob):
        return True
    
    for blob in self._incorrect:
      if CodeDelta.__match_within(cell.ast(), blob):
        return False

    return None



###############################################################################
# The crazy, the weird and the ugly...
###############################################################################



def instrument_loops(func, namespace = None):
  """A decorator that adjusts a function to include an extra keyword parameter, 'loops', to which a collections.defaultdict(int) *must* be passed, so it can count how many times it runs through each loop. Keys are line numbers within the function: it adds loops[line number - 2] += 1 to the start of each for/while body, using the line number within the parsed source - so with a single decorator line above the def the first line of the body is 1. Has a very long list of failure modes - defining functions within functions for instance. Also lets you override the namespace it compiles in, instead of using globals(), and you can pass an ast.FunctionDef in instead of the actual function which works in scenarios the actual function does not (the passed tree is copied, not modified). Yes, this is madness; it is recommended you bite down on a stick before reading on."""

  # Fetch code and compile half-way, to an abstract syntax tree...
  if isinstance(func, ast.FunctionDef):
    # Work on a copy - editing the caller's tree in place would add 'loops' twice if it is instrumented again...
    func = copy.deepcopy(func)
    from_source = False
  
  else:
    source = inspect.getsource(func)
    tree = ast.parse(source)
    
    assert(isinstance(tree.body[0], ast.FunctionDef))
    func = tree.body[0]
    from_source = True
  
  # Need to remove self from decorator list or we are going to infinite loop (well, actually crash due to the lack of source code second time around, which breaks inspect). Only decorators naming instrument_loops are removed, so other decorators survive...
  def is_self(dec):
    return (isinstance(dec, ast.Name) and dec.id=='instrument_loops') or (isinstance(dec, ast.Attribute) and dec.attr=='instrument_loops')
  
  kept = [dec for dec in func.decorator_list if not is_self(dec)]
  if from_source and len(kept)==len(func.decorator_list) and len(kept)>0:
    kept = kept[:-1] # Presumably used under an alias - assume it is the innermost decorator, as before
  func.decorator_list = kept
  
  # Drop the 'loops' keyword argument into the function definition...
  args = func.args
  args.kwonlyargs.append(ast.arg(arg='loops', annotation=None))
  args.kw_defaults.append(ast.Constant(value=None))
  
  # Go for a walk and mess with every for and while loop (ast.walk queues children before yielding, so the inserted increments are never visited)...
  for node in ast.walk(func):
    if isinstance(node, (ast.For, ast.While)):
      key = ast.Constant(value=node.lineno-2)
      if sys.version_info < (3, 9):
        key = ast.Index(value=key) # Older grammars wrap subscripts
      
      inc = ast.AugAssign(target=ast.Subscript(value=ast.Name(id='loops', ctx=ast.Load()), slice=key, ctx=ast.Store()),
                          op=ast.Add(),
                          value=ast.Constant(value=1))
      node.body.insert(0, inc)
  
  # Fix line numbers...
  ast.fix_missing_locations(func)
  
  # Finish compilation and return the new function...
  if namespace is None:
    namespace = globals()
  namespace = namespace.copy() # Don't want to overwrite the original function - not this codes job
  
  # type_ignores is required by compile() on Python 3.8+ (and harmlessly ignored before)...
  code = compile(ast.Module(body=[func], type_ignores=[]), '<instrumented_loops>', 'exec')
  exec(code, namespace)
  
  return namespace[func.name]
