Georgii Savin commited on
Commit
e6f8aa9
·
unverified ·
2 Parent(s): 0a18f7d 45fabe9

Merge pull request #4 from DeltaZN/feature/unique-session-ids

Browse files

feat: implement Google API key management and refactor client usage i…

src/agent/llm.py CHANGED
@@ -4,31 +4,17 @@ import logging
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- _API_KEYS: list[str] = []
11
- _current_key_idx = 0
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
- """Return an API key using round-robin selection."""
17
- global _API_KEYS, _current_key_idx
18
-
19
- if not _API_KEYS:
20
- keys_str = settings.gemini_api_key.get_secret_value()
21
- if keys_str:
22
- _API_KEYS = [k.strip() for k in keys_str.split(",") if k.strip()]
23
- if not _API_KEYS:
24
- msg = "Google API keys are not configured or invalid"
25
- logger.error(msg)
26
- raise ValueError(msg)
27
-
28
- key = _API_KEYS[_current_key_idx]
29
- _current_key_idx = (_current_key_idx + 1) % len(_API_KEYS)
30
- logger.debug("Using Google API key index %s", _current_key_idx)
31
- return key
32
 
33
 
34
  def create_llm(
 
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
7
+ from services.google import ApiKeyPool
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ _pool = ApiKeyPool()
 
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
+ """Return an API key using round-robin selection in a thread-safe way."""
17
+ return _pool.get_key_sync()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def create_llm(
src/agent/redis_state.py CHANGED
@@ -5,9 +5,12 @@ from __future__ import annotations
5
  import json
6
  import msgpack
7
  import redis.asyncio as redis
 
8
 
9
  from agent.models import UserState
10
 
 
 
11
 
12
  class UserRepository:
13
  """Repository for storing UserState objects in Redis."""
@@ -18,6 +21,7 @@ class UserRepository:
18
  async def get(self, user_id: str) -> UserState:
19
  """Return user state for the given id, creating it if absent."""
20
  key = f"llmgamehub:{user_id}"
 
21
  data = await self.redis.hget(key, "data")
22
  if data is None:
23
  return UserState()
@@ -27,12 +31,14 @@ class UserRepository:
27
  async def set(self, user_id: str, state: UserState) -> None:
28
  """Persist updated user state."""
29
  key = f"llmgamehub:{user_id}"
 
30
  packed = msgpack.packb(json.loads(state.json()))
31
  await self.redis.hset(key, mapping={"data": packed})
32
 
33
  async def reset(self, user_id: str) -> None:
34
  """Remove stored state for a user."""
35
  key = f"llmgamehub:{user_id}"
 
36
  await self.redis.delete(key)
37
 
38
 
@@ -40,12 +46,15 @@ _repo = UserRepository()
40
 
41
 
42
  async def get_user_state(user_hash: str) -> UserState:
 
43
  return await _repo.get(user_hash)
44
 
45
 
46
  async def set_user_state(user_hash: str, state: UserState) -> None:
 
47
  await _repo.set(user_hash, state)
48
 
49
 
50
  async def reset_user_state(user_hash: str) -> None:
 
51
  await _repo.reset(user_hash)
 
5
  import json
6
  import msgpack
7
  import redis.asyncio as redis
8
+ import logging
9
 
10
  from agent.models import UserState
11
 
12
+ logger = logging.getLogger(__name__)
13
+
14
 
15
  class UserRepository:
16
  """Repository for storing UserState objects in Redis."""
 
21
  async def get(self, user_id: str) -> UserState:
22
  """Return user state for the given id, creating it if absent."""
23
  key = f"llmgamehub:{user_id}"
24
+ logger.debug("Fetching state for %s", user_id)
25
  data = await self.redis.hget(key, "data")
26
  if data is None:
27
  return UserState()
 
31
  async def set(self, user_id: str, state: UserState) -> None:
32
  """Persist updated user state."""
33
  key = f"llmgamehub:{user_id}"
34
+ logger.debug("Saving state for %s", user_id)
35
  packed = msgpack.packb(json.loads(state.json()))
36
  await self.redis.hset(key, mapping={"data": packed})
37
 
38
  async def reset(self, user_id: str) -> None:
39
  """Remove stored state for a user."""
40
  key = f"llmgamehub:{user_id}"
41
+ logger.debug("Resetting state for %s", user_id)
42
  await self.redis.delete(key)
43
 
44
 
 
46
 
47
 
48
  async def get_user_state(user_hash: str) -> UserState:
49
+ logger.debug("get_user_state for %s", user_hash)
50
  return await _repo.get(user_hash)
51
 
52
 
53
  async def set_user_state(user_hash: str, state: UserState) -> None:
54
+ logger.debug("set_user_state for %s", user_hash)
55
  await _repo.set(user_hash, state)
56
 
57
 
58
  async def reset_user_state(user_hash: str) -> None:
59
+ logger.debug("reset_user_state for %s", user_hash)
60
  await _repo.reset(user_hash)
src/audio/audio_generator.py CHANGED
@@ -1,59 +1,71 @@
1
  import asyncio
2
- from google import genai
3
  from google.genai import types
4
- from config import settings
5
  import wave
6
  import queue
7
  import logging
8
  import io
9
  import time
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
 
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
  if user_hash in sessions:
17
- logger.info(f"Music generation already started for user hash {user_hash}, skipping new generation")
18
- return
19
- async with (
20
- client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
21
- asyncio.TaskGroup() as tg,
22
- ):
23
- # Set up task to receive server messages.
24
- tg.create_task(receive_audio(session, user_hash))
25
-
26
- # Send initial prompts and config
27
- await session.set_weighted_prompts(
28
- prompts=[
29
- types.WeightedPrompt(text=music_tone, weight=1.0),
30
- ]
31
- )
32
- await session.set_music_generation_config(
33
- config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
34
  )
35
- await session.play()
36
- logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
37
- sessions[user_hash] = {
38
- 'session': session,
39
- 'queue': queue.Queue()
40
- }
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  async def change_music_tone(user_hash: str, new_tone):
43
  logger.info(f"Changing music tone to {new_tone}")
44
- session = sessions.get(user_hash, {}).get('session')
45
  if not session:
46
  logger.error(f"No session found for user hash {user_hash}")
47
  return
48
- await session.set_weighted_prompts(
49
- prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
 
 
 
50
  )
51
-
52
 
53
  SAMPLE_RATE = 48000
54
  NUM_CHANNELS = 2 # Stereo
55
  SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
56
 
 
57
  async def receive_audio(session, user_hash):
58
  """Process incoming audio from the music generation."""
59
  while True:
@@ -61,7 +73,7 @@ async def receive_audio(session, user_hash):
61
  async for message in session.receive():
62
  if message.server_content and message.server_content.audio_chunks:
63
  audio_data = message.server_content.audio_chunks[0].data
64
- queue = sessions[user_hash]['queue']
65
  # audio_data is already bytes (raw PCM)
66
  await asyncio.to_thread(queue.put, audio_data)
67
  await asyncio.sleep(10**-12)
@@ -69,42 +81,47 @@ async def receive_audio(session, user_hash):
69
  logger.error(f"Error in receive_audio: {e}")
70
  break
71
 
 
72
  sessions = {}
73
 
 
74
  async def start_music_generation(user_hash: str, music_tone: str):
75
  """Start the music generation in a separate thread."""
76
  await generate_music(user_hash, music_tone, receive_audio)
77
-
 
78
  async def cleanup_music_session(user_hash: str):
79
  if user_hash in sessions:
80
  logger.info(f"Cleaning up music session for user hash {user_hash}")
81
- session = sessions[user_hash]['session']
82
- await session.stop()
83
- await session.close()
84
  del sessions[user_hash]
85
-
86
 
87
  def update_audio(user_hash):
88
  """Continuously stream audio from the queue as WAV bytes."""
89
  if user_hash == "":
90
  return
91
-
92
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
93
  while True:
94
  if user_hash not in sessions:
95
  time.sleep(0.5)
96
  continue
97
- queue = sessions[user_hash]['queue']
98
- pcm_data = queue.get() # This is raw PCM audio bytes
99
-
100
  if not isinstance(pcm_data, bytes):
101
- logger.warning(f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping.")
 
 
102
  continue
103
 
104
  # Lyria provides stereo, 16-bit PCM at 48kHz.
105
  # Ensure the number of bytes is consistent with stereo 16-bit audio.
106
  # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
107
- # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
108
  # it might indicate an incomplete chunk or an issue.
109
  bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
110
  if len(pcm_data) % bytes_per_frame != 0:
@@ -113,13 +130,13 @@ def update_audio(user_hash):
113
  f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
114
  )
115
  # Depending on strictness, you might want to skip this chunk:
116
- # continue
117
 
118
  wav_buffer = io.BytesIO()
119
- with wave.open(wav_buffer, 'wb') as wf:
120
  wf.setnchannels(NUM_CHANNELS)
121
- wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
122
  wf.setframerate(SAMPLE_RATE)
123
  wf.writeframes(pcm_data)
124
  wav_bytes = wav_buffer.getvalue()
125
- yield wav_bytes
 
1
  import asyncio
 
2
  from google.genai import types
 
3
  import wave
4
  import queue
5
  import logging
6
  import io
7
  import time
8
+ from config import settings
9
+ from services.google import GoogleClientFactory
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+
14
+
15
 
16
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
17
  if user_hash in sessions:
18
+ logger.info(
19
+ f"Music generation already started for user hash {user_hash}, skipping new generation"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
+ return
22
+ async with GoogleClientFactory.audio() as client:
23
+ async with (
24
+ client.live.music.connect(model="models/lyria-realtime-exp") as session,
25
+ asyncio.TaskGroup() as tg,
26
+ ):
27
+ # Set up task to receive server messages.
28
+ tg.create_task(receive_audio(session, user_hash))
29
+
30
+ # Send initial prompts and config
31
+ await asyncio.wait_for(
32
+ session.set_weighted_prompts(
33
+ prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
34
+ ),
35
+ settings.request_timeout,
36
+ )
37
+ await asyncio.wait_for(
38
+ session.set_music_generation_config(
39
+ config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
40
+ ),
41
+ settings.request_timeout,
42
+ )
43
+ await asyncio.wait_for(session.play(), settings.request_timeout)
44
+ logger.info(
45
+ f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
46
+ )
47
+ sessions[user_hash] = {"session": session, "queue": queue.Queue()}
48
+
49
+
50
  async def change_music_tone(user_hash: str, new_tone):
51
  logger.info(f"Changing music tone to {new_tone}")
52
+ session = sessions.get(user_hash, {}).get("session")
53
  if not session:
54
  logger.error(f"No session found for user hash {user_hash}")
55
  return
56
+ await asyncio.wait_for(
57
+ session.set_weighted_prompts(
58
+ prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
59
+ ),
60
+ settings.request_timeout,
61
  )
62
+
63
 
64
  SAMPLE_RATE = 48000
65
  NUM_CHANNELS = 2 # Stereo
66
  SAMPLE_WIDTH = 2 # 16-bit audio -> 2 bytes per sample
67
 
68
+
69
  async def receive_audio(session, user_hash):
70
  """Process incoming audio from the music generation."""
71
  while True:
 
73
  async for message in session.receive():
74
  if message.server_content and message.server_content.audio_chunks:
75
  audio_data = message.server_content.audio_chunks[0].data
76
+ queue = sessions[user_hash]["queue"]
77
  # audio_data is already bytes (raw PCM)
78
  await asyncio.to_thread(queue.put, audio_data)
79
  await asyncio.sleep(10**-12)
 
81
  logger.error(f"Error in receive_audio: {e}")
82
  break
83
 
84
+
85
  sessions = {}
86
 
87
+
88
  async def start_music_generation(user_hash: str, music_tone: str):
89
  """Start the music generation in a separate thread."""
90
  await generate_music(user_hash, music_tone, receive_audio)
91
+
92
+
93
  async def cleanup_music_session(user_hash: str):
94
  if user_hash in sessions:
95
  logger.info(f"Cleaning up music session for user hash {user_hash}")
96
+ session = sessions[user_hash]["session"]
97
+ await asyncio.wait_for(session.stop(), settings.request_timeout)
98
+ await asyncio.wait_for(session.close(), settings.request_timeout)
99
  del sessions[user_hash]
100
+
101
 
102
  def update_audio(user_hash):
103
  """Continuously stream audio from the queue as WAV bytes."""
104
  if user_hash == "":
105
  return
106
+
107
  logger.info(f"Starting audio update loop for user hash: {user_hash}")
108
  while True:
109
  if user_hash not in sessions:
110
  time.sleep(0.5)
111
  continue
112
+ queue = sessions[user_hash]["queue"]
113
+ pcm_data = queue.get() # This is raw PCM audio bytes
114
+
115
  if not isinstance(pcm_data, bytes):
116
+ logger.warning(
117
+ f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
118
+ )
119
  continue
120
 
121
  # Lyria provides stereo, 16-bit PCM at 48kHz.
122
  # Ensure the number of bytes is consistent with stereo 16-bit audio.
123
  # Each frame = NUM_CHANNELS * SAMPLE_WIDTH bytes.
124
+ # If len(pcm_data) is not a multiple of (NUM_CHANNELS * SAMPLE_WIDTH),
125
  # it might indicate an incomplete chunk or an issue.
126
  bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
127
  if len(pcm_data) % bytes_per_frame != 0:
 
130
  f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
131
  )
132
  # Depending on strictness, you might want to skip this chunk:
133
+ # continue
134
 
135
  wav_buffer = io.BytesIO()
136
+ with wave.open(wav_buffer, "wb") as wf:
137
  wf.setnchannels(NUM_CHANNELS)
138
+ wf.setsampwidth(SAMPLE_WIDTH) # Corresponds to 16-bit audio
139
  wf.setframerate(SAMPLE_RATE)
140
  wf.writeframes(pcm_data)
141
  wav_bytes = wav_buffer.getvalue()
142
+ yield wav_bytes
src/config.py CHANGED
@@ -29,6 +29,6 @@ class AppSettings(BaseAppSettings):
29
  top_p: float = 0.95
30
  temperature: float = 0.5
31
  pregenerate_next_scene: bool = True
32
-
33
 
34
  settings = AppSettings()
 
29
  top_p: float = 0.95
30
  temperature: float = 0.5
31
  pregenerate_next_scene: bool = True
32
+ request_timeout: int = 20
33
 
34
  settings = AppSettings()
src/images/image_generator.py CHANGED
@@ -1,18 +1,16 @@
1
- from google import genai
2
  from google.genai import types
3
  import os
4
  from PIL import Image
5
  from io import BytesIO
6
  from datetime import datetime
7
- from config import settings
8
  import logging
9
  import asyncio
10
  import gradio as gr
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
15
-
16
  safety_settings = [
17
  types.SafetySetting(
18
  category="HARM_CATEGORY_HARASSMENT",
@@ -50,14 +48,18 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
50
  logger.info(f"Generating image with prompt: {prompt}")
51
 
52
  try:
53
- response = await client.models.generate_content(
54
- model="gemini-2.0-flash-preview-image-generation",
55
- contents=prompt,
56
- config=types.GenerateContentConfig(
57
- response_modalities=["TEXT", "IMAGE"],
58
- safety_settings=safety_settings,
59
- ),
60
- )
 
 
 
 
61
 
62
  # Process the response parts
63
  image_saved = False
@@ -108,23 +110,23 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
108
  logger.error(f"Error: Image file not found at {image_path}")
109
  return None
110
 
111
- key = settings.gemini_api_key.get_secret_value()
112
-
113
- client = genai.Client(api_key=key).aio
114
-
115
  try:
116
- # Load the input image
117
- input_image = Image.open(image_path)
118
-
119
- # Make the API call with both text and image
120
- response = await client.models.generate_content(
121
- model="gemini-2.0-flash-preview-image-generation",
122
- contents=[modification_prompt, input_image],
123
- config=types.GenerateContentConfig(
124
- response_modalities=["TEXT", "IMAGE"],
125
- safety_settings=safety_settings,
126
- ),
127
- )
 
 
 
 
128
 
129
  # Process the response parts
130
  image_saved = False
 
 
1
  from google.genai import types
2
  import os
3
  from PIL import Image
4
  from io import BytesIO
5
  from datetime import datetime
 
6
  import logging
7
  import asyncio
8
  import gradio as gr
9
+ from config import settings
10
+ from services.google import GoogleClientFactory
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
 
14
  safety_settings = [
15
  types.SafetySetting(
16
  category="HARM_CATEGORY_HARASSMENT",
 
48
  logger.info(f"Generating image with prompt: {prompt}")
49
 
50
  try:
51
+ async with GoogleClientFactory.image() as client:
52
+ response = await asyncio.wait_for(
53
+ client.models.generate_content(
54
+ model="gemini-2.0-flash-preview-image-generation",
55
+ contents=prompt,
56
+ config=types.GenerateContentConfig(
57
+ response_modalities=["TEXT", "IMAGE"],
58
+ safety_settings=safety_settings,
59
+ ),
60
+ ),
61
+ settings.request_timeout,
62
+ )
63
 
64
  # Process the response parts
65
  image_saved = False
 
110
  logger.error(f"Error: Image file not found at {image_path}")
111
  return None
112
 
 
 
 
 
113
  try:
114
+ async with GoogleClientFactory.image() as client:
115
+ # Load the input image
116
+ input_image = Image.open(image_path)
117
+
118
+ # Make the API call with both text and image
119
+ response = await asyncio.wait_for(
120
+ client.models.generate_content(
121
+ model="gemini-2.0-flash-preview-image-generation",
122
+ contents=[modification_prompt, input_image],
123
+ config=types.GenerateContentConfig(
124
+ response_modalities=["TEXT", "IMAGE"],
125
+ safety_settings=safety_settings,
126
+ ),
127
+ ),
128
+ settings.request_timeout,
129
+ )
130
 
131
  # Process the response parts
132
  image_saved = False
src/main.py CHANGED
@@ -136,7 +136,7 @@ with gr.Blocks(
136
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
137
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
138
 
139
- local_storage = gr.BrowserState("", "user_hash")
140
 
141
  # Constructor Interface (visible by default)
142
  with gr.Column(
@@ -313,7 +313,7 @@ with gr.Blocks(
313
  start_btn.click(
314
  fn=start_game_with_music,
315
  inputs=[
316
- local_storage,
317
  setting_description,
318
  char_name,
319
  char_age,
@@ -330,13 +330,14 @@ with gr.Blocks(
330
  game_image,
331
  game_choices,
332
  custom_choice,
 
333
  ],
334
  concurrency_limit=CONCURRENCY_LIMIT,
335
  )
336
 
337
  back_btn.click(
338
  fn=return_to_constructor,
339
- inputs=[local_storage],
340
  outputs=[
341
  loading_indicator,
342
  constructor_interface,
@@ -345,16 +346,9 @@ with gr.Blocks(
345
  ],
346
  )
347
 
348
- game_choices.change(
349
- fn=update_scene,
350
- inputs=[local_storage, game_choices],
351
- outputs=[game_text, game_image, game_choices, custom_choice],
352
- concurrency_limit=CONCURRENCY_LIMIT,
353
- )
354
-
355
  custom_choice.submit(
356
  fn=update_scene,
357
- inputs=[local_storage, custom_choice],
358
  outputs=[game_text, game_image, game_choices, custom_choice],
359
  concurrency_limit=CONCURRENCY_LIMIT,
360
  )
@@ -363,13 +357,15 @@ with gr.Blocks(
363
  demo.load(
364
  fn=generate_user_hash,
365
  inputs=[],
366
- outputs=[local_storage],
367
  )
368
- local_storage.change(
369
  fn=update_audio,
370
- inputs=[local_storage],
371
  outputs=[audio_out],
372
  concurrency_limit=CONCURRENCY_LIMIT,
373
  )
374
 
 
 
375
  demo.launch(ssr_mode=False)
 
136
  with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
137
  gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
138
 
139
+ ls_user_hash = gr.BrowserState("", "user_hash")
140
 
141
  # Constructor Interface (visible by default)
142
  with gr.Column(
 
313
  start_btn.click(
314
  fn=start_game_with_music,
315
  inputs=[
316
+ ls_user_hash,
317
  setting_description,
318
  char_name,
319
  char_age,
 
330
  game_image,
331
  game_choices,
332
  custom_choice,
333
+ ls_user_hash,
334
  ],
335
  concurrency_limit=CONCURRENCY_LIMIT,
336
  )
337
 
338
  back_btn.click(
339
  fn=return_to_constructor,
340
+ inputs=[ls_user_hash],
341
  outputs=[
342
  loading_indicator,
343
  constructor_interface,
 
346
  ],
347
  )
348
 
 
 
 
 
 
 
 
349
  custom_choice.submit(
350
  fn=update_scene,
351
+ inputs=[ls_user_hash, custom_choice],
352
  outputs=[game_text, game_image, game_choices, custom_choice],
353
  concurrency_limit=CONCURRENCY_LIMIT,
354
  )
 
357
  demo.load(
358
  fn=generate_user_hash,
359
  inputs=[],
360
+ outputs=[ls_user_hash],
361
  )
362
+ ls_user_hash.change(
363
  fn=update_audio,
364
+ inputs=[ls_user_hash],
365
  outputs=[audio_out],
366
  concurrency_limit=CONCURRENCY_LIMIT,
367
  )
368
 
369
+
370
+ demo.queue()
371
  demo.launch(ssr_mode=False)
src/services/google.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+ from google import genai
5
+ import threading
6
+
7
+ from config import settings
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class ApiKeyPool:
13
+ """Manage Google API keys with round-robin selection."""
14
+
15
+ def __init__(self) -> None:
16
+ self._keys: list[str] | None = None
17
+ self._index = 0
18
+ self._lock = asyncio.Lock()
19
+ self._sync_lock = threading.Lock()
20
+
21
+ def _load_keys(self) -> None:
22
+ keys_raw = (
23
+ getattr(settings, "gemini_api_keys", None) or settings.gemini_api_key
24
+ )
25
+ keys_str = keys_raw.get_secret_value()
26
+ keys = [k.strip() for k in keys_str.split(',') if k.strip()] if keys_str else []
27
+ if not keys:
28
+ msg = "Google API keys are not configured or invalid"
29
+ logger.error(msg)
30
+ raise ValueError(msg)
31
+ self._keys = keys
32
+
33
+ async def get_key(self) -> str:
34
+ async with self._lock:
35
+ if self._keys is None:
36
+ self._load_keys()
37
+ key = self._keys[self._index]
38
+ self._index = (self._index + 1) % len(self._keys)
39
+ logger.debug("Using Google API key index %s", self._index)
40
+ return key
41
+
42
+ def get_key_sync(self) -> str:
43
+ """Synchronous helper for environments without an event loop."""
44
+ with self._sync_lock:
45
+ if self._keys is None:
46
+ self._load_keys()
47
+ key = self._keys[self._index]
48
+ self._index = (self._index + 1) % len(self._keys)
49
+ logger.debug("Using Google API key index %s", self._index)
50
+ return key
51
+
52
+
53
+ class GoogleClientFactory:
54
+ """Factory for thread-safe creation of Google GenAI clients."""
55
+
56
+ _pool = ApiKeyPool()
57
+
58
+ @classmethod
59
+ @asynccontextmanager
60
+ async def image(cls):
61
+ key = await cls._pool.get_key()
62
+ client = genai.Client(api_key=key)
63
+ try:
64
+ yield client.aio
65
+ finally:
66
+ pass
67
+
68
+ @classmethod
69
+ @asynccontextmanager
70
+ async def audio(cls):
71
+ key = await cls._pool.get_key()
72
+ client = genai.Client(api_key=key, http_options={"api_version": "v1alpha"})
73
+ try:
74
+ yield client.aio
75
+ finally:
76
+ pass