Spaces:
Runtime error
Runtime error
from flask import Flask, render_template | |
from flask_socketio import SocketIO, emit | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import numpy as np | |
from PIL import Image, ImageDraw | |
import io # Changed this line - io is a built-in Python module | |
import time | |
import threading | |
import random | |
# Web application and the Socket.IO server layered on top of it.
app = Flask(__name__)
socketio = SocketIO(app)

# Initialize model with lower precision where the hardware supports it.
MODEL_NAME = "Qwen/Qwen-1_8B-Chat"

# float16 halves memory use on GPU. On a CPU-only host, half-precision
# weights can fail or be very slow depending on the installed PyTorch
# version, so fall back to float32 there.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# trust_remote_code=True is passed to both loaders, as in the original setup
# (this allows code shipped with the checkpoint to run at load time).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=MODEL_DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
# Game Constants
GRID_SIZE = 12  # Cells per side; kept small for performance
CELL_SIZE = 40  # Multiplied with grid coordinates when drawing the board

# Fill colour used for each kind of element on the board.
COLORS = dict(
    background='white',
    grid='lightgray',
    snake='red',
    agent='blue',
    obstacle='gray',
)
class GameState:
    """Mutable state of one game: snake, agents, obstacles and scores.

    Positions are two-item ``[x, y]`` lists of grid coordinates.
    """

    def __init__(self):
        self.snake = [6, 6]  # Center
        self.agents = [[2, 2], [9, 9], [2, 9]]
        self.obstacles = [[4, 4], [7, 7], [4, 7]]
        self.scores = {'snake': 0, 'agents': 0}
        self.history = []

    def get_agent_state(self, agent_idx):
        """Return the game as seen by the agent at ``agent_idx``.

        Args:
            agent_idx: Index into ``self.agents``. Negative values count
                from the end, as with normal list indexing.

        Returns:
            A dict with keys ``'position'``, ``'snake_pos'``,
            ``'other_agents'`` and ``'obstacles'``. The values reference
            the live position lists; they are not copies.

        Raises:
            IndexError: If ``agent_idx`` is out of range.
        """
        count = len(self.agents)
        if not -count <= agent_idx < count:
            raise IndexError(f"agent index out of range: {agent_idx}")
        # Normalize negative indices so the agent is correctly left out of
        # 'other_agents' (a raw negative index never equals an enumerate index).
        idx = agent_idx % count
        return {
            'position': self.agents[idx],
            'snake_pos': self.snake,
            'other_agents': [pos for i, pos in enumerate(self.agents) if i != idx],
            'obstacles': self.obstacles,
        }


# Single shared game instance used by the module-level game functions.
game = GameState()
def _extract_move(text):
    """Return the first valid move that appears as a whole word in ``text``.

    Falls back to ``"STAY"`` when no move word is found.
    """
    valid_moves = {"UP", "DOWN", "LEFT", "RIGHT", "STAY"}
    # Turn every non-letter into a space so punctuation around a move
    # ("UP.", "'LEFT'") does not hide it, and so "UP" is not matched inside
    # longer words such as "SUPPORT".
    normalized = "".join(ch if ch.isalpha() else " " for ch in text.upper())
    for word in normalized.split():
        if word in valid_moves:
            return word
    return "STAY"


def get_model_decision(role, state):
    """Get next move from Qwen model.

    Args:
        role: ``"snake"`` selects the predator prompt; any other value
            selects the prey prompt.
        state: Dict describing the game. The predator prompt reads
            ``state['position']`` and ``state['other_agents']``; the prey
            prompt reads ``state['position']`` and ``state['snake_pos']``.

    Returns:
        One of ``"UP"``, ``"DOWN"``, ``"LEFT"``, ``"RIGHT"``, ``"STAY"``;
        ``"STAY"`` when the reply names no move.
    """
    if role == "snake":
        prompt = f"You are a predator trying to catch prey. Your position is {state['position']}, prey positions are {state['other_agents']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
    else:
        prompt = f"You are prey avoiding a predator. Your position is {state['position']}, predator position is {state['snake_pos']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True,
    )

    # The generated sequence starts with the prompt tokens. The prompt itself
    # lists every move word, so decoding the whole sequence would always
    # match "UP". Decode only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return _extract_move(response)
def apply_move(position, move, grid_size=None):
    """Apply move while respecting grid boundaries.

    Args:
        position: Two-item ``(x, y)`` sequence (list or tuple). It is not
            modified.
        move: ``"UP"`` decreases y, ``"DOWN"`` increases y, ``"LEFT"``
            decreases x, ``"RIGHT"`` increases x. Any other value
            (including ``"STAY"``) leaves the position unchanged.
        grid_size: Number of cells per side. Defaults to the module-level
            ``GRID_SIZE`` when omitted.

    Returns:
        A new ``[x, y]`` list. A move that would leave the grid is ignored.
    """
    if grid_size is None:
        grid_size = GRID_SIZE
    # Unpacking works for lists and tuples alike and leaves the input untouched.
    x, y = position
    if move == "UP" and y > 0:
        y -= 1
    elif move == "DOWN" and y < grid_size - 1:
        y += 1
    elif move == "LEFT" and x > 0:
        x -= 1
    elif move == "RIGHT" and x < grid_size - 1:
        x += 1
    return [x, y]
