import json
import os
import requests
import re
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Required for AWS Lambda
import matplotlib.pyplot as plt
import seaborn as sns
import boto3
import io

# Golden ratio phi = (1 + sqrt(5)) / 2, used below as the plot's width:height ratio
golden = 0.5 * (5 ** 0.5 + 1)

def _parse_state_odds(html_content):
    """Extract per-state odds from the page source.

    Lines that start with ``case `` carry the quoted state abbreviation and
    lines that start with ``description `` carry the Republican/Democrat
    percentages and the electoral vote count. The two lists are paired in
    order of appearance.

    Args:
        html_content: Raw text of the fetched page.

    Returns:
        DataFrame indexed by state abbreviation with columns ``State``,
        ``Republican``, ``Democrat`` (fractions rounded to 3 decimals) and
        ``Votes``. If an abbreviation appears twice, the last entry wins.

    Raises:
        ValueError: If no state could be parsed, or a paired line does not
            contain the expected fields.
    """
    stripped = [line.strip() for line in html_content.split('\n')]
    case_lines = [line for line in stripped if line.startswith('case ')]
    description_lines = [line for line in stripped if line.startswith('description ')]

    if len(case_lines) != len(description_lines):
        # zip() below silently drops the surplus; make the mismatch visible
        # in the logs without changing how the lines are paired.
        print(
            f"Warning: found {len(case_lines)} case lines but "
            f"{len(description_lines)} description lines; pairing in order"
        )

    data = {}
    for case, description in zip(case_lines, description_lines):
        # The abbreviation is the first single-quoted token on the case line.
        parts = case.split("'")
        if len(parts) < 2:
            raise ValueError(f"No quoted state abbreviation in line: {case!r}")
        state_abbr = parts[1]

        rep_match = re.search(r'Republican: ([\d.]+)%', description)
        dem_match = re.search(r'Democrat: ([\d.]+)%', description)
        votes_match = re.search(r'Electoral votes: (\d+)', description)
        if not (rep_match and dem_match and votes_match):
            raise ValueError(
                f"Could not parse odds for {state_abbr!r} from: {description!r}"
            )

        data[state_abbr] = {
            'State': state_abbr,
            'Republican': round(float(rep_match.group(1)) / 100, 3),
            'Democrat': round(float(dem_match.group(1)) / 100, 3),
            'Votes': int(votes_match.group(1)),
        }

    if not data:
        raise ValueError("No state odds found in page content")

    return pd.DataFrame.from_dict(data, orient='index')


def _simulate_electoral_votes(df, n_samples=20_000):
    """Simulate Republican electoral-vote totals from per-state odds.

    Each state's two-party Republican share is turned into log-odds, the
    log-odds are drawn jointly from a multivariate normal (so states move
    together), and every draw is converted back to a probability that
    decides a Bernoulli win/loss per state.

    Args:
        df: DataFrame with ``Republican``, ``Democrat`` and ``Votes`` columns.
        n_samples: Number of simulated elections.

    Returns:
        Series of length ``n_samples`` with the electoral votes won by the
        Republican candidate in each simulated election.
    """
    round_thresh = 0.03
    var_scale = 0.4
    cov_scale = var_scale * 0.7

    # Two-party share, with near-certain states pinned just inside (0, 1)
    # so the log-odds below stay finite.
    trump = (df.Republican / (df.Republican + df.Democrat)).to_numpy()
    trump = np.where(trump >= 1 - round_thresh, 0.9999, trump)
    trump = np.where(trump <= round_thresh, 0.0001, trump)
    kamala = 1 - trump

    log_odds = np.log(trump / kamala)

    # Uncertainty shrinks as the leading side's share approaches 1.
    leader = np.maximum(trump, kamala)
    variance = leader * (1 - leader) * var_scale
    cov = np.outer(1 - leader, 1 - leader) * cov_scale
    np.fill_diagonal(cov, variance)

    samples = np.random.multivariate_normal(log_odds, cov, size=n_samples)
    prob = 1 / (1 + np.exp(-samples))
    wins = np.random.binomial(1, prob)

    # Weight each state's win (0/1) by its electoral votes and total per draw.
    return pd.Series((wins * df.Votes.to_numpy()).sum(axis=1))


def _render_histogram_png(electoral_votes):
    """Render the electoral-vote histogram and return it as PNG bytes.

    The figure is always closed, even if plotting or saving fails, so a
    warm Lambda container does not accumulate open figures.
    """
    plt.rcParams['font.sans-serif'] = 'Liberation Sans'
    plt.rcParams['font.family'] = 'sans-serif'

    h = 4
    fig = plt.figure(figsize=(golden * h, h))
    try:
        sns.histplot(electoral_votes, binwidth=5, binrange=(175, 375), kde=False, color="grey")
        ax = plt.gca()
        ax.set_box_aspect(1 / golden)

        # 270 electoral votes is the winning threshold.
        plt.axvline(270, color="red", linestyle="-")
        plt.xlabel("Electoral Votes", labelpad=5, fontsize=12)
        plt.ylabel(None)
        plt.xlim(170, 370)

        # Strip everything but the bottom axis.
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
        plt.yticks([])
        plt.xticks(np.arange(220, 370, 50))

        plt.suptitle(f"{(electoral_votes >= 270).mean():.0%} chance of Trump win", y=0.98, fontsize=12)
        plt.title(f"{(electoral_votes >= 319).mean():.0%} chance of Trump win by +100 electoral votes",
                  fontsize=8, pad=5)

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=300)
        return buffer.getvalue()
    finally:
        plt.close(fig)


def lambda_handler(event, context):
    """AWS Lambda entry point: scrape odds, simulate, plot, upload to S3.

    Fetches the odds page, simulates electoral-vote outcomes, renders a
    histogram and uploads it to the bucket twice: once under a timestamped
    key and once as ``electoral_votes_latest.png``.

    Args:
        event: Lambda event payload (unused).
        context: Lambda context object (unused).

    Returns:
        Dict with ``statusCode`` 200 and a success message, or ``statusCode``
        500 and the error text if any step raised.
    """
    try:
        s3 = boto3.client('s3')
        BUCKET_NAME = 'electionprediction'

        url = 'https://electionbettingodds.com'

        # Without a timeout, requests waits indefinitely on a stalled server.
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        df = _parse_state_odds(response.text)
        electoral_votes = _simulate_electoral_votes(df)
        image_bytes = _render_histogram_png(electoral_votes)

        date = pd.Timestamp.now().strftime("%Y_%m_%d_%H_%M_%S")
        keys = (f"electoral_votes_{date}.png", "electoral_votes_latest.png")
        for key in keys:
            # Fresh buffer per upload; the content type lets S3 serve the
            # object as an image rather than binary/octet-stream.
            s3.upload_fileobj(
                io.BytesIO(image_bytes),
                BUCKET_NAME,
                key,
                ExtraArgs={'ContentType': 'image/png'},
            )

        return {
            'statusCode': 200,
            'body': json.dumps(f'Successfully generated and uploaded plots to {BUCKET_NAME}')
        }

    except Exception as exc:
        # Top-level boundary: report every failure as a 500 response.
        print(f"Error: {str(exc)}")
        return {
            'statusCode': 500,
            'body': json.dumps(f'Error: {str(exc)}')
        }