jblast94 commited on
Commit
e638507
·
verified ·
1 Parent(s): 67a2f34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -47
app.py CHANGED
@@ -1,30 +1,27 @@
1
  import gradio as gr
2
  import os
3
- import requests # Used for making API calls to your Chatterbox endpoint
4
- from transformers import AutoModel, AutoTokenizer
5
 
6
  # --- Model Loading ---
7
- # This loads the Gemma model you specified. Note that GGUF files are typically
8
- # optimized for CPU inference with libraries like llama-cpp-python, but we'll
9
- # use the `transformers` library as requested. If you encounter errors, you
10
- # may need to switch to a different method for GGUF models.
11
- # The `torch_dtype="auto"` is a good practice to automatically select the best data type.
12
  model_name = "mradermacher/gemma-3n-E2B-GGUF"
 
13
 
14
- # Let's try to load the model and tokenizer
15
  try:
16
- # Use a specific class that supports the model's architecture.
17
- # The `AutoModel` you provided is a general-purpose class.
18
- # We will use `AutoTokenizer` and a simple `AutoModel` here, as requested.
19
- # We add `trust_remote_code=True` in case the model requires it for loading.
20
- model = AutoModel.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
21
- tokenizer = AutoTokenizer.from_pretrained(model_name)
22
- print("Model and tokenizer loaded successfully!")
23
  except Exception as e:
24
- print(f"Error loading model: {e}")
25
- print("Please check the model name or if the model type is compatible with `AutoModel`.")
26
- model = None
27
- tokenizer = None
28
 
29
  # --- Constants & Configuration ---
30
  # To secure your Chatterbox endpoint URL, you should add it to your
@@ -38,9 +35,9 @@ def process_audio_and_generate(audio_file_path):
38
  This function handles the full workflow:
39
  1. Takes the path to a recorded audio file.
40
  2. Sends the audio to your Chatterbox TTS endpoint for transcription.
41
- 3. Passes the transcribed text to the Gemma model.
42
  4. Generates a text response.
43
-
44
  Args:
45
  audio_file_path (str): The file path of the recorded audio.
46
 
@@ -51,46 +48,48 @@ def process_audio_and_generate(audio_file_path):
51
  return "Please provide an audio recording.", "No audio input received."
52
 
53
  # --- Step 1: Speech-to-Text (using your Chatterbox endpoint) ---
54
- # This is where you will make the API call to your Chatterbox TTS endpoint.
55
- # You'll need to read the audio file and send it as a POST request.
56
  try:
57
  with open(audio_file_path, "rb") as audio_file:
58
- # We'll assume the API expects the audio file in the request body.
59
- headers = {"Content-Type": "audio/x-wav"} # Adjust the MIME type as needed
60
- response = requests.post(CHATTERBOX_ENDPOINT, data=audio_file, headers=headers)
 
61
  transcription = response.json().get("transcription", "Transcription failed.")
62
- except Exception as e:
63
  transcription = f"Error calling Chatterbox API: {e}"
 
 
 
 
 
64
  return transcription, "Transcription failed."
65
 
66
- # --- Step 2: Generate Response with Gemma ---
67
- # This is a placeholder for how you would pass the transcription to the model.
68
- # The actual implementation will depend on the model's specific API.
69
- # We'll use a simple text generation approach.
70
- if model and tokenizer:
71
  try:
72
- # Tokenize the input text
73
- inputs = tokenizer(transcription, return_tensors="pt")
74
-
75
- # Generate a response. You may need to adjust generation parameters.
76
- # Using `max_new_tokens` to limit the response length for efficiency.
77
- outputs = model.generate(**inputs, max_new_tokens=100)
78
 
79
- # Decode the generated tokens to a string
80
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
81
 
82
- # The model will likely repeat the input, so we'll clean it up.
83
- clean_response = response_text.replace(transcription, "", 1).strip()
84
 
85
  except Exception as e:
86
- clean_response = f"Error generating response: {e}"
87
- else:
88
- clean_response = "Gemma model not loaded. Please check the logs."
89
 
90
- return transcription, clean_response
91
 
92
  # --- Gradio Interface Setup ---
