farouk1 committed
Commit d4d58c1 · verified · 1 Parent(s): 21189d5

Create app.py

Files changed (1):
  app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch  # For device management

# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# For larger models, ensure you have sufficient VRAM (GPU memory).
# On CPU, smaller models may be necessary, or use quantization.
MODEL_NAME = "google/flan-t5-large"  # 'large' gives slightly better quality than 'base' while staying manageable.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl".
# For even larger models, consider torch.bfloat16 or bitsandbytes 4-bit loading if available.

try:
    # Determine the device to use (GPU if available, else CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Load the model in half precision (float16) on GPU to save VRAM,
    # or in 8-bit/4-bit via libraries like bitsandbytes (requires installation).
    if device == "cuda":
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

    model.eval()  # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")

except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection, the model name, and available resources (RAM/VRAM).")
    # Fail gracefully: the UI will surface the error instead of crashing.
    tokenizer, model = None, None
    device = "cpu"  # Ensure `device` is defined even if the try block failed early

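# A minimal sketch of the 8-bit loading alternative mentioned above, assuming
# the optional `bitsandbytes` and `accelerate` packages are installed (this
# path is not exercised by the app as committed):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_8bit=True)
#   model = AutoModelForSeq2SeqLM.from_pretrained(
#       MODEL_NAME, quantization_config=quant_config, device_map="auto"
#   )
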
# --- Prompt Engineering Functions (more structured) ---

def create_arabic_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        # "Write a professional, personal-style article about {topic}; focus on details,
        # engaging description, and practical tips; format with paragraphs and subheadings."
        return f"اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية."
    elif style == "Social Media Post (Short & Catchy)":
        # "Write a short, catchy, engaging post about {topic}; add 2-3 fitting emojis,
        # suggest 4 trending hashtags, and open with a hook question or sentence."
        return f"اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة."
    else:  # Video Script (Storytelling)
        # "Write a professional, compelling video script about {topic}; make it story-driven,
        # divided into key scenes, with suggested B-roll shots and SFX; focus on emotion."
        return f"اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."

def create_english_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings."
    elif style == "Social Media Post (Short & Catchy)":
        return f"Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement."
    else:  # Video Script (Storytelling)
        return f"Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."

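# Usage example (plain string formatting, no model involved):
#
#   create_english_prompt("solar energy", "Social Media Post (Short & Catchy)")
#   # -> "Write a short, catchy, and engaging social media post about: solar energy. ..."
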
# --- Content Generation Function ---

@torch.no_grad()  # Disable gradient tracking during inference to save memory
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."

    if not topic:
        return "⚠️ Please enter a topic to generate content."

    # Token budgets per requested length. Flan-T5 was trained with a
    # 512-token context window, so the input prompt is capped at 512 below.
    if length_choice == "Short":
        max_new_tokens = 150
        min_new_tokens = 50
    elif length_choice == "Medium":
        max_new_tokens = 300
        min_new_tokens = 100
    else:  # Long
        max_new_tokens = 450  # Effectively the practical ceiling for Flan-T5
        min_new_tokens = 150

    # Map UI sliders onto generation parameters
    temperature = creativity  # Direct mapping
    top_p = detail_level  # Direct mapping; higher admits a more diverse vocabulary
    # Size of n-grams that may not repeat; smaller values are MORE restrictive
    # (2 blocks any repeated bigram), and 0 disables the constraint entirely.
    no_repeat_ngram_size = int(diversity_penalty)  # Sliders return floats; generate() expects an int

    # Build the prompt
    if lang_choice == "Arabic":
        prompt = create_arabic_prompt(topic, style_choice)
    else:  # English
        prompt = create_english_prompt(topic, style_choice)

    # Strengthen the instruction when the user asks for high detail or creativity
    if detail_level > 0.7:
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if creativity > 0.8:
        prompt += " Be highly creative and imaginative in your writing."

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5,  # Combined with do_sample=True this is beam-sample decoding
            do_sample=True,  # Enable sampling for creativity
            temperature=temperature,
            top_p=top_p,
            top_k=50,  # Sample only from the 50 most likely tokens
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0,  # Adjust to bias toward shorter or longer outputs
            early_stopping=True
        )
        content = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return content
    except RuntimeError as e:
        if "out of memory" in str(e):
            return "⚠️ Generation failed: out of memory. Try a shorter length, a smaller model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {str(e)}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {str(e)}"

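# Quick smoke test outside the UI (illustrative argument values, not defaults
# taken from anywhere in this commit):
#
#   if model is not None:
#       print(generate_content("the future of solar energy", "Blog Post (Descriptive)",
#                              "English", "Short", 0.7, 0.9, 2))
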
# --- Gradio Interface ---

# Custom CSS for a more polished look
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
    background-color: #F8F0FF; /* Light Lavender */
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #8A2BE2; /* Blue Violet */
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-size: 1.1em;
}
.gr-button:hover {
    background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
    border: 1px solid #DDA0DD; /* Plum */
    border-radius: 8px;
    padding: 10px;
}
.gradio-radio input:checked + label {
    background-color: #DA70D6 !important; /* Orchid */
    color: white !important;
}
.gradio-radio label {
    border: 1px solid #DDA0DD;
    border-radius: 8px;
    padding: 8px 15px;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")

    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused."
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow a more diverse and detailed vocabulary. Lower values prune less likely words."
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=0, maximum=5, value=2, step=1,
                        label="No-Repeat N-gram Size",
                        info="Blocks repeated phrases of this length. Smaller values are more restrictive; set to 0 to disable the constraint."
                    )

        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True
                )
            gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 tokens)*")

    btn = gr.Button("🚀 Generate Content", variant="primary")

    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)

    # Download button logic
    def download_file(content):
        if content and not content.startswith("⚠️"):  # Only offer a file when content is valid
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton(value=file_path, interactive=True)
        return gr.DownloadButton(interactive=False)  # Nothing valid to download

    download_button = gr.DownloadButton("⬇️ Download Content", interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output
    )

    # Enable the download button only when there's valid content
    output.change(fn=download_file, inputs=[output], outputs=[download_button])

if __name__ == "__main__":
    iface.launch()
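
# To run locally (assumed dependency set; this commit pins no versions):
#   pip install gradio transformers torch sentencepiece
#   python app.py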