"""
This script shows how GPU offloading can be used for the Granular Solver
"""

import math

# Import libraries to access the AGX API
import agx
import agxCollide
import agxSDK

# Import useful utilities to access the current simulation, graphics root and application
from agxPythonModules.utils.environment import application, init_app, simulation


# Optional: We make a GuiEventListener so we can toggle warmstarting on/off and number
#           of solver iterations in an easy way
# - Warmstarting can be toggled on/off with Home
# - Number of solver iterations can be changed with PageUp / PageDown
class WarmstartToggle(agxSDK.GuiEventListener):
    """
    GUI listener that exposes two granular solver settings on the keyboard.

    - Home toggles granular warm starting on/off.
    - PageUp / PageDown change the number of resting iterations in steps
      of 10 (never going below 10).

    The current values are written to the scene decorator on every update.
    """

    # Step size (and lower bound) used when changing the iteration count.
    _ITERATION_STEP = 10

    def __init__(self, app):
        """Register for keyboard and update events and remember `app`."""
        listener = agxSDK.GuiEventListener
        super().__init__(listener.KEYBOARD | listener.UPDATE)
        self.app = app

    def update(self, x, y):
        """Refresh the on-screen text rows 2 and 3 with the solver settings."""
        solver = self.app.getSimulation().getSolver()
        iteration_count = solver.getNumRestingIterations()
        warm_label = "ON" if solver.useGranularWarmStarting() else "OFF"

        decorator = self.app.getSceneDecorator()
        decorator.setText(2, f"Warmstarting: {warm_label}")
        decorator.setText(3, f"Solver iters: {iteration_count}")

    def keyboard(self, key, x, y, alt, down):
        """
        Handle Home / PageUp / PageDown.

        Returns True for these three keys (on both press and release) and
        False for every other key. Settings are only changed on key press.

        NOTE(review): only `key` and `down` are read here; confirm that the
        positional order of the remaining parameters matches the AGX
        GuiEventListener.keyboard signature.
        """
        listener = agxSDK.GuiEventListener
        solver = self.app.getSimulation().getSolver()
        current = solver.getNumRestingIterations()

        if key == listener.KEY_Home:
            if down:
                solver.setUseGranularWarmStarting(
                    not solver.useGranularWarmStarting()
                )
            return True

        if key == listener.KEY_Page_Up:
            target = current + self._ITERATION_STEP
        elif key == listener.KEY_Page_Down:
            target = max(current - self._ITERATION_STEP, self._ITERATION_STEP)
        else:
            return False

        if down:
            # Keep both iteration settings in sync.
            solver.setNumRestingIterations(target)
            solver.setNumPPGSRestingIterations(target)
        return True


