#! /usr/bin/env python3

import numpy

import automark2 as am



# Location of the lab whose notebook is being marked. This is a relative path,
# so it is resolved against the process's current working directory.
LAB_DIR = '../Machine Learning 1/Labs/04 - Coffee Machine'

# Load the notebook under test; every question below is run against it.
notebook = am.Notebook(cwd=LAB_DIR)



# Q1 - Inferring distributions...

# Reference answers. Each entry maps a notebook variable name to the array
# expected for it. For the conditional tables the values sum to 1 along axis 0.
#
# `maxl` presumably holds the maximum-likelihood estimates. TODO: confirm
# against the lab.
maxl = {
    'P_he': numpy.array([0.9904747, 0.0095253]),
    'P_fp': numpy.array([0.80056381, 0.19943619]),
    'P_fc': numpy.array([0.98004913, 0.01995087]),
    'P_wr': numpy.array([0.89957047, 0.10042953]),
    'P_gs': numpy.array([0.95010757, 0.04989243]),
    'P_dp': numpy.array([0.94968414, 0.05031586]),
    'P_fh': numpy.array([0.97034073, 0.02965927]),
    'P_pw_he_fp': numpy.array([[[0., 1.], [1., 1.]],
                               [[1., 0.], [0., 0.]]]),
    'P_cb_pw_fc': numpy.array([[[1., 1.], [0., 1.]],
                               [[0., 0.], [1., 0.]]]),
    'P_gw_cb_wr_dp': numpy.array([[[[1., 1.], [1., 1.]],
                                   [[0., 1.], [0.90243902, 1.]]],
                                  [[[0., 0.], [0., 0.]],
                                   [[1., 0.], [0.09756098, 0.]]]]),
    'P_ls_he': numpy.array([[0.09996649, 1.], [0.90003351, 0.]]),
    'P_vp_pw': numpy.array([[1., 0.01009226], [0., 0.98990774]]),
    'P_lo_cb': numpy.array([[1., 0.00103092], [0., 0.99896908]]),
    'P_wv_wr': numpy.array([[0.19985412, 1.], [0.80014588, 0.]]),
    'P_hp_dp': numpy.array([[0.09929144, 1.], [0.90070856, 0.]]),
    'P_me_gw_gs': numpy.array([[[1., 1.], [0.09907773, 0.89930357]],
                               [[0., 0.], [0.90092227, 0.10069643]]]),
    'P_ta_me_fh': numpy.array([[[1., 1.], [0.04966186, 1.]],
                               [[0., 0.], [0.95033814, 0.]]]),
}

