import io  # in-memory PNG buffer for each rendered frame
import random
import re
import threading
import time

import numpy as np
import torch
from flask import Flask, render_template
from flask_socketio import SocketIO, emit
from PIL import Image, ImageDraw
from transformers import AutoModelForCausalLM, AutoTokenizer

app = Flask(__name__)
socketio = SocketIO(app)

# Load the chat model in half precision to reduce memory usage.
MODEL_NAME = "Qwen/Qwen-1_8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Game constants
GRID_SIZE = 12  # Smaller grid for performance
CELL_SIZE = 40  # Pixel width/height of one grid cell
COLORS = {
    'background': 'white',
    'grid': 'lightgray',
    'snake': 'red',
    'agent': 'blue',
    'obstacle': 'gray',
}

# Moves the model is allowed to answer with.
MOVES = ("UP", "DOWN", "LEFT", "RIGHT", "STAY")
# Whole-word match so e.g. "SUPPOSE" is not read as "UP".
_MOVE_PATTERN = re.compile(r"\b(UP|DOWN|LEFT|RIGHT|STAY)\b")


class GameState:
    """Mutable state of one predator/prey match on a GRID_SIZE x GRID_SIZE board.

    Positions are [x, y] lists of grid-cell indices.
    """

    def __init__(self):
        self.snake = [6, 6]  # Center
        self.agents = [[2, 2], [9, 9], [2, 9]]
        self.obstacles = [[4, 4], [7, 7], [4, 7]]
        self.scores = {'snake': 0, 'agents': 0}
        self.history = []

    def get_agent_state(self, agent_idx):
        """Return the view of the board used to prompt prey agent `agent_idx`."""
        return {
            'position': self.agents[agent_idx],
            'snake_pos': self.snake,
            'other_agents': [pos for i, pos in enumerate(self.agents) if i != agent_idx],
            'obstacles': self.obstacles,
        }


game = GameState()


def _parse_move(response):
    """Return the first whole-word move found in `response`, or "STAY" if none."""
    match = _MOVE_PATTERN.search(response.upper())
    return match.group(1) if match else "STAY"


def get_model_decision(role, state):
    """Ask the Qwen model for the next move of one actor.

    Args:
        role: "snake" for the predator; any other value is treated as prey.
        state: For the snake, a dict with 'position' and 'other_agents';
            for prey, a dict with 'position' and 'snake_pos'.

    Returns:
        One of "UP", "DOWN", "LEFT", "RIGHT", "STAY". Returns "STAY" when the
        model's reply contains no recognizable move.
    """
    if role == "snake":
        prompt = (
            "You are a predator trying to catch prey. "
            f"Your position is {state['position']}, "
            f"prey positions are {state['other_agents']}. "
            "Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. "
            "Just output the move word."
        )
    else:
        prompt = (
            "You are prey avoiding a predator. "
            f"Your position is {state['position']}, "
            f"predator position is {state['snake_pos']}. "
            "Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. "
            "Just output the move word."
        )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True,
    )

    # For decoder-only models generate() returns prompt + completion. Decode
    # only the new tokens: the prompt itself lists every move word, so parsing
    # the full sequence would always find "UP" and ignore the model's answer.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _parse_move(response)


def apply_move(position, move):
    """Return the [x, y] reached by applying `move` to `position`.

    The move is clamped to the grid: a move that would leave the board
    (or "STAY", or an unknown move) leaves the position unchanged. Row 0 is
    drawn at the top of the image, so "UP" decreases y. `position` itself is
    not modified.
    """
    x, y = position[0], position[1]
    if move == "UP" and y > 0:
        y -= 1
    elif move == "DOWN" and y < GRID_SIZE - 1:
        y += 1
    elif move == "LEFT" and x > 0:
        x -= 1
    elif move == "RIGHT" and x < GRID_SIZE - 1:
        x += 1
    return [x, y]


