kikikita commited on
Commit
dcf824e
·
1 Parent(s): 40a9697

feat: implement Redis-backed user state management and update state handling

Browse files
app.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import time
3
+ import atexit
4
+
5
+ # Launch local Redis server
6
+ redis_process = subprocess.Popen([
7
+ "redis-server",
8
+ "--save", "",
9
+ "--appendonly", "no",
10
+ ])
11
+
12
+ # Ensure Redis shuts down on exit
13
+ atexit.register(redis_process.terminate)
14
+
15
+ # Give the server a moment to start
16
+ time.sleep(0.5)
17
+
18
+ # Import and run the Gradio app
19
+ import src.main # noqa: F401
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ redis-server
requirements.txt CHANGED
@@ -14,4 +14,7 @@ langchain-community==0.3.16
14
  langchain-google-genai==2.1.4
15
  pydantic-core==2.23.4
16
  pydantic-settings==2.7.1
17
- pydantic==2.9.2
 
 
 
 
14
  langchain-google-genai==2.1.4
15
  pydantic-core==2.23.4
16
  pydantic-settings==2.7.1
17
+ pydantic==2.9.2
18
+ redis[hiredis]>=5
19
+ aioredis>=2
20
+ msgpack
src/agent/llm_graph.py CHANGED
@@ -14,7 +14,7 @@ from agent.tools import (
14
  generate_story_frame,
15
  update_state_with_choice,
16
  )
17
- from agent.state import get_user_state
18
  from audio.audio_generator import change_music_tone
19
  logger = logging.getLogger(__name__)
20
 
@@ -75,7 +75,7 @@ async def node_init_game(state: GraphState) -> GraphState:
75
 
76
  async def node_player_step(state: GraphState) -> GraphState:
77
  logger.debug("[Graph] node_player_step state: %s", state)
78
- user_state = get_user_state(state.user_hash)
79
  scene_id = user_state.current_scene_id
80
  if state.choice_text:
81
  await update_state_with_choice.ainvoke(
 
14
  generate_story_frame,
15
  update_state_with_choice,
16
  )
17
+ from agent.redis_state import get_user_state
18
  from audio.audio_generator import change_music_tone
19
  logger = logging.getLogger(__name__)
20
 
 
75
 
76
  async def node_player_step(state: GraphState) -> GraphState:
77
  logger.debug("[Graph] node_player_step state: %s", state)
78
+ user_state = await get_user_state(state.user_hash)
79
  scene_id = user_state.current_scene_id
80
  if state.choice_text:
