Somalitts committed on
Commit
f548f48
·
verified ·
1 Parent(s): be7c2a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -1,38 +1,44 @@
 
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import torchaudio
4
  import torch
5
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
- import uvicorn
7
  import io
8
 
9
  app = FastAPI()
10
 
11
- # Allow requests from Flutter (localhost or any domain)
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
15
- allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
- # Load model and processor once at startup
21
  processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
22
  model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
23
 
24
- @app.post("/transcribe")
25
- async def transcribe_audio(file: UploadFile = File(...)):
26
- contents = await file.read()
27
- audio_bytes = io.BytesIO(contents)
28
 
29
- waveform, sample_rate = torchaudio.load(audio_bytes)
 
 
 
 
 
30
 
31
  if sample_rate != 16000:
32
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
33
  waveform = resampler(waveform)
34
 
35
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
 
36
  with torch.no_grad():
37
  logits = model(**inputs).logits
38
 
 
1
+ import os
2
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
+
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torchaudio
7
  import torch
8
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
9
  import io
10
 
11
  app = FastAPI()
12
 
13
+ # Allow all origins (for Flutter)
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
17
  allow_methods=["*"],
18
  allow_headers=["*"],
19
  )
20
 
21
+ # Load model
22
  processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
23
  model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
24
 
25
+ @app.get("/")
26
+ async def root():
27
+ return {"message": "Somali Speech-to-Text API is running."}
 
28
 
29
+ @app.post("/transcribe")
30
+ async def transcribe(file: UploadFile = File(...)):
31
+ audio_bytes = await file.read()
32
+ audio_stream = io.BytesIO(audio_bytes)
33
+
34
+ waveform, sample_rate = torchaudio.load(audio_stream)
35
 
36
  if sample_rate != 16000:
37
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
  waveform = resampler(waveform)
39
 
40
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
+
42
  with torch.no_grad():
43
  logits = model(**inputs).logits
44