# -*- coding: utf-8 -*-
"""app.py

Blackjack tutor: a trained Deep Q-Network chooses Hit or Stick for a freshly
dealt hand, and an LLM served through OpenRouter explains that choice. The
result is exposed as a Gradio app.

Automatically generated by Colab. Original file is located at
https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7
"""

import json
import os
import random

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv

# ---- Load env variables ----
# Populate os.environ from a local .env file (if any) before reading the key;
# without this call the imported load_dotenv had no effect.
load_dotenv()
OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
if not OPENROUTER_KEY:
    raise ValueError("OPENROUTER_KEY not set in environment variables.")

OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
LLM_MODEL = "mistralai/mistral-7b-instruct"
# Seconds to wait for OpenRouter before giving up, so a stalled request
# cannot block the Gradio worker indefinitely.
LLM_TIMEOUT = 30


# ---- Blackjack Environment ----
class BlackjackEnv:
    """Minimal single-player blackjack against a dealer (infinite deck).

    Observations are ``(dealer_up_card, player_total, usable_ace)``: the
    dealer's visible card (ace counted as 1), the player's best hand total,
    and 1/0 for whether the player holds an ace counted as 11.
    Actions: ``1`` = hit, anything else = stick.
    """

    # Card values of one suit: ace=1, 2-9, and 10/J/Q/K all worth 10.
    _CARD_VALUES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10)
    # The dealer keeps drawing while its total is below this value.
    _DEALER_STAND = 17

    def __init__(self):
        self.dealer = []
        self.player = []
        self.usable_ace_player = False

    def draw_card(self):
        """Draw one card value with replacement (ace returned as 1)."""
        return random.choice(self._CARD_VALUES)

    def sum_hand(self, hand):
        """Return ``(best_total, usable_ace)`` for *hand*.

        A single ace is counted as 11 whenever doing so does not exceed 21.
        """
        total = sum(hand)
        if 1 in hand and total + 10 <= 21:
            return total + 10, True
        return total, False

    def _observation(self):
        """Build the current observation and refresh ``usable_ace_player``."""
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        return (self.dealer[0], total, int(usable_ace))

    def reset(self):
        """Deal two cards to the player, one up-card to the dealer; return obs."""
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        return self._observation()

    def step(self, action):
        """Apply *action* and return ``(observation, reward, done)``.

        Hit: draw a card; exceeding 21 ends the hand with reward -1.
        Stick: the dealer draws a hole card and keeps hitting until reaching
        17, then the hand is settled: +1 win (including dealer bust),
        -1 loss, 0 push.
        """
        if action == 1:
            self.player.append(self.draw_card())
            obs = self._observation()
            if obs[1] > 21:
                return obs, -1, True
            return obs, 0, False

        # Recompute rather than reuse the value from reset(): hits may have
        # turned a soft ace into a hard one.
        obs = self._observation()
        player_total = obs[1]

        # Local copy: the stored dealer hand keeps only the up-card.
        dealer_hand = list(self.dealer)
        dealer_total, _ = self.sum_hand(dealer_hand)
        while len(dealer_hand) < 2 or dealer_total < self._DEALER_STAND:
            dealer_hand.append(self.draw_card())
            dealer_total, _ = self.sum_hand(dealer_hand)

        if dealer_total > 21 or dealer_total < player_total:
            reward = 1
        elif dealer_total > player_total:
            reward = -1
        else:
            reward = 0
        return obs, reward, True


# ---- QNetwork ----
class QNetwork(nn.Module):
    """MLP mapping a blackjack observation to one Q-value per action."""

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super(QNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, x):
        return self.model(x)


# ---- Load model ----
model = QNetwork()
model_path = "qnetwork_blackjack_weights.pth"
# map_location lets weights that were saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

env = BlackjackEnv()


# ---- LLM Explanation ----
def explain_action(state, action):
    """Ask the LLM to justify *action* taken in *state*, in 2-3 sentences.

    Args:
        state: ``(dealer_up_card, player_total, usable_ace)`` observation.
        action: 1 for Hit, anything else for Stick.

    Returns:
        The explanation text, or a human-readable error string. Never raises,
        so the UI always has something to display.
    """
    prompt = f"""
You are a blackjack strategy explainer.
The player has a total of {state[1]}.
The dealer is showing {state[0]}.
Usable ace: {bool(state[2])}.
The DQN model chose to {'Hit' if action == 1 else 'Stick'}.
Explain why this action makes sense in 2-3 sentences.
"""

    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        "model": LLM_MODEL,
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt},
        ],
    }

    try:
        response = requests.post(
            OPENROUTER_URL,
            headers=headers,
            data=json.dumps(data),
            timeout=LLM_TIMEOUT,
        )
    except requests.RequestException as e:
        return f"LLM call failed: {e}"

    if response.status_code != 200:
        return f"LLM error: {response.status_code} - {response.text}"

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # Non-JSON body or unexpected payload shape.
        return f"LLM call failed: {e}"


# ---- Gradio App ----
def play_hand():
    """Deal a new hand, let the DQN choose, and fetch an LLM explanation.

    Returns the six strings shown in the Gradio output boxes: player sum,
    dealer card, usable ace, chosen action, raw Q-values, explanation.
    """
    state = env.reset()
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(state_tensor)
    action = torch.argmax(q_values).item()

    explanation = explain_action(state, action)
    action_name = "Hit" if action == 1 else "Stick"
    dealer_card, player_sum, usable_ace = state

    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        action_name,
        str(q_values.numpy().tolist()),
        explanation,
    ]


demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    outputs=[
        gr.Textbox(label="Player Sum"),
        gr.Textbox(label="Dealer Card"),
        gr.Textbox(label="Usable Ace"),
        gr.Textbox(label="DQN Action"),
        gr.Textbox(label="Q-values"),
        gr.Textbox(label="LLM Explanation"),
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description="Play a hand of blackjack. See how a Deep Q Network plays, and get a natural language explanation from Mistral-7B via OpenRouter.",
)

if __name__ == "__main__":
    demo.launch()