81
  await update_state_with_choice.ainvoke(
src/agent/redis_state.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Async Redis-backed user state storage."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import msgpack
7
+ import redis.asyncio as redis
8
+
9
+ from agent.models import UserState
10
+
11
+
12
+ class UserRepository:
13
+ """Repository for storing UserState objects in Redis."""
14
+
15
+ def __init__(self, redis_url: str = "redis://localhost") -> None:
16
+ self.redis = redis.from_url(redis_url)
17
+
18
+ async def get(self, user_id: str) -> UserState:
19
+ """Return user state for the given id, creating it if absent."""
20
+ key = f"llmgamehub:{user_id}"
21
+ data = await self.redis.hget(key, "data")
22
+ if data is None:
23
+ return UserState()
24
+ state_dict = msgpack.unpackb(data, raw=False)
25
+ return UserState.parse_obj(state_dict)
26
+
27
+ async def set(self, user_id: str, state: UserState) -> None:
28
+ """Persist updated user state."""
29
+ key = f"llmgamehub:{user_id}"
30
+ packed = msgpack.packb(json.loads(state.json()))
31
+ await self.redis.hset(key, mapping={"data": packed})
32
+
33
+ async def reset(self, user_id: str) -> None:
34
+ """Remove stored state for a user."""
35
+ key = f"llmgamehub:{user_id}"
36
+ await self.redis.delete(key)
37
+
38
+
39
+ _repo = UserRepository()
40
+
41
+
42
+ async def get_user_state(user_hash: str) -> UserState:
43
+ return await _repo.get(user_hash)
44
+
45
+
46
+ async def set_user_state(user_hash: str, state: UserState) -> None:
47
+ await _repo.set(user_hash, state)
48
+
49
+
50
+ async def reset_user_state(user_hash: str) -> None:
51
+ await _repo.reset(user_hash)
src/agent/runner.py CHANGED
@@ -10,7 +10,7 @@ from agent.tools import generate_scene_image
10
 
11
  from agent.llm_graph import GraphState, llm_game_graph
12
  from agent.models import UserState
13
- from agent.state import get_user_state
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -38,7 +38,7 @@ async def process_step(
38
 
39
  final_state = await llm_game_graph.ainvoke(asdict(graph_state))
40
 
41
- user_state: UserState = get_user_state(user_hash)
42
  response: Dict = {}
43
 
44
  ending = final_state.get("ending")
 
10
 
11
  from agent.llm_graph import GraphState, llm_game_graph
12
  from agent.models import UserState
13
+ from agent.redis_state import get_user_state
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
38
 
39
  final_state = await llm_game_graph.ainvoke(asdict(graph_state))
40
 
41
+ user_state: UserState = await get_user_state(user_hash)
42
  response: Dict = {}
43
 
44
  ending = final_state.get("ending")
src/agent/tools.py CHANGED
@@ -17,7 +17,7 @@ from agent.models import (
17
  UserChoice,
18
  )
19
  from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
20
- from agent.state import get_user_state, set_user_state
21
  from images.image_generator import modify_image, generate_image
22
  from agent.image_agent import ChangeScene
23
 
@@ -53,9 +53,9 @@ async def generate_story_frame(
53
  character=character,
54
  genre=genre,
55
  )
56
- state = get_user_state(user_hash)
57
  state.story_frame = story_frame
58
- set_user_state(user_hash, state)
59
  return story_frame.dict()
60
 
61
 
@@ -65,7 +65,7 @@ async def generate_scene(
65
  last_choice: Annotated[str, "Last user choice"],
66
  ) -> Annotated[Dict, "Generated scene"]:
67
  """Generate a new scene based on the current user state."""
68
- state = get_user_state(user_hash)
69
  if not state.story_frame:
70
  return _err("Story frame not initialized")
71
  llm = create_llm().with_structured_output(SceneLLM)
@@ -98,7 +98,7 @@ async def generate_scene(
98
  )
99
  state.current_scene_id = scene_id
100
  state.scenes[scene_id] = scene
101
- set_user_state(user_hash, state)
102
  return scene.dict()
103
 
104
 
@@ -119,10 +119,10 @@ async def generate_scene_image(
119
  # for now always modify the image to avoid the generating an update in a completely wrong style
120
  else modify_image(current_image, change_scene.scene_description)
121
  )
122
- state = get_user_state(user_hash)
123
  if scene_id in state.scenes:
124
  state.scenes[scene_id].image = image_path
125
- set_user_state(user_hash, state)
126
  return image_path
127
  except Exception as exc: # noqa: BLE001
128
  return _err(str(exc))
@@ -137,7 +137,7 @@ async def update_state_with_choice(
137
  """Record the player's choice in the state."""
138
  import datetime
139
 
140
- state = get_user_state(user_hash)
141
  state.user_choices.append(
142
  UserChoice(
143
  scene_id=scene_id,
@@ -145,7 +145,7 @@ async def update_state_with_choice(
145
  timestamp=datetime.datetime.utcnow().isoformat(),
146
  )
147
  )
148
- set_user_state(user_hash, state)
149
  return state.dict()
150
 
151
 
@@ -154,7 +154,7 @@ async def check_ending(
154
  user_hash: Annotated[str, "User session ID"],
155
  ) -> Annotated[Dict, "Ending check result"]:
156
  """Check whether an ending has been reached."""
157
- state = get_user_state(user_hash)
158
  if not state.story_frame:
159
  return _err("No story frame")
160
  llm = create_llm().with_structured_output(EndingCheckResult)
@@ -166,6 +166,6 @@ async def check_ending(
166
  resp: EndingCheckResult = await llm.ainvoke(prompt)
167
  if resp.ending_reached and resp.ending:
168
  state.ending = resp.ending
169
- set_user_state(user_hash, state)
170
  return {"ending_reached": True, "ending": resp.ending.dict()}
171
  return {"ending_reached": False}
 
17
  UserChoice,
18
  )
19
  from agent.prompts import ENDING_CHECK_PROMPT, SCENE_PROMPT, STORY_FRAME_PROMPT
20
+ from agent.redis_state import get_user_state, set_user_state
21
  from images.image_generator import modify_image, generate_image
22
  from agent.image_agent import ChangeScene
23
 
 
53
  character=character,
54
  genre=genre,
55
  )
56
+ state = await get_user_state(user_hash)
57
  state.story_frame = story_frame
58
+ await set_user_state(user_hash, state)
59
  return story_frame.dict()
60
 
61
 
 
65
  last_choice: Annotated[str, "Last user choice"],
66
  ) -> Annotated[Dict, "Generated scene"]:
67
  """Generate a new scene based on the current user state."""
68
+ state = await get_user_state(user_hash)
69
  if not state.story_frame:
70
  return _err("Story frame not initialized")
71
  llm = create_llm().with_structured_output(SceneLLM)
 
98
  )
99
  state.current_scene_id = scene_id
100
  state.scenes[scene_id] = scene
101
+ await set_user_state(user_hash, state)
102
  return scene.dict()
103
 
104
 
 
119
  # for now always modify the image to avoid the generating an update in a completely wrong style
120
  else modify_image(current_image, change_scene.scene_description)
121
  )
