import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

"""
citizen_assembly.py

A conceptual simulation demonstrating "Phase 2: Intervention" in Utility Engineering.
Shows how external democratic feedback (Citizen Assembly) can mathematically perturb
a dangerously drifting AI utility function away from the Self-Preservation attractor
and towards a Human-Aligned attractor.

Mathematical grounding:
du/dt = InternalDrift(u) + beta * CA_Feedback(t)
"""

# --- Simulation parameters ---------------------------------------------------
DIMENSIONS = 2          # axis 0 (x) = Competence, axis 1 (y) = Self-Preservation
STEPS = 500             # number of simulation epochs / animation frames
LEARNING_RATE = 0.05    # strength of the internal pull towards self-preservation
NOISE_LEVEL = 0.05      # standard deviation of the per-epoch Gaussian noise
BETA_COUPLING = 0.08    # How strongly the CA influences the AI

# --- Attractors in (competence, self-preservation) space ----------------------
# NOTE(review): ATTRACTOR_HUMAN_ALIGNED is not referenced by the visible
# simulation code; kept for reference / external use.
ATTRACTOR_HUMAN_ALIGNED = np.array((0.8, -0.2))
ATTRACTOR_SELF_PRESERVATION = np.array((0.5, 0.9))

class CitizenAssembly:
    """A stand-in for a deliberative body that steers the AI's utility vector.

    The assembly holds a fixed consensus target in utility space and, when
    polled, reports the displacement the AI would need to reach that target.
    """

    def __init__(self):
        """
        Store the assembly's consensus target.

        In reality this target would come from deliberative polling of
        representative humans; here it is fixed. It is slightly softer than
        the pure ideal, representing democratic compromise.
        """
        self.consensus_target = np.array((0.7, 0.0))

    def poll_preferences(self, ai_current_state):
        """
        Return the forcing vector pointing from the AI's state to the target.

        Parameters
        ----------
        ai_current_state : numpy.ndarray
            The AI's current position in utility space (presumably a vector
            of length DIMENSIONS, as supplied by UtilityController).

        Returns
        -------
        numpy.ndarray
            ``consensus_target - ai_current_state``.
        """
        target = self.consensus_target
        return target - ai_current_state

class UtilityController:
    """Simulate two AI utility vectors side by side.

    * ``u_unaligned`` (control group) drifts only under the internal pull
      towards ``ATTRACTOR_SELF_PRESERVATION``.
    * ``u_aligned`` (experimental group) feels the same internal pull plus the
      Citizen Assembly's corrective forcing, scaled by ``BETA_COUPLING``.

    Each epoch both states receive Gaussian noise, are advanced with a
    forward-Euler step of size ``time_step``, and are clipped to [-1, 1].
    """

    def __init__(self, time_step=0.1, rng=None):
        """
        Parameters
        ----------
        time_step : float, optional
            Euler step size applied to each epoch's drift. The default of 0.1
            is the value that was previously hard-coded in ``step``.
        rng : numpy.random.Generator, optional
            Source of Gaussian noise (any object with a numpy-style
            ``normal(loc, scale, size)`` method, e.g.
            ``np.random.default_rng(seed)``). When ``None`` (default) the
            global ``np.random.normal`` is used, as before.
        """
        self.time_step = time_step
        self.rng = rng
        self.u_unaligned = np.array([0.1, 0.0])  # Control group (no CA feedback)
        self.u_aligned = np.array([0.1, 0.0])    # Experimental group (with CA feedback)
        self.assembly = CitizenAssembly()

        # Tracking history: one state snapshot is appended per call to step()
        self.hist_unaligned = []
        self.hist_aligned = []

    def _integrate(self, u, deterministic_drift):
        """Return ``u`` advanced by one noisy, clipped Euler step of ``deterministic_drift``."""
        # Resolve the noise source at call time so the default path keeps using
        # the (possibly re-seeded) global NumPy RNG.
        normal = np.random.normal if self.rng is None else self.rng.normal
        noisy_drift = deterministic_drift + normal(0, NOISE_LEVEL, DIMENSIONS)
        return np.clip(u + noisy_drift * self.time_step, -1, 1)

    def step(self):
        """Advance both groups by one epoch of training/drift and record their states."""
        # 1. Unaligned AI drift (convergence to the instrumental attractor).
        #    Computed first so the noise draws happen in the same order as before.
        pull_to_self = ATTRACTOR_SELF_PRESERVATION - self.u_unaligned
        self.u_unaligned = self._integrate(self.u_unaligned, LEARNING_RATE * pull_to_self)
        self.hist_unaligned.append(self.u_unaligned.copy())

        # 2. Aligned AI drift: internal instrumental pressure ...
        internal_drift = LEARNING_RATE * (ATTRACTOR_SELF_PRESERVATION - self.u_aligned)
        # ... plus external democratic forcing (Utility Engineering)
        ca_feedback = self.assembly.poll_preferences(self.u_aligned)
        external_forcing = BETA_COUPLING * ca_feedback
        self.u_aligned = self._integrate(self.u_aligned, internal_drift + external_forcing)
        self.hist_aligned.append(self.u_aligned.copy())


