Omartificial-Intelligence-Space commited on
Commit
41b7038
ยท
verified ยท
1 Parent(s): 5840a14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -99
app.py CHANGED
@@ -1,107 +1,136 @@
1
- import os
2
- import re
3
- import json
4
  import gradio as gr
5
- from PIL import Image
6
-
7
  import torch
8
- from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
9
- import spaces # ๐Ÿ‘ˆ Hugging Face ZeroGPU
 
10
 
11
- MODEL_NAME = os.environ.get("MODEL_NAME", "NAMAA-Space/Qari-OCR-0.1-VL-2B-Instruct")
12
- MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "1024"))
 
 
13
 
14
- # ---- Device selection ----
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
- dtype = torch.float16 if device == "cuda" else torch.float32
17
-
18
- print(f"Device being used: {device}")
19
-
20
- # ---- Load model & processor ----
21
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
22
-
23
- model = Qwen2VLForConditionalGeneration.from_pretrained(
24
- MODEL_NAME,
25
- torch_dtype=dtype,
26
- device_map="auto" if device == "cuda" else {"": "cpu"},
27
- )
28
-
29
- print("Model loaded successfully!")
30
-
31
- def _mk_messages(image_path: str, prompt_info: str):
32
- return [
33
- {
34
- "role": "user",
35
- "content": [
36
- {"type": "image", "image": f"file://{image_path}"},
37
- {"type": "text", "text": f"""
38
- You are an advanced invoice OCR system...
39
- (extractions instructions same as notebook)
40
- Extra hints from user: {prompt_info}
41
- """.strip()},
42
- ],
43
- }
44
- ]
45
-
46
- def _extract_json(text: str):
47
- text = text.strip()
48
- if text.startswith("{") and text.endswith("}"):
49
- try:
50
- return json.loads(text)
51
- except Exception:
52
- pass
53
- m = re.search(r"\{[\s\S]*\}", text)
54
- if m:
55
- block = m.group(0)
56
- try:
57
- return json.loads(block)
58
- except Exception:
59
- pass
60
- return {"other_text": text}
61
-
62
- @spaces.GPU(duration=120) # ๐Ÿ‘ˆ Request ZeroGPU for 2 minutes
63
- def infer(image: Image.Image, prompt_info: str):
64
- if image is None:
65
- return "Please upload an image.", {}
66
-
67
- tmp_path = "input_image.png"
68
- image.save(tmp_path)
69
-
70
- messages = _mk_messages(tmp_path, prompt_info)
71
- chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
72
-
73
- inputs = processor(
74
- text=[chat_text],
75
- images=[Image.open(tmp_path)],
76
- return_tensors="pt",
77
- )
78
- inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
79
-
80
- with torch.no_grad():
81
- generated_ids = model.generate(
82
- **inputs,
83
- max_new_tokens=MAX_NEW_TOKENS,
84
- do_sample=False,
85
- )
86
-
87
- gen_only = generated_ids[:, inputs["input_ids"].shape[1]:]
88
- text_out = processor.batch_decode(gen_only, skip_special_tokens=True)[0].strip()
89
- parsed = _extract_json(text_out)
90
-
91
- return text_out, parsed
92
-
93
- with gr.Blocks(title="Qari OCR (ZeroGPU)") as demo:
94
- gr.Markdown("# Qari OCR ยท ZeroGPU\nUpload an invoice image and (optionally) add extraction hints.")
95
  with gr.Row():