def create_game_image():
    """Render the current board and return it as an in-memory PNG.

    Draws the grid lines, obstacles (filled squares), agents and snake
    (filled circles) and the score text, reading positions and scores
    from the module-level ``game``.

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to position 0.
    """
    board_px = GRID_SIZE * CELL_SIZE
    img = Image.new("RGB", (board_px, board_px), COLORS['background'])
    draw = ImageDraw.Draw(img)
    radius = CELL_SIZE // 3

    def draw_dot(cell, color):
        # Filled circle centred in the middle of the given grid cell.
        cx = (cell[0] + 0.5) * CELL_SIZE
        cy = (cell[1] + 0.5) * CELL_SIZE
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=color)

    # Grid: one vertical and one horizontal line per cell boundary.
    for line_no in range(GRID_SIZE + 1):
        offset = line_no * CELL_SIZE
        draw.line([(offset, 0), (offset, board_px)], fill=COLORS['grid'])
        draw.line([(0, offset), (board_px, offset)], fill=COLORS['grid'])

    # Obstacles fill their whole cell.
    for col, row in game.obstacles:
        left, top = col * CELL_SIZE, row * CELL_SIZE
        right, bottom = (col + 1) * CELL_SIZE, (row + 1) * CELL_SIZE
        draw.rectangle([left, top, right, bottom], fill=COLORS['obstacle'])

    # Agents first, then the snake, so the snake is drawn on top on a shared cell.
    for agent_cell in game.agents:
        draw_dot(agent_cell, COLORS['agent'])
    draw_dot(game.snake, COLORS['snake'])

    # Score line in the top-left corner.
    draw.text((10, 10), f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}", fill="black")

    # Serialize to PNG bytes and rewind so callers can read from the start.
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    buffer.seek(0)
    return buffer
def _pick_respawn_cell(agent_idx):
    """Return a random free cell for the agent at ``agent_idx``, or None if none is free.

    A cell is free when it holds no obstacle, is not the snake's cell and
    is not occupied by any other agent.
    """
    taken = list(game.obstacles)
    taken.append(game.snake)
    taken.extend(pos for i, pos in enumerate(game.agents) if i != agent_idx)
    free_cells = [
        [x, y]
        for x in range(GRID_SIZE)
        for y in range(GRID_SIZE)
        if [x, y] not in taken
    ]
    return random.choice(free_cells) if free_cells else None


def update_game():
    """Update game state for one turn.

    The snake moves first, then each agent in list order; a move onto an
    obstacle cell is discarded. Afterwards every agent sharing the snake's
    cell scores one point for the snake and is respawned on a random free
    cell.
    """
    # Snake's turn
    snake_state = {'position': game.snake, 'other_agents': game.agents}
    snake_move = get_model_decision('snake', snake_state)
    candidate = apply_move(game.snake, snake_move)
    if candidate not in game.obstacles:
        game.snake = candidate

    # Agents' turns. Replacing an element while enumerating is safe because
    # the list length never changes here.
    for idx, current_pos in enumerate(game.agents):
        agent_move = get_model_decision('agent', game.get_agent_state(idx))
        candidate = apply_move(current_pos, agent_move)
        if candidate not in game.obstacles:
            game.agents[idx] = candidate

    # Check captures
    for idx, agent_pos in enumerate(game.agents):
        if agent_pos != game.snake:
            continue
        game.scores['snake'] += 1
        # Respawn on a cell that is actually free, so the agent never lands
        # on top of another agent; if the board is full it stays where it is
        # instead of looping forever.
        respawn_cell = _pick_respawn_cell(idx)
        if respawn_cell is not None:
            game.agents[idx] = respawn_cell
def game_loop():
    """Run the game forever: advance one turn, render it, broadcast it.

    Each iteration emits a ``'game_update'`` Socket.IO event carrying the
    PNG frame as a hex string together with the current scores, then
    sleeps for one second. The function never returns.
    """
    while True:
        update_game()
        frame = create_game_image()
        payload = {
            'image': frame.getvalue().hex(),
            'scores': game.scores,
        }
        socketio.emit('game_update', payload)
        # Slower updates to reduce resource usage
        time.sleep(1.0)
# Register the view on the root URL; without a route Flask never calls it
# and "/" returns 404.
@app.route('/')
def index():
    """Serve the game page rendered from the ``index.html`` template."""
    return render_template('index.html')
# Register the handler for the Socket.IO 'connect' event; without this
# decorator the function is never invoked.
@socketio.on('connect')
def handle_connect():
    """Log that a Socket.IO client has connected."""
    print('Client connected')
if __name__ == '__main__':
    # Run the game loop on a daemon thread so it does not keep the process
    # alive once the server stops.
    loop_thread = threading.Thread(target=game_loop, daemon=True)
    loop_thread.start()
    # Serve on all interfaces, port 7860.
    socketio.run(app, host='0.0.0.0', port=7860)