kikikita commited on
Commit
40a9697
·
1 Parent(s): 2909890

Revert "feat: add flow blockers"

Browse files

This reverts commit 29098903dfdf5988a4f7083a15a9fd9f799d5b88.

src/agent/state.py CHANGED
@@ -1,29 +1,24 @@
1
  """Simple in-memory user state storage."""
2
 
3
  from typing import Dict
4
- from threading import Lock
5
 
6
  from agent.models import UserState
7
 
8
  _USER_STATE: Dict[str, UserState] = {}
9
- _STATE_LOCK = Lock()
10
 
11
 
12
  def get_user_state(user_hash: str) -> UserState:
13
  """Return user state for the given id, creating it if necessary."""
14
- with _STATE_LOCK:
15
- if user_hash not in _USER_STATE:
16
- _USER_STATE[user_hash] = UserState()
17
- return _USER_STATE[user_hash]
18
 
19
 
20
  def set_user_state(user_hash: str, state: UserState) -> None:
21
  """Persist updated user state."""
22
- with _STATE_LOCK:
23
- _USER_STATE[user_hash] = state
24
 
25
 
26
  def reset_user_state(user_hash: str) -> None:
27
  """Reset stored state for a user."""
28
- with _STATE_LOCK:
29
- _USER_STATE[user_hash] = UserState()
 
1
  """Simple in-memory user state storage."""
2
 
3
  from typing import Dict
 
4
 
5
  from agent.models import UserState
6
 
7
  _USER_STATE: Dict[str, UserState] = {}
 
8
 
9
 
10
  def get_user_state(user_hash: str) -> UserState:
11
  """Return user state for the given id, creating it if necessary."""
12
+ if user_hash not in _USER_STATE:
13
+ _USER_STATE[user_hash] = UserState()
14
+ return _USER_STATE[user_hash]
 
15
 
16
 
17
  def set_user_state(user_hash: str, state: UserState) -> None:
18
  """Persist updated user state."""
19
+ _USER_STATE[user_hash] = state
 
20
 
21
 
22
  def reset_user_state(user_hash: str) -> None:
23
  """Reset stored state for a user."""
24
+ _USER_STATE[user_hash] = UserState()
 
src/audio/audio_generator.py CHANGED
@@ -7,16 +7,14 @@ import queue
7
  import logging
8
  import io
9
  import time
10
- from threading import Lock
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
15
 
16
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
17
- with _SESSIONS_LOCK:
18
- if user_hash in sessions:
19
- return
20
  async with (
21
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
22
  asyncio.TaskGroup() as tg,
@@ -34,19 +32,15 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
34
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
35
  )
36
  await session.play()
37
- logger.info(
38
- f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
39
- )
40
- with _SESSIONS_LOCK:
41
- sessions[user_hash] = {
42
- 'session': session,
43
- 'queue': queue.Queue()
44
- }
45
 
46
  async def change_music_tone(user_hash: str, new_tone):
47
  logger.info(f"Changing music tone to {new_tone}")
48
- with _SESSIONS_LOCK:
49
- session = sessions.get(user_hash, {}).get('session')
50
  if not session:
51
  logger.error(f"No session found for user hash {user_hash}")
52
  return
@@ -75,36 +69,27 @@ async def receive_audio(session, user_hash):
75
  break
76
 
77
  sessions = {}
78
- _SESSIONS_LOCK = Lock()
79
 
80
  async def start_music_generation(user_hash: str, music_tone: str):
81
  """Start the music generation in a separate thread."""
82
  await generate_music(user_hash, music_tone, receive_audio)
83
-
84
  async def cleanup_music_session(user_hash: str):
85
- with _SESSIONS_LOCK:
86
- session_info = sessions.get(user_hash)
87
- if not session_info:
88
- return
89
  logger.info(f"Cleaning up music session for user hash {user_hash}")
90
- session = session_info['session']
91
- try:
92
- await session.stop()
93
- await session.close()
94
- except Exception as exc: # noqa: BLE001
95
- logger.error(f"Error closing music session: {exc}")
96
  del sessions[user_hash]
97
 
98
 
99
  def update_audio(user_hash):
100
  """Continuously stream audio from the queue as WAV bytes."""
101
  while True:
102
- with _SESSIONS_LOCK:
103
- session_info = sessions.get(user_hash)
104
- if not session_info:
105
  time.sleep(0.5)
106
  continue
107
- queue = session_info['queue']
108
  pcm_data = queue.get() # This is raw PCM audio bytes
109
 
110
  if not isinstance(pcm_data, bytes):
@@ -132,4 +117,4 @@ def update_audio(user_hash):
132
  wf.setframerate(SAMPLE_RATE)
133
  wf.writeframes(pcm_data)
134
  wav_bytes = wav_buffer.getvalue()
135
- yield wav_bytes
 
7
  import logging
8
  import io
9
  import time
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
  client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
+ if user_hash in sessions:
17
+ return
 
18
  async with (
19
  client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
20
  asyncio.TaskGroup() as tg,
 
32
  config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
33
  )
34
  await session.play()
35
+ logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
36
+ sessions[user_hash] = {
37
+ 'session': session,
38
+ 'queue': queue.Queue()
39
+ }
 
 
 
40
 
41
  async def change_music_tone(user_hash: str, new_tone):
42
  logger.info(f"Changing music tone to {new_tone}")
43
+ session = sessions.get(user_hash, {}).get('session')
 
