Nick088's picture
Update app.py
df9a14a verified
raw
history blame
16.7 kB
import os
import subprocess
import random
import numpy as np
import json
from datetime import timedelta
import gradio as gr
from groq import Groq
# llms
# Groq client shared by every tab. The API key is read from the "Groq_Api_Key"
# environment variable; if it is unset, os.getenv yields None and that value is
# passed straight through to the Groq constructor.
client = Groq(api_key=os.getenv("Groq_Api_Key"))

# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
def update_max_tokens(model):
    """Return a Gradio update that resizes the Max Tokens slider for *model*.

    Args:
        model: Model id selected in the "Model" dropdown.

    Returns:
        ``gr.update(maximum=...)`` with the model's token limit. For a model id that
        is not in the table below, an empty ``gr.update()`` is returned so the
        slider is left unchanged rather than being overwritten with ``None``.
    """
    token_limits = {
        "llama3-70b-8192": 8192,
        "llama3-8b-8192": 8192,
        "gemma-7b-it": 8192,
        "gemma2-9b-it": 8192,
        "mixtral-8x7b-32768": 32768,
    }
    maximum = token_limits.get(model)
    if maximum is None:
        return gr.update()
    return gr.update(maximum=maximum)
def create_history_messages(history):
    """Convert Gradio chat history into chat-completion messages, in conversation order.

    Args:
        history: Sequence of ``(user_message, assistant_message)`` pairs as passed
            by ``gr.ChatInterface``; ``None``/empty means no prior turns.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that alternates user and
        assistant turns exactly as they happened. Turns whose content is ``None``
        (e.g. a user message that has no reply yet) are skipped, because they carry
        no text to send.
    """
    history_messages = []
    for user_message, assistant_message in history or []:
        if user_message is not None:
            history_messages.append({"role": "user", "content": user_message})
        if assistant_message is not None:
            history_messages.append({"role": "assistant", "content": assistant_message})
    return history_messages
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a Groq chat completion, yielding the reply accumulated so far.

    Args:
        prompt: The new user message.
        history: Previous ``(user, assistant)`` turns from ``gr.ChatInterface``.
        model: Groq model id.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate (the slider may deliver a
            float, so it is cast to ``int``).
        top_p: Nucleus-sampling probability mass.
        seed: Sampling seed; 0 or an empty field (``None``) picks a random seed in
            ``[1, MAX_SEED]``.

    Yields:
        The full assistant reply received so far, so Gradio can render it
        incrementally.
    """
    messages = create_history_messages(history)
    messages.append({"role": "user", "content": prompt})

    if not seed:
        seed = random.randint(1, MAX_SEED)

    stream = client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=int(max_tokens),
        top_p=top_p,
        seed=int(seed),
        stop=None,
        stream=True,
    )

    response = ""
    for chunk in stream:
        # Skip any chunk that carries no choices instead of indexing into it.
        if not chunk.choices:
            continue
        delta_content = chunk.choices[0].delta.content
        if delta_content:
            response += delta_content
            yield response
# speech to text
# File extensions accepted by the upload widgets and by check_file(); these are
# the formats the Groq speech-to-text endpoint accepts.
ALLOWED_FILE_EXTENSIONS = ["flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"]
# Uploads larger than this are downsampled by check_file() before being sent.
MAX_FILE_SIZE_MB = 25
# Display name -> language code, used as (label, value) choices in the language
# dropdowns and passed as the `language` argument of the transcription call.
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # Whisper uses "jw" (not ISO 639-1 "jv") as its Javanese code.
    "Javanese": "jw",
    "Sundanese": "su",
}
def check_file(audio_file_path):
    """Validate an uploaded audio file and downsample it if it is too large.

    The extension must be in ALLOWED_FILE_EXTENSIONS. A file larger than
    MAX_FILE_SIZE_MB is re-encoded with ffmpeg to 16 kHz mono FLAC, and the result
    is size-checked again.

    Args:
        audio_file_path: Path of the uploaded file, or a falsy value if nothing
            was uploaded.

    Returns:
        A ``(path, error)`` tuple. On success, ``path`` is the file to send to the
        API (the original, or a ``<name>_downsampled.flac`` copy next to it) and
        ``error`` is ``None``. On failure, ``path`` is ``None`` and ``error`` is a
        ``gr.Error``. The error is returned, not raised, so the caller decides how
        to surface it.
    """
    if not audio_file_path:
        return None, gr.Error("Please upload an audio file.")

    file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    # splitext only inspects the final path component, so dots in directory
    # names are not mistaken for the extension.
    file_extension = os.path.splitext(audio_file_path)[1].lstrip(".").lower()
    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        return (
            None,
            gr.Error(
                f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}"
            ),
        )

    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None

    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz. Maximum allowed: {MAX_FILE_SIZE_MB} MB"
    )
    # FLAC is lossless but compressed, so the downsampled copy stays far smaller
    # than the equivalent 16 kHz PCM WAV would be.
    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.flac"
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-y",  # overwrite a leftover output from an earlier run instead of prompting
                "-i",
                audio_file_path,
                "-ar",
                "16000",
                "-ac",
                "1",
                "-map",
                "0:a:0",  # first audio stream only; any video/cover-art streams are dropped
                "-c:a",
                "flac",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg binary as well as a failed run.
        return None, gr.Error(f"Error during downsampling: {e}")

    downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
    if downsampled_size_mb > MAX_FILE_SIZE_MB:
        try:
            os.remove(output_file_path)
        except OSError:
            pass  # best-effort cleanup; the error returned below is what matters
        return (
            None,
            gr.Error(
                f"File size still too large after downsampling ({downsampled_size_mb:.2f} MB). Maximum allowed: {MAX_FILE_SIZE_MB} MB"
            ),
        )
    return output_file_path, None
def transcribe_audio(audio_file_path, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio file with Groq's ``whisper-large-v3``.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context or spelling hints passed to the model.
        language: Language code used when auto-detection is off.
        auto_detect_language: If true, no language is sent and the API detects it.

    Returns:
        The transcribed text.

    Raises:
        gr.Error: If ``check_file`` rejects the upload (missing file, unsupported
            type, downsampling failure, or still too large).
    """
    processed_path, error_message = check_file(audio_file_path)
    if error_message:
        # Gradio shows a raised gr.Error to the user; returning it would only put
        # the exception object into the transcription textbox.
        raise error_message

    with open(processed_path, "rb") as file:
        audio_bytes = file.read()
    if processed_path != audio_file_path:
        # check_file created a temporary downsampled copy; the bytes are in memory now.
        try:
            os.remove(processed_path)
        except OSError:
            pass  # best-effort cleanup

    transcription = client.audio.transcriptions.create(
        file=(os.path.basename(processed_path), audio_bytes),
        model="whisper-large-v3",
        prompt=prompt,
        response_format="json",
        language=None if auto_detect_language else language,
        temperature=0.0,
    )
    return transcription.text
