from state import *

# Number of Monte Carlo samples taken per estimate: dealer playouts in
# estimate_stand, random single-card draws in estimate_hit / estimate_double.
ITERATIONS = 1000

def random_card():
    """Draw an independent random card (sampling with replacement, no deck state).

    The rank is a uniform integer in ``2..ACE`` and the suit a uniform integer
    in ``CLUBS..SPADES`` (both inclusive). The rank is drawn before the suit.
    NOTE(review): assumes the rank/suit constants from ``state`` are contiguous
    integers — confirm.
    """
    return Card(random.randint(2, ACE), random.randint(CLUBS, SPADES))


def estimate(player_hand, dealer_card):
    """Return the estimated payout of the best legal move for ``player_hand``.

    Every move reported by ``estimate_all`` is valued and the highest value is
    returned. If the hand offers no moves, ``max`` raises ValueError.
    """
    return max(estimate_all(player_hand, dealer_card).values())


stand_cache = {}
def estimate_stand(player_hand, dealer_card):
    """Estimate the expected payout (in bet units) of standing on ``player_hand``.

    Fixed results:
    - a player blackjack returns 1.5;
    - a bust hand returns -1.

    Otherwise the dealer's hand is played out ITERATIONS times, starting from
    ``dealer_card``. The dealer draws at least one card and keeps drawing while
    below 17. Each playout scores:
    - +1 if the player's total is higher or the dealer busts;
    - -1 if the dealer's total is higher;
    - 0 on a tie.

    The mean is memoised in ``stand_cache``, keyed on the player's best total
    and the dealer card's values.
    """
    if player_hand.is_blackjack():
        return 1.5
    if player_hand.is_bust():
        return -1

    player_total = player_hand.get_best_hand_value()
    dealer_key = '-'.join(str(v) for v in dealer_card.get_values())
    cache_key = str(player_total) + '|' + dealer_key
    if cache_key in stand_cache:
        return stand_cache[cache_key]

    net = 0
    dealer = Hand()
    dealer.add_card(dealer_card)
    for _ in range(ITERATIONS):
        dealer.add_card(random_card())
        dealer_total = dealer.get_best_hand_value()
        # A None total stops the drawing and counts as a player win, i.e. it
        # is treated as a dealer bust.
        while dealer_total is not None and dealer_total < 17:
            dealer.add_card(random_card())
            dealer_total = dealer.get_best_hand_value()

        if dealer_total is None or player_total > dealer_total:
            net += 1
        elif player_total < dealer_total:
            net -= 1

        # Rewind to just the up card before the next playout.
        while dealer.get_hand_size() > 1:
            dealer.remove_last_card()

    stand_cache[cache_key] = net / ITERATIONS
    return stand_cache[cache_key]


hit_cache = {}
def estimate_hit(player_hand, dealer_card):
    """Estimate the expected payout of taking one card and then playing on.

    ITERATIONS random cards are drawn. Each one is temporarily appended to
    ``player_hand``, and the resulting hand is valued with ``estimate`` (the
    best follow-up move). The mean is memoised in ``hit_cache``, keyed on the
    hand's non-bust values and the dealer card's values.

    ``player_hand`` is mutated during sampling. Each drawn card is removed
    again in a ``finally`` clause, so the caller's hand is restored even if
    ``estimate`` raises.
    """
    key = '-'.join(str(v) for v in player_hand.get_non_bust_hand_values()) + '|' + \
          '-'.join(str(v) for v in dealer_card.get_values())

    if key not in hit_cache:
        total = 0
        for _ in range(ITERATIONS):
            player_hand.add_card(random_card())
            try:
                total += estimate(player_hand, dealer_card)
            finally:
                # Undo the temporary draw so neither the caller's hand nor
                # the next sample sees the extra card, even on an exception.
                player_hand.remove_last_card()
        hit_cache[key] = total / ITERATIONS
    return hit_cache[key]


double_cache = {}
def estimate_double(player_hand, dealer_card):
    """Estimate the expected payout of doubling down.

    ITERATIONS random cards are drawn. Each one is temporarily appended to
    ``player_hand``, and the resulting hand is valued with ``estimate_stand``
    (exactly one card, then stand). The mean is doubled to reflect the
    doubled stake. The result is memoised in ``double_cache``, keyed on the
    hand's non-bust values and the dealer card's values.

    ``player_hand`` is mutated during sampling. Each drawn card is removed
    again in a ``finally`` clause, so the caller's hand is restored even if
    ``estimate_stand`` raises.
    """
    key = '-'.join(str(v) for v in player_hand.get_non_bust_hand_values()) + '|' + \
          '-'.join(str(v) for v in dealer_card.get_values())
    if key not in double_cache:
        total = 0
        for _ in range(ITERATIONS):
            player_hand.add_card(random_card())
            try:
                total += estimate_stand(player_hand, dealer_card)
            finally:
                # Undo the temporary draw so neither the caller's hand nor
                # the next sample sees the extra card, even on an exception.
                player_hand.remove_last_card()
        double_cache[key] = 2 * total / ITERATIONS
    return double_cache[key]


# Dispatch table mapping each move name (as produced by player_hand.get_moves()
# and consumed by estimate_all) to the function that estimates its payout.
estimate_functions = dict(
    stand=estimate_stand,
    hit=estimate_hit,
    double=estimate_double,
)


def estimate_all(player_hand, dealer_card):
    """Map every legal move for ``player_hand`` to its estimated payout.

    Moves are taken from ``player_hand.get_moves()`` and evaluated in that
    order with the matching entry of ``estimate_functions``. A move name
    missing from that table raises KeyError.
    """
    return {
        move: estimate_functions[move](player_hand, dealer_card)
        for move in player_hand.get_moves()
    }


def best_move(player_hand, dealer_card):
    """Return the name of the move with the highest estimated payout.

    On a tie, the move listed first by ``player_hand.get_moves()`` wins,
    because ``max`` keeps the first maximal item. With no legal moves,
    ``max`` raises ValueError.
    """
    scores = estimate_all(player_hand, dealer_card)
    move, _ = max(scores.items(), key=lambda pair: pair[1])
    return move

    