Rausda6 committed on
Commit
4211f84
·
verified ·
1 Parent(s): 428399e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -184
app.py CHANGED
@@ -10,211 +10,133 @@ import uuid
10
  import json
11
  from typing import List, Dict
12
 
13
-
14
- from transformers import pipeline
15
  import torch
16
-
17
- # Configuration: local model path or remote repo ID and HF token secret
18
- LOCAL_MODEL_PATH = os.getenv("GEMMA_MODEL_PATH") # e.g. "./models/gemma-2-l6b"
19
- REMOTE_MODEL_ID = "meta-llama/Llama-3.1-8B"
20
- HF_TOKEN = os.getenv("Tokentest") # your secret token name
21
-
22
- # Determine model source and auth
23
- if LOCAL_MODEL_PATH and os.path.isdir(LOCAL_MODEL_PATH):
24
- model_source = LOCAL_MODEL_PATH
25
- auth_token = None
26
- else:
27
- model_source = REMOTE_MODEL_ID
28
- auth_token = HF_TOKEN or os.getenv("HUGGINGFACE_HUB_TOKEN")
29
-
30
- # Initialize Gemma text-generation pipeline
31
- _device = 0 if torch.cuda.is_available() else -1
32
- pipeline_kwargs = {
33
- "model": model_source,
34
- "device": _device,
35
- "torch_dtype": "auto"
36
- }
37
- if auth_token:
38
- pipeline_kwargs["use_auth_token"] = auth_token
39
-
40
- text_generator = pipeline(
41
- "text-generation",
42
- **pipeline_kwargs
43
  )
44
-
45
- # Constants
46
- MAX_FILE_SIZE_MB = 20
47
- MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
 
 
 
 
 
 
 
48
 
49
  class PodcastGenerator:
 
 
 
50
  def __init__(self):
51
  pass
52
 
53
  async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
54
- example = """
55
- {
56
- "topic": "AGI",
57
- "podcast": [ ... ]
58
- }
59
- """
60
- if language == "Auto Detect":
61
- language_instruction = "- The podcast MUST be in the same language as the user input."
62
- else:
63
- language_instruction = f"- The podcast MUST be in {language} language"
64
-
65
- system_prompt = f"""
66
- You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
67
- {language_instruction}
68
- - The podcast should have 2 speakers.
69
- - The podcast should be long.
70
- - Do not use names for the speakers.
71
- - The podcast should be interesting, lively, and engaging, and hook the listener from the start.
72
- - The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
73
- - The script must be in JSON format.
74
- Follow this example structure:
75
- {example}
76
- """
77
-
78
  if prompt and file_obj:
79
- user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
80
  elif prompt:
81
- user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
82
  else:
83
- user_prompt = "Please generate a podcast script based on the uploaded file."
84
 
85
  full_prompt = system_prompt + "\n\n" + user_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  loop = asyncio.get_event_loop()
87
- result = await loop.run_in_executor(
88
- None,
89
- lambda: text_generator(full_prompt, max_new_tokens=512, do_sample=True)
90
- )
91
- gen_text = result[0]["generated_text"]
92
- return json.loads(gen_text)
93
 
94
  async def _read_file_bytes(self, file_obj) -> bytes:
95
- if hasattr(file_obj, 'size'):
96
- file_size = file_obj.size
97
- else:
98
- file_size = os.path.getsize(file_obj.name)
99
- if file_size > MAX_FILE_SIZE_BYTES:
100
- raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
101
- if hasattr(file_obj, 'read'):
102
- return file_obj.read()
103
- else:
104
- async with aiofiles.open(file_obj.name, 'rb') as f:
105
- return await f.read()
106
-
107
- def _get_mime_type(self, filename: str) -> str:
108
- ext = os.path.splitext(filename)[1].lower()
109
- if ext == '.pdf':
110
- return "application/pdf"
111
- elif ext == '.txt':
112
- return "text/plain"
113
- mime_type, _ = mimetypes.guess_type(filename)
114
- return mime_type or "application/octet-stream"
115
-
116
- async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
117
- voice = speaker1 if speaker == 1 else speaker2
118
- speech = edge_tts.Communicate(text, voice)
119
- temp_filename = f"temp_{uuid.uuid4()}.wav"
120
- try:
121
- await asyncio.wait_for(speech.save(temp_filename), timeout=30)
122
- return temp_filename
123
- except Exception:
124
- if os.path.exists(temp_filename):
125
- os.remove(temp_filename)
126
- raise
127
-
128
- async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
129
- if progress:
130
- progress(0.9, "Combining audio files...")
131
- combined_audio = AudioSegment.empty()
132
- for audio_file in audio_files:
133
- combined_audio += AudioSegment.from_file(audio_file)
134
- os.remove(audio_file)
135
- output_filename = f"output_{uuid.uuid4()}.wav"
136
- combined_audio.export(output_filename, format="wav")
137
- if progress:
138
- progress(1.0, "Podcast generated successfully!")
139
- return output_filename
140
-
141
- async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
142
- return await asyncio.wait_for(
143
- self._generate_podcast_internal(input_text, language, speaker1, speaker2, file_obj, progress),
144
- timeout=600
145
- )
146
-
147
- async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
148
- if progress:
149
- progress(0.2, "Generating podcast script...")
150
- podcast_json = await self.generate_script(input_text, language, file_obj, progress)
151
- if progress:
152
- progress(0.5, "Converting text to speech...")
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  audio_files = []
155
- total_lines = len(podcast_json['podcast'])
156
- batch_size = 10
157
- for batch_start in range(0, total_lines, batch_size):
158
- batch = podcast_json['podcast'][batch_start:batch_start+batch_size]
159
- tasks = [self.tts_generate(item['line'], item['speaker'], speaker1, speaker2) for item in batch]
160
- results = await asyncio.gather(*tasks)
161
- audio_files.extend(results)
162
- if progress:
163
- progress(0.5 + 0.4 * ((batch_start+len(batch)) / total_lines), f"Processed {batch_start+len(batch)}/{total_lines} segments...")
164
-
165
- combined = await self.combine_audio_files(audio_files, progress)
166
- return combined
167
-
168
- async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, progress=None) -> str:
169
- generator = PodcastGenerator()
170
- return await generator.generate_podcast(input_text, language, speaker1, speaker2, input_file, progress)
171
 
172
  # Gradio UI
173
 
174
- def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2, progress=gr.Progress()):
175
- def progress_callback(val, msg):
176
- progress(val, msg)
177
- result = asyncio.run(process_input(
178
- input_text,
179
- input_file,
180
- language,
181
- speaker1,
182
- speaker2,
183
- progress_callback
184
- ))
185
- return result
186
-
187
-
188
- def main():
189
- language_options = ["Auto Detect", "English", "German", "French"]
190
- voice_options = [
191
- "Andrew - English (United States)",
192
- "Ava - English (United States)",
193
- "Brian - English (United States)",
194
- "Emma - English (United States)",
195
- "Florian - German (Germany)",
196
- "Seraphina - German (Germany)",
197
- "Remy - French (France)",
198
- "Vivienne - French (France)"
199
- ]
200
- with gr.Blocks(title="PodcastGen 🎙️") as demo:
201
- gr.Markdown("# PodcastGen 🎙️")
202
- gr.Markdown("Generate a 2-speaker podcast from text input or documents!")
203
- with gr.Row():
204
- input_text = gr.Textbox(label="Input Text", lines=10)
205
- input_file = gr.File(label="Or Upload a PDF or TXT file", file_types=[".pdf", ".txt"])
206
- with gr.Row():
207
- language = gr.Dropdown(label="Language", choices=language_options, value="Auto Detect")
208
- speaker1 = gr.Dropdown(label="Speaker 1 Voice", choices=voice_options, value="Andrew - English (United States)")
209
- speaker2 = gr.Dropdown(label="Speaker 2 Voice", choices=voice_options, value="Ava - English (United States)")
210
- generate_btn = gr.Button("Generate Podcast", variant="primary")
211
- output_audio = gr.Audio(label="Generated Podcast", type="filepath", format="wav")
212
- generate_btn.click(
213
- fn=generate_podcast_gradio,
214
- inputs=[input_text, input_file, language, speaker1, speaker2],
215
- outputs=[output_audio]
216
  )