44
  if not session:
45
  logger.error(f"No session found for user hash {user_hash}")
46
  return
 
69
  break
70
 
71
  sessions = {}
 
72
 
73
  async def start_music_generation(user_hash: str, music_tone: str):
74
  """Start the music generation in a separate thread."""
75
  await generate_music(user_hash, music_tone, receive_audio)
76
+
77
  async def cleanup_music_session(user_hash: str):
78
+ if user_hash in sessions:
 
 
 
79
  logger.info(f"Cleaning up music session for user hash {user_hash}")
80
+ session = sessions[user_hash]['session']
81
+ await session.stop()
82
+ await session.close()
 
 
 
83
  del sessions[user_hash]
84
 
85
 
86
  def update_audio(user_hash):
87
  """Continuously stream audio from the queue as WAV bytes."""
88
  while True:
89
+ if user_hash not in sessions:
 
 
90
  time.sleep(0.5)
91
  continue
92
+ queue = sessions[user_hash]['queue']
93
  pcm_data = queue.get() # This is raw PCM audio bytes
94
 
95
  if not isinstance(pcm_data, bytes):
 
117
  wf.setframerate(SAMPLE_RATE)
118
  wf.writeframes(pcm_data)
119
  wav_bytes = wav_buffer.getvalue()
120
+ yield wav_bytes
src/images/image_generator.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  from PIL import Image
5
  from io import BytesIO
6
  from datetime import datetime
7
- import uuid
8
  from config import settings
9
  import logging
10
  import asyncio
@@ -64,9 +63,9 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
64
  image_saved = False
65
  for part in response.candidates[0].content.parts:
66
  if part.inline_data is not None:
67
- # Create a filename with timestamp and uuid to avoid collisions
68
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
69
- filename = f"gemini_{timestamp}_{uuid.uuid4().hex}.png"
70
  filepath = os.path.join(output_dir, filename)
71
 
72
  # Save the image
@@ -131,9 +130,9 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
131
  image_saved = False
132
  for part in response.candidates[0].content.parts:
133
  if part.inline_data is not None:
134
- # Create a filename with timestamp and uuid to avoid collisions
135
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
136
- filename = f"gemini_modified_{timestamp}_{uuid.uuid4().hex}.png"
137
  filepath = os.path.join(output_dir, filename)
138
 
139
  # Save the modified image
 
4
  from PIL import Image
5
  from io import BytesIO
6
  from datetime import datetime
 
7
  from config import settings
8
  import logging
9
  import asyncio
 
63
  image_saved = False
64
  for part in response.candidates[0].content.parts:
65
  if part.inline_data is not None:
66
+ # Create a filename with timestamp
67
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
68
+ filename = f"gemini_{timestamp}.png"
69
  filepath = os.path.join(output_dir, filename)
70
 
71
  # Save the image
 
130
  image_saved = False
131
  for part in response.candidates[0].content.parts:
132
  if part.inline_data is not None:
133
+ # Create a filename with timestamp
134
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
135
+ filename = f"gemini_modified_{timestamp}.png"
136
  filepath = os.path.join(output_dir, filename)
137
 
138
  # Save the modified image
src/main.py CHANGED
@@ -61,8 +61,7 @@ async def update_scene(user_hash: str, choice):
61
  return (
62
  gr.update(value=ending_text),
63
  gr.update(value=ending_image),
64
- gr.Dropdown(choices=[], label="", value=None, visible=False,
65
- allow_custom_value=True),
66
  gr.update(value="", visible=False),
67
  )
68
 
@@ -70,11 +69,10 @@ async def update_scene(user_hash: str, choice):
70
  return (
71
  scene["description"],
72
  scene.get("image", ""),
73
- gr.Dropdown(
74
  choices=[ch["text"] for ch in scene.get("choices", [])],
75
  label="What do you choose? (select an option or write your own)",
76
  value=None,
77
- allow_custom_value=True,
78
  elem_classes=["choice-buttons"],
79
  ),
80
  gr.update(value=""),
@@ -263,11 +261,10 @@ with gr.Blocks(
263
  lines=3,
264
  )
265
  with gr.Column(elem_classes=["choice-area"]):
266
- game_choices = gr.Dropdown(
267
  choices=[],
268
  label="What do you choose? (select an option or write your own)",
269
  value=None,
270
- allow_custom_value=True,
271
  elem_classes=["choice-buttons"],
272
  )
273
  custom_choice = gr.Textbox(
 
61
  return (
62
  gr.update(value=ending_text),
63
  gr.update(value=ending_image),
64
+ gr.Radio(choices=[], label="", value=None, visible=False),
 
65
  gr.update(value="", visible=False),
66
  )
67
 
 
69
  return (
70
  scene["description"],
71
  scene.get("image", ""),
72
+ gr.Radio(
73
  choices=[ch["text"] for ch in scene.get("choices", [])],
74
  label="What do you choose? (select an option or write your own)",
75
  value=None,
 
76
  elem_classes=["choice-buttons"],
77
  ),
78
  gr.update(value=""),
 
261
  lines=3,
262
  )
263
  with gr.Column(elem_classes=["choice-area"]):
264
+ game_choices = gr.Radio(
265
  choices=[],
266
  label="What do you choose? (select an option or write your own)",
267
  value=None,
 
268
  elem_classes=["choice-buttons"],
269
  )
270
  custom_choice = gr.Textbox(