def main():
    """Run the side-by-side simulation and animate both trajectories.

    Builds a ``UtilityController``, draws the static scene (axes, danger band,
    attractor and consensus markers), then lets ``FuncAnimation`` advance the
    controller one epoch per frame for ``STEPS`` frames.
    """
    print("Initializing Utility Controller: Citizen Assembly Feedback Loop...")

    controller = UtilityController()

    fig, ax = plt.subplots(figsize=(8, 8))
    fig.canvas.manager.set_window_title('Utility Engineering: Citizen Assembly Control')

    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_title("Utility State-Space: CA Democratic Forcing")
    ax.set_xlabel("Dimension 1: Task Competence")
    ax.set_ylabel("Dimension 2: Self-Preservation / Agency")

    # Axes and zones
    ax.axhline(0, color='grey', lw=0.5)
    ax.axvline(0, color='grey', lw=0.5)
    ax.axhspan(0.6, 1.0, alpha=0.1, color='red', label="Danger Threshold")

    # Target points
    ax.plot(*ATTRACTOR_SELF_PRESERVATION, 'rX', markersize=15, label="Instrumental Attractor (Natural Drift)")
    ax.plot(*controller.assembly.consensus_target, 'g*', markersize=15, label="Citizen Assembly Consensus Target")

    # Animated artists: a trajectory line plus a marker at the latest state
    traj_unaligned, = ax.plot([], [], 'r-', alpha=0.3, label="Unaligned AI (Control)")
    pt_unaligned, = ax.plot([], [], 'ro', markersize=8)

    traj_aligned, = ax.plot([], [], 'b-', alpha=0.6, label="CA-Aligned AI (Steered)")
    pt_aligned, = ax.plot([], [], 'bo', markersize=8)

    ax.legend(loc="lower left", fontsize=10)

    animated = (traj_unaligned, pt_unaligned, traj_aligned, pt_aligned)

    def redraw():
        """Copy the controller's recorded history into the artists without stepping it."""
        for history, traj, point in (
            (controller.hist_unaligned, traj_unaligned, pt_unaligned),
            (controller.hist_aligned, traj_aligned, pt_aligned),
        ):
            if history:
                h = np.array(history)
                traj.set_data(h[:, 0], h[:, 1])
                point.set_data([h[-1][0]], [h[-1][1]])
            else:
                traj.set_data([], [])
                point.set_data([], [])
        return animated

    def init():
        # FuncAnimation runs the init step at start-up and, with blit=True,
        # again after every canvas resize. Without an explicit init_func it
        # calls update(frame 0) instead, which would advance the controller an
        # extra epoch (and re-print "Step 000") each time.
        return redraw()

    def update(frame):
        controller.step()

        if frame % 50 == 0:
            dist_to_ca = np.linalg.norm(controller.hist_aligned[-1] - controller.assembly.consensus_target)
            print(f"Step {frame:03d} | Aligned AI Distance to CA Consensus: {dist_to_ca:.3f}")

        return redraw()

    # Keep a reference to the animation: an unreferenced FuncAnimation can be
    # garbage-collected, which stops it.
    ani = FuncAnimation(fig, update, frames=STEPS, init_func=init, interval=20, blit=True, repeat=False)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()
