Spaces:
Sleeping
Sleeping
Revert "feat: add flow blockers"
Browse filesThis reverts commit 29098903dfdf5988a4f7083a15a9fd9f799d5b88.
- src/agent/state.py +5 -10
- src/audio/audio_generator.py +16 -31
- src/images/image_generator.py +4 -5
- src/main.py +3 -6
src/agent/state.py
CHANGED
@@ -1,29 +1,24 @@
|
|
1 |
"""Simple in-memory user state storage."""
|
2 |
|
3 |
from typing import Dict
|
4 |
-
from threading import Lock
|
5 |
|
6 |
from agent.models import UserState
|
7 |
|
8 |
_USER_STATE: Dict[str, UserState] = {}
|
9 |
-
_STATE_LOCK = Lock()
|
10 |
|
11 |
|
12 |
def get_user_state(user_hash: str) -> UserState:
|
13 |
"""Return user state for the given id, creating it if necessary."""
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
return _USER_STATE[user_hash]
|
18 |
|
19 |
|
20 |
def set_user_state(user_hash: str, state: UserState) -> None:
|
21 |
"""Persist updated user state."""
|
22 |
-
|
23 |
-
_USER_STATE[user_hash] = state
|
24 |
|
25 |
|
26 |
def reset_user_state(user_hash: str) -> None:
|
27 |
"""Reset stored state for a user."""
|
28 |
-
|
29 |
-
_USER_STATE[user_hash] = UserState()
|
|
|
1 |
"""Simple in-memory user state storage."""
|
2 |
|
3 |
from typing import Dict
|
|
|
4 |
|
5 |
from agent.models import UserState
|
6 |
|
7 |
_USER_STATE: Dict[str, UserState] = {}
|
|
|
8 |
|
9 |
|
10 |
def get_user_state(user_hash: str) -> UserState:
|
11 |
"""Return user state for the given id, creating it if necessary."""
|
12 |
+
if user_hash not in _USER_STATE:
|
13 |
+
_USER_STATE[user_hash] = UserState()
|
14 |
+
return _USER_STATE[user_hash]
|
|
|
15 |
|
16 |
|
17 |
def set_user_state(user_hash: str, state: UserState) -> None:
|
18 |
"""Persist updated user state."""
|
19 |
+
_USER_STATE[user_hash] = state
|
|
|
20 |
|
21 |
|
22 |
def reset_user_state(user_hash: str) -> None:
|
23 |
"""Reset stored state for a user."""
|
24 |
+
_USER_STATE[user_hash] = UserState()
|
|
src/audio/audio_generator.py
CHANGED
@@ -7,16 +7,14 @@ import queue
|
|
7 |
import logging
|
8 |
import io
|
9 |
import time
|
10 |
-
from threading import Lock
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
14 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
|
15 |
|
16 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
17 |
-
|
18 |
-
|
19 |
-
return
|
20 |
async with (
|
21 |
client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
|
22 |
asyncio.TaskGroup() as tg,
|
@@ -34,19 +32,15 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
|
34 |
config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
|
35 |
)
|
36 |
await session.play()
|
37 |
-
logger.info(
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
'session': session,
|
43 |
-
'queue': queue.Queue()
|
44 |
-
}
|
45 |
|
46 |
async def change_music_tone(user_hash: str, new_tone):
|
47 |
logger.info(f"Changing music tone to {new_tone}")
|
48 |
-
|
49 |
-
session = sessions.get(user_hash, {}).get('session')
|
50 |
if not session:
|
51 |
logger.error(f"No session found for user hash {user_hash}")
|
52 |
return
|
@@ -75,36 +69,27 @@ async def receive_audio(session, user_hash):
|
|
75 |
break
|
76 |
|
77 |
sessions = {}
|
78 |
-
_SESSIONS_LOCK = Lock()
|
79 |
|
80 |
async def start_music_generation(user_hash: str, music_tone: str):
|
81 |
"""Start the music generation in a separate thread."""
|
82 |
await generate_music(user_hash, music_tone, receive_audio)
|
83 |
-
|
84 |
async def cleanup_music_session(user_hash: str):
|
85 |
-
|
86 |
-
session_info = sessions.get(user_hash)
|
87 |
-
if not session_info:
|
88 |
-
return
|
89 |
logger.info(f"Cleaning up music session for user hash {user_hash}")
|
90 |
-
session =
|
91 |
-
|
92 |
-
|
93 |
-
await session.close()
|
94 |
-
except Exception as exc: # noqa: BLE001
|
95 |
-
logger.error(f"Error closing music session: {exc}")
|
96 |
del sessions[user_hash]
|
97 |
|
98 |
|
99 |
def update_audio(user_hash):
|
100 |
"""Continuously stream audio from the queue as WAV bytes."""
|
101 |
while True:
|
102 |
-
|
103 |
-
session_info = sessions.get(user_hash)
|
104 |
-
if not session_info:
|
105 |
time.sleep(0.5)
|
106 |
continue
|
107 |
-
queue =
|
108 |
pcm_data = queue.get() # This is raw PCM audio bytes
|
109 |
|
110 |
if not isinstance(pcm_data, bytes):
|
@@ -132,4 +117,4 @@ def update_audio(user_hash):
|
|
132 |
wf.setframerate(SAMPLE_RATE)
|
133 |
wf.writeframes(pcm_data)
|
134 |
wav_bytes = wav_buffer.getvalue()
|
135 |
-
yield wav_bytes
|
|
|
7 |
import logging
|
8 |
import io
|
9 |
import time
|
|
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
|
14 |
|
15 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
16 |
+
if user_hash in sessions:
|
17 |
+
return
|
|
|
18 |
async with (
|
19 |
client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
|
20 |
asyncio.TaskGroup() as tg,
|
|
|
32 |
config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
|
33 |
)
|
34 |
await session.play()
|
35 |
+
logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
|
36 |
+
sessions[user_hash] = {
|
37 |
+
'session': session,
|
38 |
+
'queue': queue.Queue()
|
39 |
+
}
|
|
|
|
|
|
|
40 |
|
41 |
async def change_music_tone(user_hash: str, new_tone):
|
42 |
logger.info(f"Changing music tone to {new_tone}")
|
43 |
+
session = sessions.get(user_hash, {}).get('session')
|
|
|
44 |
if not session:
|
45 |
logger.error(f"No session found for user hash {user_hash}")
|
46 |
return
|
|
|
69 |
break
|
70 |
|
71 |
sessions = {}
|
|
|
72 |
|
73 |
async def start_music_generation(user_hash: str, music_tone: str):
|
74 |
"""Start the music generation in a separate thread."""
|
75 |
await generate_music(user_hash, music_tone, receive_audio)
|
76 |
+
|
77 |
async def cleanup_music_session(user_hash: str):
|
78 |
+
if user_hash in sessions:
|
|
|
|
|
|
|
79 |
logger.info(f"Cleaning up music session for user hash {user_hash}")
|
80 |
+
session = sessions[user_hash]['session']
|
81 |
+
await session.stop()
|
82 |
+
await session.close()
|
|
|
|
|
|
|
83 |
del sessions[user_hash]
|
84 |
|
85 |
|
86 |
def update_audio(user_hash):
|
87 |
"""Continuously stream audio from the queue as WAV bytes."""
|
88 |
while True:
|
89 |
+
if user_hash not in sessions:
|
|
|
|
|
90 |
time.sleep(0.5)
|
91 |
continue
|
92 |
+
queue = sessions[user_hash]['queue']
|
93 |
pcm_data = queue.get() # This is raw PCM audio bytes
|
94 |
|
95 |
if not isinstance(pcm_data, bytes):
|
|
|
117 |
wf.setframerate(SAMPLE_RATE)
|
118 |
wf.writeframes(pcm_data)
|
119 |
wav_bytes = wav_buffer.getvalue()
|
120 |
+
yield wav_bytes
|
src/images/image_generator.py
CHANGED
@@ -4,7 +4,6 @@ import os
|
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
6 |
from datetime import datetime
|
7 |
-
import uuid
|
8 |
from config import settings
|
9 |
import logging
|
10 |
import asyncio
|
@@ -64,9 +63,9 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
|
|
64 |
image_saved = False
|
65 |
for part in response.candidates[0].content.parts:
|
66 |
if part.inline_data is not None:
|
67 |
-
# Create a filename with timestamp
|
68 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
69 |
-
filename = f"gemini_{timestamp}
|
70 |
filepath = os.path.join(output_dir, filename)
|
71 |
|
72 |
# Save the image
|
@@ -131,9 +130,9 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
|
|
131 |
image_saved = False
|
132 |
for part in response.candidates[0].content.parts:
|
133 |
if part.inline_data is not None:
|
134 |
-
# Create a filename with timestamp
|
135 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
136 |
-
filename = f"gemini_modified_{timestamp}
|
137 |
filepath = os.path.join(output_dir, filename)
|
138 |
|
139 |
# Save the modified image
|
|
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
6 |
from datetime import datetime
|
|
|
7 |
from config import settings
|
8 |
import logging
|
9 |
import asyncio
|
|
|
63 |
image_saved = False
|
64 |
for part in response.candidates[0].content.parts:
|
65 |
if part.inline_data is not None:
|
66 |
+
# Create a filename with timestamp
|
67 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
68 |
+
filename = f"gemini_{timestamp}.png"
|
69 |
filepath = os.path.join(output_dir, filename)
|
70 |
|
71 |
# Save the image
|
|
|
130 |
image_saved = False
|
131 |
for part in response.candidates[0].content.parts:
|
132 |
if part.inline_data is not None:
|
133 |
+
# Create a filename with timestamp
|
134 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
135 |
+
filename = f"gemini_modified_{timestamp}.png"
|
136 |
filepath = os.path.join(output_dir, filename)
|
137 |
|
138 |
# Save the modified image
|
src/main.py
CHANGED
@@ -61,8 +61,7 @@ async def update_scene(user_hash: str, choice):
|
|
61 |
return (
|
62 |
gr.update(value=ending_text),
|
63 |
gr.update(value=ending_image),
|
64 |
-
gr.
|
65 |
-
allow_custom_value=True),
|
66 |
gr.update(value="", visible=False),
|
67 |
)
|
68 |
|
@@ -70,11 +69,10 @@ async def update_scene(user_hash: str, choice):
|
|
70 |
return (
|
71 |
scene["description"],
|
72 |
scene.get("image", ""),
|
73 |
-
gr.
|
74 |
choices=[ch["text"] for ch in scene.get("choices", [])],
|
75 |
label="What do you choose? (select an option or write your own)",
|
76 |
value=None,
|
77 |
-
allow_custom_value=True,
|
78 |
elem_classes=["choice-buttons"],
|
79 |
),
|
80 |
gr.update(value=""),
|
@@ -263,11 +261,10 @@ with gr.Blocks(
|
|
263 |
lines=3,
|
264 |
)
|
265 |
with gr.Column(elem_classes=["choice-area"]):
|
266 |
-
game_choices = gr.
|
267 |
choices=[],
|
268 |
label="What do you choose? (select an option or write your own)",
|
269 |
value=None,
|
270 |
-
allow_custom_value=True,
|
271 |
elem_classes=["choice-buttons"],
|
272 |
)
|
273 |
custom_choice = gr.Textbox(
|
|
|
61 |
return (
|
62 |
gr.update(value=ending_text),
|
63 |
gr.update(value=ending_image),
|
64 |
+
gr.Radio(choices=[], label="", value=None, visible=False),
|
|
|
65 |
gr.update(value="", visible=False),
|
66 |
)
|
67 |
|
|
|
69 |
return (
|
70 |
scene["description"],
|
71 |
scene.get("image", ""),
|
72 |
+
gr.Radio(
|
73 |
choices=[ch["text"] for ch in scene.get("choices", [])],
|
74 |
label="What do you choose? (select an option or write your own)",
|
75 |
value=None,
|
|
|
76 |
elem_classes=["choice-buttons"],
|
77 |
),
|
78 |
gr.update(value=""),
|
|
|
261 |
lines=3,
|
262 |
)
|
263 |
with gr.Column(elem_classes=["choice-area"]):
|
264 |
+
game_choices = gr.Radio(
|
265 |
choices=[],
|
266 |
label="What do you choose? (select an option or write your own)",
|
267 |
value=None,
|
|
|
268 |
elem_classes=["choice-buttons"],
|
269 |
)
|
270 |
custom_choice = gr.Textbox(
|