import math

from pandac.PandaModules import *
from direct.showbase import DirectObject
import direct.directbase.DirectStart


class InitODE(DirectObject.DirectObject):
  """This creates the various ODE core objects (world, collision space and contact joint group), and exposes them to other plugins. Should be called ode.

  It also runs a fixed time step physics simulation as a Panda3D task, keeps registered NodePaths synchronised with their ODE bodies, applies optional per-body damping and dispatches per-geom collision callbacks.
  """
  def __init__(self,manager,xml):
    """Builds the physics world from the plugin configuration.

    manager -- the plugin manager (not used by this plugin).
    xml -- configuration element; each <surface> child must carry 'name', 'mu', 'bounce' and 'absorb' attributes.
    """
    # Setup the physics world...
    self.world = OdeWorld()
    self.world.setGravity(0.0,0.0,-9.81) # This should really come from the config file.
    self.world.setErp(0.8)
    self.world.setAutoDisableFlag(True)

    # Create a surface table - contains interactions between different surface types - loaded from config file...
    surElem = list(xml.findall('surface'))
    self.world.initSurfaceTable(len(surElem))
    self.surFromName = dict()

    # Parse every surface's parameters once, in config order - the index is the surface number...
    surProps = []
    for index, elem in enumerate(surElem):
      self.surFromName[elem.get('name')] = index
      surProps.append((float(elem.get('mu')), float(elem.get('bounce')), float(elem.get('absorb'))))

    # Maths used below is obviously wrong - should probably work out something better.
    # NOTE(review): 'absorb' is passed in the argument slot Panda3D's setSurfaceEntry documents as bounce_vel - confirm that is intended.
    for a, (muA, bounceA, absorbA) in enumerate(surProps):
      # Interaction with same surface...
      self.world.setSurfaceEntry(a,a,muA,bounceA,absorbA,0.8,1e-3,0.0,0.1)

      # Interaction with other surfaces - only b > a is set, one entry per unordered pair...
      for b in range(a+1,len(surProps)):
        muB, bounceB, absorbB = surProps[b]
        self.world.setSurfaceEntry(a,b,muA*muB,bounceA*bounceB,absorbA+absorbB,0.8,1e-3,0.0,0.1)

    # Create a space to manage collisions...
    self.space = OdeHashSpace()
    self.space.setAutoCollideWorld(self.world)

    # Setup a contact group to handle collision events...
    self.contactGroup = OdeJointGroup()
    self.space.setAutoCollideJointGroup(self.contactGroup)

    # Create the synch database - this is a database of NodePath and ODEBodys - each frame the NodePaths have their positions synched with the ODEBodys...
    self.synch = dict() # dict of tuples (node,body), indexed by an integer key that is stored on the NodePath using setPythonTag under 'ode_key'
    self.nextKey = 0

    # Create the extra function databases - pre- and post- functions for before and after each collision step...
    self.preCollide = dict() # name -> func
    self.postCollide = dict() # name -> func

    # Create the damping database - damps objects so that they slow down over time, which is very good for stability...
    self.damping = dict() # str(body) -> (body,linear,angular)

    # Arrange for the physics simulation to run on automatic - start and stop are used to enable/disable it however...
    self.runSim = False
    self.timeRem = 0.0
    self.step = 1.0/50.0

    def simulationTask(task):
      """Per-frame task: runs as many fixed-size steps as the elapsed time allows, then syncs NodePaths."""
      if self.runSim:
        # Step the simulation and set the new positions - fixed time step; leftover time carries over to the next frame...
        self.timeRem += globalClock.getDt()
        while self.timeRem>self.step:
          # Call the pre-collision functions - iterate a snapshot so a function may (un)register others...
          for func in list(self.preCollide.values()):
            func()

          # Apply damping to all enabled objects in the damping db: force = -linear*vel, torque = -angular*angVel...
          for body, linear, angular in list(self.damping.values()):
            if body.isEnabled():
              vel = body.getLinearVel()
              vel *= -linear
              body.addForce(vel)
              rot = body.getAngularVel()
              rot *= -angular
              body.addTorque(rot)

          # A single step of collision detection...
          self.space.autoCollide() # Setup the contact joints
          self.world.quickStep(self.step)
          self.timeRem -= self.step
          self.contactGroup.empty() # Clear the contact joints

          # Call the post-collision functions - snapshot for the same reason as above...
          for func in list(self.postCollide.values()):
            func()

        # Update all objects registered with this class to have their positions updated (in render's coordinate space)...
        for node, body in list(self.synch.values()):
          node.setPosQuat(render,body.getPosition(),Quat(body.getQuaternion()))

      return task.cont

    taskMgr.add(simulationTask,'Physics Sim',sort=100)

    # Arrange callback for collisions - we have one callback function which then checks a database to see if either geom in question has a specific callback function, in which case its called...
    self.collCB = dict() # OdeGeom to func(entry,flag), where flag is False if it is geom1, True if it is geom2.

    def onCollision(entry):
      """Dispatches a space collision event to the callbacks registered for either geom."""
      geom1 = entry.getGeom1()
      geom2 = entry.getGeom2()

      # Snapshot so a callback may (un)register callbacks without breaking the iteration.
      for geom, func in list(self.collCB.items()): # bad way of doing this - needs to be hashed, but doesn't seem to be possible currently.
        if geom1==geom:
          func(entry,False)
        if geom2==geom:
          func(entry,True)

    self.space.setCollisionEvent("collision")
    self.accept("collision",onCollision)


  def start(self):
    """Enables the physics simulation task."""
    self.runSim = True

  def stop(self):
    """Disables the physics simulation task and discards any accumulated, unsimulated time."""
    self.runSim = False
    self.timeRem = 0.0


  def getWorld(self):
    """Returns the ODE world"""
    return self.world

  def getSpace(self):
    """Returns the ODE space used for automatic collisions."""
    return self.space

  def getSurface(self,name):
    """This returns the surface number given the surface name. If it doesn't exist it prints a warning and returns 0 instead of failing."""
    if name in self.surFromName:
      return self.surFromName[name]
    print('Warning: Surface %s does not exist'%name)
    return 0


  def regBodySynch(self,node,body):
    """Given a NodePath and a Body this arranges that the NodePath tracks the Body. Registering the same NodePath again replaces its Body."""
    key = None
    if node.hasPythonTag('ode_key'):
      key = node.getPythonTag('ode_key')
      existing = self.synch.get(key)
      # A tag whose entry belongs to a different node (e.g. inherited by copying a registered node) must not be reused, or that node's entry would be overwritten.
      if existing is not None and not (existing[0] == node):
        key = None

    if key is None:
      key = self.nextKey
      self.nextKey += 1
      node.setPythonTag('ode_key',key)

    self.synch[key] = (node,body)

  def unregBodySynch(self,node):
    """Removes a NodePath/Body pair from the synchronisation database, so the NodePath will stop automatically tracking the Body. Does nothing if the NodePath is not registered."""
    if not node.hasPythonTag('ode_key'):
      return
    key = node.getPythonTag('ode_key')
    existing = self.synch.get(key)
    # Only remove the entry if it really belongs to this node (see regBodySynch).
    if existing is not None and existing[0] == node:
      del self.synch[key]

  def regPreFunc(self,name,func):
    """Registers a function under a unique name to be called before every step of the physics simulation - this is different from every frame, being entirely regular."""
    self.preCollide[name] = func

  def unregPreFunc(self,name):
    """Unregisters a function to be called every step, by name. Unknown names are ignored."""
    self.preCollide.pop(name,None)

  def regPostFunc(self,name,func):
    """Registers a function under a unique name to be called after every step of the physics simulation - this is different from every frame, being entirely regular."""
    self.postCollide[name] = func

  def unregPostFunc(self,name):
    """Unregisters a function to be called every step, by name. Unknown names are ignored."""
    self.postCollide.pop(name,None)

  def regCollisionCB(self,geom,func):
    """Registers a callback that will be called whenever the given geom collides. The function must take an OdeCollisionEntry followed by a flag, which will be False if geom1 is the given geom, True if its geom2."""
    self.collCB[geom] = func

  def unregCollisionCB(self,geom):
    """Unregisters the collision callback for a given geom. Unknown geoms are ignored."""
    self.collCB.pop(geom,None)

  def regDamping(self,body,linear,angular):
    """Given a body this applies a damping force, such that the velocity and rotation will be reduced in time. If the body is already registered this will update the current setting.

    Each simulation step, while the body is enabled, a force of -linear*linearVel and a torque of -angular*angularVel are added.
    """
    self.damping[str(body)] = (body,linear,angular)

  def unregDamping(self,body):
    """Unregisters a body from damping. Does nothing if the body was not registered."""
    self.damping.pop(str(body),None)

  # Backwards-compatible alias for the original, misspelt method name.
  unregDampingl = unregDamping
