#! /usr/bin/env python

import os
import shutil
import math
import random
import string
from collections import defaultdict
import cPickle as pickle

from utils.prog_bar import ProgBar

from p_cat.kde_inc.loo_cov import SubsetPrecisionLOO

from dp_al.pool import Pool
from p_cat.p_cat import ClassifyKDE

from shuttle.shuttle import Shuttle_AL



# Parameters...
# Names of the query selection strategies to compare - each one is handed to
# pool.select() by do_run, and gets its own set of output files.
methods = [
  'p_wrong_soft',
  'p_wrong_hard',
  'entropy',
  'outlier',
  'random',
]
runs = 32    # How many times every method is repeated.
limit = 200  # Cap on queries per run; also passed to the precision solver below.

# Random 8 character key (lowercase letters and digits), so that each
# invocation of the script writes into its own results directory...
rk = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in xrange(8))
out_dir = 'shuttle_results_' + rk
prec_cache = 'shuttle_prec_cache.pickle' # Where the precision matrix is cached between invocations.



# Load the problem...
data = Shuttle_AL()
print 'Loaded - %i trainning examples, %i testing examples'%(data.getTrainClasses().shape[0], data.getTestClasses().shape[0])



# This calculates a suitable precision matrix to use...
if os.path.exists(prec_cache):
  f = open(prec_cache,'rb')
  precision = pickle.load(f)
  f.close()
else:
  print 'Calculating loo optimal precision matrix for data set...'
  p = ProgBar()
  loo = SubsetPrecisionLOO()
  for i in xrange(data.getTrainVectors().shape[0]):
    loo.addSample(data.getTrainVectors()[i,:])
  loo.solve(256, limit, p.callback)
  precision = loo.getBest()
  del p

  f = open(prec_cache, 'wb')
  pickle.dump(precision,f,-1)
  f.close()

print 'Optimal standard deviation = %s'%str(math.sqrt(1.0/precision[0,0]))



# Helper for do_run - scores a classifier against the test set...
def _mean_class_rate(classifier, testV, testD):
  """Returns the mean per-class hit rate of classifier on the test set.

  For each class that appears in testD the fraction of its exemplars for which
  classifier.getCat(...) equals the true label is calculated; the return value
  is the unweighted mean of those fractions, so each class counts equally
  however many exemplars it has. testV is indexed as testV[j,:] to get the
  feature vector that matches label testD[j]. An empty test set gives 0.0
  rather than a division by zero."""
  correct = defaultdict(int) # Class -> number of its exemplars classified correctly.
  total = defaultdict(int)   # Class -> number of its exemplars in the test set.

  for jj in xrange(testV.shape[0]):
    actual = testD[jj]
    guess = classifier.getCat(testV[jj,:])
    total[actual] += 1
    if guess==actual: correct[actual] += 1

  if not total:
    return 0.0

  rate = 0.0
  for cat in total:
    rate += float(correct[cat]) / float(total[cat])
  return rate / float(len(total))



# Define a function to run a test with a specific algorithm - will return a tuple of arrays of stats...
def do_run(method, callback = None, fout = None):
  """Runs one active learning experiment using the given selection method.

  A pool is filled with every training exemplar and a ClassifyKDE, built with
  the module level precision matrix, is given every training vector via
  priorAdd. Then, up to min(limit, pool.size()) times, the pool is updated from
  the classifier, one item is selected with pool.select(method), the classifier
  is given that item with its true label, and the classifier is scored on the
  test set.

  method - Name of the selection strategy, passed straight to pool.select.
  callback - Optional progress function, called as callback(done, total) at the
             start of each query.
  fout - Optional file-like object; one line 'query number, classes found,
         inlier rate, new prior' is written and flushed after each query.

  Returns (found, inlier, conc), three lists with one entry per query plus an
  initial entry (0, 0.0 and 1.0 respectively) for the state before any queries:
  the number of distinct classes selected so far, the mean per-class hit rate
  on the test set, and the value concentration / (concentration + samples).

  Reads the module level data, precision and limit."""
  # Create a pool and fill it...
  trainV = data.getTrainVectors()
  trainD = data.getTrainClasses()
  testV = data.getTestVectors()
  testD = data.getTestClasses()

  pool = Pool()
  for ii in xrange(trainV.shape[0]):
    pool.store(trainV[ii,:], trainD[ii])

  # Create a classifier...
  classifier = ClassifyKDE(precision)

  for ii in xrange(trainV.shape[0]):
    classifier.priorAdd(trainV[ii,:])


  # Setup the variables required to store the stats...
  found = [0]
  inlier = [0.0]
  conc = [1.0]

  classes = defaultdict(int) # For counting how many of each class has been found.

  # Do the active learning loop - select an item, update the model, test the model...
  lim = min(limit, pool.size())
  for ii in xrange(lim):
    if callback is not None: callback(ii, lim)

    # Update the information in the pool...
    pool.update(classifier)

    # Do the selection...
    sample, _, truth = pool.select(method)
    classes[truth] += 1

    # Update the model using the oracle...
    classifier.add(sample, truth)

    # Test...
    mean_rate = _mean_class_rate(classifier, testV, testD)

    # Calculate the concentration as a prior on the probability of something being new...
    c = float(pool.getConcentration())
    c = c / (c + classifier.getSampleTotal())

    # Record the statistics...
    found.append(len(classes))
    inlier.append(mean_rate)
    conc.append(c)

    if fout is not None:
      fout.write('%i, %i, %.4f, %.3f\n'%(ii+1, len(classes), mean_rate, c))
      fout.flush() # So partial results survive if the run is interrupted.

  # Return the stats...
  return (found, inlier, conc)



# Create the output directory...
os.makedirs(out_dir)


# Loop, do and record the list of required tests...
for ii in xrange(runs):
  for method in methods:
    print 'Run %i, method %s:'%(ii, method)
    
    # Setup output...
    f = open('%s/%s_%i.csv'%(out_dir, method, ii), 'w')
    f.write('Queries, Classes Found, Inlier Rate, New Prior\n')
    
    # Run it...
    p = ProgBar()
    do_run(method, p.callback, f)
    del p
    
    # Close the output file...
    f.close()
