Spaces:
Sleeping
Sleeping
feat: implement Redis-backed user state management and update state handling
Browse files- app.py +19 -0
- packages.txt +1 -0
- requirements.txt +4 -1
- src/agent/llm_graph.py +2 -2
- src/agent/redis_state.py +51 -0
- src/agent/runner.py +2 -2
- src/agent/tools.py +11 -11
- src/main.py +2 -2
app.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import time
|
3 |
+
import atexit
|
4 |
+
|
5 |
+
# Launch local Redis server
|
6 |
+
redis_process = subprocess.Popen([
|
7 |
+
"redis-server",
|
8 |
+
"--save", "",
|
9 |
+
"--appendonly", "no",
|
10 |
+
])
|
11 |
+
|
12 |
+
# Ensure Redis shuts down on exit
|
13 |
+
atexit.register(redis_process.terminate)
|
14 |
+
|
15 |
+
# Give the server a moment to start
|
16 |
+
time.sleep(0.5)
|
17 |
+
|
18 |
+
# Import and run the Gradio app
|
19 |
+
import src.main # noqa: F401
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
redis-server
|
requirements.txt
CHANGED
@@ -14,4 +14,7 @@ langchain-community==0.3.16
|
|
14 |
langchain-google-genai==2.1.4
|
15 |
pydantic-core==2.23.4
|
16 |
pydantic-settings==2.7.1
|
17 |
-
pydantic==2.9.2
|
|
|
|
|
|
|
|
14 |
langchain-google-genai==2.1.4
|
15 |
pydantic-core==2.23.4
|
16 |
pydantic-settings==2.7.1
|
17 |
+
pydantic==2.9.2
|
18 |
+
redis[hiredis]>=5
|
19 |
+
aioredis>=2
|
20 |
+
msgpack
|
src/agent/llm_graph.py
CHANGED
@@ -14,7 +14,7 @@ from agent.tools import (
|
|
14 |
generate_story_frame,
|
15 |
update_state_with_choice,
|
16 |
)
|
17 |
-
from agent.
|
18 |
from audio.audio_generator import change_music_tone
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
@@ -75,7 +75,7 @@ async def node_init_game(state: GraphState) -> GraphState:
|
|
75 |
|
76 |
async def node_player_step(state: GraphState) -> GraphState:
|
77 |
logger.debug("[Graph] node_player_step state: %s", state)
|
78 |
-
user_state = get_user_state(state.user_hash)
|
79 |
scene_id = user_state.current_scene_id
|
80 |
if state.choice_text:
|
81 |
await update_state_with_choice.ainvoke(
|
|
|
14 |
generate_story_frame,
|
15 |
update_state_with_choice,
|
16 |
)
|
17 |
+
from agent.redis_state import get_user_state
|
18 |
from audio.audio_generator import change_music_tone
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
|
|
75 |
|
76 |
async def node_player_step(state: GraphState) -> GraphState:
|
77 |
logger.debug("[Graph] node_player_step state: %s", state)
|
78 |
+
user_state = await get_user_state(state.user_hash)
|
79 |
scene_id = user_state.current_scene_id
|
80 |
if state.choice_text:
|
81 |
await update_state_with_choice.ainvoke(
|
src/agent/redis_state.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Async Redis-backed user state storage."""
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import json
|
6 |
+
import msgpack
|
7 |
+
import redis.asyncio as redis
|
8 |
+
|
9 |
+
from agent.models import UserState
|
10 |
+
|
11 |
+
|
12 |
+
class UserRepository:
|
13 |
+
"""Repository for storing UserState objects in Redis."""
|
14 |
+
|
15 |
+
def __init__(self, redis_url: str = "redis://localhost") -> None:
|
16 |
+
self.redis = redis.from_url(redis_url)
|
17 |
+
|
18 |
+
async def get(self, user_id: str) -> UserState:
|
19 |
+
"""Return user state for the given id, creating it if absent."""
|
20 |
+
key = f"llmgamehub:{user_id}"
|
21 |
+
data = await self.redis.hget(key, "data")
|
22 |
+
if data is None:
|
23 |
+
return UserState()
|
24 |
+
state_dict = msgpack.unpackb(data, raw=False)
|
25 |
+
return UserState.parse_obj(state_dict)
|
26 |
+
|
27 |
+
async def set(self, user_id: str, state: UserState) -> None:
|
28 |
+
"""Persist updated user state."""
|
29 |
+
key = f"llmgamehub:{user_id}"
|
30 |
+
packed = msgpack.packb(json.loads(state.json()))
|
31 |
+
await self.redis.hset(key, mapping={"data": packed})
|
32 |
+
|
33 |
+
async def reset(self, user_id: str) -> None:
|
34 |
+
"""Remove stored state for a user."""
|
35 |
+
key = f"llmgamehub:{user_id}"
|
36 |
+
await self.redis.delete(key)
|
37 |
+
|
38 |
+
|
39 |
+
_repo = UserRepository()
|
40 |
+
|
41 |
+
|
42 |
+
async def get_user_state(user_hash: str) -> UserState:
|
43 |
+
return await _repo.get(user_hash)
|
44 |
+
|
45 |
+
|
46 |
+
async def set_user_state(user_hash: str, state: UserState) -> None:
|
47 |
+
await _repo.set(user_hash, state)
|
48 |
+
|
49 |
+
|
50 |
+
async def reset_user_state(user_hash: str) -> None:
|
51 |
+
await _repo.reset(user_hash)
|
src/agent/runner.py
CHANGED
@@ -10,7 +10,7 @@ from agent.tools import generate_scene_image
|
|
10 |
|
11 |
from agent.llm_graph import GraphState, llm_game_graph
|
12 |
from agent.models import UserState
|
13 |
-
from agent.
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
@@ -38,7 +38,7 @@ async def process_step(
|
|
38 |
|
39 |
final_state = await llm_game_graph.ainvoke(asdict(graph_state))
|
40 |
|
41 |
-
user_state: UserState = get_user_state(user_hash)
|
42 |
response: Dict = {}
|
43 |
|
44 |
ending = final_state.get("ending")
|
|
|
10 |
|
11 |
from agent.llm_graph import GraphState, llm_game_graph
|
12 |
from agent.models import UserState
|
13 |
+
from agent.redis_state import get_user_state
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
|
|
38 |
|
39 |
final_state = await llm_game_graph.ainvoke(asdict(graph_state))
|
40 |
|
41 |
+
user_state: UserState = await get_user_state(user_hash)
|
42 |
response: Dict = {}
|
43 |
|
44 |
ending = final_state.get("ending")
|
src/agent/tools.py
CHANGED
@@ -17,7 +17,7 @@ from agent.models import (
|
|
17 |
UserChoice,
|
18 |
)
|
19 |
from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
|
20 |
-
from agent.
|
21 |
from images.image_generator import modify_image, generate_image
|
22 |
from agent.image_agent import ChangeScene
|
23 |
|
@@ -53,9 +53,9 @@ async def generate_story_frame(
|
|
53 |
character=character,
|
54 |
genre=genre,
|
55 |
)
|
56 |
-
state = get_user_state(user_hash)
|
57 |
state.story_frame = story_frame
|
58 |
-
set_user_state(user_hash, state)
|
59 |
return story_frame.dict()
|
60 |
|
61 |
|
@@ -65,7 +65,7 @@ async def generate_scene(
|
|
65 |
last_choice: Annotated[str, "Last user choice"],
|
66 |
) -> Annotated[Dict, "Generated scene"]:
|
67 |
"""Generate a new scene based on the current user state."""
|
68 |
-
state = get_user_state(user_hash)
|
69 |
if not state.story_frame:
|
70 |
return _err("Story frame not initialized")
|
71 |
llm = create_llm().with_structured_output(SceneLLM)
|
@@ -98,7 +98,7 @@ async def generate_scene(
|
|
98 |
)
|
99 |
state.current_scene_id = scene_id
|
100 |
state.scenes[scene_id] = scene
|
101 |
-
set_user_state(user_hash, state)
|
102 |
return scene.dict()
|
103 |
|
104 |
|
@@ -119,10 +119,10 @@ async def generate_scene_image(
|
|
119 |
# for now always modify the image to avoid the generating an update in a completely wrong style
|
120 |
else modify_image(current_image, change_scene.scene_description)
|
121 |
)
|
122 |
-
state = get_user_state(user_hash)
|
123 |
if scene_id in state.scenes:
|
124 |
state.scenes[scene_id].image = image_path
|
125 |
-
set_user_state(user_hash, state)
|
126 |
return image_path
|
127 |
except Exception as exc: # noqa: BLE001
|
128 |
return _err(str(exc))
|
@@ -137,7 +137,7 @@ async def update_state_with_choice(
|
|
137 |
"""Record the player's choice in the state."""
|
138 |
import datetime
|
139 |
|
140 |
-
state = get_user_state(user_hash)
|
141 |
state.user_choices.append(
|
142 |
UserChoice(
|
143 |
scene_id=scene_id,
|
@@ -145,7 +145,7 @@ async def update_state_with_choice(
|
|
145 |
timestamp=datetime.datetime.utcnow().isoformat(),
|
146 |
)
|
147 |
)
|
148 |
-
set_user_state(user_hash, state)
|
149 |
return state.dict()
|
150 |
|
151 |
|
@@ -154,7 +154,7 @@ async def check_ending(
|
|
154 |
user_hash: Annotated[str, "User session ID"],
|
155 |
) -> Annotated[Dict, "Ending check result"]:
|
156 |
"""Check whether an ending has been reached."""
|
157 |
-
state = get_user_state(user_hash)
|
158 |
if not state.story_frame:
|
159 |
return _err("No story frame")
|
160 |
llm = create_llm().with_structured_output(EndingCheckResult)
|
@@ -166,6 +166,6 @@ async def check_ending(
|
|
166 |
resp: EndingCheckResult = await llm.ainvoke(prompt)
|
167 |
if resp.ending_reached and resp.ending:
|
168 |
state.ending = resp.ending
|
169 |
-
set_user_state(user_hash, state)
|
170 |
return {"ending_reached": True, "ending": resp.ending.dict()}
|
171 |
return {"ending_reached": False}
|
|
|
17 |
UserChoice,
|
18 |
)
|
19 |
from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
|
20 |
+
from agent.redis_state import get_user_state, set_user_state
|
21 |
from images.image_generator import modify_image, generate_image
|
22 |
from agent.image_agent import ChangeScene
|
23 |
|
|
|
53 |
character=character,
|
54 |
genre=genre,
|
55 |
)
|
56 |
+
state = await get_user_state(user_hash)
|
57 |
state.story_frame = story_frame
|
58 |
+
await set_user_state(user_hash, state)
|
59 |
return story_frame.dict()
|
60 |
|
61 |
|
|
|
65 |
last_choice: Annotated[str, "Last user choice"],
|
66 |
) -> Annotated[Dict, "Generated scene"]:
|
67 |
"""Generate a new scene based on the current user state."""
|
68 |
+
state = await get_user_state(user_hash)
|
69 |
if not state.story_frame:
|
70 |
return _err("Story frame not initialized")
|
71 |
llm = create_llm().with_structured_output(SceneLLM)
|
|
|
98 |
)
|
99 |
state.current_scene_id = scene_id
|
100 |
state.scenes[scene_id] = scene
|
101 |
+
await set_user_state(user_hash, state)
|
102 |
return scene.dict()
|
103 |
|
104 |
|
|
|
119 |
# for now always modify the image to avoid the generating an update in a completely wrong style
|
120 |
else modify_image(current_image, change_scene.scene_description)
|
121 |
)
|
122 |
+
state = await get_user_state(user_hash)
|
123 |
if scene_id in state.scenes:
|
124 |
state.scenes[scene_id].image = image_path
|
125 |
+
await set_user_state(user_hash, state)
|
126 |
return image_path
|
127 |
except Exception as exc: # noqa: BLE001
|
128 |
return _err(str(exc))
|
|
|
137 |
"""Record the player's choice in the state."""
|
138 |
import datetime
|
139 |
|
140 |
+
state = await get_user_state(user_hash)
|
141 |
state.user_choices.append(
|
142 |
UserChoice(
|
143 |
scene_id=scene_id,
|
|
|
145 |
timestamp=datetime.datetime.utcnow().isoformat(),
|
146 |
)
|
147 |
)
|
148 |
+
await set_user_state(user_hash, state)
|
149 |
return state.dict()
|
150 |
|
151 |
|
|
|
154 |
user_hash: Annotated[str, "User session ID"],
|
155 |
) -> Annotated[Dict, "Ending check result"]:
|
156 |
"""Check whether an ending has been reached."""
|
157 |
+
state = await get_user_state(user_hash)
|
158 |
if not state.story_frame:
|
159 |
return _err("No story frame")
|
160 |
llm = create_llm().with_structured_output(EndingCheckResult)
|
|
|
166 |
resp: EndingCheckResult = await llm.ainvoke(prompt)
|
167 |
if resp.ending_reached and resp.ending:
|
168 |
state.ending = resp.ending
|
169 |
+
await set_user_state(user_hash, state)
|
170 |
return {"ending_reached": True, "ending": resp.ending.dict()}
|
171 |
return {"ending_reached": False}
|
src/main.py
CHANGED
@@ -26,9 +26,9 @@ logger = logging.getLogger(__name__)
|
|
26 |
|
27 |
async def return_to_constructor(user_hash: str):
|
28 |
"""Return to the constructor and reset user state and audio."""
|
29 |
-
from agent.
|
30 |
|
31 |
-
reset_user_state(user_hash)
|
32 |
await cleanup_music_session(user_hash)
|
33 |
# Generate a new hash to avoid stale state
|
34 |
new_hash = str(uuid.uuid4())
|
|
|
26 |
|
27 |
async def return_to_constructor(user_hash: str):
|
28 |
"""Return to the constructor and reset user state and audio."""
|
29 |
+
from agent.redis_state import reset_user_state
|
30 |
|
31 |
+
await reset_user_state(user_hash)
|
32 |
await cleanup_music_session(user_hash)
|
33 |
# Generate a new hash to avoid stale state
|
34 |
new_hash = str(uuid.uuid4())
|