def _draw_token(draw, pos, color):
    """Draw a filled circle centred in grid cell `pos`."""
    cx = (pos[0] + 0.5) * CELL_SIZE
    cy = (pos[1] + 0.5) * CELL_SIZE
    radius = CELL_SIZE // 3
    draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=color)


def create_game_image():
    """Render the current board as a PNG.

    Returns:
        An io.BytesIO holding the PNG data, rewound to position 0.
    """
    side = GRID_SIZE * CELL_SIZE
    img = Image.new("RGB", (side, side), COLORS['background'])
    draw = ImageDraw.Draw(img)

    # Grid lines (GRID_SIZE + 1 lines per axis to close the outer border).
    for i in range(GRID_SIZE + 1):
        offset = i * CELL_SIZE
        draw.line([(offset, 0), (offset, side)], fill=COLORS['grid'])
        draw.line([(0, offset), (side, offset)], fill=COLORS['grid'])

    # Obstacles fill their whole cell.
    for pos in game.obstacles:
        draw.rectangle([
            pos[0] * CELL_SIZE,
            pos[1] * CELL_SIZE,
            (pos[0] + 1) * CELL_SIZE,
            (pos[1] + 1) * CELL_SIZE,
        ], fill=COLORS['obstacle'])

    # Prey first, then the snake, so the snake is drawn on top when they share a cell.
    for pos in game.agents:
        _draw_token(draw, pos, COLORS['agent'])
    _draw_token(draw, game.snake, COLORS['snake'])

    draw.text(
        (10, 10),
        f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}",
        fill="black",
    )

    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)
    return img_byte_arr


def _resolve_captures():
    """Score a capture for every prey sharing the snake's cell and respawn it.

    A prey respawns on a random cell that is not an obstacle, the snake, or
    another prey. The board is much larger than the number of occupied cells,
    so the rejection loop terminates quickly.
    """
    for i, agent_pos in enumerate(game.agents):
        if agent_pos != game.snake:
            continue
        game.scores['snake'] += 1
        while True:
            new_pos = [random.randint(0, GRID_SIZE - 1), random.randint(0, GRID_SIZE - 1)]
            if (
                new_pos not in game.obstacles
                and new_pos != game.snake
                and new_pos not in game.agents
            ):
                game.agents[i] = new_pos
                break


def update_game():
    """Advance the game by one turn.

    The snake moves first, then each prey moves in order. Captures are
    resolved after the snake's move and again after the prey moves. Moves
    into obstacle cells are ignored.
    """
    # Snake's turn
    snake_state = {'position': game.snake, 'other_agents': game.agents}
    snake_move = get_model_decision('snake', snake_state)
    new_pos = apply_move(game.snake, snake_move)
    if new_pos not in game.obstacles:
        game.snake = new_pos
    # Resolve now so prey the snake lands on cannot escape on its own turn.
    _resolve_captures()

    # Agents' turns
    for i, pos in enumerate(game.agents):
        agent_move = get_model_decision('agent', game.get_agent_state(i))
        new_pos = apply_move(pos, agent_move)
        if new_pos not in game.obstacles:
            game.agents[i] = new_pos

    # Prey may have walked into the snake.
    _resolve_captures()


def game_loop():
    """Run turns forever, broadcasting a rendered frame about once per second."""
    while True:
        try:
            update_game()
            img_bytes = create_game_image()
            socketio.emit('game_update', {
                'image': img_bytes.getvalue().hex(),
                'scores': game.scores,
            })
        except Exception:
            # One failed generate()/render must not end the loop permanently.
            app.logger.exception("Game loop iteration failed")
        # socketio.sleep cooperates with whichever async mode Flask-SocketIO uses.
        socketio.sleep(1.0)  # Slower updates to reduce resource usage


@app.route('/')
def index():
    return render_template('index.html')


@socketio.on('connect')
def handle_connect():
    print('Client connected')


if __name__ == '__main__':
    # start_background_task picks the right task primitive for the
    # configured async mode (threading, eventlet or gevent).
    socketio.start_background_task(game_loop)
    socketio.run(app, host='0.0.0.0', port=7860)