"""
Tutorial demonstrating a simple four-wheel steerable vehicle on deformable
terrain, using the TerrainWheel class for the wheels.

The TerrainWheels are connected with WheelJoint suspension joints and an
Ackermann steering constraint. The terrain is generated with waves and a
Gaussian bump to create uneven ground. The simulation includes keyboard
controls for driving, braking, steering, and adjusting parameters such as
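target speed, steering step, auto-acceleration, and gravity (Earth, Moon,
Mars). An on-screen overlay shows the simulation time, the chassis speed and
velocity, the key bindings with their current settings, and the active inputs.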
"""

# Python standard library
from __future__ import annotations

import math
from dataclasses import dataclass, field
from enum import Enum

import osg

# AGX libraries
import agx
import agxCollide
import agxOSG
import agxRender
import agxSDK
import agxTerrain
import agxUtil
import agxVehicle

# AGX Python modules
from agxPythonModules.utils.callbacks import StepEventCallback
from agxPythonModules.utils.environment import (
    application,
    init_app,
    root,
    simulation,
)


# -----------------------------------------------------------------------------
# Central Settings
# -----------------------------------------------------------------------------

# Wheel geometry and mass (applied to each of the four TerrainWheels).
WHEEL_RADIUS = 0.5  # [m]
WHEEL_WIDTH = 0.4  # [m]
WHEEL_MASS = 10  # [kg]
# Chassis box: width along chassis-local x, length along local y, height along local z.
CHASSIS_WIDTH = 2.4  # [m]
CHASSIS_LENGTH = 4.0  # [m]
CHASSIS_HEIGHT = 1.0  # [m]
CHASSIS_MASS = 1000  # [kg]


# -----------------------------------------------------------------------------
# Additional Settings
# -----------------------------------------------------------------------------

# Attach an orbit camera to the chassis geometry node (see apply_camera_data).
FOCUS_CAMERA_ON_CHASSIS = True  # [-] bool
# Distance from the bottom of the chassis box to the bottom of the wheels
# (in the chassis frame, before any suspension travel).
CHASSIS_CLEARANCE = 0.6  # [m]
# Motorize all four wheels; when False only the rear pair is driven.
FOUR_WHEEL_DRIVE = True  # [-] bool

# Steering: the angle is clamped to +/- MAX_STEER_ANGLE_RAD. While a steering
# key is held the angle changes by the current steer step every simulation
# step; the i/o keys change that step by STEER_STEP_INCREMENT_RAD.
MAX_STEER_ANGLE_RAD = math.radians(45.0)  # [rad]
DEFAULT_STEER_STEP_RAD = math.radians(1.0)  # [rad]
STEER_STEP_INCREMENT_RAD = math.radians(0.1)  # [rad]

# Drive: initial target speed and the j/k adjustment step.
DEFAULT_DRIVE_SPEED_MPS = 2.0  # [m/s]
DRIVE_SPEED_INCREMENT_MPS = 0.1  # [m/s]

# Suspension: travel limits (Range1D) and spring/damper values, converted to
# compliance and SPOOK damping for the suspension Lock1D.
SUSPENSION_RANGE_MIN = -0.1  # [m]
SUSPENSION_RANGE_MAX = 0.1  # [m]
SUSPENSION_SPRING_CONSTANT = 1.0e6  # [N/m]
SUSPENSION_DAMPING = 4.0e4  # [Ns/m]

# Force ranges of the wheel motors while driving and while braking
# (braking = motor enabled with zero target speed).
DEFAULT_DRIVE_FORCE = 1.0e5  # [Nm]
DEFAULT_BRAKE_FORCE = 1.0e5  # [Nm]

# Auto-accelerate target speed =
#     AUTO_ACCELERATE_MULTIPLIER * chassis speed
#     + AUTO_ACCELERATE_BASE_ANGULAR_SPEED * wheel radius.
AUTO_ACCELERATE_BASE_ANGULAR_SPEED = 1.0  # [rad/s]
AUTO_ACCELERATE_MULTIPLIER = 1.25  # [-]

# Height range passed to TerrainVoxelRenderer.setRenderHeights.
RENDER_HEIGHT_RANGE_MIN = -0.1  # [m]
RENDER_HEIGHT_RANGE_MAX = 0.1  # [m]


# -----------------------------------------------------------------------------
# Configuration and runtime state
# -----------------------------------------------------------------------------

class GravityMode(Enum):
    """
    Selectable uniform-gravity presets.

    Each member's value is the human-readable label shown in the overlay; the
    matching gravity vector lives in ``GRAVITY_BY_MODE``.
    """

    EARTH = "Earth"
    MOON = "Moon"
    MARS = "Mars"


# Uniform gravity vector [m/s^2] for each preset, always pointing along world -Z.
GRAVITY_BY_MODE = {
    mode: agx.Vec3(0.0, 0.0, -magnitude)
    for mode, magnitude in (
        (GravityMode.EARTH, 9.81),
        (GravityMode.MOON, 1.61),
        (GravityMode.MARS, 3.73),
    )
}


@dataclass
class VehicleConfig:
    """
    Geometry and mass parameters of the four-wheel vehicle.

    Lengths are in metres and masses in kilograms. Defaults come from the
    module-level settings.
    """

    wheel_radius: float = WHEEL_RADIUS
    wheel_width: float = WHEEL_WIDTH
    wheel_mass: float = WHEEL_MASS
    chassis_width: float = CHASSIS_WIDTH
    chassis_length: float = CHASSIS_LENGTH
    chassis_height: float = CHASSIS_HEIGHT
    chassis_mass: float = CHASSIS_MASS
    chassis_clearance: float = CHASSIS_CLEARANCE
    four_wheel_drive: bool = FOUR_WHEEL_DRIVE

    @property
    def chassis_half_extents(self) -> agx.Vec3:
        """Half-extents of the chassis box as (x: width, y: length, z: height)."""
        half_width = self.chassis_width * 0.5
        half_length = self.chassis_length * 0.5
        half_height = self.chassis_height * 0.5
        return agx.Vec3(half_width, half_length, half_height)


@dataclass
class TerrainConfig:
    """
    Parameters of the generated terrain.

    The terrain is square: it spans ``size = end_x - start_x`` metres along
    both x and y and is sampled with a spacing of roughly ``element_size``.
    Heights are sums of sine waves along x and y plus an optional Gaussian
    bump.
    """

    start_x: float = -5.0
    end_x: float = 50.0
    element_size: float = 0.15
    use_gaussian_bump: bool = True

    # Sine-wave components: amplitude [m] paired with wavelength [m].
    wave_amplitudes_x: tuple[float, ...] = (0.1, 0.12, 0.2)
    wave_wavelengths_x: tuple[float, ...] = (7.0, 12.0, 18.0)
    wave_amplitudes_y: tuple[float, ...] = (0.05, 0.1, 0.15)
    wave_wavelengths_y: tuple[float, ...] = (7.0, 10.0, 15.0)

    # Gaussian bump: peak height, centre (world x, terrain-local y), widths,
    # extra length scales, and an exponent applied to the squared distance.
    bump_amplitude: float = 2.5
    bump_center_x: float = 20.0
    bump_center_y: float = 0.0
    bump_sigma_x: float = 5.0
    bump_sigma_y: float = 5.0
    bump_length_scale_x: float = 1.0
    bump_length_scale_y: float = 2.0
    bump_sharpness: float = 1.0

    @property
    def size(self) -> float:
        """Side length of the square terrain [m]."""
        return self.end_x - self.start_x

    @property
    def resolution(self) -> int:
        """
        Number of height-field samples per side.

        The result is one more than the number of whole ``element_size``
        intervals that fit in ``size``. A ratio that is an integer up to
        floating-point round-off (e.g. ``3.0 / 0.1 == 29.999999999999996``)
        counts as that integer, so the last interval is not silently dropped.
        """
        ratio = self.size / self.element_size
        nearest = round(ratio)
        if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-9):
            return int(nearest) + 1
        return math.floor(ratio) + 1


@dataclass
class RuntimeState:
    """
    Mutable settings shared by the keyboard controller, the per-step
    controller and the overlay. Key presses change them at runtime.
    """

    # Target vehicle speed [m/s]; divided by the wheel radius to get the drive
    # motors' angular speed. Overwritten every step while auto_accelerate is on.
    target_drive_speed_mps: float = DEFAULT_DRIVE_SPEED_MPS
    # Steering-angle change [rad] applied per step while a steering key is held.
    steer_step_rad: float = DEFAULT_STEER_STEP_RAD
    # When True, the target speed is recomputed from the chassis speed each step.
    auto_accelerate: bool = True
    # Whether the full on-screen overlay is drawn.
    show_overlay: bool = True
    # Selected gravity preset; written to the simulation every step.
    gravity_mode: GravityMode = GravityMode.EARTH
    # Labels of the currently held control keys, shown in the overlay.
    pressed_keys: list[str] = field(default_factory=list)


# -----------------------------------------------------------------------------
# Vehicle rig
# -----------------------------------------------------------------------------

class FourWheelRig(agxSDK.StepEventListener):
    """
    Vehicle rig containing the chassis, wheels, wheel joints, and steering setup.

    Construction builds everything immediately:
    - the chassis body;
    - four TerrainWheels attached with WheelJoints, collected in an
      agxSDK.Assembly so the vehicle can be posed as one unit;
    - an Ackermann steering constraint on the front wheels.

    The rig then adds itself to the simulation as a step listener.
    """

    def __init__(self, config: VehicleConfig):
        super().__init__()
        self.config = config

        self.vehicle: agxSDK.Assembly | None = None
        self.chassis: agx.RigidBody | None = None
        self.chassis_geometry: agxCollide.Geometry | None = None

        # Surface material shared by all four wheels. It is created in
        # _create_vehicle_assembly(), but is declared here so the attribute
        # always exists (callers check it against None).
        self.wheel_material: agx.Material | None = None

        self.wheels: dict[str, agxTerrain.TerrainWheel] = {}
        self.wheel_joints: dict[str, agxVehicle.WheelJoint] = {}

        self.steering_constraint: agxVehicle.Ackermann | None = None
        self.steering_angle_rad: float = 0.0
        self.contact_material: agx.ContactMaterial | None = None

        self.build()
        simulation().add(self)

    def build(self) -> None:
        """Create the vehicle, move it to its start pose, and configure joints and steering."""
        self.vehicle = self._create_vehicle_assembly(simulation(), root())
        self.chassis = self.vehicle.getRigidBody("chassis")

        self._configure_initial_pose()
        self._configure_wheel_joints()
        self._configure_steering()

    def _create_vehicle_assembly(
        self,
        sim: agxSDK.Simulation,
        scene_root,
    ) -> agxSDK.Assembly:
        """
        Create the chassis and four wheels with their WheelJoints.

        Parts are placed relative to the chassis' current frame. Every body,
        wheel and joint is added both to ``sim`` and to a new Assembly.

        Returns:
            The assembly holding the chassis, the wheels and the joints.
        """
        assembly = agxSDK.Assembly()

        chassis_body = self._create_chassis_body(scene_root)
        chassis_frame = chassis_body.getFrame()

        sim.add(chassis_body)
        assembly.add(chassis_body)

        self.chassis = chassis_body

        steering_axis_world = chassis_frame.transformVectorToWorld(agx.Vec3.Z_AXIS())
        wheel_axis_world = chassis_frame.transformVectorToWorld(-agx.Vec3.X_AXIS())

        # Must exist before _create_wheel() is called below.
        self.wheel_material = agx.Material("wheel_material")

        for wheel_name, local_wheel_position in self._iter_wheel_mounts():
            wheel = self._create_wheel(scene_root)
            wheel_body = wheel.getRigidBodies()[0].get()

            wheel.setPosition(chassis_frame.transformPointToWorld(local_wheel_position))
            # Rotate the wheel so its local Y axis lines up with the chassis X axis.
            wheel.setRotation(agx.Quat(agx.Vec3.Y_AXIS(), agx.Vec3.X_AXIS()))
            wheel_body.getMassProperties().setMass(self.config.wheel_mass)

            # The joint anchor sits at the chassis corner on the wheel's side,
            # at the wheel's height.
            joint_anchor_local = agx.Vec3(
                math.copysign(self.config.chassis_half_extents[0], local_wheel_position.x()),
                math.copysign(self.config.chassis_half_extents[1], local_wheel_position.y()),
                local_wheel_position.z(),
            )
            joint_anchor_world = chassis_frame.transformPointToWorld(joint_anchor_local)

            wheel_joint_frame = agxVehicle.WheelJointFrame(
                joint_anchor_world,
                wheel_axis_world,
                steering_axis_world,
            )
            wheel_joint = agxVehicle.WheelJoint(
                wheel_joint_frame,
                wheel_body,
                chassis_body,
            )

            sim.add(wheel)
            sim.add(wheel_joint)

            assembly.add(wheel)
            assembly.add(wheel_joint)

            self.wheels[wheel_name] = wheel
            self.wheel_joints[wheel_name] = wheel_joint

        return assembly

    def _create_chassis_body(self, scene_root) -> agx.RigidBody:
        """Create the box-shaped chassis body named "chassis" with a translucent blue visual."""
        chassis_body = agx.RigidBody("chassis")
        chassis_geom = agxCollide.Geometry(agxCollide.Box(self.config.chassis_half_extents))
        chassis_body.add(chassis_geom)
        chassis_body.getMassProperties().setMass(self.config.chassis_mass)

        self.chassis_geometry = chassis_geom

        chassis_node = agxOSG.createVisual(chassis_body, scene_root)
        agxOSG.setDiffuseColor(chassis_node, agxRender.Color.Blue())
        agxOSG.setAlpha(chassis_node, 0.25)

        return chassis_body

    def _create_wheel(self, scene_root) -> agxTerrain.TerrainWheel:
        """Create one TerrainWheel using the shared wheel material and a textured visual."""
        wheel = agxTerrain.TerrainWheel(
            self.config.wheel_radius,
            self.config.wheel_width,
        )

        wheel.setMaterial(self.wheel_material)

        wheel_node = agxOSG.createVisual(wheel, scene_root)
        agxOSG.setTexture(wheel_node, "checkboard_mini.png")

        return wheel

    def _iter_wheel_mounts(self) -> list[tuple[str, agx.Vec3]]:
        """
        Returns wheel positions in chassis-local coordinates.
        AGX uses the current chassis frame to transform these to world.

        Front wheels sit at +y (half the chassis length) and rear wheels at -y.
        Wheels sit one wheel width outside the chassis sides, low enough that
        the wheel bottoms are ``chassis_clearance`` below the chassis bottom.
        """
        half = self.config.chassis_half_extents
        z = -half[2] - self.config.chassis_clearance + self.config.wheel_radius
        x_offset = half[0] + self.config.wheel_width
        y_offset = half[1]

        return [
            ("front_left", agx.Vec3(+x_offset, +y_offset, z)),
            ("front_right", agx.Vec3(-x_offset, +y_offset, z)),
            ("rear_left", agx.Vec3(+x_offset, -y_offset, z)),
            ("rear_right", agx.Vec3(-x_offset, -y_offset, z)),
        ]

    def _driven_wheel_names(self) -> tuple[str, ...]:
        """Names of the motor-driven wheels: all four with four-wheel drive, else the rear pair."""
        if self.config.four_wheel_drive:
            return ("front_left", "front_right", "rear_left", "rear_right")
        return ("rear_left", "rear_right")

    def _configure_initial_pose(self) -> None:
        """Place the whole assembly at its start position, rotated -90 degrees about world Z."""
        assert self.vehicle is not None
        self.vehicle.setPosition(agx.Vec3(0.0, -8.0, 1.5))
        self.vehicle.setRotation(agx.EulerAngles(agx.Vec3(0.0, 0.0, -0.5 * agx.PI)))

    def _configure_steering(self) -> None:
        """Couple the two front wheel joints with an Ackermann steering constraint."""
        front_left = self.wheel_joints["front_left"]
        front_right = self.wheel_joints["front_right"]

        # NOTE(review): the front mounts have y = +half length, so chassis
        # forward is +y. With up = +z, that makes +x the vehicle's right-hand
        # side, yet the "*_left" wheels are mounted at +x. The constraint
        # receives front_right first; confirm against the agxVehicle.Ackermann
        # documentation which argument is the left wheel.
        self.steering_constraint = agxVehicle.Ackermann(front_right, front_left)
        simulation().add(self.steering_constraint)

    def _configure_wheel_joints(self) -> None:
        """
        Configure each wheel joint.

        - Lock steering on non-steerable wheels.
        - Disable chassis/wheel collisions.
        - Set up the sprung, damped, range-limited suspension.
        - Enable the wheel motor on driven wheels only.
        """
        assert self.chassis is not None

        driven_wheels = set(self._driven_wheel_names())
        steerable_wheels = {"front_left", "front_right"}

        for wheel_name, joint in self.wheel_joints.items():
            is_steerable = wheel_name in steerable_wheels
            is_driven = wheel_name in driven_wheels

            # Lock steering on non-steerable wheels
            joint.getLock1D(agxVehicle.WheelJoint.STEERING).setEnable(not is_steerable)

            # Disable chassis-wheel collisions
            agxUtil.setEnableCollisions(self.chassis, joint.getWheelRigidBody(), False)

            # Suspension range
            suspension_range = joint.getRange1D(agxVehicle.WheelJoint.SUSPENSION)
            suspension_range.setEnable(True)
            suspension_range.setRange(SUSPENSION_RANGE_MIN, SUSPENSION_RANGE_MAX)

            # Suspension spring/damper: a compliant lock acts as the spring.
            suspension_lock = joint.getLock1D(agxVehicle.WheelJoint.SUSPENSION)
            suspension_lock.setEnable(True)
            suspension_lock.setCompliance(
                agxUtil.convertSpringConstantToCompliance(SUSPENSION_SPRING_CONSTANT)
            )
            suspension_lock.setDamping(
                agxUtil.convertDampingCoefficientToSpookDamping(
                    SUSPENSION_DAMPING,
                    SUSPENSION_SPRING_CONSTANT,
                )
            )

            # Enable wheel motor only for driven wheels
            wheel_motor = joint.getMotor1D(agxVehicle.WheelJoint.WHEEL)
            joint.setEnableComputeForces(True)
            wheel_motor.setEnable(is_driven)
            wheel_motor.setForceRange(agx.RangeReal(DEFAULT_DRIVE_FORCE))

            # Steering motors are not used when Ackermann is active
            steering_motor = joint.getMotor1D(agxVehicle.WheelJoint.STEERING)
            steering_motor.setEnable(False)

    def get_driven_joints(self) -> list[agxVehicle.WheelJoint]:
        """Return the wheel joints of the driven wheels (front pair first when 4WD)."""
        return [self.wheel_joints[name] for name in self._driven_wheel_names()]

    def get_chassis_speed(self) -> float:
        """Return the magnitude of the chassis' linear velocity [m/s]."""
        assert self.chassis is not None
        return self.chassis.getVelocity().length()

    def get_chassis_velocity(self) -> agx.Vec3:
        """Return the chassis' linear velocity vector [m/s]."""
        assert self.chassis is not None
        return self.chassis.getVelocity()


# -----------------------------------------------------------------------------
# Terrain
# -----------------------------------------------------------------------------

class Terrain:
    """
    Deformable terrain generated from a height field.

    Heights come from get_height(): an optional Gaussian bump plus sums of
    sine waves along x and y. The sampled HeightField becomes an
    agxTerrain.Terrain, which is added to the simulation together with a
    TerrainVoxelRenderer.
    """

    def __init__(self, config: TerrainConfig):
        self.config = config
        self.material: agx.Material | None = None
        self.terrain: agxTerrain.Terrain | None = None
        self.renderer: agxOSG.TerrainVoxelRenderer | None = None
        self.create()

    def create(self) -> None:
        """
        Build and register the terrain.

        Samples the height field, creates the terrain from it, assigns the
        materials, and adds both the terrain and its renderer to the
        simulation.
        """
        resolution = self.config.resolution
        size = self.config.size
        height_field = agxCollide.HeightField(resolution, resolution, size, size)

        # The field is square, so the same sample coordinates serve both axes.
        positions = [self._hf_index_to_position(i) for i in range(resolution)]

        for i, x in enumerate(positions):
            for j, y in enumerate(positions):
                height_field.setHeight(i, j, self.get_height(x, y))

        # NOTE(review): the second argument is presumably the maximum soil depth [m]; confirm.
        terrain = agxTerrain.Terrain.createFromHeightField(height_field, 2)
        # Centre the terrain between start_x and end_x in world x;
        # get_height() relies on this when converting local x to world x.
        terrain_position = agx.Vec3(
            0.5 * (self.config.start_x + self.config.end_x),
            -0.1,
            0.0,
        )
        terrain.setPosition(terrain_position)

        self.material = agx.Material("terrain_material")
        terrain.loadLibraryMaterial("terrain_wheel_sand_1")
        terrain.setMaterial(self.material)
        terrain.getProperties().setEnableAvalanching(True)

        simulation().add(terrain)
        self.terrain = terrain

        self.renderer = agxOSG.TerrainVoxelRenderer(terrain, root())
        self._configure_renderer(self.renderer)
        simulation().add(self.renderer)

    def _configure_renderer(self, renderer) -> None:
        """Render the height field colored by height; turn off all voxel/debug layers."""
        # NOTE(review): the three colors are presumably the low / mid / high height colors; confirm.
        renderer.setRenderHeights(
            True,
            agx.RangeReal(RENDER_HEIGHT_RANGE_MIN, RENDER_HEIGHT_RANGE_MAX),
            True,
            osg.Vec4(0.0, 0.0, 1.0, 1.0),
            osg.Vec4(0.6, 0.6, 0.6, 1.0),
            osg.Vec4(1.0, 0.0, 0.0, 1.0),
        )
        renderer.setRenderCompaction(False, agx.RangeReal(0.85, 1.15))
        renderer.setRenderVoxelSolidMass(False)
        renderer.setRenderVoxelFluidMass(False)
        renderer.setRenderHeightField(True)
        renderer.setRenderVoxelBoundingBox(False)
        renderer.setRenderVelocityField(False)
        renderer.setRenderSoilParticlesMesh(False)
        renderer.setRenderDefaultTerrainMaterial(False)
        renderer.setRenderTerrainMaterials(False)

    def _hf_index_to_position(self, index: int) -> float:
        """
        Map a height-field sample index to its terrain-local coordinate.

        ``resolution`` samples are spread evenly over ``size``, centred on the
        terrain origin. Their spacing is therefore ``size / (resolution - 1)``,
        so index 0 maps to -size/2 and index ``resolution - 1`` maps to
        +size/2. Dividing by ``resolution`` instead never reaches +size/2 and
        misplaces every sampled height relative to its vertex.
        """
        # Guard the degenerate single-sample field against division by zero.
        intervals = max(self.config.resolution - 1, 1)
        return (index / float(intervals) - 0.5) * self.config.size

    def get_height(self, x: float, y: float) -> float:
        """
        Height in terrain-local coordinates.

        Args:
            x: Terrain-local x coordinate [m].
            y: Terrain-local y coordinate [m].

        Returns:
            The Gaussian bump (if enabled), evaluated at the world x
            coordinate, plus the x and y sine waves, evaluated at the local
            coordinates.
        """
        height = 0.0

        # Local x -> world x (the terrain is centred between start_x and end_x).
        x_star = x + 0.5 * self.config.size + self.config.start_x

        if self.config.use_gaussian_bump:
            dx = (x_star - self.config.bump_center_x) / (
                self.config.bump_sigma_x * self.config.bump_length_scale_x
            )
            dy = (y - self.config.bump_center_y) / (
                self.config.bump_sigma_y * self.config.bump_length_scale_y
            )
            r2 = dx * dx + dy * dy
            height += self.config.bump_amplitude * math.exp(
                -0.5 * (r2 ** self.config.bump_sharpness)
            )

        height += self.generate_waves(
            self.config.wave_amplitudes_x,
            self.config.wave_wavelengths_x,
            x,
        )
        height += self.generate_waves(
            self.config.wave_amplitudes_y,
            self.config.wave_wavelengths_y,
            y,
        )
        return height

    def generate_waves(
        self,
        amplitudes: tuple[float, ...],
        wavelengths: tuple[float, ...],
        coordinate: float,
    ) -> float:
        """
        Sum ``amplitude * sin(2*pi/wavelength * coordinate)`` over paired components.

        Extra entries in the longer of the two tuples are ignored (zip semantics).
        """
        return sum(
            (
                amplitude * math.sin(self.wavelength_to_angular_frequency(wavelength) * coordinate)
                for amplitude, wavelength in zip(amplitudes, wavelengths)
            ),
            0.0,
        )

    @staticmethod
    def wavelength_to_angular_frequency(wavelength: float) -> float:
        """Return the spatial angular frequency 2*pi / wavelength [rad/m]."""
        return 2.0 * math.pi / wavelength


# -----------------------------------------------------------------------------
# Controls
# -----------------------------------------------------------------------------

class VehicleKeyboardController(agxSDK.GuiEventListener):
    """
    Keyboard listener that keeps track of key state.
    Actual control application happens in VehicleControllerStepListener.

    Continuous keys (arrow keys and ``h``) are tracked as held/released flags.
    One-shot keys (``j``/``k``, ``i``/``o``, ``p``, ``z``, ``a``) update the
    shared RuntimeState immediately on key-down.
    """

    # (pressed-flag name, overlay label), in the order labels are listed.
    _PRESSED_KEY_LABELS = (
        ("up", "drive forward"),
        ("down", "drive reverse"),
        ("left", "steer left"),
        ("right", "steer right"),
        ("brake", "brake"),
    )

    def __init__(
        self,
        rig: FourWheelRig,
        runtime_state: RuntimeState,
    ):
        super().__init__(agxSDK.GuiEventListener.KEYBOARD)

        self.rig = rig
        self.runtime_state = runtime_state

        # NOTE(review): these two ranges are not read anywhere in the visible
        # source; the step listener uses DEFAULT_DRIVE_FORCE/DEFAULT_BRAKE_FORCE directly.
        self.drive_force_range = agx.RangeReal(DEFAULT_DRIVE_FORCE)
        self.brake_force_range = agx.RangeReal(DEFAULT_BRAKE_FORCE)

        self.drive_speed_step_mps = DRIVE_SPEED_INCREMENT_MPS
        self.steer_step_increment_rad = STEER_STEP_INCREMENT_RAD

        self._pressed = {
            "up": False,
            "down": False,
            "left": False,
            "right": False,
            "brake": False,
        }

    def keyboard(self, key, modifier, x, y, keydown):
        """
        Handle a key event.

        Returns:
            True when the key was consumed by this listener, else False.
        """
        if self._handle_continuous_key(key, keydown):
            self._update_pressed_keys_overlay()
            return True

        if keydown and self._handle_one_shot_key(key):
            self._update_pressed_keys_overlay()
            return True

        return False

    def _handle_continuous_key(self, key, keydown: bool) -> bool:
        """Record held/released state for arrow keys and ``h``; return True if handled."""
        if key == self.KEY_Up:
            self._pressed["up"] = bool(keydown)
        elif key == self.KEY_Down:
            self._pressed["down"] = bool(keydown)
        elif key == self.KEY_Left:
            self._pressed["left"] = bool(keydown)
        elif key == self.KEY_Right:
            self._pressed["right"] = bool(keydown)
        elif key in (ord("h"), ord("H")):
            self._pressed["brake"] = bool(keydown)
        else:
            return False

        return True

    def _handle_one_shot_key(self, key) -> bool:
        """Apply a one-shot settings change to the runtime state; return True if handled."""
        if key in (ord("j"), ord("J")):
            # Target speed never goes below zero.
            self.runtime_state.target_drive_speed_mps = max(
                0.0,
                self.runtime_state.target_drive_speed_mps - self.drive_speed_step_mps,
            )
        elif key in (ord("k"), ord("K")):
            self.runtime_state.target_drive_speed_mps += self.drive_speed_step_mps
        elif key in (ord("i"), ord("I")):
            self.runtime_state.steer_step_rad = max(
                0.0,
                self.runtime_state.steer_step_rad - self.steer_step_increment_rad,
            )
        elif key in (ord("o"), ord("O")):
            self.runtime_state.steer_step_rad += self.steer_step_increment_rad
        elif key in (ord("p"), ord("P")):
            self.runtime_state.show_overlay = not self.runtime_state.show_overlay
        elif key in (ord("z"), ord("Z")):
            self.runtime_state.auto_accelerate = not self.runtime_state.auto_accelerate
        elif key in (ord("a"), ord("A")):
            self.runtime_state.gravity_mode = self._next_gravity_mode(self.runtime_state.gravity_mode)
        else:
            return False

        return True

    def _next_gravity_mode(self, current: GravityMode) -> GravityMode:
        """
        Return the mode after ``current`` in GravityMode's definition order, wrapping around.

        The cycle is derived from the enum itself, so a newly added mode is
        included automatically.
        """
        modes = list(GravityMode)
        return modes[(modes.index(current) + 1) % len(modes)]

    def _update_pressed_keys_overlay(self) -> None:
        """Publish the labels of the currently held control keys to the runtime state."""
        self.runtime_state.pressed_keys = [
            label for name, label in self._PRESSED_KEY_LABELS if self._pressed[name]
        ]

    @property
    def up_pressed(self) -> bool:
        return self._pressed["up"]

    @property
    def down_pressed(self) -> bool:
        return self._pressed["down"]

    @property
    def left_pressed(self) -> bool:
        return self._pressed["left"]

    @property
    def right_pressed(self) -> bool:
        return self._pressed["right"]

    @property
    def brake_pressed(self) -> bool:
        return self._pressed["brake"]


class VehicleControllerStepListener(agxSDK.StepEventListener):
    """
    Applies driving, braking, steering, and gravity each simulation step.

    Each post() does, in order:
    1. steering;
    2. brake or drive;
    3. the auto-accelerate target update, which is therefore used on the
       following step;
    4. the uniform gravity of the selected GravityMode.
    """

    def __init__(
        self,
        rig: FourWheelRig,
        keyboard: VehicleKeyboardController,
        runtime_state: RuntimeState,
    ):
        super().__init__()
        self.rig = rig
        self.keyboard = keyboard
        self.runtime_state = runtime_state

        # Constant term of the auto-accelerate target [m/s]: base angular speed * wheel radius.
        radius = self.rig.config.wheel_radius
        self.auto_accelerate_added_speed = AUTO_ACCELERATE_BASE_ANGULAR_SPEED * radius
        self.auto_accelerate_multiplier = AUTO_ACCELERATE_MULTIPLIER

    def post(self, t: float) -> None:
        direction = self._resolve_drive_direction()
        turn = self._resolve_steer_direction()

        if turn:
            self._apply_steer(turn)

        if not self.keyboard.brake_pressed:
            self._apply_drive(direction)
        else:
            self._apply_brake(True)

        if self.runtime_state.auto_accelerate:
            self._update_auto_accelerated_speed()

        self._apply_gravity()

    def _resolve_drive_direction(self) -> int:
        """+1 for forward only, -1 for reverse only, 0 for neither or both."""
        forward = bool(self.keyboard.up_pressed)
        reverse = bool(self.keyboard.down_pressed)
        return int(forward) - int(reverse)

    def _resolve_steer_direction(self) -> int:
        """+1 for left only, -1 for right only, 0 for neither or both."""
        left = bool(self.keyboard.left_pressed)
        right = bool(self.keyboard.right_pressed)
        return int(left) - int(right)

    def _apply_drive(self, drive_dir: int) -> None:
        """Run the driven wheels at the signed target speed, or disable their motors when idle."""
        wheel_speed = self.runtime_state.target_drive_speed_mps / self.rig.config.wheel_radius
        joints = self.rig.get_driven_joints()

        idle = drive_dir == 0 or wheel_speed <= 0.0
        if idle:
            for joint in joints:
                self._wheel_motor(joint).setEnable(False)
            return

        signed_speed = float(drive_dir) * float(wheel_speed)
        for joint in joints:
            motor = self._wheel_motor(joint)
            motor.setEnable(True)
            motor.setSpeed(signed_speed)
            motor.setForceRange(agx.RangeReal(DEFAULT_DRIVE_FORCE))

    def _apply_brake(self, enabled: bool) -> None:
        """When enabled, hold the driven wheels at zero speed with the brake force range."""
        if enabled:
            for joint in self.rig.get_driven_joints():
                motor = self._wheel_motor(joint)
                motor.setEnable(True)
                motor.setSpeed(0.0)
                motor.setForceRange(agx.RangeReal(DEFAULT_BRAKE_FORCE))

    def _apply_steer(self, steer_dir: int) -> None:
        """Advance the steering angle by one step, clamp it, and push it to the Ackermann constraint."""
        increment = float(steer_dir) * float(self.runtime_state.steer_step_rad)
        unclamped = self.rig.steering_angle_rad + increment
        self.rig.steering_angle_rad = max(
            -MAX_STEER_ANGLE_RAD,
            min(MAX_STEER_ANGLE_RAD, unclamped),
        )

        constraint = self.rig.steering_constraint
        if constraint is not None:
            constraint.setSteeringAngle(self.rig.steering_angle_rad)

    def _update_auto_accelerated_speed(self) -> None:
        """Set the target speed to multiplier * chassis speed + constant term."""
        speed = self.rig.get_chassis_speed()
        self.runtime_state.target_drive_speed_mps = (
            self.auto_accelerate_multiplier * speed + self.auto_accelerate_added_speed
        )

    def _apply_gravity(self) -> None:
        """Write the selected preset's gravity vector to the simulation."""
        gravity = GRAVITY_BY_MODE[self.runtime_state.gravity_mode]
        simulation().setUniformGravity(gravity)

    @staticmethod
    def _wheel_motor(joint: agxVehicle.WheelJoint):
        """Return the rotational (WHEEL) motor of a wheel joint."""
        return joint.getMotor1D(agxVehicle.WheelJoint.WHEEL)


# -----------------------------------------------------------------------------
# Scene helpers
# -----------------------------------------------------------------------------

def apply_camera_data(app, geometry: agxCollide.Geometry | None) -> None:
    """
    Set the initial camera view and optionally attach an orbit camera to the chassis.

    Args:
        app: The application object (as returned by ``application()``).
        geometry: Geometry whose scene-graph node the orbit camera follows,
            or None to keep only the static view. The orbit camera is attached
            only when FOCUS_CAMERA_ON_CHASSIS is True and the node is found.
    """
    camera_data = app.getCameraData()
    camera_data.eye = agx.Vec3(0.0, -25, 10)
    camera_data.center = agx.Vec3(0.0, 0.0, 0.0)
    # The up vector must be non-zero: a zero vector leaves the look-at
    # orientation undefined. Gravity in this scene points along -Z, so +Z is up.
    camera_data.up = agx.Vec3(0.0, 0.0, 1.0)
    camera_data.nearClippingPlane = 0.1
    camera_data.farClippingPlane = 5000
    app.applyCameraData(camera_data)

    if FOCUS_CAMERA_ON_CHASSIS and geometry is not None:
        camera_node = agxOSG.findGeometryNode(geometry, root())
        if camera_node:
            # NOTE(review): arguments are kept as originally written (negated
            # center/eye, and -center in the last slot); confirm them against
            # the ExampleApplication.setOrbitCamera signature.
            app.setOrbitCamera(
                camera_node,
                -camera_data.center,
                -camera_data.eye,
                -camera_data.center,
            )


def set_number_of_threads(
    num_threads: int | None = None,
    use_half_of_available_threads: bool = False,
) -> None:
    """
    Configure the number of AGX worker threads.

    Args:
        num_threads: Explicit thread count (at least 1 is used). Ignored when
            ``use_half_of_available_threads`` is True. When None and halving
            is off, the thread setting is left untouched.
        use_half_of_available_threads: Reset AGX with ``setNumThreads(0)``,
            read back the resulting count, and use half of it (at least 1).
    """
    if not use_half_of_available_threads:
        if num_threads is not None:
            agx.setNumThreads(max(1, num_threads))
        return

    # NOTE(review): 0 presumably makes AGX pick its default/maximum thread count.
    agx.setNumThreads(0)
    agx.setNumThreads(max(1, agx.getNumThreads() // 2))


def create_and_setup_contact_material(
    rig: FourWheelRig,
    terrain: Terrain,
):
    """
    Create (or fetch) and configure the terrain/wheel contact material.

    Args:
        rig: Vehicle rig; must expose a non-None ``wheel_material``.
        terrain: Terrain; must expose a non-None ``material``.

    Returns:
        The contact material, configured via
        ``agxTerrain.TerrainWheel.configureContactMaterial`` and with zero
        restitution.

    Raises:
        RuntimeError: If either material is missing or None.
    """
    # getattr() turns a missing attribute into the intended RuntimeError
    # rather than an AttributeError.
    terrain_material = getattr(terrain, "material", None)
    if terrain_material is None:
        raise RuntimeError("Terrain material is not initialized.")

    wheel_material = getattr(rig, "wheel_material", None)
    if wheel_material is None:
        raise RuntimeError("Wheel material is not initialized.")

    material_manager = simulation().getMaterialManager()

    contact_material = material_manager.getOrCreateContactMaterial(
        terrain_material,
        wheel_material,
    )

    agxTerrain.TerrainWheel.configureContactMaterial(contact_material)
    contact_material.setRestitution(0.0)

    return contact_material


def install_overlay(
    rig: FourWheelRig,
    runtime_state: RuntimeState,
) -> None:
    """
    Install the on-screen status/help text.

    The scene background is set to black once. A post-step callback then
    either writes ten text rows each step (time, chassis velocity, key
    bindings with their current values, and active inputs) or, when the
    overlay is hidden, clears the text and shows only the hint to re-enable it.

    Args:
        rig: Vehicle rig whose chassis velocity is displayed.
        runtime_state: Shared runtime settings, read on every step.
    """

    def _status_rows(t: float) -> list:
        """Build the (text, color) pairs for rows 0..9 of the full overlay."""
        velocity = rig.get_chassis_velocity()
        speed = velocity.length()
        velocity_text = "[{: 4.2f}, {: 4.2f}, {: 4.2f}]".format(
            velocity.x(),
            velocity.y(),
            velocity.z(),
        )

        gravity_magnitude = simulation().getUniformGravity().length()
        gravity_text = f"{runtime_state.gravity_mode.value}: {gravity_magnitude:.2f} m/s^2"

        return [
            (f"Time: {t:4.2f} s", agxRender.Color.White()),
            (
                f"Chassis speed: {speed: 4.2f} (m/s), vel: {velocity_text} (m/s)",
                agxRender.Color.White(),
            ),
            ("Arrow keys: drive and steer", agxRender.Color.White()),
            ("h: brake", agxRender.Color.White()),
            (
                f"[j, k]: -/+ max drive speed (max: {runtime_state.target_drive_speed_mps:5.2f} m/s)",
                agxRender.Color.White(),
            ),
            (
                f"[i, o]: -/+ steer step (step: {math.degrees(runtime_state.steer_step_rad):5.1f} deg)",
                agxRender.Color.White(),
            ),
            (
                f"z: auto-accelerate on/off. Mode: {'on' if runtime_state.auto_accelerate else 'off'}",
                agxRender.Color.White(),
            ),
            (f"a: toggle gravity mode ({gravity_text})", agxRender.Color.White()),
            ("p: hide on-screen info", agxRender.Color.White()),
            (
                "Active inputs: " + ", ".join(runtime_state.pressed_keys),
                agxRender.Color.Orange(),
            ),
        ]

    def show_on_screen_info(t: float) -> None:
        decorator = application().getSceneDecorator()
        decorator.setFontSize(0.01)

        if not runtime_state.show_overlay:
            decorator.clearText()
            decorator.setText(0, "p: show on-screen info", agxRender.Color.White())
            return

        for row, (text, color) in enumerate(_status_rows(t)):
            decorator.setText(row, text, color)

    application().getSceneDecorator().setBackgroundColor(
        agx.Vec3(0.0, 0.0, 0.0),
        agx.Vec3(0.0, 0.0, 0.0),
    )
    StepEventCallback.postCallback(show_on_screen_info)


# -----------------------------------------------------------------------------
# Scene builder
# -----------------------------------------------------------------------------

def build_four_wheel_rig() -> None:
    """
    Build the complete tutorial scene.

    Steps, in order:
    1. Create the vehicle rig and the terrain, and pair their materials.
    2. Configure threading and the camera.
    3. Register the keyboard listener and the per-step controller.
    4. Install the on-screen overlay.
    """
    vehicle_config, terrain_config = VehicleConfig(), TerrainConfig()
    runtime_state = RuntimeState()

    rig = FourWheelRig(vehicle_config)
    terrain = Terrain(terrain_config)
    rig.contact_material = create_and_setup_contact_material(rig=rig, terrain=terrain)

    # With use_half_of_available_threads=True, set_number_of_threads ignores num_threads.
    set_number_of_threads(num_threads=4, use_half_of_available_threads=True)
    apply_camera_data(application(), rig.chassis_geometry)

    keyboard = VehicleKeyboardController(rig, runtime_state)
    controller = VehicleControllerStepListener(rig, keyboard, runtime_state)
    for listener in (keyboard, controller):
        simulation().add(listener)

    install_overlay(rig, runtime_state)


def buildScene():
    """Alternative scene entry point; builds the same scene as build_four_wheel_rig()."""
    return build_four_wheel_rig()


# Register the application: build_four_wheel_rig is the only scene (listed under
# "1" -- presumably its selection key; confirm against init_app), and the
# simulation steps automatically.
init = init_app(
    name=__name__,
    scenes=[(build_four_wheel_rig, "1")],
    autoStepping=True,
)
