bishaltwr committed on
Commit
03d0560
·
1 Parent(s): fa923dd
Files changed (1) hide show
  1. app.py +14 -37
app.py CHANGED
@@ -25,48 +25,28 @@ try:
25
  logging.info("Custom M2M100 model loaded successfully")
26
  except Exception as e:
27
  logging.error(f"Error loading custom M2M100 model: {e}")
28
- try:
29
- # Fall back to official model
30
- checkpoint_dir = "facebook/m2m100_418M"
31
- logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
32
- tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
33
- model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
34
- logging.info("Official M2M100 model loaded successfully")
35
- m2m_available = True
36
- except Exception as e2:
37
- logging.error(f"Error loading official M2M100 model: {e2}")
38
- m2m_available = False
39
- logging.info("Setting m2m_available to False")
40
 
41
  # Set device after model loading
42
- if m2m_available:
43
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
- logging.info(f"Using device: {device}")
45
- model_m2m.to(device)
46
 
47
  # Initialize ASR model
48
  model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
49
- try:
50
- processor = AutoProcessor.from_pretrained(model_id)
51
- model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
52
- asr_available = True
53
- except Exception as e:
54
- logging.error(f"Error loading ASR model: {e}")
55
- asr_available = False
56
 
57
  # Initialize X-Transformer model
58
- try:
59
- from inference import translate as xtranslate
60
- xtransformer_available = True
61
- except Exception as e:
62
- logging.error(f"Error loading XTransformer model: {e}")
63
- xtransformer_available = False
64
 
65
  def m2m_translate(text, source_lang, target_lang):
66
  """Translation using M2M100 model"""
67
- if not m2m_available:
68
- return "M2M100 model not available"
69
-
70
  tokenizer.src_lang = source_lang
71
  inputs = tokenizer(text, return_tensors="pt").to(device)
72
  translated_tokens = model_m2m.generate(
@@ -78,9 +58,6 @@ def m2m_translate(text, source_lang, target_lang):
78
 
79
  def transcribe_audio(audio_path, language="npi"):
80
  """Transcribe audio using ASR model"""
81
- if not asr_available:
82
- return "ASR model not available"
83
-
84
  import librosa
85
  audio, sr = librosa.load(audio_path, sr=16000)
86
  processor.tokenizer.set_target_lang(language)
@@ -130,9 +107,9 @@ def translate_text(text, model_choice, source_lang=None, target_lang=None):
130
  target_lang = "ne" if source_lang == "en" else "en"
131
 
132
  # Choose the translation model
133
- if model_choice == "XTransformer" and xtransformer_available:
134
  return xtranslate(text)
135
- elif model_choice == "M2M100" and m2m_available:
136
  return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
137
  else:
138
  return "Selected model is not available"
 
25
  logging.info("Custom M2M100 model loaded successfully")
26
  except Exception as e:
27
  logging.error(f"Error loading custom M2M100 model: {e}")
28
+ # Fall back to official model
29
+ checkpoint_dir = "facebook/m2m100_418M"
30
+ logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
31
+ tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
32
+ model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
33
+ logging.info("Official M2M100 model loaded successfully")
 
 
 
 
 
 
34
 
35
  # Set device after model loading
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ logging.info(f"Using device: {device}")
38
+ model_m2m.to(device)
 
39
 
40
  # Initialize ASR model
41
  model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
42
+ processor = AutoProcessor.from_pretrained(model_id)
43
+ model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
 
 
 
 
 
44
 
45
  # Initialize X-Transformer model
46
+ from inference import translate as xtranslate
 
 
 
 
 
47
 
48
  def m2m_translate(text, source_lang, target_lang):
49
  """Translation using M2M100 model"""
 
 
 
50
  tokenizer.src_lang = source_lang
51
  inputs = tokenizer(text, return_tensors="pt").to(device)
52
  translated_tokens = model_m2m.generate(
 
58
 
59
  def transcribe_audio(audio_path, language="npi"):
60
  """Transcribe audio using ASR model"""
 
 
 
61
  import librosa
62
  audio, sr = librosa.load(audio_path, sr=16000)
63
  processor.tokenizer.set_target_lang(language)
 
107
  target_lang = "ne" if source_lang == "en" else "en"
108
 
109
  # Choose the translation model
110
+ if model_choice == "XTransformer":
111
  return xtranslate(text)
112
+ elif model_choice == "M2M100":
113
  return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
114
  else:
115
  return "Selected model is not available"