Spaces:
Running
Running
Commit
·
fdc104d
1
Parent(s):
90d2e58
1) added model id sanitization i.e. removing invaid character as per hugging face \
Browse files2) Added proper validation using huggingface_hub model_info to check if requested model is available on hugging face
- ASR_Server.py +15 -3
- requirements.txt +2 -1
- utils/model_validity.py +16 -0
ASR_Server.py
CHANGED
@@ -3,11 +3,13 @@ from flask_cors import CORS
|
|
3 |
from datasets import load_dataset, Audio
|
4 |
import pandas as pd
|
5 |
import os
|
|
|
6 |
import threading
|
7 |
from dotenv import load_dotenv
|
8 |
from utils.load_csv import upload_csv, download_csv
|
9 |
from utils.generate_results import generateResults
|
10 |
from utils.generate_box_plot import box_plot_data
|
|
|
11 |
|
12 |
# Set the cache directory for Hugging Face datasets
|
13 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
@@ -132,8 +134,14 @@ def generateTranscript(ASR_model):
|
|
132 |
df["transcript"] = transcripts
|
133 |
df["rtfx"] = rtfx_score
|
134 |
|
135 |
-
job_status
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
# df.to_csv(csv_result, index=False)
|
138 |
upload_csv(df, csv_transcript)
|
139 |
print(f"\n📄 Transcripts saved to: {csv_transcript}")
|
@@ -187,10 +195,14 @@ def get_status():
|
|
187 |
@app.route('/api', methods=['GET'])
|
188 |
def api():
|
189 |
model = request.args.get('ASR_model', default="", type=str)
|
|
|
|
|
190 |
csv_transcript = f'test_with_{model.replace("/","_")}.csv'
|
191 |
csv_result = f'test_with_{model.replace("/","_")}_WER.csv'
|
192 |
if not model:
|
193 |
-
return jsonify({'error': 'ASR_model parameter is required'})
|
|
|
|
|
194 |
elif (download_csv(csv_transcript) is not None):
|
195 |
# Load the CSV file from the Hugging Face Hub
|
196 |
Results = generateResults(model)
|
|
|
3 |
from datasets import load_dataset, Audio
|
4 |
import pandas as pd
|
5 |
import os
|
6 |
+
import re
|
7 |
import threading
|
8 |
from dotenv import load_dotenv
|
9 |
from utils.load_csv import upload_csv, download_csv
|
10 |
from utils.generate_results import generateResults
|
11 |
from utils.generate_box_plot import box_plot_data
|
12 |
+
from utils.model_validity import is_valid_asr_model
|
13 |
|
14 |
# Set the cache directory for Hugging Face datasets
|
15 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
|
|
134 |
df["transcript"] = transcripts
|
135 |
df["rtfx"] = rtfx_score
|
136 |
|
137 |
+
job_status.update({
|
138 |
+
"running": False,
|
139 |
+
"model": None,
|
140 |
+
"completed": None,
|
141 |
+
"%_completed" : None,
|
142 |
+
"message": "No Transcription in progress",
|
143 |
+
"total": None
|
144 |
+
})
|
145 |
# df.to_csv(csv_result, index=False)
|
146 |
upload_csv(df, csv_transcript)
|
147 |
print(f"\n📄 Transcripts saved to: {csv_transcript}")
|
|
|
195 |
@app.route('/api', methods=['GET'])
|
196 |
def api():
|
197 |
model = request.args.get('ASR_model', default="", type=str)
|
198 |
+
# model = re.sub(r"\s+", "", model)
|
199 |
+
model = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model) # sanitize the model ID
|
200 |
csv_transcript = f'test_with_{model.replace("/","_")}.csv'
|
201 |
csv_result = f'test_with_{model.replace("/","_")}_WER.csv'
|
202 |
if not model:
|
203 |
+
return jsonify({'error': 'ASR_model parameter is required'})
|
204 |
+
elif not is_valid_asr_model(model):
|
205 |
+
return jsonify({'message': 'Invalid ASR model ID, please check if your model is available on Hugging Face'}), 400 # Return 400 if model is invalid
|
206 |
elif (download_csv(csv_transcript) is not None):
|
207 |
# Load the CSV file from the Hugging Face Hub
|
208 |
Results = generateResults(model)
|
requirements.txt
CHANGED
@@ -15,4 +15,5 @@ pymongo
|
|
15 |
flask-cors
|
16 |
pandas
|
17 |
tqdm
|
18 |
-
dotenv
|
|
|
|
15 |
flask-cors
|
16 |
pandas
|
17 |
tqdm
|
18 |
+
dotenv
|
19 |
+
huggingface_hub
|
utils/model_validity.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import model_info
|
2 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
3 |
+
import re
|
4 |
+
|
5 |
+
def is_valid_asr_model(model_id: str) -> bool:
|
6 |
+
try:
|
7 |
+
model_id = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model_id) # Sanitize the model ID
|
8 |
+
info = model_info(model_id)
|
9 |
+
# Optionally check if it's an ASR model (i.e., "automatic-speech-recognition" in the tags)
|
10 |
+
return "automatic-speech-recognition" in info.tags
|
11 |
+
except RepositoryNotFoundError:
|
12 |
+
return False
|
13 |
+
|
14 |
+
# Test examples
|
15 |
+
# print(is_valid_asr_model("facebook/hubert-large-ls960-ft")) # True
|
16 |
+
# print(is_valid_asr_model("facebook/hubert-largeXX-ls960-ft")) # False
|