# Source: Hugging Face Space by Nick088 — app.py
# Last commit: "Update app.py" (1b66d6b, verified); file size 16.4 kB.
import os
import subprocess
import random
import numpy as np
import gradio as gr
from groq import Groq
# Shared Groq API client; the API key comes from the "Groq_Api_Key" environment variable.
client = Groq(api_key=os.getenv("Groq_Api_Key"))
# llms
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
# Context window (max tokens) of each model offered in the LLM dropdown.
_MODEL_MAX_TOKENS = {
    "llama3-70b-8192": 8192,
    "llama3-8b-8192": 8192,
    "gemma-7b-it": 8192,
    "gemma2-9b-it": 8192,
    "mixtral-8x7b-32768": 32768,
}


def update_max_tokens(model):
    """Return a Gradio update that sets the Max Tokens slider's maximum for ``model``.

    Args:
        model: Model id selected in the dropdown.

    Returns:
        ``gr.update(maximum=...)``. Unknown model ids fall back to 8192 so the
        slider is never updated with ``None``.
    """
    return gr.update(maximum=_MODEL_MAX_TOKENS.get(model, 8192))
def create_history_messages(history):
    """Convert Gradio chat history into chat-completion messages.

    Args:
        history: Either a list of ``(user_message, assistant_message)`` pairs
            (Gradio "tuples" history) or a list of dicts carrying ``role`` and
            ``content`` (Gradio "messages" history). ``None`` is treated as empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in conversation order:
        each user turn is immediately followed by its assistant reply. Turns
        whose content is ``None`` are skipped.
    """
    history_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history: keep only role/content, drop any other keys.
            history_messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_message, assistant_message = turn[0], turn[1]
        if user_message is not None:
            history_messages.append({"role": "user", "content": user_message})
        if assistant_message is not None:
            history_messages.append({"role": "assistant", "content": assistant_message})
    return history_messages
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a Groq chat completion for the Gradio ChatInterface.

    Args:
        prompt: The new user message.
        history: Prior chat history, as accepted by ``create_history_messages``.
        model: Groq model id.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        seed: Sampling seed; ``0`` or an empty field picks a random seed.

    Yields:
        The accumulated response text after each streamed content chunk.
    """
    messages = create_history_messages(history)
    messages.append({"role": "user", "content": prompt})
    # The conversation is deliberately not printed: that would log user chats to stdout.
    if not seed:
        seed = random.randint(1, MAX_SEED)
    stream = client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=int(max_tokens),
        top_p=top_p,
        seed=int(seed),
        stop=None,
        stream=True,
    )
    response = ""
    for chunk in stream:
        # Guard against streamed chunks that carry no choices.
        if not chunk.choices:
            continue
        delta_content = chunk.choices[0].delta.content
        if delta_content:
            response += delta_content
            yield response
# speech to text
# Upload extensions accepted for transcription/translation (formats the Groq
# speech-to-text endpoint accepts).
ALLOWED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "flac", "ogg"]
# Upload size limit; larger files are downsampled in check_file.
MAX_FILE_SIZE_MB = 25
# Dropdown label -> Whisper language code.
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # Whisper uses the legacy code "jw" for Javanese; the label is the language name.
    "Javanese": "jw",
    "Sundanese": "su",
}
def check_file(audio_file_path):
    """Validate an uploaded audio file and downsample it if it is too large.

    The file must have an extension in ``ALLOWED_FILE_EXTENSIONS``. Files larger
    than ``MAX_FILE_SIZE_MB`` are re-encoded with ffmpeg to 16 kHz mono FLAC,
    and the re-encoded file is used only if it fits under the limit.

    Args:
        audio_file_path: Path of the uploaded file, or a falsy value when
            nothing was uploaded.

    Returns:
        A ``(path, error)`` tuple: ``(path_to_use, None)`` on success, or
        ``(None, gr.Error(...))`` describing why the file was rejected.
    """
    if not audio_file_path:
        return None, gr.Error("Please upload an audio file.")

    # splitext only looks at the final path component, so dots in directory
    # names are not mistaken for the extension.
    file_extension = os.path.splitext(audio_file_path)[1].lstrip(".").lower()
    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        return (
            None,
            gr.Error(
                f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}"
            ),
        )

    file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None

    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz. Maximum allowed: {MAX_FILE_SIZE_MB} MB"
    )
    # FLAC keeps the output compressed. An uncompressed WAV of an oversized
    # compressed upload would usually be even larger than the limit.
    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.flac"
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-y",  # overwrite a leftover output file instead of prompting
                "-i",
                audio_file_path,
                "-ar",
                "16000",
                "-ac",
                "1",
                "-map",
                "0:a",  # keep only the audio streams of the input
                "-c:a",
                "flac",
                output_file_path,
            ],
            check=True,
            stdin=subprocess.DEVNULL,  # never let ffmpeg read the server's stdin
        )
    except subprocess.CalledProcessError as e:
        return None, gr.Error(f"Error during downsampling: {e}")
    except FileNotFoundError:
        return None, gr.Error("Error during downsampling: ffmpeg is not installed.")

    # Check size after downsampling
    downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
    if downsampled_size_mb > MAX_FILE_SIZE_MB:
        return (
            None,
            gr.Error(
                f"File size still too large after downsampling ({downsampled_size_mb:.2f} MB). Maximum allowed: {MAX_FILE_SIZE_MB} MB"
            ),
        )
    return output_file_path, None
def transcribe_audio(audio_file_path, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio file with Groq's ``whisper-large-v3``.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context or spelling hints passed to the model.
        language: Whisper language code used when auto-detection is off.
        auto_detect_language: If true, no language is sent and the model detects it.

    Returns:
        The transcribed text.

    Raises:
        gr.Error: If the upload fails the checks in ``check_file``.
    """
    # Check and process the file first
    processed_path, error_message = check_file(audio_file_path)
    if error_message is not None:
        # gr.Error is meant to be raised so Gradio reports it as an error,
        # rather than being written into the output textbox as plain text.
        raise error_message
    with open(processed_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(os.path.basename(processed_path), file.read()),
            model="whisper-large-v3",
            prompt=prompt,
            response_format="json",
            language=None if auto_detect_language else language,
            temperature=0.0,
        )
    return transcription.text
