#!/usr/bin/env python
import os
import re
import tempfile
import gc
from collections.abc import Iterator
from threading import Thread
import json
import requests
import gradio as gr
import spaces
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from peft import PeftModel
# BitsAndBytesConfig is imported conditionally
try:
    from transformers import BitsAndBytesConfig
    BITSANDBYTES_AVAILABLE = True
except ImportError:
    logger.warning("BitsAndBytesConfig could not be imported. Quantization features are disabled.")
    BITSANDBYTES_AVAILABLE = False
# CSV/TXT analysis
import pandas as pd
# PDF text extraction
import PyPDF2
##############################################################################
# Constants
##############################################################################
MAX_CONTENT_CHARS = 2000  # maximum number of characters taken from document contents
MAX_INPUT_LENGTH = 4096   # maximum number of input tokens for the model
##############################################################################
# Global variables
##############################################################################
model = None
tokenizer = None
device = None
##############################################################################
# Memory cleanup helper
##############################################################################
def clear_cuda_cache():
    """Explicitly empty the CUDA cache and run garbage collection."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
##############################################################################
# SERPHouse API key from environment variable
##############################################################################
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
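# Note: the key is read from the environment; in a Hugging Face Space this would typically be
# configured as a repository secret (assumed setup, not shown in this file). If no key is set,
# the SerpHouse request below will most likely fail and do_web_search() returns an error string.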
##############################################################################
# Simple keyword extraction (keeps Korean, alphanumeric characters, whitespace, and basic punctuation)
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """
    1) Keep only Korean (가-힣), English letters, digits, whitespace, and basic punctuation
    2) Split into tokens on whitespace
    3) Keep at most the first top_k unique tokens
    """
    # Strip special characters, keeping basic punctuation
    text = re.sub(r"[^a-zA-Z0-9가-힣\s\.\,\?\!]", "", text)
    tokens = text.split()
    # Deduplicate while preserving order
    seen = set()
    unique_tokens = []
    for token in tokens:
        if token not in seen and len(token) > 1:  # skip single-character tokens
            seen.add(token)
            unique_tokens.append(token)
    key_tokens = unique_tokens[:top_k]
    return " ".join(key_tokens)
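# Illustrative example (not executed here): extract_keywords("최신 AI 뉴스 알려줘") keeps all four
# tokens and returns "최신 AI 뉴스 알려줘"; single-character tokens and duplicates are dropped
# before the top_k cut-off.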
##############################################################################
# SerpHouse Live endpoint call
##############################################################################
def do_web_search(query: str) -> str:
    """
    Return the top 20 'organic' result items (title, link, snippet, etc.)
    as a single Markdown-formatted string.
    """
    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "en",
            "num": "20"
        }
        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }
        logger.info(f"Calling SerpHouse API... query: {query}")
        response = requests.get(url, headers=headers, params=params, timeout=60)
        response.raise_for_status()
        data = response.json()
        # Handle several possible response structures
        results = data.get("results", {})
        organic = None
        if isinstance(results, dict) and "organic" in results:
            organic = results["organic"]
        elif isinstance(results, dict) and "results" in results:
            if isinstance(results["results"], dict) and "organic" in results["results"]:
                organic = results["results"]["organic"]
        elif "organic" in data:
            organic = data["organic"]
        if not organic:
            logger.warning("No organic results found in the response.")
            return "No web search results found or unexpected API response structure."
        # Limit the number of results and keep the context length manageable
        max_results = min(20, len(organic))
        limited_organic = organic[:max_results]
        # Format the results as Markdown
        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)
            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**Source**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )
        instructions = """
# Web Search Results
Below are the search results. Use this information when answering the question:
1. Reference the title, content, and source links from each result
2. Explicitly cite relevant sources in your response
3. Include actual source links in your response
4. Synthesize information from multiple sources when answering
5. Provide comprehensive analysis based on the search data
6. Cross-reference multiple sources for accuracy verification
"""
        search_results = instructions + "\n".join(summary_lines)
        logger.info(f"Processed {len(limited_organic)} search results")
        return search_results
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"
##############################################################################
# Load model and tokenizer (optimized for the Space environment)
##############################################################################
def load_model(model_name="VIDraft/Gemma-3-R1984-1B", adapter_name="openfree/Gemma-3-R1984-1B-0613"):
    global model, tokenizer, device
    logger.info(f"Starting model load: {model_name} (adapter: {adapter_name})")
    clear_cuda_cache()  # free cached GPU memory before loading
    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # Try quantized loading; fall back to a regular load if it fails
    if BITSANDBYTES_AVAILABLE:
        try:
            # Double-check that bitsandbytes is actually installed
            import bitsandbytes
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            # Load the base model with 4-bit quantization
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=False,
            )
            logger.info("Model loaded with 4-bit quantization")
        except ImportError:
            logger.warning("bitsandbytes is not installed. Loading the model without quantization.")
            # Load the base model without quantization
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # use float16 to save GPU memory
                device_map="auto",
                trust_remote_code=False,
            )
    else:
        logger.info("BitsAndBytesConfig is unavailable. Loading the model in regular mode.")
        # Load the base model without quantization
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # use float16 to save GPU memory
            device_map="auto",
            trust_remote_code=False,
        )
    # Load the tokenizer (the adapter uses the same tokenizer as the base model)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Additional setting for Korean text handling
    tokenizer.model_max_length = MAX_INPUT_LENGTH
    # Load the PEFT adapter on top of the base model
    try:
        model = PeftModel.from_pretrained(model, adapter_name)
        logger.info(f"PEFT adapter loaded: {adapter_name}")
    except Exception as e:
        logger.error(f"Failed to load PEFT adapter: {e}")
        logger.warning("Adapter loading failed. Continuing with the base model.")
    model.eval()  # inference mode
    # Log the model configuration
    logger.info(f"Model config - device: {device}, dtype: {model.dtype}")
    logger.info(f"Tokenizer config - vocab_size: {tokenizer.vocab_size}, max_length: {tokenizer.model_max_length}")
    logger.info("Model and tokenizer loaded")
    return model, tokenizer
##############################################################################
# CSV, TXT, and PDF analysis functions
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Convert a CSV file to a single string. Show only part of it if it is too long."""
    try:
        df = pd.read_csv(path)
        if df.shape[0] > 50 or df.shape[1] > 10:
            df = df.iloc[:50, :10]
        df_str = df.to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
    """Read a TXT file in full. Show only part of it if it is too long."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
    """Convert PDF text to Markdown, extracting plain text page by page."""
    text_chunks = []
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            max_pages = min(5, len(reader.pages))
            for page_num in range(max_pages):
                page = reader.pages[page_num]
                page_text = page.extract_text() or ""
                page_text = page_text.strip()
                if page_text:
                    if len(page_text) > MAX_CONTENT_CHARS // max_pages:
                        page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
                    text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
            if len(reader.pages) > max_pages:
                text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
    full_text = "\n".join(text_chunks)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
##############################################################################
# Document file check
##############################################################################
def is_document_file(file_path: str) -> bool:
    return (
        file_path.lower().endswith(".pdf")
        or file_path.lower().endswith(".csv")
        or file_path.lower().endswith(".txt")
    )
##############################################################################
# Message processing (text and document files only)
##############################################################################
def process_new_user_message(message: dict) -> str:
    """Combine the user message and any attached document files into a single text."""
    content_parts = [message["text"]]
    if message.get("files"):
        csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
        txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
        pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
        for csv_path in csv_files:
            csv_analysis = analyze_csv_file(csv_path)
            content_parts.append(csv_analysis)
        for txt_path in txt_files:
            txt_analysis = analyze_txt_file(txt_path)
            content_parts.append(txt_analysis)
        for pdf_path in pdf_files:
            pdf_markdown = pdf_to_markdown(pdf_path)
            content_parts.append(pdf_markdown)
    return "\n\n".join(content_parts)
##############################################################################
# Conversation history processing
##############################################################################
def process_history(history: list[dict]) -> str:
    """Convert the conversation history into a plain-text transcript."""
    conversation_text = ""
    for item in history:
        if item["role"] == "assistant":
            conversation_text += f"\nAssistant: {item['content']}\n"
        else:  # user
            content = item["content"]
            if isinstance(content, str):
                conversation_text += f"\nUser: {content}\n"
            elif isinstance(content, list) and len(content) > 0:
                # Show only the file path
                file_path = content[0]
                conversation_text += f"\nUser: [File: {os.path.basename(file_path)}]\n"
    return conversation_text
##############################################################################
# Model generation helper
##############################################################################
def _model_gen_with_oom_catch(**kwargs):
    """Run generation in a separate thread so OutOfMemoryError can be caught."""
    global model
    try:
        model.generate(**kwargs)
    except torch.cuda.OutOfMemoryError:
        raise RuntimeError(
            "[OutOfMemoryError] GPU memory insufficient. "
            "Please reduce Max New Tokens or shorten the prompt length."
        )
    finally:
        clear_cuda_cache()
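# Note: model.generate() blocks until generation finishes, so run() below starts this helper in a
# background Thread and reads partial text from a TextIteratorStreamer while generation proceeds.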
##############################################################################
# Main inference function (text only)
##############################################################################
@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    web_search_query: str = "",
) -> Iterator[str]:
    global model, tokenizer
    # Load the model on first use
    if model is None or tokenizer is None:
        load_model()
    try:
        # Build the full prompt
        full_prompt = ""
        # System prompt
        if system_prompt.strip():
            full_prompt += f"System: {system_prompt.strip()}\n\n"
        # Web search
        if use_web_search:
            user_text = message["text"]
            ws_query = extract_keywords(user_text, top_k=5)
            if ws_query.strip():
                logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                ws_result = do_web_search(ws_query)
                full_prompt += f"[Web Search Results]\n{ws_result}\n\n"
                # Language-specific citation instruction
                if any(ord('가') <= ord(char) <= ord('힣') for char in user_text):
                    # Korean prompt: "Important: cite the sources from the search results above in Korean."
                    full_prompt += "[중요: 위 검색결과의 출처를 한글로 인용하여 답변해 주세요.]\n\n"
                else:
                    full_prompt += "[Important: Please cite the sources from the search results above.]\n\n"
        # Conversation history
        if history:
            conversation_history = process_history(history)
            full_prompt += conversation_history
        # Current user message
        user_content = process_new_user_message(message)
        # Language detection and extra instruction
        has_korean = any(ord('가') <= ord(char) <= ord('힣') for char in user_content)
        if has_korean:
            # Korean prompt: "Important: you must answer in Korean. Do not answer in English."
            lang_instruction = "\n[중요: 반드시 한글로 답변하세요. 영어로 답변하지 마세요.]\n"
            logger.info("Korean question detected - responding in Korean")
        else:
            lang_instruction = ""
            logger.info("English question detected - responding in English")
        full_prompt += f"\nUser: {user_content}{lang_instruction}\nAssistant:"
        # Log the prompt length
        logger.info(f"Prompt length: {len(full_prompt)} characters")
        # Tokenize
        inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_LENGTH
        ).to(device=model.device)
        # Streaming setup
        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=30.0,
            skip_prompt=True,
            skip_special_tokens=True
        )
        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=0.8,         # raised from 0.7 to 0.8
            top_p=0.95,              # raised from 0.9 to 0.95
            top_k=50,                # added top_k sampling
            repetition_penalty=1.1,  # discourage repetition
            do_sample=True,
        )
        # Generate in a separate thread
        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
        t.start()
        # Stream the output
        output = ""
        chunk_count = 0
        for new_text in streamer:
            output += new_text
            chunk_count += 1
            # Periodically run garbage collection
            if chunk_count % 100 == 0:
                gc.collect()
            yield output
    except Exception as e:
        logger.error(f"Error in run: {str(e)}")
        yield f"Sorry, an error occurred: {str(e)}"
    finally:
        # Clean up memory
        try:
            del inputs
        except:
            pass
        clear_cuda_cache()
title_html = """
<h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 🤗 Gemma3-R1984-1B (Text-Only) </h1>
<p align="center" style="font-size:1.1em; color:#555;">
✅Agentic AI Platform ✅Reasoning & Analysis ✅Text Analysis ✅Deep Research & RAG <br>
✅Document Processing (PDF, CSV, TXT) ✅Web Search Integration ✅Korean/English Support<br>
✅Running on Independent Local Server with 'NVIDIA L40s / A100(ZeroGPU) GPU'<br>
@Model Repository: VIDraft/Gemma-3-R1984-1B, @Based on: 'Google Gemma-3-1b'
</p>
"""
with gr.Blocks(title="Gemma3-R1984-1B") as demo:
    gr.Markdown(title_html)
    with gr.Accordion("Advanced Settings", open=False):
        web_search_checkbox = gr.Checkbox(
            label="Deep Research (Enable Web Search)",
            value=False
        )
        max_tokens_slider = gr.Slider(
            label="Max Tokens (Response Length)",
            minimum=100,
            maximum=8000,
            step=50,
            value=2048,
            info="Increase this value for longer responses"
        )
        system_prompt_box = gr.Textbox(
            lines=5,
            label="System Prompt",
            value="""You are an AI assistant that performs deep thinking. Please follow these guidelines:
1. **Language**: If the user asks in Korean, you must answer in Korean. If they ask in English, answer in English.
2. **Response Length**: Provide sufficiently detailed and rich responses. Write responses with at least 3-5 paragraphs.
3. **Analysis Method**: Thoroughly analyze problems and provide accurate solutions through systematic reasoning processes.
4. **Structure**: Organize responses with clear structure, using numbers or bullet points when necessary.
5. **Examples and Explanations**: Include specific examples and detailed explanations whenever possible."""
        )
        web_search_text = gr.Textbox(
            lines=1,
            label="(Unused) Web Search Query",
            placeholder="No direct input needed",
            visible=False
        )
    chat = gr.ChatInterface(
        fn=run,
        type="messages",
        chatbot=gr.Chatbot(type="messages", scale=1),
        textbox=gr.MultimodalTextbox(
            file_types=[".csv", ".txt", ".pdf"],  # images/videos removed
            file_count="multiple",
            autofocus=True
        ),
        multimodal=True,
        additional_inputs=[
            system_prompt_box,
            max_tokens_slider,
            web_search_checkbox,
            web_search_text,
        ],
        stop_btn=False,
        css_paths=None,
        delete_cache=(1800, 1800),
    )
if __name__ == "__main__":
    demo.launch()