# Construct a scene, particles in a pyramid like pile placed on a plane
#
def sample_scene():
    """
    Build the scene shared by all tutorial scenes.

    Configures the granular solver, limits the number of threads, creates a
    material with a self contact material, adds a ground plane and a
    pyramid-like pile of granular particles, positions the camera and
    registers the WarmstartToggle keyboard listener.
    """
    radius = 0.05
    levels = 40
    iterations = 100

    # Solver setting
    solver = simulation().getSolver()
    solver.setUse32bitGranularBodySolver(True)
    solver.setUseGranularWarmStarting(True)
    solver.setNumRestingIterations(iterations)
    solver.setNumPPGSRestingIterations(iterations)

    # Num threads: one less than half of the thread count reported after
    # setNumThreads(0) (half is taken as the number of physical cores,
    # assuming SMT). Never fewer than one thread.
    agx.setNumThreads(0)
    num_threads = agx.getNumThreads()
    agx.setNumThreads(max(1, num_threads // 2 - 1))

    # Material and contact material handling
    mat = agx.Material("Mat")
    simulation().add(mat)

    mm = simulation().getMaterialManager()
    cm = mm.getOrCreateContactMaterial(mat, mat)
    cm.setRestitution(0.0)
    cm.setYoungsModulus(7e8)
    cm.setFrictionCoefficient(0.9)
    cm.setRollingResistanceCoefficient(1.4 * radius)

    # Ground plane, placed half a radius below the origin
    plane = agxCollide.Geometry(agxCollide.Plane())
    plane.setMaterial(mat)
    plane.setPosition(0, 0, -0.5 * radius)
    simulation().add(plane)

    # Granular bodies
    gbs = agx.GranularBodySystem()
    gbs.setParticleMass(0.2)
    gbs.setParticleRadius(radius)
    gbs.setMaterial(mat)

    # Square layers of particles, each layer one particle narrower than the
    # one below and raised by radius * sqrt(2).
    base_z = 0.5 * radius
    base_offset = radius * levels - radius
    layer_height = radius * math.sqrt(2)
    for lvl in range(levels):
        num_particles = levels - lvl
        if num_particles == 1:
            # The single-particle top layer is left out.
            continue
        layer_z = base_z + layer_height * lvl
        # Half the side length of this layer, so the layer is centred in x/y.
        layer_offset = base_offset - lvl * radius
        for i in range(num_particles):
            for j in range(num_particles):
                particle = gbs.createParticle()
                particle.setPosition(
                    agx.Vec3(
                        i * radius * 2 - layer_offset,
                        j * radius * 2 - layer_offset,
                        layer_z,
                    )
                )

    simulation().add(gbs)

    # Position camera and configure app
    eye = agx.Vec3(5.85, -10.35, 4.05)
    center = agx.Vec3(0.165, 0.0, 1.0)
    up = agx.Vec3(-0.0620, 0.2537, 0.9653)
    application().setCameraHome(eye, center, up)
    # Render with the OSG renderer instead of the debug renderer.
    application().setEnableDebugRenderer(False)
    application().setEnableOSGRenderer(True)

    # Keyboard listener for settings
    simulation().add(WarmstartToggle(application()))


def buildScene1():
    """
    Granular material on plane, GPU offloading enabled.

    Builds the shared sample scene and then tries to enable GPU offloading
    on device 0. The outcome is shown on text row 1 of the scene decorator.
    """
    sample_scene()

    gpu_device = 0
    solver = simulation().getSolver()

    # Older AGX builds may lack the GPU API entirely, hence the hasattr probe.
    if not hasattr(solver, "setUseGpu"):
        message = "GPU offloading not available in current AGX build"
    elif solver.setUseGpu(True, gpu_device):
        message = "GPU offloading enabled"
    else:
        message = "GPU offloading could NOT be enabled"

    application().getSceneDecorator().setText(1, message)


def buildScene2():
    """
    Granular material on plane, solved with the CPU based parallel PGS solver.

    Builds the shared sample scene, switches the solver to parallel PGS and
    labels the scene on text row 1 of the scene decorator.
    """
    sample_scene()

    solver = simulation().getSolver()
    solver.setUseParallelPgs(True)

    decorator = application().getSceneDecorator()
    decorator.setText(1, "Parallel PGS")


def addRemainingScenes(app):
    """
    Register the additional scenes so they can be started from the keyboard.

    The script path is taken from the application's argument 1 with any
    "agxscene:" text removed. `app` is accepted for the caller's convenience;
    the global application() is what is actually used.
    """
    arguments = application().getArguments()
    script_path = arguments.getArgumentName(1).replace("agxscene:", "")

    for scene_name in ("buildScene2",):
        # The key is the single character for the next scene number, so this
        # only works for scene numbers up to 9.
        scene_number = application().getNumScenes() + 1
        key_code = ord(str(scene_number))
        application().addScene(script_path, scene_name, key_code, True)


def buildScene():
    """
    Entry point when running this script using agxViewer.

    Registers the remaining scenes the first time through (while the
    application only knows about one scene), then builds scene 1.
    """
    app = application()
    only_entry_scene = app.getNumScenes() == 1
    if only_entry_scene:
        # Add the other scenes to ExampleApplication from agxViewer.
        addRemainingScenes(app)

    buildScene1()


# Entry point when this script is started with python executable
def _announce_initialized(app):
    """Report that the application finished initializing."""
    print("App successfully initialized.")


def _announce_shutdown(app):
    """Report that the application has shut down."""
    print("App successfully shut down.")


# Scenes are registered together with the key that starts them.
init = init_app(
    name=__name__,
    scenes=[(buildScene1, "1"), (buildScene2, "2")],
    onInitialized=_announce_initialized,
    onShutdown=_announce_shutdown,
)
