import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class ModelHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self):
        """Initialize the model and tokenizer."""
        if self.initialized:
            return
        try:
            # Load model and tokenizer from the local path (the directory containing this file)
            model_path = os.path.dirname(os.path.abspath(__file__))
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16  # Use float16 for T4 GPU optimization
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {str(e)}")

    def predict(self, input_data):
        """
        Process the input data and generate an answer from the model.

        Args:
            input_data (dict): The input question.

        Returns:
            dict: The model's generated answer.
        """
        if not self.initialized:
            self.initialize()

        try:
            # Extract the question from input_data
            question = input_data.get('question', '')
            if not question:
                return {"error": "No question provided."}

            # Build the prompt around the user's question
            # (Arabic labels: "السؤال" = "Question", "الإجابة" = "Answer")
            alpaca_prompt = f"""
السؤال: {question}
الإجابة:
"""
            formatted_prompt = alpaca_prompt.strip()

            # Tokenize the input and move tensors to the device the model was placed on
            # (with device_map="auto" this is more reliable than self.device)
            inputs = self.tokenizer([formatted_prompt], return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate with proper error handling and memory management
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=True,  # required for temperature/top_k/top_p to take effect
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode the output
            decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # Strip the prompt labels so only the answer text remains
            clean_output = decoded_output[0].replace("السؤال:", "").replace("الإجابة:", "").strip()

            # Clear CUDA cache if using GPU
            if self.device == "cuda":
                torch.cuda.empty_cache()

            return {"answer": clean_output}

        except Exception as e:
            return {"error": f"Prediction error: {str(e)}"}


# Create a global handler instance
handler = ModelHandler()


def predict(input_data):
    """Wrapper function for the handler's predict method."""
    return handler.predict(input_data)
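

# Example usage: a minimal local sketch, not part of the serving contract.
# The `predict` entrypoint and the {"question": ...} / {"answer": ...} shapes
# come from this module; running it stand-alone assumes the model weights sit
# in the same directory as this file, which is how initialize() resolves
# model_path. The sample question below is purely illustrative.
if __name__ == "__main__":
    sample_input = {"question": "ما هي عاصمة فرنسا؟"}  # "What is the capital of France?"
    result = predict(sample_input)
    if "answer" in result:
        print("Answer:", result["answer"])
    else:
        print("Error:", result["error"])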