import os.path
import numpy
import scipy.io



class Digits_AL:
  """Provides an interface to the MNIST digits data set, as adapted by Tim for the purpose of testing an active learning algorithm.

  All data is read at construction from the file 'mnist_digits.mat', which is looked up in the same directory as this module."""
  def __init__(self):
    """Loads 'mnist_digits.mat' and stores its train/test vectors, labels and the train image index as numpy arrays. Errors raised by scipy.io.loadmat (e.g. for a missing file) or by a missing variable in the file are not caught."""
    directory = os.path.dirname(os.path.abspath(__file__))
    data = scipy.io.loadmat(os.path.join(directory, 'mnist_digits.mat'), struct_as_record=True)

    # The 'CV' struct arrives wrapped in a 2D array, hence the [0,0] needed to reach the content of each of its fields.
    cv = data['CV']

    self.trainVec = numpy.asarray(cv['Xtr'][0,0], dtype=numpy.float32)
    self.trainAns = self._as_vector(cv['Ytr'][0,0], numpy.uint8)
    self.trainIndex = self._as_vector(data['Y_idx'], numpy.int32)

    self.testVec = numpy.asarray(cv['Xte'][0,0], dtype=numpy.float32)
    self.testAns = self._as_vector(cv['Yte'][0,0], numpy.uint8)

  @staticmethod
  def _as_vector(raw, dtype):
    """Converts raw to a numpy array of the given dtype and drops every axis after the first, so that an array of shape (n,1) becomes (n,). Raises ValueError if the array can not be reshaped to the length of its first axis alone."""
    arr = numpy.asarray(raw, dtype=dtype)
    # reshape is used instead of assigning to arr.shape, as numpy discourages setting the shape attribute in place.
    return arr.reshape(arr.shape[:1])

  def getTrainVectors(self):
    """Returns the training vectors, as created by PCA - a numpy array of float32, where each vector is [i,:]."""
    return self.trainVec

  def getTrainClasses(self):
    """Returns the digit represented by each training vector - a numpy array of uint8, indexed by [i]."""
    return self.trainAns

  def getTrainIndex(self):
    """Returns the image index associated with each training vector, so it can be visualised - a numpy array of int32, indexed by [i]."""
    return self.trainIndex

  def getTestVectors(self):
    """Returns the test vectors, as created by PCA - a numpy array of float32, where each vector is [i,:]."""
    return self.testVec

  def getTestClasses(self):
    """Returns the digit represented by each test vector - a numpy array of uint8, indexed by [i]."""
    return self.testAns