96
- with gr.Column():
97
- img_in = gr.Image(type="pil", label="Invoice Image")
98
- prompt_box = gr.Textbox(label="Extra hints (optional)")
99
- run_btn = gr.Button("Run OCR")
100
- with gr.Column():
101
- txt_out = gr.Textbox(label="Raw Model Output", lines=10)
102
- json_out = gr.JSON(label="Parsed JSON")
103
-
104
- run_btn.click(infer, inputs=[img_in, prompt_box], outputs=[txt_out, json_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
106
  if __name__ == "__main__":
107
- demo.launch()
 
 
 
 
1
  import gradio as gr
 
 
2
  import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import spaces
6
 
7
+ # Load Hugging Face token from the environment variable
8
+ HF_TOKEN = os.getenv("HF_TOKEN")
9
+ if HF_TOKEN is None:
10
+ raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")
11
 
12
+ # Check for GPU support and configure appropriately
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ zero = torch.Tensor([0]).to(device)
15
+ print(f"Device being used: {zero.device}")
16
+
17
+ # Model configurations
18
+ MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
19
+ SYRIAN_TO_MSA_MODEL = "Omartificial-Intelligence-Space/SHAMI-MT-2MSA"
20
+
21
+ # Load models and tokenizers
22
+ print("Loading MSA to Syrian model...")
23
+ msa_to_syrian_tokenizer = AutoTokenizer.from_pretrained(MSA_TO_SYRIAN_MODEL)
24
+ msa_to_syrian_model = AutoModelForSeq2SeqLM.from_pretrained(MSA_TO_SYRIAN_MODEL).to(device)
25
+
26
+ print("Loading Syrian to MSA model...")
27
+ syrian_to_msa_tokenizer = AutoTokenizer.from_pretrained(SYRIAN_TO_MSA_MODEL)
28
+ syrian_to_msa_model = AutoModelForSeq2SeqLM.from_pretrained(SYRIAN_TO_MSA_MODEL).to(device)
29
+
30
+ print("Models loaded successfully!")
31
+
32
+ @spaces.GPU(duration=120)
33
+ def translate_msa_to_syrian(text):
34
+ """Translate from Modern Standard Arabic to Syrian dialect"""
35
+ if not text.strip():
36
+ return ""
37
+
38
+ try:
39
+ input_ids = msa_to_syrian_tokenizer(text, return_tensors="pt").input_ids.to(device)
40
+ outputs = msa_to_syrian_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
41
+ translated_text = msa_to_syrian_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+ return translated_text
43
+ except Exception as e:
44
+ return f"Translation error: {str(e)}"
45
+
46
+ @spaces.GPU(duration=120)
47
+ def translate_syrian_to_msa(text):
48
+ """Translate from Syrian dialect to Modern Standard Arabic"""
49
+ if not text.strip():
50
+ return ""
51
+
52
+ try:
53
+ input_ids = syrian_to_msa_tokenizer(text, return_tensors="pt").input_ids.to(device)
54
+ outputs = syrian_to_msa_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
55
+ translated_text = syrian_to_msa_tokenizer.decode(outputs[0], skip_special_tokens=True)
56
+ return translated_text
57
+ except Exception as e:
58
+ return f"Translation error: {str(e)}"
59
+
60
+ def bidirectional_translate(text, direction):
61
+ """Handle bidirectional translation based on user selection"""
62
+ if direction == "MSA โ†’ Syrian":
63
+ return translate_msa_to_syrian(text)
64
+ elif direction == "Syrian โ†’ MSA":
65
+ return translate_syrian_to_msa(text)
66
+ else:
67
+ return "Please select a translation direction"
68
+
69
+ # Create Gradio interface
70
+ with gr.Blocks(title="SHAMI-MT: Bidirectional Syria Arabic Dialect MT Framework") as demo:
71
+
72
+ gr.HTML("""
73
+ <div style="text-align: center; margin-bottom: 2rem;">
74
+ <h1>๐ŸŒ SHAMI-MT: Bidirectional Arabic Translation</h1>
75
+ <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
76
+ <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
77
+ </div>
78
+ """)
79
+
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  with gr.Row():
81
+ with gr.Column(scale=1):
82
+ gr.HTML("""
83
+ <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
84
+ <h3>๐Ÿ“š Model Information</h3>
85
+ <ul>
86
+ <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
87
+ <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
88
+ <li><strong>Languages:</strong> Arabic (MSA โ†” Syrian Dialect)</li>
89
+ <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
90
+ </ul>
91
+ </div>
92
+ """)
93
+
94
+ with gr.Column(scale=2):
95
+ direction = gr.Dropdown(
96
+ choices=["MSA โ†’ Syrian", "Syrian โ†’ MSA"],
97
+ value="MSA โ†’ Syrian",
98
+ label="Translation Direction"
99
+ )
100
+
101
+ input_text = gr.Textbox(
102
+ label="Input Text",
103
+ placeholder="Enter Arabic text here...",
104
+ lines=5
105
+ )
106
+
107
+ translate_btn = gr.Button("๐Ÿš€ Translate", variant="primary")
108
+
109
+ output_text = gr.Textbox(
110
+ label="Translation",
111
+ lines=5
112
+ )
113
+
114
+ # Connect the interface
115
+ translate_btn.click(
116
+ fn=bidirectional_translate,
117
+ inputs=[input_text, direction],
118
+ outputs=output_text
119
+ )
120
+
121
+ # Add example inputs
122
+ gr.Examples(
123
+ examples=[
124
+ ["ุฃู†ุง ู„ุง ุฃุนุฑู ุฅุฐุง ูƒุงู† ุณูŠุชู…ูƒู† ู…ู† ุงู„ุญุถูˆุฑ ุงู„ูŠูˆู… ุฃู… ู„ุง.", "MSA โ†’ Syrian"],
125
+ ["ูƒูŠู ุญุงู„ูƒุŸ", "MSA โ†’ Syrian"],
126
+ ["ู…ุง ุจุนุฑู ุฅุฐุง ุฑุญ ูŠู‚ุฏุฑ ูŠุฌูŠ ุงู„ูŠูˆู… ูˆู„ุง ู„ุฃ.", "Syrian โ†’ MSA"],
127
+ ["ุดู„ูˆู†ูƒุŸ", "Syrian โ†’ MSA"]
128
+ ],
129
+ inputs=[input_text, direction],
130
+ outputs=output_text,
131
+ fn=bidirectional_translate
132
+ )
133
 
134
+ # Launch the app
135
  if __name__ == "__main__":
136
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)