def translate_audio(audio_file_path, prompt):
    """Translate speech in an uploaded audio file to English text with ``whisper-large-v3``.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context or spelling hints passed to the model.

    Returns:
        The English translation text.

    Raises:
        gr.Error: If the upload fails the checks in ``check_file``.
    """
    # Check and process the file first
    processed_path, error_message = check_file(audio_file_path)
    if error_message is not None:
        # Raise (not return) the gr.Error so Gradio reports it as an error.
        raise error_message
    with open(processed_path, "rb") as file:
        translation = client.audio.translations.create(
            file=(os.path.basename(processed_path), file.read()),
            model="whisper-large-v3",
            prompt=prompt,
            response_format="json",
            temperature=0.0,
        )
    return translation.text
# subtitles maker
def _srt_field(item, name, default=None):
    """Read ``name`` from a dict, or from an attribute of an SDK response object."""
    if isinstance(item, dict):
        return item.get(name, default)
    return getattr(item, name, default)


def _format_srt_timestamp(seconds):
    """Format an offset in seconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
    total_ms = int(round(float(seconds) * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def create_srt_from_json(transcription_json):
    """Convert a Whisper verbose-JSON transcription to SRT text.

    Args:
        transcription_json: A dict or response object with a ``segments``
            sequence. Each segment (dict or object) provides ``start`` and
            ``end`` in seconds and ``text``.

    Returns:
        The SRT document: numbered cues, each formatted as
        ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` followed by the stripped segment
        text and a blank line. Returns ``""`` when there are no segments.
    """
    segments = _srt_field(transcription_json, "segments") or []
    srt_lines = []
    for index, segment in enumerate(segments, start=1):
        start_timestamp = _format_srt_timestamp(_srt_field(segment, "start"))
        end_timestamp = _format_srt_timestamp(_srt_field(segment, "end"))
        # Whisper segment text typically starts with a space; trim it for the cue.
        text = (_srt_field(segment, "text", "") or "").strip()
        srt_lines.append(f"{index}\n{start_timestamp} --> {end_timestamp}\n{text}\n\n")
    return "".join(srt_lines)
def generate_subtitles(audio_file_path, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio/video file and return the result as SRT text.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context or spelling hints passed to the model.
        language: Whisper language code used when auto-detection is off.
        auto_detect_language: If true, no language is sent and the model detects it.

    Returns:
        The transcription formatted as an SRT document.

    Raises:
        gr.Error: If the upload fails the checks in ``check_file``.
    """
    processed_path, error_message = check_file(audio_file_path)
    if error_message is not None:
        raise error_message
    with open(processed_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(os.path.basename(processed_path), file.read()),
            model="whisper-large-v3",
            prompt=prompt,
            # Timestamps per segment are only part of the verbose response.
            response_format="verbose_json",
            language=None if auto_detect_language else language,  # Conditional language parameter
            temperature=0.0,
        )
    # NOTE(review): assumes the SDK exposes the verbose response's extra
    # "segments" field as an attribute (a list of dicts) — confirm against the groq SDK.
    segments = getattr(transcription, "segments", None) or []
    srt_content = create_srt_from_json({"segments": segments})
    return srt_content
def add_subtitles_to_video(video_file_path, srt_content):
    """Mux SRT subtitles into a copy of a video using ffmpeg.

    The video stream, and the video's audio if it has any, are copied without
    re-encoding. The SRT text is fed to ffmpeg on stdin and stored as an MP4
    ``mov_text`` subtitle track.

    Args:
        video_file_path: Path of the uploaded video file.
        srt_content: Subtitles in SRT format.

    Returns:
        Path of the new ``*_with_subs.mp4`` file, or ``None`` when there is no
        video or no subtitle text to add.

    Raises:
        gr.Error: If ffmpeg is missing or fails.
    """
    if not video_file_path or not srt_content:
        return None
    output_file_path = os.path.splitext(video_file_path)[0] + "_with_subs.mp4"
    command = [
        "ffmpeg",
        "-y",  # overwrite a previous output instead of prompting
        "-i",
        video_file_path,
        "-f",
        "srt",  # declare the stdin format explicitly instead of relying on probing
        "-i",
        "-",  # Input for subtitles from stdin
        "-map",
        "0:v",
        "-map",
        "0:a?",  # optional audio comes from the video (input 0), not the subtitles
        "-map",
        "1:s",  # Map subtitles
        "-c:v",
        "copy",
        "-c:a",
        "copy",
        "-c:s",
        "mov_text",  # Subtitle codec
        output_file_path,
    ]
    try:
        subprocess.run(
            command,
            input=srt_content.encode("utf-8"),  # Pass SRT content to ffmpeg
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # A string returned here would be treated as a file path by the gr.File output.
        raise gr.Error(f"Error during subtitle addition: {e}") from e
    except FileNotFoundError as e:
        raise gr.Error("Error during subtitle addition: ffmpeg is not installed.") from e
    return output_file_path
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Groq API UI
Inference by Groq. Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
"""
    )
    with gr.Tabs():
        with gr.TabItem("LLMs"):
            with gr.Row():
                with gr.Column(scale=1, min_width=250):
                    model = gr.Dropdown(
                        choices=[
                            "llama3-70b-8192",
                            "llama3-8b-8192",
                            "mixtral-8x7b-32768",
                            "gemma-7b-it",
                            "gemma2-9b-it",
                        ],
                        value="llama3-70b-8192",
                        label="Model",
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Temperature",
                        info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                    )
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=8192,
                        step=1,
                        value=4096,
                        label="Max Tokens",
                        info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama3 8b & 70b, 32k for mixtral 8x7b.",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Top P",
                        info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                    )
                    seed = gr.Number(
                        precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random"
                    )
                    # Keep the Max Tokens slider's upper bound in sync with the selected model
                    # (registered once; a duplicate listener would run the update twice).
                    model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
                with gr.Column(scale=1, min_width=400):
                    chatbot = gr.ChatInterface(
                        fn=generate_response,
                        chatbot=None,
                        additional_inputs=[
                            model,
                            temperature,
                            max_tokens,
                            top_p,
                            seed,
                        ],
                    )
        with gr.TabItem("Speech To Text"):
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcribe audio from files to text!")
                    with gr.Row():
                        audio_input = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                    with gr.Row():
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        with gr.Column():
                            language = gr.Dropdown(
                                choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
                                value="en",
                                label="Language",
                            )
                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                    transcribe_button = gr.Button("Transcribe")
                    transcription_output = gr.Textbox(label="Transcription")
                    transcribe_button.click(
                        transcribe_audio,
                        inputs=[audio_input, transcribe_prompt, language, auto_detect_language],
                        outputs=transcription_output,
                    )
                with gr.TabItem("Translation"):
                    gr.Markdown("Transcribe audio from files and translate it to English text!")
                    with gr.Row():
                        audio_input_translate = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                    with gr.Row():
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    translate_button = gr.Button("Translate")
                    translation_output = gr.Textbox(label="Translation")
                    translate_button.click(
                        translate_audio,
                        inputs=[audio_input_translate, translate_prompt],
                        outputs=translation_output,
                    )
                with gr.TabItem("Subtitle Maker"):
                    with gr.Row():
                        # Same filepath contract as the other uploads; the handlers expect a path string.
                        audio_input_subtitles = gr.File(
                            type="filepath",
                            label="Upload Audio/Video",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                        transcribe_prompt_subtitles = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    with gr.Row():
                        language_subtitles = gr.Dropdown(
                            choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
                            value="en",
                            label="Language",
                        )
                        auto_detect_language_subtitles = gr.Checkbox(
                            label="Auto Detect Language"
                        )
                    transcribe_button_subtitles = gr.Button("Generate Subtitles")
                    srt_output = gr.Textbox(label="SRT Output")
                    video_output = gr.File(label="Output Video (with Subtitles)")
                    transcribe_button_subtitles.click(
                        generate_subtitles,
                        inputs=[
                            audio_input_subtitles,
                            transcribe_prompt_subtitles,
                            language_subtitles,
                            auto_detect_language_subtitles,
                        ],
                        outputs=srt_output,
                    )
                    # Once SRT text is produced, mux it into the uploaded video.
                    srt_output.change(
                        add_subtitles_to_video,
                        inputs=[audio_input_subtitles, srt_output],
                        outputs=video_output,
                    )
demo.launch()