hackergeek committed
Commit 5292e3e · verified · 1 Parent(s): 4ae98c5

Update app.py

Files changed (1)
  1. app.py +36 -1
app.py CHANGED
@@ -1,10 +1,45 @@
 import gradio as gr
+import torch
+import torchaudio
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+# Load a smaller Wav2Vec model and processor for Persian
+model_name = "facebook/wav2vec2-base"  # Smaller model
+processor = Wav2Vec2Processor.from_pretrained(model_name)
+model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+def transcribe_audio(audio):
+    # Load the audio file and resample to 16kHz
+    waveform, sample_rate = torchaudio.load(audio)
+    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+    waveform = resampler(waveform)
+
+    # Preprocess the audio
+    input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values
+
+    # Perform inference
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    # Decode the logits to text
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.decode(predicted_ids[0])
+
+    return transcription
 
 with gr.Blocks(fill_height=True) as demo:
     with gr.Sidebar():
         gr.Markdown("# Inference Provider")
         gr.Markdown("This Space showcases the google/gemma-2-2b-it model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
         button = gr.LoginButton("Sign in")
-    gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
 
+    with gr.Tab("Text Inference"):
+        gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
+
+    with gr.Tab("Persian ASR"):
+        audio_input = gr.Audio(label="Upload Audio", type="filepath")
+        text_output = gr.Textbox(label="Transcription")
+        transcribe_button = gr.Button("Transcribe")
+        transcribe_button.click(transcribe_audio, inputs=audio_input, outputs=text_output)
+
 demo.launch()
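
Note on the ASR branch of this diff: facebook/wav2vec2-base is the self-supervised pretrained checkpoint, so it ships without a CTC vocabulary or a head fine-tuned for transcription; Wav2Vec2Processor.from_pretrained will typically fail to find a tokenizer for it, and the decoded output would not be meaningful Persian. Below is a minimal sketch of the same transcribe_audio function assuming a wav2vec2 checkpoint fine-tuned for Persian ASR; "m3hrdadfi/wav2vec2-large-xlsr-persian" is used here only as a placeholder name and should be swapped for whichever checkpoint the Space actually targets.

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Assumed checkpoint (not from the original commit): any wav2vec2 model
# fine-tuned for Persian CTC transcription should work here.
ASR_MODEL = "m3hrdadfi/wav2vec2-large-xlsr-persian"

processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL)
model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL)
model.eval()

def transcribe_audio(audio_path: str) -> str:
    # Load audio, downmix to mono, and resample to the 16 kHz the model expects
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

    # Extract features and run CTC inference
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits

    # Greedy decode: pick the most likely token at each frame, then collapse to text
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]

Downmixing to mono and resampling only when needed keeps the input in the single-channel 16 kHz form the processor expects for one utterance, which matters because gr.Audio file uploads may arrive in stereo or at other sample rates.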