yakine commited on
Commit
2efd55d
·
verified ·
1 Parent(s): 3ac2eeb

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +18 -0
  2. app.py +143 -0
  3. requirements.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.9 as the base image
2
+ FROM python:3.9
3
+
4
+ # Set the working directory inside the container
5
+ WORKDIR /app
6
+
7
+ # Copy the requirements file and install dependencies
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy the entire project into the container
12
+ COPY . .
13
+
14
+ # Expose port 7860 (same as FastAPI runs on)
15
+ EXPOSE 7860
16
+
17
+ # Command to run FastAPI on startup
18
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
5
+ import logging
6
+ import re
7
+
8
+ app = FastAPI()
9
+
10
+ # Enable CORS if needed
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"], # In production, restrict this to your frontend URL
15
+ allow_credentials=True,
16
+ allow_methods=["POST"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+ ####################################
24
+ # Text Generation Endpoint
25
+ ####################################
26
+
27
+ TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
28
+ text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
29
+ text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
30
+
31
+ general_prompt_template = """
32
+ أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
33
+ 1. واضحًا وسهل الفهم.
34
+ 2. مناسبًا للمستوى التعليمي المحدد.
35
+ 3. مرتبطًا بالمادة التعليمية المطلوبة.
36
+ 4. قصيرًا ومباشرًا.
37
+ ### أمثلة:
38
+ 1. المادة: العلوم
39
+ المستوى: الابتدائي
40
+ النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
41
+ 2. المادة: التاريخ
42
+ المستوى: المتوسط
43
+ النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
44
+ 3. المادة: الجغرافيا
45
+ المستوى: المتوسط
46
+ النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
47
+ ---
48
+ المادة: {المادة}
49
+ المستوى: {المستوى}
50
+ اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
51
+ """
52
+
53
+ class GenerateTextRequest(BaseModel):
54
+ المادة: str
55
+ المستوى: str
56
+
57
+ @app.post("/generate-text")
58
+ def generate_text(request: GenerateTextRequest):
59
+ if not request.المادة or not request.المستوى:
60
+ raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
61
+
62
+ try:
63
+ prompt = general_prompt_template.format(المادة=request.المادة, المستوى=request.المستوى)
64
+ inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
65
+
66
+ with torch.no_grad():
67
+ outputs = text_model.generate(
68
+ inputs.input_ids,
69
+ max_length=300,
70
+ num_return_sequences=1,
71
+ temperature=0.7,
72
+ top_p=0.95,
73
+ do_sample=True,
74
+ )
75
+
76
+ generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()
77
+ logger.info(f"Generated text: {generated_text}")
78
+ return {"generated_text": generated_text}
79
+
80
+ except Exception as e:
81
+ logger.error(f"Error during text generation: {str(e)}")
82
+ raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
83
+
84
+ ####################################
85
+ # Question & Answer Generation Model
86
+ ####################################
87
+ QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
88
+ qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
89
+ qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
90
+
91
+ def extract_answer(context: str) -> str:
92
+ """Extract the first sentence (or a key phrase) from the context."""
93
+ sentences = re.split(r'[.!؟]', context)
94
+ sentences = [s.strip() for s in sentences if s.strip()]
95
+ return sentences[0] if sentences else context
96
+
97
+ def get_question(context: str, answer: str) -> str:
98
+ """Generate a question based on the context and the candidate answer."""
99
+ text = f"النص: {context} الإجابة: {answer} </s>"
100
+ text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
101
+
102
+ with torch.no_grad():
103
+ generated_ids = qa_model.generate(
104
+ input_ids=text_encoding['input_ids'],
105
+ attention_mask=text_encoding['attention_mask'],
106
+ max_length=64,
107
+ num_beams=5,
108
+ num_return_sequences=1
109
+ )
110
+
111
+ question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True).replace("question:", "").strip()
112
+ return question
113
+
114
+ class GenerateQARequest(BaseModel):
115
+ text: str
116
+
117
+ @app.post("/generate-qa")
118
+ def generate_qa(request: GenerateQARequest):
119
+ if not request.text:
120
+ raise HTTPException(status_code=400, detail="Text is required.")
121
+
122
+ try:
123
+ question, answer = get_question(request.text, extract_answer(request.text))
124
+ logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
125
+ return {"question": question, "answer": answer}
126
+
127
+ except Exception as e:
128
+ logger.error(f"Error during QA generation: {str(e)}")
129
+ raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
130
+
131
+ ####################################
132
+ # Root Endpoint
133
+ ####################################
134
+ @app.get("/")
135
+ def read_root():
136
+ return {"message": "Welcome to the Arabic Text Generation API!"}
137
+
138
+ ####################################
139
+ # Running the FastAPI Server
140
+ ####################################
141
+ if __name__ == "__main__":
142
+ import uvicorn
143
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ transformers
3
+ torch
4
+ numpy
5
+ pandas
6
+ fastapi
7
+ uvicorn[standard]
8
+ pandas
9
+ transformers
10
+ torch
11
+ accelerate
12
+ huggingface-hub
13
+ tqdm