satyamr196 committed on
Commit
fdc104d
·
1 Parent(s): 90d2e58

1) Added model ID sanitization, i.e. removing invalid characters as per Hugging Face

Browse files

2) Added proper validation using huggingface_hub's model_info to check whether the requested model is available on Hugging Face

Files changed (3) hide show
  1. ASR_Server.py +15 -3
  2. requirements.txt +2 -1
  3. utils/model_validity.py +16 -0
ASR_Server.py CHANGED
@@ -3,11 +3,13 @@ from flask_cors import CORS
3
  from datasets import load_dataset, Audio
4
  import pandas as pd
5
  import os
 
6
  import threading
7
  from dotenv import load_dotenv
8
  from utils.load_csv import upload_csv, download_csv
9
  from utils.generate_results import generateResults
10
  from utils.generate_box_plot import box_plot_data
 
11
 
12
  # Set the cache directory for Hugging Face datasets
13
  os.environ["HF_HOME"] = "/tmp/huggingface"
@@ -132,8 +134,14 @@ def generateTranscript(ASR_model):
132
  df["transcript"] = transcripts
133
  df["rtfx"] = rtfx_score
134
 
135
- job_status["running"] = False
136
- job_status["message"] = "Transcription completed."
 
 
 
 
 
 
137
  # df.to_csv(csv_result, index=False)
138
  upload_csv(df, csv_transcript)
139
  print(f"\n📄 Transcripts saved to: {csv_transcript}")
@@ -187,10 +195,14 @@ def get_status():
187
  @app.route('/api', methods=['GET'])
188
  def api():
189
  model = request.args.get('ASR_model', default="", type=str)
 
 
190
  csv_transcript = f'test_with_{model.replace("/","_")}.csv'
191
  csv_result = f'test_with_{model.replace("/","_")}_WER.csv'
192
  if not model:
193
- return jsonify({'error': 'ASR_model parameter is required'}), 400 # Return 400 if model is missing
 
 
194
  elif (download_csv(csv_transcript) is not None):
195
  # Load the CSV file from the Hugging Face Hub
196
  Results = generateResults(model)
 
3
  from datasets import load_dataset, Audio
4
  import pandas as pd
5
  import os
6
+ import re
7
  import threading
8
  from dotenv import load_dotenv
9
  from utils.load_csv import upload_csv, download_csv
10
  from utils.generate_results import generateResults
11
  from utils.generate_box_plot import box_plot_data
12
+ from utils.model_validity import is_valid_asr_model
13
 
14
  # Set the cache directory for Hugging Face datasets
15
  os.environ["HF_HOME"] = "/tmp/huggingface"
 
134
  df["transcript"] = transcripts
135
  df["rtfx"] = rtfx_score
136
 
137
+ job_status.update({
138
+ "running": False,
139
+ "model": None,
140
+ "completed": None,
141
+ "%_completed" : None,
142
+ "message": "No Transcription in progress",
143
+ "total": None
144
+ })
145
  # df.to_csv(csv_result, index=False)
146
  upload_csv(df, csv_transcript)
147
  print(f"\n📄 Transcripts saved to: {csv_transcript}")
 
195
  @app.route('/api', methods=['GET'])
196
  def api():
197
  model = request.args.get('ASR_model', default="", type=str)
198
+ # model = re.sub(r"\s+", "", model)
199
+ model = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model) # sanitize the model ID
200
  csv_transcript = f'test_with_{model.replace("/","_")}.csv'
201
  csv_result = f'test_with_{model.replace("/","_")}_WER.csv'
202
  if not model:
203
+ return jsonify({'error': 'ASR_model parameter is required'})
204
+ elif not is_valid_asr_model(model):
205
+ return jsonify({'message': 'Invalid ASR model ID, please check if your model is available on Hugging Face'}), 400 # Return 400 if model is invalid
206
  elif (download_csv(csv_transcript) is not None):
207
  # Load the CSV file from the Hugging Face Hub
208
  Results = generateResults(model)
requirements.txt CHANGED
@@ -15,4 +15,5 @@ pymongo
15
  flask-cors
16
  pandas
17
  tqdm
18
- dotenv
 
 
15
  flask-cors
16
  pandas
17
  tqdm
18
+ dotenv
19
+ huggingface_hub
utils/model_validity.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import model_info
2
+ from huggingface_hub.utils import RepositoryNotFoundError
3
+ import re
4
+
5
def is_valid_asr_model(model_id: str) -> bool:
    """Return True if *model_id* is a Hugging Face model tagged for ASR.

    The ID is first sanitized by dropping every character that is not an
    ASCII letter, digit, slash, underscore, hyphen or dot. The sanitized ID
    is then looked up on the Hub with ``model_info`` and accepted only when
    its tags contain ``"automatic-speech-recognition"``.

    Args:
        model_id: Model repository ID as supplied by the caller,
            e.g. ``"facebook/hubert-large-ls960-ft"``.

    Returns:
        True when the repository exists and carries the ASR tag. False when
        the sanitized ID is empty, is rejected by the Hub client as
        malformed, does not exist, or lacks the ASR tag.
    """
    model_id = re.sub(r"[^a-zA-Z0-9/_\-.]", "", model_id)  # Sanitize the model ID
    if not model_id:
        # Nothing left after sanitization: not a model, and no point in
        # spending a network round trip on it.
        return False

    # Imported locally so this block does not depend on a change to the
    # module's import section. The sanitizing regex still lets through
    # shapes the Hub client refuses (e.g. "a/b/c" or "--x"); for those,
    # model_info raises HFValidationError instead of RepositoryNotFoundError.
    from huggingface_hub.utils import HFValidationError

    try:
        info = model_info(model_id)
    except (RepositoryNotFoundError, HFValidationError):
        return False

    # tags is optional on the returned object; treat a missing list as empty.
    return "automatic-speech-recognition" in (info.tags or ())
13
+
14
+ # Test examples
15
+ # print(is_valid_asr_model("facebook/hubert-large-ls960-ft")) # True
16
+ # print(is_valid_asr_model("facebook/hubert-largeXX-ls960-ft")) # False