217
  demo.launch()
218
 
219
- if __name__ == "__main__":
220
- main()
 
10
  import json
11
  from typing import List, Dict
12
 
13
+ # Model imports
 
14
  import torch
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
16
+
# Configuration
# Model to load: a Hugging Face repo id, or a local path supplied through
# the GEMMA_MODEL_PATH environment variable.
MODEL_ID = os.getenv("GEMMA_MODEL_PATH", "tabularisai/german-gemma-3-1b-it")
# Hugging Face token secret (optional, only needed for gated/private models).
HF_TOKEN = os.getenv("Tokentest")

# Load tokenizer and model.
# NOTE(review): trust_remote_code=True executes Python shipped inside the model
# repository, and MODEL_ID can be overridden through the environment. Keep it
# only if the chosen checkpoint really needs custom code.
print(f"Loading model {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,  # `use_auth_token` is deprecated in favour of `token`
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,
    # bfloat16 on GPU to halve memory; plain float32 on CPU.
    torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else torch.float32),
    device_map="auto",
).eval()

# Token ids used by generation.
# A pad id of 0 is valid, so test against None explicitly instead of using
# `or`, which would discard it.
PAD = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)
# Stop on <end_of_turn>; if the tokenizer does not know that token (the lookup
# yields None or the unknown-token id), fall back to the regular EOS id.
EOT = tokenizer.convert_tokens_to_ids("<end_of_turn>")
if EOT is None or EOT == tokenizer.unk_token_id:
    EOT = tokenizer.eos_token_id
41
 
