cwhuh committed on
Commit
dc592b9
·
1 Parent(s): eb9cbe1

fix : google-genai -> openai

Browse files
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
__pycache__/llm_wrapper.cpython-310.pyc CHANGED
Binary files a/__pycache__/llm_wrapper.cpython-310.pyc and b/__pycache__/llm_wrapper.cpython-310.pyc differ
 
llm_wrapper.py CHANGED
@@ -1,105 +1,43 @@
1
- import logging
2
- from PIL import Image
3
- from io import BytesIO
4
- import requests, os, json, time
5
-
6
- from google import genai
7
 
8
  prompt_base_path = ""
9
- client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
-
11
-
12
- def encode_image(image_source):
13
- """
14
- 이미지 경로가 URL이든 로컬 파일이든 Pillow Image 객체이든 동일하게 처리하는 함수.
15
- 이미지를 열어 google.genai.types.Part 객체로 변환합니다.
16
- Pillow에서 지원되지 않는 포맷에 대해서는 예외를 발생시킵니다.
17
- """
18
- try:
19
- # 이미 Pillow 이미지 객체인 경우 그대로 사용
20
- if isinstance(image_source, Image.Image):
21
- image = image_source
22
- else:
23
- # URL에서 이미지 다운로드
24
- if isinstance(image_source, str) and (
25
- image_source.startswith("http://")
26
- or image_source.startswith("https://")
27
- ):
28
- response = requests.get(image_source)
29
- image = Image.open(BytesIO(response.content))
30
- # 로컬 파일에서 이미지 열기
31
- else:
32
- image = Image.open(image_source)
33
-
34
- # 이미지 포맷이 None인 경우 (메모리에서 생성된 이미지 등)
35
- if image.format is None:
36
- image_format = "JPEG"
37
- else:
38
- image_format = image.format
39
-
40
- # 이미지 포맷이 지원되지 않는 경우 예외 발생
41
- if image_format not in Image.registered_extensions().values():
42
- raise ValueError(f"Unsupported image format: {image_format}.")
43
-
44
- buffered = BytesIO()
45
- # PIL에서 지원되지 않는 포맷이나 다양한 채널을 RGB로 변환 후 저장
46
- if image.mode in ("RGBA", "P", "CMYK"): # RGBA, 팔레트, CMYK 등은 RGB로 변환
47
- image = image.convert("RGB")
48
- image.save(buffered, format="JPEG")
49
-
50
- return genai.types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/jpeg")
51
 
52
- except requests.exceptions.RequestException as e:
53
- raise ValueError(f"Failed to download the image from URL: {e}")
54
- except IOError as e:
55
- raise ValueError(f"Failed to process the image file: {e}")
56
- except ValueError as e:
57
- raise ValueError(e)
58
 
59
 
60
  def run_gemini(
61
  target_prompt: str,
62
  prompt_in_path: str,
63
- img_in_data: str = None,
64
- model: str = "gemini-2.0-flash",
65
  ) -> str:
66
  """
67
- GEMINI API를 동기 방식으로 호출하여 문자열 응답을 받습니다.
68
- retry 논리는 제거되었습니다.
69
  """
70
- with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
 
 
 
 
71
  prompt_dict = json.load(file)
72
 
73
  system_prompt = prompt_dict["system_prompt"]
74
- user_prompt_head = prompt_dict["user_prompt"]["head"]
75
- user_prompt_tail = prompt_dict["user_prompt"]["tail"]
76
-
77
- user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
78
- input_content = [user_prompt_text]
79
-
80
- if img_in_data is not None:
81
- encoded_image = encode_image(img_in_data)
82
- input_content.append(encoded_image)
83
-
84
- logging.info("Requested API for chat completion response (sync call)...")
85
- start_time = time.time()
86
-
87
- # 동기 방식: client.models.generate_content(...)
88
- chat_completion = client.models.generate_content(
89
- model=model,
90
- contents=input_content,
91
- config={
92
- "system_instruction": system_prompt,
93
- }
94
  )
95
- print(f"Chat Completion: {chat_completion}")
96
-
97
- chat_output = chat_completion.candidates[0].content.parts[0].text
98
- input_token = chat_completion.usage_metadata.prompt_token_count
99
- output_token = chat_completion.usage_metadata.candidates_token_count
100
- pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500
101
 
102
- logging.info(
103
- f"[GEMINI] Request completed (sync). Time taken: {time.time()-start_time:.2f}s / Pricing(KRW): {pricing:.2f}"
 
 
 
 
 
 
 
104
  )
 
105
  return chat_output
 
1
+ import openai, os, json
 
 
 
 
 
2
 
3
  prompt_base_path = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
# Gemini is reached through Google's OpenAI-compatible endpoint, so the
# standard OpenAI client works once pointed at the right base URL.
client = openai.OpenAI(
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    api_key=os.getenv("GEMINI_API_KEY"),
)
 
 
9
 
10
 
11
def run_gemini(
    target_prompt: str,
    prompt_in_path: str,
    llm_model: str = "gemini-2.0-flash-exp",
) -> str:
    """Call a Gemini model through the OpenAI-compatible endpoint.

    Loads a JSON prompt template, wraps ``target_prompt`` between the
    template's head and tail, and returns the model's text reply.

    The template file must contain the keys ``"system_prompt"`` and
    ``"user_prompt"`` (the latter a dict with ``"head"`` and ``"tail"``).

    Args:
        target_prompt: Caller-supplied text inserted into the template.
        prompt_in_path: Path of the JSON prompt template, joined onto the
            module-level ``prompt_base_path``.
        llm_model: Model identifier passed through to the API.

    Returns:
        The assistant message content from the first choice.

    Raises:
        OSError: If the template file cannot be opened.
        json.JSONDecodeError: If the template is not valid JSON.
        KeyError: If a required template key is missing.
    """
    # Load the prompt template.
    with open(
        os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8"
    ) as file:
        prompt_dict = json.load(file)

    system_prompt = prompt_dict["system_prompt"]
    user_prompt_head = prompt_dict["user_prompt"]["head"]
    user_prompt_tail = prompt_dict["user_prompt"]["tail"]

    user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
    input_content = [{"type": "text", "text": user_prompt_text}]

    # Plain-text completion: use the stable chat.completions.create call.
    # (``beta.chat.completions.parse`` is the structured-outputs helper and
    # expects a ``response_format``; it is unnecessary for free-form text.)
    chat_completion = client.chat.completions.create(
        model=llm_model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_content},
        ],
    )
    return chat_completion.choices[0].message.content
requirements.txt CHANGED
@@ -5,5 +5,5 @@ transformers==4.42.4
5
  xformers
6
  sentencepiece
7
  peft==0.12.0
8
- google-genai
9
- gradio
 
5
  xformers
6
  sentencepiece
7
  peft==0.12.0
8
+ openai
9
+ gradio==4.43.0