def translate_audio(audio_file_path, prompt):
    """Translate the speech in an uploaded audio file to English text.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context or spelling hints passed to the model.

    Returns:
        The English translation text returned by ``whisper-large-v3``.

    Raises:
        gr.Error: If ``check_file`` rejects the upload (missing file, unsupported
            type, downsampling failure, or still too large).
    """
    processed_path, error_message = check_file(audio_file_path)
    if error_message:
        # Gradio shows a raised gr.Error to the user; returning it would only put
        # the exception object into the translation textbox.
        raise error_message

    with open(processed_path, "rb") as file:
        audio_bytes = file.read()
    if processed_path != audio_file_path:
        # check_file created a temporary downsampled copy; the bytes are in memory now.
        try:
            os.remove(processed_path)
        except OSError:
            pass  # best-effort cleanup

    translation = client.audio.translations.create(
        file=(os.path.basename(processed_path), audio_bytes),
        model="whisper-large-v3",
        prompt=prompt,
        response_format="json",
        temperature=0.0,
    )
    return translation.text
# subtitles maker
def _format_srt_timestamp(seconds):
    """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
    # timedelta rounds to whole microseconds first, so float noise cannot push a
    # value just below a millisecond boundary before the floor division.
    total_ms = max(timedelta(seconds=seconds) // timedelta(milliseconds=1), 0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1_000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"


def create_srt_from_json(transcription_json):
    """Convert a Whisper verbose JSON transcription to SRT text.

    Args:
        transcription_json: Dict with a ``"segments"`` list whose items have
            ``"start"`` and ``"end"`` (seconds) and ``"text"`` keys.

    Returns:
        The SRT document: numbered cues starting at 1, each with an
        ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` line and the segment text (surrounding
        whitespace stripped), followed by a blank line. An empty segment list
        yields ``""``.

    Raises:
        KeyError: If ``"segments"`` or a segment field is missing.
    """
    srt_lines = []
    for index, segment in enumerate(transcription_json["segments"], start=1):
        start_timestamp = _format_srt_timestamp(segment["start"])
        end_timestamp = _format_srt_timestamp(segment["end"])
        text = segment["text"].strip()
        srt_lines.append(f"{index}\n{start_timestamp} --> {end_timestamp}\n{text}\n\n")
    return "".join(srt_lines)
def generate_subtitles(audio_file_path, prompt, language, auto_detect_language):
    """Transcribe an uploaded file and return its subtitles in SRT format.

    Args:
        audio_file_path: Path of the uploaded audio/video file.
        prompt: Optional context or spelling hints passed to the model.
        language: Language code used when auto-detection is off.
        auto_detect_language: If true, no language is sent and the API detects it.

    Returns:
        The SRT text built by ``create_srt_from_json``.

    Raises:
        gr.Error: If ``check_file`` rejects the upload.
    """
    processed_path, error_message = check_file(audio_file_path)
    if error_message:
        # Raise so Gradio shows the error. Returning it would write the message into
        # the SRT box, whose change event then tries to mux it into the video.
        raise error_message

    with open(processed_path, "rb") as file:
        audio_bytes = file.read()
    if processed_path != audio_file_path:
        # check_file created a temporary downsampled copy; the bytes are in memory now.
        try:
            os.remove(processed_path)
        except OSError:
            pass  # best-effort cleanup

    transcription = client.audio.transcriptions.create(
        file=(os.path.basename(processed_path), audio_bytes),
        model="whisper-large-v3",
        prompt=prompt,
        # "verbose_json" includes the timestamped "segments" the SRT builder needs;
        # plain "json" carries only the text.
        response_format="verbose_json",
        language=None if auto_detect_language else language,
        temperature=0.0,
    )
    transcription_json = json.loads(transcription.to_json())
    return create_srt_from_json(transcription_json)
def add_subtitles_to_video(video_file_path, srt_content):
    """Mux SRT subtitles into an MP4 copy of the uploaded media file with ffmpeg.

    Args:
        video_file_path: Path of the uploaded audio/video file.
        srt_content: Subtitles in SRT format; they are piped to ffmpeg on stdin.

    Returns:
        Path of ``<name>_with_subs.mp4`` next to the input. Returns ``None`` when
        there is no input file or the SRT text is empty (e.g. the box was cleared),
        so the output file component is simply cleared.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with an error.
    """
    if not video_file_path or not srt_content or not srt_content.strip():
        return None

    output_file_path = os.path.splitext(video_file_path)[0] + "_with_subs.mp4"
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-y",  # stdin carries the SRT, so ffmpeg must never stop to ask about overwriting
                "-i",
                video_file_path,
                "-f",
                "srt",  # the subtitle input is a pipe, so name its format explicitly
                "-i",
                "-",  # input 1: subtitles from stdin
                "-map",
                "0:v?",  # video from the upload, if it has any (audio-only uploads are allowed)
                "-map",
                "0:a?",  # audio also comes from the upload (input 0), not the subtitle input
                "-map",
                "1:s",  # the subtitle track
                "-c:v",
                "copy",
                # NOTE(review): stream copy fails for audio codecs MP4 cannot hold
                # (e.g. PCM from .wav uploads); consider re-encoding if that matters.
                "-c:a",
                "copy",
                "-c:s",
                "mov_text",  # text subtitle codec for MP4
                output_file_path,
            ],
            input=srt_content.encode("utf-8"),
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg binary as well as a failed run.
        raise gr.Error(f"Error during subtitle addition: {e}") from e
    return output_file_path
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Groq API UI
        Inference by Groq. Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
        """
    )
    with gr.Tabs():
        with gr.TabItem("LLMs"):
            with gr.Row():
                with gr.Column(scale=1, min_width=250):
                    model = gr.Dropdown(
                        choices=[
                            "llama3-70b-8192",
                            "llama3-8b-8192",
                            "mixtral-8x7b-32768",
                            "gemma-7b-it",
                            "gemma2-9b-it",
                        ],
                        value="llama3-70b-8192",
                        label="Model",
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Temperature",
                        info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                    )
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=8192,
                        step=1,
                        value=4096,
                        label="Max Tokens",
                        info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama3 8b & 70b, 32k for mixtral 8x7b.",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Top P",
                        info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                    )
                    seed = gr.Number(
                        precision=0,
                        value=42,
                        label="Seed",
                        info="A starting point to initiate generation, use 0 for random",
                    )
                    # Keep the Max Tokens maximum in sync with the selected model.
                    # Registered once only; a second registration would run it twice.
                    model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
                with gr.Column(scale=1, min_width=400):
                    chatbot = gr.ChatInterface(
                        fn=generate_response,
                        chatbot=None,
                        additional_inputs=[
                            model,
                            temperature,
                            max_tokens,
                            top_p,
                            seed,
                        ],
                    )
        with gr.TabItem("Speech To Text"):
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcribe audio from files to text!")
                    with gr.Row():
                        audio_input = gr.File(
                            type="filepath",
                            label="Upload File containing Audio",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                    with gr.Row():
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        with gr.Column():
                            language = gr.Dropdown(
                                choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
                                value="en",
                                label="Language",
                            )
                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                    transcribe_button = gr.Button("Transcribe")
                    transcription_output = gr.Textbox(label="Transcription")
                    transcribe_button.click(
                        transcribe_audio,
                        inputs=[audio_input, transcribe_prompt, language, auto_detect_language],
                        outputs=transcription_output,
                    )
                with gr.TabItem("Translation"):
                    gr.Markdown("Transcribe audio from files and translate them to English text!")
                    with gr.Row():
                        audio_input_translate = gr.File(
                            type="filepath",
                            label="Upload File containing Audio",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                    with gr.Row():
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    translate_button = gr.Button("Translate")
                    translation_output = gr.Textbox(label="Translation")
                    translate_button.click(
                        translate_audio,
                        inputs=[audio_input_translate, translate_prompt],
                        outputs=translation_output,
                    )
                with gr.TabItem("Subtitle Maker"):
                    with gr.Row():
                        # type="filepath" to match the other uploaders: check_file and
                        # add_subtitles_to_video both expect a path string.
                        audio_input_subtitles = gr.File(
                            type="filepath",
                            label="Upload Audio/Video",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                        transcribe_prompt_subtitles = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    with gr.Row():
                        language_subtitles = gr.Dropdown(
                            choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
                            value="en",
                            label="Language",
                        )
                        auto_detect_language_subtitles = gr.Checkbox(
                            label="Auto Detect Language"
                        )
                    transcribe_button_subtitles = gr.Button("Generate Subtitles")
                    srt_output = gr.Textbox(label="SRT Output")
                    video_output = gr.File(label="Output Video (with Subtitles)")
                    transcribe_button_subtitles.click(
                        generate_subtitles,
                        inputs=[
                            audio_input_subtitles,
                            transcribe_prompt_subtitles,
                            language_subtitles,
                            auto_detect_language_subtitles,
                        ],
                        outputs=srt_output,
                    )
                    # Whenever the SRT text changes (generated or hand-edited), rebuild
                    # the subtitled media file from the original upload.
                    srt_output.change(
                        add_subtitles_to_video,
                        inputs=[audio_input_subtitles, srt_output],
                        outputs=video_output,
                    )
demo.launch()