class PodcastGenerator:
    """Generate a two-speaker podcast: LLM script, per-line TTS, one WAV file.

    Text generation uses the module-level ``tokenizer``, ``model``, ``PAD``
    and ``EOT`` objects; speech comes from ``edge_tts`` and the clips are
    joined with ``AudioSegment``.
    """

    MAX_FILE_MB = 20
    MAX_FILE_BYTES = MAX_FILE_MB * 1024 * 1024
    # Upper bound on tokens generated for one script.
    MAX_NEW_TOKENS = 512

    # The UI shows labels such as "Florian - German (Germany)", while edge-tts
    # expects voice ids. Keyed by the first name in the label.
    VOICE_IDS = {
        "Andrew": "en-US-AndrewMultilingualNeural",
        "Ava": "en-US-AvaMultilingualNeural",
        "Brian": "en-US-BrianMultilingualNeural",
        "Emma": "en-US-EmmaMultilingualNeural",
        "Florian": "de-DE-FlorianMultilingualNeural",
        "Seraphina": "de-DE-SeraphinaMultilingualNeural",
        "Remy": "fr-FR-RemyMultilingualNeural",
        "Vivienne": "fr-FR-VivienneMultilingualNeural",
    }

    def __init__(self):
        pass

    @classmethod
    def _resolve_voice(cls, label: str) -> str:
        """Map a UI voice label to an edge-tts voice id.

        Labels that are not in ``VOICE_IDS`` (for example a value that is
        already a voice id) are returned unchanged.
        """
        first_name = label.split(" - ", 1)[0].strip()
        return cls.VOICE_IDS.get(first_name, label)

    @staticmethod
    def _extract_json(text: str) -> Dict:
        """Parse the outermost ``{...}`` object found in *text*.

        Model output is often wrapped in code fences or extra prose, so only
        the span from the first ``{`` to the last ``}`` is parsed.

        Raises:
            ValueError: if no JSON object is present or it cannot be parsed.
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError("Model output does not contain a JSON object")
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError as err:
            raise ValueError(f"Model output is not valid JSON: {err}") from err

    @staticmethod
    def _remove_files(paths) -> None:
        """Best-effort removal of temporary files; missing files are ignored."""
        for path in paths:
            try:
                os.remove(path)
            except OSError:
                # Cleanup must never mask the error that triggered it.
                pass

    async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
        """Ask the language model for a podcast script and return it as a dict.

        Args:
            prompt: User text the podcast should be based on.
            language: Target language, or "Auto Detect" to follow the input.
            file_obj: Optional upload; it only selects the wording of the
                request.
            progress: Accepted for interface compatibility; unused here.

        Returns:
            The parsed JSON object produced by the model.

        Raises:
            ValueError: if neither text nor a file was supplied, or the model
                output is not valid JSON.
        """
        if not prompt and file_obj is None:
            raise ValueError("Provide input text or upload a file")

        # The example spells out the per-segment keys that generate_podcast
        # reads ('speaker' and 'line'); "[ ... ]" alone gave the model no schema.
        example = (
            '{"topic": "AGI", "podcast": ['
            '{"speaker": 1, "line": "..."}, '
            '{"speaker": 2, "line": "..."}]}'
        )
        lang_inst = (
            "- The podcast MUST be in the same language as the user input."
            if language == "Auto Detect"
            else f"- The podcast MUST be in {language} language"
        )
        system_prompt = (
            "You are a professional podcast generator. Your task is to generate a professional podcast script..."
            f"\n{lang_inst}\n- The podcast should have 2 speakers.\n- The podcast should be long."
            "\n- Do not use names for the speakers.\n- The podcast should be interesting, lively, and engaging..."
            "\n- The script must be in JSON format. Follow this example structure:" + example
        )

        # NOTE(review): the uploaded file's contents are never added to the
        # prompt, so a file-only request gives the model nothing to work from.
        if prompt and file_obj:
            user_prompt = f"Generate podcast script based on file and prompt:\n{prompt}"
        elif prompt:
            user_prompt = f"Generate podcast script based on prompt:\n{prompt}"
        else:
            user_prompt = "Generate podcast script based on uploaded file."

        full_prompt = system_prompt + "\n\n" + user_prompt

        def gen_sync() -> str:
            """Blocking generation; executed in a worker thread."""
            inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=self.MAX_NEW_TOKENS,
                do_sample=True,
                pad_token_id=PAD,
                eos_token_id=EOT,
            )
            # generate() returns prompt + continuation; decode only the
            # continuation so the prompt text does not break JSON parsing.
            prompt_len = inputs["input_ids"].shape[-1]
            return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, gen_sync)
        return self._extract_json(text)

    async def _read_file_bytes(self, file_obj) -> bytes:
        """Return the bytes of an upload, enforcing ``MAX_FILE_BYTES``.

        Accepts a path, an object exposing ``.name`` (a path), or a file-like
        object exposing ``.read()``.

        Raises:
            ValueError: if the file is larger than ``MAX_FILE_MB`` megabytes.
        """
        if isinstance(file_obj, (str, os.PathLike)):
            path = file_obj
        else:
            path = getattr(file_obj, "name", None)

        # Only touch the filesystem when the object carries no size itself.
        size = getattr(file_obj, "size", None)
        if size is None and path is not None and os.path.exists(path):
            size = os.path.getsize(path)
        if size is not None and size > self.MAX_FILE_BYTES:
            raise ValueError(f"File > {self.MAX_FILE_MB}MB")

        if hasattr(file_obj, "read"):
            data = file_obj.read()
        else:
            async with aiofiles.open(path, "rb") as handle:
                data = await handle.read()

        # Covers file-like objects whose size was unknown up front.
        if len(data) > self.MAX_FILE_BYTES:
            raise ValueError(f"File > {self.MAX_FILE_MB}MB")
        return data

    async def tts_generate(self, text: str, speaker: int, s1: str, s2: str) -> str:
        """Synthesize one script line and return the path of the audio clip.

        Args:
            text: The line to speak.
            speaker: 1 selects ``s1``; any other value selects ``s2``.
            s1: Voice (UI label or voice id) for speaker 1.
            s2: Voice (UI label or voice id) for speaker 2.
        """
        # Models sometimes emit the speaker number as a string.
        label = s1 if speaker in (1, "1") else s2
        speech = edge_tts.Communicate(text, self._resolve_voice(label))
        # edge-tts writes MP3 data, so name the clip accordingly.
        fname = f"tmp_{uuid.uuid4()}.mp3"
        try:
            await speech.save(fname)
        except BaseException:
            # Do not leave a partial clip behind on failure or cancellation.
            self._remove_files([fname])
            raise
        return fname

    async def combine_audio_files(self, files: List[str], progress=None) -> str:
        """Concatenate the clips into one WAV file and delete the clips.

        Args:
            files: Clip paths in playback order.
            progress: Optional callable taking ``(fraction, message)``.

        Returns:
            Path of the combined WAV file.
        """
        if progress is not None:
            progress(0.9, "Combining audio files...")
        combined = AudioSegment.empty()
        try:
            for clip in files:
                combined += AudioSegment.from_file(clip)
            out = f"out_{uuid.uuid4()}.wav"
            combined.export(out, format="wav")
        finally:
            # Remove every clip even when decoding or exporting fails.
            self._remove_files(files)
        if progress is not None:
            progress(1.0, "Podcast generated successfully!")
        return out

    async def generate_podcast(self, text: str, lang: str, sp1: str, sp2: str, file_obj=None, progress=None) -> str:
        """Run the full pipeline and return the path of the podcast WAV.

        Args:
            text: User input text.
            lang: Target language or "Auto Detect".
            sp1: Voice for speaker 1.
            sp2: Voice for speaker 2.
            file_obj: Optional uploaded file.
            progress: Optional callable taking ``(fraction, message)``.

        Raises:
            ValueError: if the generated script has no usable segments.
        """
        if progress is not None:
            progress(0.2, "Generating podcast script...")
        pj = await self.generate_script(text, lang, file_obj, progress)

        parts = pj.get("podcast") if isinstance(pj, dict) else None
        if not isinstance(parts, list) or not parts:
            raise ValueError("Generated script contains no 'podcast' segments")

        if progress is not None:
            progress(0.5, "Converting text to speech...")
        audio_files: List[str] = []
        try:
            for seg in parts:
                audio_files.append(
                    await self.tts_generate(seg["line"], seg["speaker"], sp1, sp2)
                )
        except BaseException:
            # Drop the clips produced so far before propagating the failure.
            self._remove_files(audio_files)
            raise
        return await self.combine_audio_files(audio_files, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # Gradio UI
120
 
def run_app():
    """Build the Gradio interface for the podcast generator and launch it."""
    langs = ["Auto Detect", "German", "English", "French"]
    voices = ["Florian - German (Germany)", "Andrew - English (US)"]
    gen = PodcastGenerator()

    def on_generate(text, upload, lang, voice1, voice2):
        """Click handler: run the async pipeline and return the audio path.

        Failures are re-raised as ``gr.Error`` so the message appears in the
        UI rather than only in the server log.
        """
        if not (text and text.strip()) and upload is None:
            raise gr.Error("Please enter some text or upload a file.")
        try:
            return asyncio.run(
                gen.generate_podcast(text, lang, voice1, voice2, upload)
            )
        except gr.Error:
            raise
        except Exception as exc:
            raise gr.Error(f"Podcast generation failed: {exc}") from exc

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Input Text")
        # Restrict uploads to the formats the label advertises.
        file_u = gr.File(label="Upload PDF/TXT", file_types=[".pdf", ".txt"])
        lang_dd = gr.Dropdown(langs, value="Auto Detect", label="Language")
        sp1 = gr.Dropdown(voices, value=voices[0], label="Speaker 1")
        sp2 = gr.Dropdown(voices, value=voices[1], label="Speaker 2")
        out = gr.Audio(label="Podcast", type="filepath")
        btn = gr.Button("Generate")
        btn.click(
            on_generate,
            inputs=[inp, file_u, lang_dd, sp1, sp2],
            outputs=[out],
        )
    demo.launch()


if __name__ == '__main__':
    run_app()