import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os, gc, logging
from threading import Thread
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from typing import Tuple, Iterator
from datetime import datetime
from functools import lru_cache
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams
from pydantic import BaseModel
import unittest

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('app.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Global state
model = None
tokenizer = None
current_file_context = None

# CSS styles
CSS = """
.chat-container {
    height: 600px !important;
    margin-bottom: 10px;
}
.input-container {
    height: 70px !important;
    display: flex;
    align-items: center;
    gap: 10px;
    margin-top: 5px;
}
.input-textbox {
    height: 70px !important;
    border-radius: 8px !important;
    font-size: 1.1em !important;
    padding: 10px 15px !important;
}
.custom-button {
    background: linear-gradient(145deg, #2196f3, #1976d2);
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-weight: 600;
    transition: all 0.3s ease;
}
"""

# Configuration
class Config:
    def __init__(self):
        # Alternative Korean-specialized model: "EleutherAI/polyglot-ko-12.8b"
        self.MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
        self.MAX_HISTORY = 10
        self.MAX_TOKENS = 4096
        self.DEFAULT_TEMPERATURE = 0.8
        self.HF_TOKEN = os.environ.get("HF_TOKEN", None)
        self.MODELS = os.environ.get("MODELS")

config = Config()

# Custom exception for file handling
class FileProcessingError(Exception):
    pass

# Response model
class ChatResponse(BaseModel):
    message: str
    status: str
    timestamp: datetime

def initialize_model_and_tokenizer():
    global model, tokenizer
    try:
        model = load_model()
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
        return True
    except Exception as e:
        logger.error(f"Initialization error: {str(e)}")
        return False

# File processing helpers
class FileProcessor:
    @staticmethod
    def safe_file_read(file_path):
        # Try a sequence of encodings so Korean (cp949/euc-kr) text files load correctly.
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
        for encoding in encodings:
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
        raise FileProcessingError("Unable to read file with supported encodings")

    @staticmethod
    def process_pdf(file_path):
        try:
            # Extract text with pdfminer; LAParams tunes layout analysis so
            # line and word grouping is preserved.
            return extract_text(
                file_path,
                laparams=LAParams(
                    line_margin=0.5,
                    word_margin=0.1,
                    char_margin=2.0,
                    all_texts=True
                )
            )
        except Exception as e:
            raise FileProcessingError(f"PDF processing error: {str(e)}")

    @staticmethod
    def process_csv(file_path):
        try:
            return pd.read_csv(file_path)
        except Exception as e:
            raise FileProcessingError(f"CSV processing error: {str(e)}")

    @staticmethod
    def analyze_file_content(content, file_type):
        try:
            if file_type == 'pdf':
                words = len(content.split())
                lines = content.count('\n') + 1
                return f"PDF Analysis:\nWords: {words}\nLines: {lines}"
            elif file_type == 'csv':
                # process_csv already returns a DataFrame; use it directly.
                df = content
                return f"CSV Analysis:\nRows: {len(df)}\nColumns: {len(df.columns)}"
            else:
                lines = content.split('\n')
                return f"Text Analysis:\nLines: {len(lines)}"
        except Exception as e:
            raise FileProcessingError(f"Content analysis error: {str(e)}")
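# Usage sketch (illustrative only; "sample.csv" is a hypothetical path):
# FileProcessor works standalone, e.g. for a quick local check of the
# analysis text before wiring it into the chat flow.
#
#   df = FileProcessor.process_csv("sample.csv")
#   print(FileProcessor.analyze_file_content(df, "csv"))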
# Memory management
@torch.no_grad()
def clear_cuda_memory():
    # Release cached CUDA blocks and collect garbage. The model stays on the
    # GPU: moving it to the CPU here would break in-flight generation.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

# Model loading
@spaces.GPU
def load_model():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        return model
    except Exception as e:
        logger.error(f"Model loading error: {str(e)}")
        raise

# Context retrieval
@lru_cache(maxsize=100)
def find_relevant_context(query, top_k=3):
    try:
        # TF-IDF similarity between the query and the pre-vectorized wiki
        # questions; keep the top_k matches with positive similarity.
        query_vector = vectorizer.transform([query])
        similarities = (query_vector * question_vectors.T).toarray()[0]
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_contexts = []
        for idx in top_indices:
            if similarities[idx] > 0:
                relevant_contexts.append({
                    'question': questions[idx],
                    'answer': wiki_dataset['train']['answer'][idx],
                    'similarity': similarities[idx]
                })
        return relevant_contexts
    except Exception as e:
        logger.error(f"Context search error: {str(e)}")
        return []

# Streaming chat
@spaces.GPU
def stream_chat(message: str, history: list, uploaded_file, temperature: float,
                max_new_tokens: int, top_p: float, top_k: int,
                penalty: float) -> Iterator[Tuple[str, list]]:
    """Generate a streaming chat response."""
    global model, tokenizer, current_file_context
    try:
        if model is None or tokenizer is None:
            if not initialize_model_and_tokenizer():
                raise Exception("Model initialization failed")

        logger.info(f'Processing message: {message}')
        logger.debug(f'History length: {len(history)}')

        # File handling
        file_context = ""
        if uploaded_file:
            try:
                file_ext = os.path.splitext(uploaded_file.name)[1].lower()
                if file_ext == '.pdf':
                    content = FileProcessor.process_pdf(uploaded_file.name)
                elif file_ext == '.csv':
                    content = FileProcessor.process_csv(uploaded_file.name)
                else:
                    content = FileProcessor.safe_file_read(uploaded_file.name)
                file_context = FileProcessor.analyze_file_content(content, file_ext.replace('.', ''))
                current_file_context = file_context
            except Exception as e:
                logger.error(f"File processing error: {str(e)}")
                file_context = f"\n\n❌ File analysis error: {str(e)}"

        # Retrieve context and build the prompt
        relevant_contexts = find_relevant_context(message)
        wiki_context = "\n\n관련 위키피디아 정보:\n" + "\n".join([
            f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}"
            for ctx in relevant_contexts
        ])

        # Tokenization and generation: flatten [user, assistant] history
        # pairs into a role-tagged conversation
        conversation = [
            {"role": "user" if i % 2 == 0 else "assistant", "content": msg}
            for hist in history[-config.MAX_HISTORY:]
            for i, msg in enumerate(hist)
        ]
        final_message = f"{file_context}{wiki_context}\n현재 질문: {message}"
        conversation.append({"role": "user", "content": final_message})

        inputs = tokenizer(
            tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True),
            return_tensors="pt"
        ).to("cuda")

        # Input length check
        if len(inputs.input_ids[0]) > config.MAX_TOKENS:
            raise ValueError("Input too long")

        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=30.0,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=penalty,
            max_new_tokens=min(max_new_tokens, 2048),  # hard cap regardless of the slider
            do_sample=True,
            temperature=temperature,
            eos_token_id=[255001],  # model-specific end-of-turn token id
        )

        clear_cuda_memory()

        # Run generation in a background thread; the streamer yields tokens
        # as they are produced. (Grad mode is thread-local, so generate
        # manages its own no-grad context inside the worker thread.)
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield "", history + [[message, buffer]]

        clear_cuda_memory()

    except Exception as e:
        logger.error(f"Stream chat error: {str(e)}")
        yield "", history + [[message, f"Error: {str(e)}"]]
        clear_cuda_memory()
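# Usage sketch (illustrative only): stream_chat is a generator, so callers
# receive the growing partial answer as tokens arrive. Outside Gradio it
# could be driven manually, e.g.:
#
#   for _, updated_history in stream_chat(
#           "안녕하세요", history=[], uploaded_file=None, temperature=0.8,
#           max_new_tokens=256, top_p=0.8, top_k=20, penalty=1.0):
#       print(updated_history[-1][1])  # latest partial response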
# UI construction
def create_demo():
    with gr.Blocks(css=CSS) as demo:
        with gr.Column(elem_classes="markdown-style"):
            gr.Markdown(
                """
                # 🤖 RAGOndevice
                #### 📊 RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
                Upload your files for data analysis and learning
                """
            )

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            label="AI Assistant",
            elem_classes="chat-container"
        )

        with gr.Row(elem_classes="input-container"):
            with gr.Column(scale=1, min_width=70):
                file_upload = gr.File(
                    type="filepath",
                    elem_classes="file-upload-icon",
                    scale=1,
                    container=True,
                    interactive=True,
                    show_label=False
                )
            with gr.Column(scale=3):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here... 💭",
                    container=False,
                    elem_classes="input-textbox",
                    scale=1
                )
            with gr.Column(scale=1, min_width=70):
                send = gr.Button(
                    "Send",
                    elem_classes="send-button custom-button",
                    scale=1
                )
            with gr.Column(scale=1, min_width=70):
                clear = gr.Button(
                    "Clear",
                    elem_classes="clear-button custom-button",
                    scale=1
                )

        with gr.Accordion("🎮 Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    temperature = gr.Slider(
                        minimum=0.1,  # sampling requires a strictly positive temperature
                        maximum=1,
                        step=0.1,
                        value=config.DEFAULT_TEMPERATURE,
                        label="Creativity Level 🎨"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=128,
                        maximum=8000,
                        step=1,
                        value=4000,
                        label="Maximum Token Count 📝"
                    )
                with gr.Column(scale=1):
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                        label="Diversity Control 🎯"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=20,
                        label="Selection Range 📊"
                    )
                    penalty = gr.Slider(
                        minimum=0.1,  # repetition penalty must be positive
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        label="Repetition Penalty 🔄"
                    )

        # Event bindings
        chat_inputs = [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty]
        msg.submit(stream_chat, chat_inputs, [msg, chatbot])
        send.click(stream_chat, chat_inputs, [msg, chatbot])
        clear.click(lambda: ([], None, ""), outputs=[chatbot, file_upload, msg])

    return demo

# Tests
class TestChatBot(unittest.TestCase):
    def setUp(self):
        self.file_processor = FileProcessor()

    def test_file_processing(self):
        # File processing test
        test_content = "Test content"
        result = self.file_processor.analyze_file_content(test_content, 'txt')
        self.assertIsNotNone(result)

    def test_context_search(self):
        # Context search test
        test_query = "테스트 질문"
        result = find_relevant_context(test_query)
        self.assertIsInstance(result, list)

# Main entry point
if __name__ == "__main__":
    try:
        # Initialize model and tokenizer
        if not initialize_model_and_tokenizer():
            logger.error("Failed to initialize model and tokenizer")
            exit(1)

        # Load the Korean Wikipedia QnA dataset
        wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
        logger.info("Wikipedia dataset loaded")

        # Initialize the TF-IDF vectorizer over the first 10,000 questions
        questions = wiki_dataset['train']['question'][:10000]
        vectorizer = TfidfVectorizer(max_features=1000)
        question_vectors = vectorizer.fit_transform(questions)
        logger.info("TF-IDF vectorization completed")

        # Launch the UI
        demo = create_demo()
        demo.launch(share=False, server_name="0.0.0.0")
    except Exception as e:
        logger.error(f"Application startup error: {str(e)}")
        exit(1)
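# Running the unit tests (a sketch, assuming this file is saved as app.py):
# the model, dataset, and UI setup all live behind the __main__ guard, so a
# plain import for testing triggers no downloads or GPU work.
#
#   python -m unittest app.TestChatBot -v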