ponix-generator / llm_wrapper.py
cwhuh's picture
chore : remove output structure, refined prompt
39fee28
raw
history blame
3.85 kB
import logging
from PIL import Image
from io import BytesIO
import requests, os, json, time
from google import genai
# Base directory that the prompt file path given to run_gemini is joined onto.
# With the empty string, os.path.join leaves that path unchanged.
prompt_base_path = ""
# Shared Gemini API client, created once when this module is imported.
# The API key is read from the GEMINI_API_KEY environment variable
# (os.getenv yields None when the variable is not set).
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
def encode_image(image_source, timeout: float = 30.0):
    """Convert an image into a JPEG-encoded ``google.genai.types.Part``.

    The source may be an ``http(s)://`` URL, a local file (anything
    ``PIL.Image.open`` accepts), or an already-loaded Pillow ``Image``; all
    three are handled uniformly. The image is always re-encoded as JPEG
    before being wrapped in a ``Part``.

    Args:
        image_source: URL string, local path / file object, or ``Image.Image``.
        timeout: Seconds to wait for the HTTP download when ``image_source``
            is a URL, so a stalled server cannot block the caller forever.

    Returns:
        A ``genai.types.Part`` carrying the JPEG bytes with MIME type
        ``image/jpeg``.

    Raises:
        ValueError: If the download fails (including non-2xx responses), the
            image cannot be opened or encoded, or its source format is not
            one Pillow has registered.
    """
    try:
        if isinstance(image_source, Image.Image):
            # Already a Pillow image: use it as-is.
            image = image_source
        elif isinstance(image_source, str) and image_source.startswith(
            ("http://", "https://")
        ):
            # Remote image: download it, and fail on HTTP errors instead of
            # handing an error page to Pillow.
            response = requests.get(image_source, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Anything else is treated as a local file.
            image = Image.open(image_source)

        # Images created in memory carry no format; treat them as JPEG.
        image_format = "JPEG" if image.format is None else image.format
        if image_format not in Image.registered_extensions().values():
            raise ValueError(f"Unsupported image format: {image_format}.")

        # Normalise every mode other than plain RGB / greyscale (alpha,
        # palette, CMYK, integer, float, ...) to RGB so the JPEG save below
        # does not fail on modes the encoder cannot write.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return genai.types.Part.from_bytes(
            data=buffered.getvalue(), mime_type="image/jpeg"
        )
    # RequestException derives from OSError, so it must be handled first.
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download the image from URL: {e}") from e
    except OSError as e:
        raise ValueError(f"Failed to process the image file: {e}") from e
    # A ValueError raised above propagates unchanged, traceback intact.
# Figures used only for the cost estimate written to the log.
# NOTE(review): presumably prices per one million tokens and a fixed
# currency-to-KRW conversion rate -- confirm against current pricing.
_INPUT_PRICE_PER_MTOK = 0.1
_OUTPUT_PRICE_PER_MTOK = 0.7
_KRW_RATE = 1500


def _estimate_cost_krw(input_tokens, output_tokens):
    """Return the estimated request cost; a ``None`` count is treated as 0."""
    input_tokens = input_tokens or 0
    output_tokens = output_tokens or 0
    return (
        input_tokens / 1000000 * _INPUT_PRICE_PER_MTOK * _KRW_RATE
        + output_tokens / 1000000 * _OUTPUT_PRICE_PER_MTOK * _KRW_RATE
    )


def run_gemini(
    target_prompt: str,
    prompt_in_path: str,
    img_in_data: "str | Image.Image | None" = None,
    model: str = "gemini-2.0-flash",
) -> tuple:
    """Call the Gemini API synchronously and return the model's reply.

    The prompt file is a JSON document with a ``system_prompt`` entry and a
    ``user_prompt`` object holding ``head`` and ``tail``. The user message is
    ``head``, ``target_prompt`` and ``tail`` joined by newlines. There is no
    retry logic: any API error propagates to the caller.

    Args:
        target_prompt: Text inserted between the prompt file's head and tail.
        prompt_in_path: Path of the prompt JSON file, joined onto the
            module-level ``prompt_base_path``.
        img_in_data: Optional image (URL, local path, or Pillow image) that
            is attached to the request via ``encode_image``.
        model: Name of the Gemini model to call.

    Returns:
        A ``(chat_output, chat_completion)`` tuple: the reply content (the
        parsed object when the SDK provides one, otherwise the plain text)
        and the raw response object.

    Raises:
        OSError: If the prompt file cannot be read.
        KeyError: If the prompt file lacks one of the expected keys.
        ValueError: If the prompt file is not valid JSON or the image cannot
            be encoded.
    """
    prompt_file = os.path.join(prompt_base_path, prompt_in_path)
    with open(prompt_file, "r", encoding="utf-8") as file:
        prompt_dict = json.load(file)

    system_prompt = prompt_dict["system_prompt"]
    user_prompt = prompt_dict["user_prompt"]
    user_prompt_text = "\n".join(
        [user_prompt["head"], target_prompt, user_prompt["tail"]]
    )

    input_content = [user_prompt_text]
    if img_in_data is not None:
        input_content.append(encode_image(img_in_data))

    # Forward the system prompt to the model; it was previously loaded from
    # the prompt file but never sent. An empty prompt sends no config.
    config = (
        genai.types.GenerateContentConfig(system_instruction=system_prompt)
        if system_prompt
        else None
    )

    logging.info("Requested API for chat completion response (sync call)...")
    # Monotonic clock: the elapsed time is immune to wall-clock adjustments.
    start_time = time.perf_counter()
    # Synchronous call: client.models.generate_content(...)
    chat_completion = client.models.generate_content(
        model=model,
        contents=input_content,
        config=config,
    )

    # `parsed` is only populated when a response schema is configured; with
    # none set it is None, so fall back to the plain-text reply.
    chat_output = getattr(chat_completion, "parsed", None)
    if chat_output is None:
        chat_output = chat_completion.text

    # Usage metadata (or individual counts) may be missing from a response.
    usage = chat_completion.usage_metadata
    pricing = _estimate_cost_krw(
        getattr(usage, "prompt_token_count", None),
        getattr(usage, "candidates_token_count", None),
    )
    logging.info(
        "[GEMINI] Request completed (sync). Time taken: %.2fs / Pricing(KRW): %.2f",
        time.perf_counter() - start_time,
        pricing,
    )
    return chat_output, chat_completion