# The same variables, estimated differently. The entries that are hard 0/1 in
# `maxl` are small but non-zero here, which suggests a prior / pseudo-count
# was applied. TODO: confirm.
prior = {
    'P_he': numpy.array([0.99047096, 0.00952904]),
    'P_fp': numpy.array([0.80056152, 0.19943848]),
    'P_fc': numpy.array([0.98004547, 0.01995453]),
    'P_wr': numpy.array([0.89956742, 0.10043258]),
    'P_gs': numpy.array([0.95010414, 0.04989586]),
    'P_dp': numpy.array([0.94968071, 0.05031929]),
    'P_fh': numpy.array([0.97033714, 0.02966286]),
    'P_pw_he_fp': numpy.array([[[4.81037502e-06, 9.99980683e-01],
                                [9.99495714e-01, 9.98069498e-01]],
                               [[9.99995190e-01, 1.93173257e-05],
                                [5.04286435e-04, 1.93050193e-03]]]),
    'P_cb_pw_fc': numpy.array([[[9.99981208e-01, 9.99048525e-01],
                                [4.90910787e-06, 9.99760937e-01]],
                               [[1.87916941e-05, 9.51474786e-04],
                                [9.99995091e-01, 2.39062874e-04]]]),
    'P_gw_cb_wr_dp': numpy.array([[[[9.99980048e-01, 9.99611349e-01],
                                    [9.99817718e-01, 9.96309963e-01]],
                                   [[5.75251529e-06, 9.99892404e-01],
                                    [9.02397787e-01, 9.99056604e-01]]],
                                  [[[1.99517168e-05, 3.88651380e-04],
                                    [1.82282173e-04, 3.69003690e-03]],
                                   [[9.99994247e-01, 1.07596299e-04],
                                    [9.76022133e-02, 9.43396226e-04]]]]),
    'P_ls_he': numpy.array([[9.99695743e-02, 9.99599840e-01],
                            [9.00030426e-01, 4.00160064e-04]]),
    'P_vp_pw': numpy.array([[9.99981572e-01, 1.00969772e-02],
                            [1.84284240e-05, 9.89903023e-01]]),
    'P_lo_cb': numpy.array([[9.99982890e-01, 1.03582176e-03],
                            [1.71101035e-05, 9.98964178e-01]]),
    'P_wv_wr': numpy.array([[1.99856670e-01, 9.99962019e-01],
                            [8.00143330e-01, 3.79809336e-05]]),
    'P_hp_dp': numpy.array([[9.92946545e-02, 9.99924196e-01],
                            [9.00705346e-01, 7.58035173e-05]]),
    'P_me_gw_gs': numpy.array([[[9.99987818e-01, 9.99768626e-01],
                                [9.90825358e-02, 8.99212419e-01]],
                               [[1.21821969e-05, 2.31374364e-04],
                                [9.00917464e-01, 1.00787581e-01]]]),
    'P_ta_me_fh': numpy.array([[[9.99990701e-01, 9.99696233e-01],
                                [4.96679947e-02, 9.99777134e-01]],
                               [[9.29903848e-06, 3.03766707e-04],
                                [9.50332005e-01, 2.22866057e-04]]]),
}


# Closeness checks, one per reference set. am.VarsClose presumably compares
# each notebook variable against the array stored under the same name, within
# a tolerance. TODO: confirm the tolerance used by automark2.
vc_maxl, vc_prior = am.VarsClose(maxl), am.VarsClose(prior)


# Question 1 is out of 4 marks.
q1 = am.Question(1, 4)

# 'base' (3 marks): passes when either set of estimates matches. This assumes
# am.Any succeeds if any of its checks does — verify against automark2.
q1.worth('base', 3)
q1.add('base', am.Any(vc_maxl, vc_prior))

# 'prior' (1 mark): only awarded when the prior-based estimates match.
q1.worth('prior', 1)
q1.add('prior', vc_prior)

q1(notebook)



# Q2 - Implement BP...
# Each check presumably calls the notebook's `marginals` and compares it with
# its `brute_marginals`, passing the given dict as the known/observed values.
# The dict keys are integers (e.g. 11, 15). Their meaning is not visible here.
q2 = am.Question(2, 10)

# 5 marks: nothing known.
q2.worth('unknown', 5)
q2.add('unknown', am.FuncMatch('brute_marginals', 'marginals', {}))

# 5 marks: mode 'all' presumably means every check below must pass.
q2.worth('known', 5)
q2.mode('known', 'all')
for _known in ({11 : True}, {11 : False}, {11 : True, 15 : False}):
    q2.add('known', am.FuncMatch('brute_marginals', 'marginals', _known))

q2(notebook)



# Q3 - Identify broken part...
# One regex per expected "<letter> : <part>" line, with optional whitespace
# around the colon. am.PrintRE presumably matches these against the
# notebook's printed output. TODO: confirm.
_q3_patterns = [
    r'A\s*:\s*fh',
    r'B\s*:\s*fp',
    r'C\s*:\s*wr',
    r'D\s*:\s*dp',
    r'E\s*:\s*gs',
]

# Question 3 is out of 1 mark.
q3 = am.Question(3, 1)
q3.add(None, am.PrintRE(_q3_patterns))

q3(notebook)
