from fastapi import FastAPI, File, UploadFile, Form, HTTPException # Keep these if you use them elsewhere in your app (HTML, static files) from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from fastapi.responses import FileResponse # Removed 'requests' as we are using gradio_client for captioning # import requests # Not needed if using gradio_client for captioning import base64 # Keep if needed elsewhere import os import random # Removed unused Iterator import as we are not streaming with llama.cpp # from typing import Iterator # Removed unused IO import # from typing import IO # Import necessary classes from transformers (Keeping only AutoTokenizer) from transformers import AutoTokenizer # Import necessary modules for llama-cpp-python and downloading from Hub from llama_cpp import Llama # The core Llama class from huggingface_hub import hf_hub_download # For downloading GGUF files # Import the Gradio Client and handle_file for captioning from gradio_client import Client, handle_file # Import necessary modules for temporary file handling (for gradio_client) import tempfile # shutil is not strictly necessary for this version # import shutil from deep_translator import GoogleTranslator from deep_translator.exceptions import InvalidSourceOrTargetLanguage app = FastAPI() # --- Llama.cpp Language Model Setup (Local CPU Inference) --- # Repository on Hugging Face Hub for Qwen1.5 0.5B Chat GGUF (verified from image) LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF" # Specific filename for Q4_K_M quantization in that repo (verified from image) LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf" # Original model name for the tokenizer (needed by transformers for chat templating) ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat" tokenizer = None # Using transformers tokenizer llm_model = None # This will hold the llama_cpp.Llama instance # --- Hugging Face Gradio Space Client Setup (For External Image Captioning) --- # Global Gradio Client for Captioning caption_client = None # The URL of the external Gradio Space for image captioning CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base" # --- Translation Setup (Keeping original global translators, only _ar is used in endpoint) --- translator_to_en = GoogleTranslator(source='arabic', target='english') # Keep for parity with original code translator_to_ar = GoogleTranslator(source='english', target='arabic') # Used to translate final English story to Arabic # Function to load the language model (GGUF via llama.cpp) and its tokenizer def load_language_model(): global tokenizer, llm_model print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...") try: # --- Load Tokenizer (using transformers) --- print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...") tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME) # Set pad_token if needed (Qwen tokenizers usually have one, but robust check) if tokenizer.pad_token is None: if tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token elif tokenizer.unk_token is not None: tokenizer.pad_token = tokenizer.unk_token else: print("Warning: Neither EOS nor UNK token found for tokenizer. Setting pad_token to None.") tokenizer.pad_token = None # --- Download GGUF model file (using huggingface_hub) --- print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...") model_path = hf_hub_download( repo_id=LLM_MODEL_REPO, filename=LLM_MODEL_FILE, # cache_dir="/tmp/hf_cache" # Optional: specify a custom cache directory ) print(f"GGUF model downloaded to: {model_path}") # --- Load the GGUF model (using llama-cpp-python) --- print(f"Loading GGUF model into llama_cpp...") llm_model = Llama( model_path=model_path, n_gpu_layers=0, # Explicitly use CPU n_ctx=4096, # Context window size (4096 is a common safe value) n_threads=2 # Use 2 CPU threads on your basic tier ) print("Llama.cpp model loaded successfully.") except Exception as e: print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}") tokenizer = None llm_model = None # Ensure model is None if loading fails # Function to initialize the Gradio Client for the captioning Space (Corrected version) def initialize_caption_client(): global caption_client print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...") try: # If the target Gradio Space requires authentication (e.g., private) # store HF_TOKEN as a Space Secret and uncomment these lines. # HF_TOKEN = os.environ.get("HF_TOKEN") # if HF_TOKEN: # print("Using HF_TOKEN for Gradio client.") # caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN) # else: # print("HF_TOKEN not found. Initializing public Gradio client.") # caption_client = Client(CAPTION_SPACE_URL) # Assuming the caption space is public caption_client = Client(CAPTION_SPACE_URL) print("Gradio client initialized successfully.") except Exception as e: print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}") # Set client to None so the endpoint can check and return an error caption_client = None # Load models and initialize clients when the app starts @app.on_event("startup") async def startup_event(): # Load the language model (Qwen 0.5B GGUF via llama.cpp) load_language_model() # Initialize the Gradio client for captioning initialize_caption_client() # --- Image Captioning Function (Using gradio_client and temporary file) --- # This function is the correctly implemented version using gradio_client and temporary files. def generate_image_caption(image_file: UploadFile): """ Generates a caption for the uploaded image using the external Gradio Space API. Reads the uploaded file's content, saves it to a temporary file, and uses the temporary file's path with handle_file for the API call. """ if caption_client is None: # If the client failed to initialize at startup error_msg = "Gradio caption client not initialized. Cannot generate caption." print(error_msg) return f"Error: {error_msg}" temp_file_path = None # Variable to store the path of the temporary file try: print(f"Attempting to generate caption for file: {image_file.filename}") # Read the content of the uploaded file # Seek to the beginning just in case the file-like object's pointer was moved image_file.file.seek(0) image_bytes = image_file.file.read() # Create a temporary file on the local filesystem # delete=False ensures the file persists after closing the handle # suffix helps hint at the file type for the Gradio API temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image_file.filename)[1] or '.jpg') temp_file.write(image_bytes) temp_file.close() # Close the file handle so gradio_client can access the file temp_file_path = temp_file.name # Get the full path to the temporary file print(f"Saved uploaded file temporarily to: {temp_file_path}") # Use handle_file() with the path string to the temporary file. # This correctly prepares the file for the Gradio API input. prepared_input = handle_file(temp_file_path) # Call the predict method on the initialized gradio_client # api_name="/predict" matches the endpoint specified in the Gradio API docs caption = caption_client.predict(img=prepared_input, api_name="/predict") print(f"Caption generated successfully.") # Return the caption string received from the API return caption except Exception as e: # Catch any exceptions that occur during reading, writing, or the API call print(f"Error during caption generation API call: {e}") # Log the error details server-side # Return a structured error string including the exception type and message return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}" finally: # Clean up the temporary file regardless of whether the process succeeded or failed if temp_file_path and os.path.exists(temp_file_path): print(f"Cleaning up temporary file: {temp_file_path}") try: os.remove(temp_file_path) # Delete the file using its path except OSError as e: print(f"Error removing temporary file {temp_file_path}: {e}") # Log cleanup errors # Removed original get_prompt function (CodeLlama specific) # def get_prompt(...): ... # --- Story Generation Function (Qwen 0.5B via llama.cpp) --- # This replaces the original 'run' function and does NOT stream. # It generates the story in English first. def generate_story_qwen_0_5b_english(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str: """ Generates the story in English using the loaded Qwen1.5 0.5B model via llama.cpp. Calls llama.cpp's create_chat_completion. """ # Check if the language model was loaded successfully if tokenizer is None or llm_model is None: raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.") # Construct the messages list for Qwen1.5 Chat # Note: Original CodeLlama system_prompt is now part of the main prompt_text messages = [ {"role": "user", "content": prompt_text} ] try: print("Calling llama.cpp create_chat_completion for Qwen 0.5B (English generation)...") response = llm_model.create_chat_completion( messages=messages, max_tokens=max_new_tokens, temperature=temperature, top_p=top_p, # top_k is sometimes supported by llama-cpp-python, but less standard for chat completions # top_k=top_k, stream=False # Request full response ) print("Llama.cpp completion received for Qwen 0.5B.") # Parse the response to get the generated text content if response and response.get('choices') and len(response['choices']) > 0: story_english = response['choices'][0].get('message', {}).get('content', '') else: print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.") story_english = "" except Exception as e: print(f"Llama.cpp Qwen 0.5B inference failed: {e}") raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}") return story_english.strip() # --- FastAPI Endpoint for Story Generation --- # Keeping the original endpoint signature (no 'language' Form parameter) # The story will ALWAYS be translated to Arabic at the end. @app.post("/generate-story/") async def generate_story(image_file: UploadFile = File(...)): # No 'language' parameter as per original code # Original code had system_prompt here, let's integrate the random theme into the prompt_text story_theme = random.choice([ 'an adventurous journey', 'a mysterious encounter', 'a heroic quest', 'a magical adventure', 'a thrilling escape', 'an unexpected discovery', 'a dangerous mission', 'a romantic escapade', 'an epic battle', 'a journey into the unknown' ]) # Step 1: Get image caption using the external API via gradio_client (corrected method) # Pass the UploadFile object directly caption = generate_image_caption(image_file) # Check if caption generation failed if caption.startswith("Error:"): print(f"Caption generation failed: {caption}") raise HTTPException(status_code=500, detail=caption) # Step 2: Construct the full prompt text for the language model # Incorporate the random theme/instruction and the caption. prompt_text = f"Write a detailed story that is approximately 300 words long. Ensure the story has a clear beginning, middle, and end about {story_theme}. Incorporate the following image description: {caption}\n\nStory:" # Step 3: Generate the story in English using the local language model (Qwen via llama.cpp) try: story_english = generate_story_qwen_0_5b_english( # Generate English story prompt_text, max_new_tokens=350, # Request ~300 new tokens temperature=0.7, top_p=0.9, top_k=50 ) # Basic cleanup story_english = story_english.strip() # Check if the generated English story is empty if not story_english: print("Language model generated empty story.") raise HTTPException(status_code=500, detail="Story generation failed: Language model returned empty response.") except RuntimeError as e: # Catch specific RuntimeError raised if LLM loading or inference fails print(f"Language model generation error: {e}") raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}") except Exception as e: # Catch any other unexpected errors during story generation print(f"An unexpected error occurred during story generation: {e}") raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}") # Step 4: Translate the generated English story to Arabic (hardcoded as per original code) print("Translating English story to Arabic...") try: # Use the global translator_to_ar (English to Arabic) translated_story_arabic = translator_to_ar.translate(story_english) # Check if translation returned None or an empty string if translated_story_arabic is None or translated_story_arabic == "": print("Arabic translation returned None or empty string.") # If translation fails, raise an error raise HTTPException(status_code=500, detail="Arabic translation failed.") story_final = translated_story_arabic except Exception as e: # Catch any errors during translation print(f"Arabic translation failed: {e}") raise HTTPException(status_code=500, detail=f"Arabic translation failed: {type(e).__name__}: {e}") # Step 5: Return the final Arabic story as a JSON response return {"story": story_final} # --- Optional: Serve a simple HTML form for testing --- # To use this, uncomment the imports related to HTMLResponse, StaticFiles, Jinja2Templates, FileResponse # at the top of the file, and create a 'templates' directory with an 'index.html' file. # from fastapi import Request # from fastapi.templating import Jinja2Templates # from fastapi.staticfiles import StaticFiles # templates = Jinja2Templates(directory="templates") # app.mount("/static", StaticFiles(directory="static"), name="static") # @app.get("/", response_class=HTMLResponse) # async def read_root(request: Request): # # Example HTML form - this version expects no language input from the user # html_content = """ # # #