93
- # This creates the user interface with a microphone input and two text outputs.
94
  iface = gr.Interface(
95
  fn=process_audio_and_generate,
96
  inputs=gr.Audio(sources=["microphone"], type="filepath"),
@@ -105,3 +104,4 @@ iface = gr.Interface(
105
  # Launch the Gradio app
106
  if __name__ == "__main__":
107
  iface.launch()
 
 
import gradio as gr
import os
import requests
from llama_cpp import Llama  # GGUF inference engine (llama-cpp-python)

# --- Model Loading ---
# The selected model is distributed in the GGUF format, which the standard
# Hugging Face `AutoModel` class cannot load; llama-cpp-python is the
# dedicated inference engine for it.
model_name = "mradermacher/gemma-3n-E2B-GGUF"

# `llm` is None when loading fails; downstream code checks `if llm:` before use.
try:
    # BUG FIX: the original called `gr.mount_model(model_name)`, but Gradio has
    # no such function — the script crashed with AttributeError before the UI
    # ever started. llama-cpp-python can download a GGUF file from the
    # Hugging Face Hub itself via `Llama.from_pretrained` (it uses
    # huggingface_hub under the hood), so no separate download step is needed.
    llm = Llama.from_pretrained(
        repo_id=model_name,
        # The repo ships several quantizations; pick one by glob pattern.
        # Q4_K_M is a common speed/quality middle ground — adjust as needed.
        filename="*Q4_K_M.gguf",
        n_gpu_layers=1,   # offload one layer to GPU if available
        verbose=False,    # keep the logs clean
    )
    print("Llama model initialized successfully!")
except Exception as e:
    # Best-effort startup: log the failure and fall back to llm=None so the
    # Gradio app still launches and can report the problem to the user.
    print(f"Error initializing Llama model: {e}")
    print("Please check if the model is compatible with llama-cpp-python.")
    llm = None
26
  # --- Constants & Configuration ---
27
  # To secure your Chatterbox endpoint URL, you should add it to your
 
35
  This function handles the full workflow:
36
  1. Takes the path to a recorded audio file.
37
  2. Sends the audio to your Chatterbox TTS endpoint for transcription.
38
+ 3. Passes the transcribed text to the GGUF model.
39
  4. Generates a text response.
40
+
41
  Args:
42
  audio_file_path (str): The file path of the recorded audio.
43
 
 
48
  return "Please provide an audio recording.", "No audio input received."
49
 
50
  # --- Step 1: Speech-to-Text (using your Chatterbox endpoint) ---
51
+ transcription = "Transcription failed." # Default value in case of error
 
52
  try:
53
  with open(audio_file_path, "rb") as audio_file:
54
+ # Assumes the API expects a multipart form data request with the file.
55
+ files = {'file': audio_file}
56
+ response = requests.post(CHATTERBOX_ENDPOINT, files=files)
57
+ response.raise_for_status() # Raise an exception for bad status codes
58
  transcription = response.json().get("transcription", "Transcription failed.")
59
+ except requests.exceptions.RequestException as e:
60
  transcription = f"Error calling Chatterbox API: {e}"
61
+ print(transcription)
62
+ return transcription, "Transcription service is not available."
63
+ except Exception as e:
64
+ transcription = f"Error during transcription: {e}"
65
+ print(transcription)
66
  return transcription, "Transcription failed."
67
 
68
+ # --- Step 2: Generate Response with Gemma (GGUF model) ---
69
+ response_text = "Gemma model is not available." # Default value
70
+ if llm:
 
 
71
  try:
72
+ # We'll use the model's `create_completion` method to generate text.
73
+ # We wrap the transcription in a prompt template that the model expects.
74
+ prompt = f"### User:\n{transcription}\n### Assistant:\n"
 
 
 
75
 
76
+ # Generate the response from the model
77
+ completion = llm.create_completion(
78
+ prompt,
79
+ max_tokens=150, # Limits the length of the response
80
+ stop=["### User:"], # Stops generation when it sees the next user turn
81
+ echo=False, # Don't repeat the input prompt in the output
82
+ )
83
 
84
+ response_text = completion['choices'][0]['text']
 
85
 
86
  except Exception as e:
87
+ response_text = f"Error generating response from model: {e}"
88
+ print(response_text)
 
89
 
90
+ return transcription, response_text.strip()
91
 
92
  # --- Gradio Interface Setup ---
 
93
  iface = gr.Interface(
94
  fn=process_audio_and_generate,
95
  inputs=gr.Audio(sources=["microphone"], type="filepath"),
 
104
  # Launch the Gradio app
105
  if __name__ == "__main__":
106
  iface.launch()
107
+