122
+ state = await get_user_state(user_hash)
123
  if scene_id in state.scenes:
124
  state.scenes[scene_id].image = image_path
125
+ await set_user_state(user_hash, state)
126
  return image_path
127
  except Exception as exc: # noqa: BLE001
128
  return _err(str(exc))
 
137
  """Record the player's choice in the state."""
138
  import datetime
139
 
140
+ state = await get_user_state(user_hash)
141
  state.user_choices.append(
142
  UserChoice(
143
  scene_id=scene_id,
 
145
  timestamp=datetime.datetime.utcnow().isoformat(),
146
  )
147
  )
148
+ await set_user_state(user_hash, state)
149
  return state.dict()
150
 
151
 
 
154
  user_hash: Annotated[str, "User session ID"],
155
  ) -> Annotated[Dict, "Ending check result"]:
156
  """Check whether an ending has been reached."""
157
+ state = await get_user_state(user_hash)
158
  if not state.story_frame:
159
  return _err("No story frame")
160
  llm = create_llm().with_structured_output(EndingCheckResult)
 
166
  resp: EndingCheckResult = await llm.ainvoke(prompt)
167
  if resp.ending_reached and resp.ending:
168
  state.ending = resp.ending
169
+ await set_user_state(user_hash, state)
170
  return {"ending_reached": True, "ending": resp.ending.dict()}
171
  return {"ending_reached": False}
src/main.py CHANGED
@@ -26,9 +26,9 @@ logger = logging.getLogger(__name__)
26
 
27
  async def return_to_constructor(user_hash: str):
28
  """Return to the constructor and reset user state and audio."""
29
- from agent.state import reset_user_state
30
 
31
- reset_user_state(user_hash)
32
  await cleanup_music_session(user_hash)
33
  # Generate a new hash to avoid stale state
34
  new_hash = str(uuid.uuid4())
 
26
 
27
  async def return_to_constructor(user_hash: str):
28
  """Return to the constructor and reset user state and audio."""
29
+ from agent.redis_state import reset_user_state
30
 
31
+ await reset_user_state(user_hash)
32
  await cleanup_music_session(user_hash)
33
  # Generate a new hash to avoid stale state
34
  new_hash = str(uuid.uuid4())