import gymnasium as gym
import torch

from networks import DQN

# Build the CartPole environment with on-screen rendering so the agent can be watched.
env = gym.make("CartPole-v1", render_mode="human")
n_actions = env.action_space.n
# Reset once only to discover the length of the observation vector,
# which is the network's input size.
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to("cpu")
try:
    # map_location="cpu" lets a checkpoint saved from a GPU training run load
    # on a CPU-only machine; without it torch.load tries to restore tensors
    # onto their original (possibly unavailable) device.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    policy_net.load_state_dict(torch.load("model.pth", map_location="cpu"))
    print("Modèle chargé.")
except FileNotFoundError:
    # Deliberate best-effort: keep running with the weights DQN(...) was
    # constructed with (presumably untrained) instead of aborting.
    print("Modèle non trouvé.")
# Inference only: disable training-time behaviour (dropout, batch-norm
# statistic updates) in case the DQN architecture contains any.
policy_net.eval()

def select_action(state, net=None):
    """Return the greedy action for *state* as a (1, 1) tensor of indices.

    Args:
        state: batched observation tensor; assumed to hold a single
            observation, i.e. shape (1, n_observations), since the result
            is viewed as (1, 1).
        net: network to query. Defaults to the module-level ``policy_net``;
            pass another module to reuse this helper with a different model.

    Returns:
        The index of the highest-scoring action along dim 1, shaped (1, 1).
    """
    if net is None:
        net = policy_net
    # Pure inference: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        return net(state).max(1).indices.view(1, 1)

# Play episodes with the greedy policy until the user asks to quit.
# try/finally guarantees the render window is released even if the loop is
# interrupted (Ctrl+C) or fails mid-episode.
try:
    while True:
        state, info = env.reset()
        # Add a batch dimension so the network sees shape (1, n_observations).
        state = torch.tensor(state, dtype=torch.float32, device="cpu").unsqueeze(0)

        for t in range(1000):
            action = select_action(state)
            observation, reward, terminated, truncated, _ = env.step(action.item())
            state = torch.tensor(observation, dtype=torch.float32, device="cpu").unsqueeze(0)
            if terminated or truncated:
                print(f"Épisode terminé après {t+1} étapes.")
                break

        try:
            answer = input("Appuyez sur q pour quitter, ou une autre touche pour continuer :")
        except EOFError:
            # stdin closed (e.g. piped input exhausted): treat as a request to quit.
            break
        # Tolerate stray whitespace and an uppercase "Q".
        if answer.strip().lower() == "q":
            break
finally:
    env.close()