Snake_agents / app.py
SmokeyBandit's picture
Update app.py
141636f verified
from flask import Flask, render_template
from flask_socketio import SocketIO, emit
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from PIL import Image, ImageDraw
import io # Changed this line - io is a built-in Python module
import time
import threading
import random
# Flask application and the Socket.IO server wrapped around it. game_loop()
# uses `socketio` to emit 'game_update' events.
app = Flask(__name__)
socketio = SocketIO(app)
# Initialize model with lower precision
# The tokenizer and model are created once, when this module is imported, and
# are the objects get_model_decision() uses for every move request.
MODEL_NAME = "Qwen/Qwen-1_8B-Chat"
# NOTE(review): trust_remote_code=True presumably allows Python code shipped
# with the model repository to run locally - confirm the repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,  # half-precision weights (the "lower precision" above)
device_map="auto",  # NOTE(review): presumably automatic device placement - confirm
trust_remote_code=True
)
# Game constants
# The board is GRID_SIZE x GRID_SIZE cells and each cell is rendered as a
# CELL_SIZE x CELL_SIZE pixel square.
GRID_SIZE = 12  # Smaller grid for performance
CELL_SIZE = 40
# Fill colour for each kind of thing drawn on the board.
COLORS = dict(
    background='white',
    grid='lightgray',
    snake='red',
    agent='blue',
    obstacle='gray',
)
class GameState:
    """Mutable state of one game.

    Every position is a two-element ``[x, y]`` list of grid coordinates.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Single predator piece; starts at the centre of the board.
        self.snake = [6, 6]
        # One [x, y] entry per prey agent.
        self.agents = [[2, 2], [9, 9], [2, 9]]
        # Cells that pieces are not allowed to move onto.
        self.obstacles = [[4, 4], [7, 7], [4, 7]]
        # Running score per side.
        self.scores = {'snake': 0, 'agents': 0}
        # Starts empty; nothing in this class appends to it.
        self.history = []

    def get_agent_state(self, agent_idx):
        """Return the view of the board handed to agent ``agent_idx``.

        The result maps 'position' to the agent's own cell, 'snake_pos' to
        the snake's cell, 'other_agents' to a new list holding every other
        agent's cell, and 'obstacles' to the obstacle list. The position
        lists are the live objects stored on this instance, not copies.

        Raises:
            IndexError: if ``agent_idx`` is not a valid index into ``agents``.
        """
        others = []
        for idx, cell in enumerate(self.agents):
            if idx != agent_idx:
                others.append(cell)

        view = {}
        view['position'] = self.agents[agent_idx]
        view['snake_pos'] = self.snake
        view['other_agents'] = others
        view['obstacles'] = self.obstacles
        return view


# Module-wide game instance shared by the update and rendering functions.
game = GameState()
def get_model_decision(role, state):
    """Ask the Qwen model for the next move of one piece.

    Args:
        role: "snake" selects the predator prompt; any other value selects
            the prey prompt.
        state: dict describing the board from the piece's point of view.
            The predator prompt reads 'position' and 'other_agents'; the
            prey prompt reads 'position' and 'snake_pos'.

    Returns:
        One of "UP", "DOWN", "LEFT", "RIGHT" or "STAY". "STAY" is returned
        when the model's reply contains none of those words.
    """
    import re  # local import: the file's import block does not provide `re`

    if role == "snake":
        prompt = f"You are a predator trying to catch prey. Your position is {state['position']}, prey positions are {state['other_agents']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
    else:
        prompt = f"You are prey avoiding a predator. Your position is {state['position']}, predator position is {state['snake_pos']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True
    )

    # generate() returns the prompt tokens followed by the completion. The
    # prompt itself lists every move word, so decoding the whole sequence
    # would make the search below always hit "UP". Decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Take the first move word the model actually wrote. Whole-word matching
    # avoids false hits such as the "UP" inside "SUPPORT".
    found = re.search(r"\b(UP|DOWN|LEFT|RIGHT|STAY)\b", response.upper())
    if found is None:
        return "STAY"
    return found.group(1)
def apply_move(position, move):
    """Return the cell reached from ``position`` by ``move``.

    ``position`` is an ``[x, y]`` list and is not modified. "UP" decreases
    y, "DOWN" increases y, "LEFT" decreases x and "RIGHT" increases x, each
    by one cell. A move that would leave the ``GRID_SIZE`` x ``GRID_SIZE``
    board, or any other move word (for example "STAY"), leaves the
    coordinates unchanged. The result is always a new list.
    """
    col, row = position.copy()

    if move == "UP":
        if row > 0:
            row = row - 1
    elif move == "DOWN":
        if row < GRID_SIZE - 1:
            row = row + 1
    elif move == "LEFT":
        if col > 0:
            col = col - 1
    elif move == "RIGHT":
        if col < GRID_SIZE - 1:
            col = col + 1

    return [col, row]
def create_game_image():
    """Render the current board as a PNG.

    Draws, in this order: the grid lines, the obstacle squares, one disc
    per agent, the snake disc and the score line. Reads the module-level
    ``game`` state and does not change it.

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to position 0.
    """
    board_px = GRID_SIZE * CELL_SIZE
    canvas = Image.new("RGB", (board_px, board_px), COLORS['background'])
    pen = ImageDraw.Draw(canvas)

    def paint_disc(cell, colour):
        # Disc centred in the cell, with a radius of one third of the cell size.
        mid_x = (cell[0] + 0.5) * CELL_SIZE
        mid_y = (cell[1] + 0.5) * CELL_SIZE
        reach = CELL_SIZE // 3
        pen.ellipse(
            [mid_x - reach, mid_y - reach, mid_x + reach, mid_y + reach],
            fill=colour,
        )

    # Grid: one vertical and one horizontal line per cell boundary.
    for boundary in range(GRID_SIZE + 1):
        offset = boundary * CELL_SIZE
        pen.line([(offset, 0), (offset, board_px)], fill=COLORS['grid'])
        pen.line([(0, offset), (board_px, offset)], fill=COLORS['grid'])

    # Obstacles: filled squares covering the whole cell.
    for cell in game.obstacles:
        left = cell[0] * CELL_SIZE
        top = cell[1] * CELL_SIZE
        right = (cell[0] + 1) * CELL_SIZE
        bottom = (cell[1] + 1) * CELL_SIZE
        pen.rectangle([left, top, right, bottom], fill=COLORS['obstacle'])

    # Agents first, then the snake, so the snake is painted on top when it
    # shares a cell with an agent.
    for cell in game.agents:
        paint_disc(cell, COLORS['agent'])
    paint_disc(game.snake, COLORS['snake'])

    # Score line in the top-left corner.
    pen.text((10, 10), f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}", fill="black")

    # Serialise to PNG in memory and rewind so callers can read from the start.
    buffer = io.BytesIO()
    canvas.save(buffer, format='PNG')
    buffer.seek(0)
    return buffer
def update_game():
    """Advance the module-level ``game`` by one turn.

    The snake moves first, then each agent in list order, then captures are
    resolved. A move onto an obstacle cell is discarded and the piece stays
    where it was. Every agent that ends the turn on the snake's cell scores
    one point for the snake and is respawned on a random cell that is
    neither an obstacle nor the snake's cell.
    """
    # Snake's turn
    snake_view = {'position': game.snake, 'other_agents': game.agents}
    chosen = get_model_decision('snake', snake_view)
    target = apply_move(game.snake, chosen)
    if target not in game.obstacles:
        game.snake = target

    # Agents' turns: each agent sees the board as left by the earlier movers.
    for idx, current in enumerate(game.agents):
        chosen = get_model_decision('agent', game.get_agent_state(idx))
        target = apply_move(current, chosen)
        if target in game.obstacles:
            continue
        game.agents[idx] = target

    # Check captures
    for idx, cell in enumerate(game.agents):
        if cell != game.snake:
            continue
        game.scores['snake'] += 1
        # Respawn agent: redraw until the cell is free of obstacles and snake.
        respawn = None
        while respawn is None:
            spot = [random.randint(0, GRID_SIZE - 1), random.randint(0, GRID_SIZE - 1)]
            if spot in game.obstacles or spot == game.snake:
                continue
            respawn = spot
        game.agents[idx] = respawn
def game_loop():
    """Run the game forever, one turn per iteration.

    Each iteration updates the game, renders it, emits a 'game_update'
    Socket.IO event carrying the PNG as a hex string plus the scores, and
    then sleeps for one second. This function never returns.
    """
    while True:
        update_game()
        frame = create_game_image()
        payload = {
            'image': frame.getvalue().hex(),
            'scores': game.scores
        }
        socketio.emit('game_update', payload)
        # Slower updates to reduce resource usage
        time.sleep(1.0)
@app.route('/')
def index():
    """Serve the page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@socketio.on('connect')
def handle_connect():
    """Print a message when a Socket.IO 'connect' event is received."""
    notice = 'Client connected'
    print(notice)
if __name__ == '__main__':
    # Run the game in a daemon thread so it does not keep the process alive
    # after the server stops.
    worker = threading.Thread(target=game_loop, daemon=True)
    worker.start()
    # Serve on all interfaces, port 7860.
    socketio.run(app, host='0.0.0.0', port=7860)