x_alhdrawi1 / app.py
Alhdrawi's picture
Update app.py
48badb6 verified
raw
history blame
2.87 kB
import gradio as gr
import torch
import os
import urllib.request
from torchvision import transforms
from PIL import Image
import torch.nn as nn
# إعدادات النموذج
REPO_ID = "Alhdrawi/x_alhdrawi"
MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt"
MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}"
MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}"
# قائمة الأمراض التي يتوقعها النموذج
diseases = [
"Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
"Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
"Pleural_Thickening", "Pneumonia", "Pneumothorax"
]
# تحويل الصورة مثل ما دربت النموذج
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485], std=[0.229])
])
# تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب)
class SimpleCNN(nn.Module):
def __init__(self, num_classes=14):
super(SimpleCNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# تحميل النموذج
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(diseases)).to(device)
def download_and_load_model():
if not os.path.exists(MODEL_LOCAL_PATH):
print(f"Downloading model from {MODEL_URL}")
urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH)
state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(f"✅ Model loaded from {MODEL_FILE}")
# دالة التنبؤ
def predict(image):
img = transform(image.convert("L")).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img)
probs = torch.sigmoid(outputs).cpu().squeeze().numpy()
results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
return results
# تحميل النموذج عند بدء التشغيل
download_and_load_model()
# واجهة Gradio
with gr.Blocks() as demo:
gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`")
with gr.Row():
image_input = gr.Image(type="pil", label="صورة أشعة X-Ray")
output = gr.Label(num_top_classes=5)
image_input.change(fn=predict, inputs=image_input, outputs=output)
demo.launch()