sanpdy committed on
Commit
b485e09
·
1 Parent(s): 012982e

Adding app files

Browse files
Files changed (3) hide show
  1. .gradio/certificate.pem +31 -0
  2. app.py +144 -0
  3. requirements.txt +91 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import torchvision.transforms as T
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import gradio as gr
9
+ from ultralytics import YOLO
10
+ from transformers import ResNetModel
11
+ import cv2
12
+
13
class FlakeLayerClassifier(nn.Module):
    """Classifier built on a pretrained ResNet-18 backbone with two heads.

    ``fc_img`` scores the pooled image features on their own, while
    ``fc_comb`` scores those features concatenated with a learned material
    embedding. ``forward`` picks the head depending on whether a material
    index tensor is supplied.

    Args:
        num_materials: Number of rows in the material embedding table.
        material_dim: Width of each material embedding vector.
        num_classes: Number of output logits produced by either head.
        dropout_prob: Dropout probability used inside both heads (a single
            ``nn.Dropout`` module is shared between them).
        freeze_cnn: When true, gradients are disabled for every backbone
            parameter.
    """

    def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
        super().__init__()
        self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
        if freeze_cnn:
            self.cnn.requires_grad_(False)

        feat_dim = self.cnn.config.hidden_sizes[-1]
        self.material_embedding = nn.Embedding(num_materials, material_dim)
        self.dropout = nn.Dropout(dropout_prob)

        # Image-only head first, then the image+material head, so layers are
        # created in the same order as their attribute names appear.
        self.fc_img = self._build_head(feat_dim, num_classes)
        self.fc_comb = self._build_head(feat_dim + material_dim, num_classes)

    def _build_head(self, width, num_classes):
        """Return a Linear -> ReLU -> shared Dropout -> Linear stack."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            self.dropout,
            nn.Linear(width, num_classes),
        )

    def forward(self, pixel_values, material=None):
        """Return class logits for a batch of images.

        Args:
            pixel_values: Image batch forwarded to the backbone.
            material: Optional tensor of material indices; when given, the
                combined head is used instead of the image-only head.
        """
        pooled = self.cnn(pixel_values=pixel_values).pooler_output
        # Collapse everything after the batch axis into one feature vector.
        img_feats = pooled.view(pooled.size(0), -1)

        if material is not None:
            mat_emb = self.material_embedding(material)
            return self.fc_comb(torch.cat([img_feats, mat_emb], dim=1))

        return self.fc_img(img_feats)
50
+
51
def calibration(source_img, target_img):
    """Match the LAB colour statistics of ``target_img`` to ``source_img``.

    Both images are converted from BGR to LAB; each target channel is then
    shifted and scaled so that its mean and standard deviation equal those of
    the corresponding source channel, and the result is converted back to BGR.

    Args:
        source_img: Reference image whose per-channel statistics are copied.
            Assumed to be an 8-bit, 3-channel BGR array -- TODO confirm
            against callers.
        target_img: Image to be corrected (same assumptions). It is not
            modified.

    Returns:
        The colour-corrected image as a ``uint8`` BGR array.
    """
    source_lab = cv2.cvtColor(source_img, cv2.COLOR_BGR2LAB)
    target_lab = cv2.cvtColor(target_img, cv2.COLOR_BGR2LAB)

    # meanStdDev returns one (mean, stddev) entry per channel; flatten them so
    # they broadcast along the last (channel) axis.
    src_mean, src_std = (stat.reshape(-1) for stat in cv2.meanStdDev(source_lab))
    tgt_mean, tgt_std = (stat.reshape(-1) for stat in cv2.meanStdDev(target_lab))

    # A perfectly flat target channel has zero spread; dividing by it would
    # produce inf/NaN. Its centred values are all zero, so any finite scale
    # yields the source mean -- substitute 1.0 for the divisor.
    safe_tgt_std = np.where(tgt_std > 0, tgt_std, 1.0)
    scale = src_std / safe_tgt_std

    # Do the arithmetic in float and round once at the end, instead of letting
    # the assignment into a uint8 array silently truncate.
    matched = (target_lab.astype(np.float64) - tgt_mean) * scale + src_mean
    matched = np.rint(matched).clip(0, 255).astype(np.uint8)

    return cv2.cvtColor(matched, cv2.COLOR_LAB2BGR)
65
+
66
+
67
# ---------------------------------------------------------------------------
# Runtime setup: device selection, model loading, preprocessing and font.
# ---------------------------------------------------------------------------

# Prefer the second GPU when the host has one (the original target), but fall
# back to the first GPU on single-GPU hosts instead of failing on a
# non-existent "cuda:1", and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load YOLO detector
# Alternative detector weights that were tried previously:
#   /home/sankalp/flake_classification/models/best.pt
#   /home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt
yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
# NOTE(review): this only stores an attribute on the wrapper object; confirm
# against the Ultralytics predict docs whether the threshold also needs to be
# passed as ``conf=`` at call time for it to take effect.
yolo.conf = 0.5

# Load classifier weights
ckpt = torch.load(
    "/home/sankalp/flake_classification/models/flake_classifier.pth",
    map_location=device,
)
num_classes = len(ckpt["class_to_idx"])
classifier = FlakeLayerClassifier(
    # NOTE(review): the material vocabulary size is set to the class count;
    # presumably this mirrors the training configuration -- verify.
    num_materials=num_classes,
    material_dim=64,
    num_classes=num_classes,
    dropout_prob=0.1,
    freeze_cnn=False,
).to(device)
classifier.load_state_dict(ckpt["model_state_dict"])
classifier.eval()  # inference only: disables dropout

# Image processing transforms applied to every detected crop before it is
# fed to the classifier: fixed 224x224 size, tensor conversion, then
# per-channel normalisation.
clf_tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Label font: use Arial when the file can be found, otherwise PIL's built-in
# default font.
try:
    FONT = ImageFont.truetype("arial.ttf", 20)
except OSError:
    FONT = ImageFont.load_default()
103
+
104
+ # Inference + drawing
105
def detect_and_classify(image: Image.Image):
    """Detect flakes in ``image`` and annotate each with its predicted layer.

    The image is passed to the module-level ``yolo`` detector; every detected
    box is cropped, preprocessed with ``clf_tf`` and scored by the
    module-level ``classifier`` (image-only head). A red rectangle and a
    ``"Layer N (p)"`` label are drawn for each detection.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A new RGB PIL image with the annotations drawn on it, or ``None`` if
        ``image`` is ``None``. The input image itself is left untouched.
    """
    if image is None:
        # Gradio passes None when the user submits without an image.
        return None

    # Work on an RGB copy so that (a) non-RGB uploads still give the
    # classifier exactly three channels and (b) the caller's image is not
    # mutated.
    pristine = image.convert("RGB")
    annotated = pristine.copy()

    # The detector is fed a BGR array, as in the original pipeline.
    img_bgr = np.ascontiguousarray(np.asarray(pristine)[:, :, ::-1])

    # Pass the confidence threshold explicitly at call time; merely setting
    # an attribute on the YOLO wrapper is not a predict argument.
    conf_threshold = getattr(yolo, "conf", 0.5)
    results = yolo(img_bgr, device=str(device), conf=conf_threshold)
    boxes = results[0].boxes.xyxy.cpu().numpy()

    draw = ImageDraw.Draw(annotated)
    for x1, y1, x2, y2 in boxes.tolist():
        box = tuple(int(round(v)) for v in (x1, y1, x2, y2))
        left, top, right, bottom = box
        if right <= left or bottom <= top:
            # Zero-area boxes cannot be resized for the classifier.
            continue

        # Crop from the untouched copy so that rectangles/labels drawn for
        # earlier detections never leak into later classifier inputs.
        crop = pristine.crop(box)
        inp = clf_tf(crop).unsqueeze(0).to(device)  # (1,C,H,W)

        with torch.no_grad():
            probs = F.softmax(classifier(pixel_values=inp), dim=1)
        prob, pred = probs.max(dim=1)

        label = f"Layer {pred.item() + 1} ({prob.item():.2f})"
        draw.rectangle([left, top, right, bottom], outline="red", width=2)
        # Place the label just above the box, clamped to the top edge.
        draw.text((left, max(0, top - 18)), label, fill="red", font=FONT)

    return annotated
133
+
134
+ # Gradio UI
135
# Build the web interface: one image in, one annotated image out.
_image_input = gr.Image(type="pil", label="Upload Flake Image")
_image_output = gr.Image(type="pil", label="Annotated Output")

demo = gr.Interface(
    fn=detect_and_classify,
    inputs=_image_input,
    outputs=_image_output,
    title="Flake Detection + Layer Classification",
    description="Upload an image → YOLO finds flakes → ResNet-18 head classifies their layer.",
)


def _main():
    """Start the Gradio server with a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _main()
requirements.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ certifi==2025.6.15
5
+ charset-normalizer==3.4.2
6
+ click==8.2.1
7
+ contourpy==1.3.2
8
+ cycler==0.12.1
9
+ fastapi==0.115.13
10
+ ffmpy==0.6.0
11
+ filelock==3.18.0
12
+ fonttools==4.58.4
13
+ fsspec==2025.5.1
14
+ gradio==5.34.2
15
+ gradio_client==1.10.3
16
+ groovy==0.1.2
17
+ h11==0.16.0
18
+ hf-xet==1.1.5
19
+ httpcore==1.0.9
20
+ httpx==0.28.1
21
+ huggingface-hub==0.33.1
22
+ idna==3.10
23
+ Jinja2==3.1.6
24
+ kiwisolver==1.4.8
25
+ markdown-it-py==3.0.0
26
+ MarkupSafe==3.0.2
27
+ matplotlib==3.10.3
28
+ mdurl==0.1.2
29
+ mpmath==1.3.0
30
+ networkx==3.5
31
+ numpy==2.3.1
32
+ nvidia-cublas-cu12==12.6.4.1
33
+ nvidia-cuda-cupti-cu12==12.6.80
34
+ nvidia-cuda-nvrtc-cu12==12.6.77
35
+ nvidia-cuda-runtime-cu12==12.6.77
36
+ nvidia-cudnn-cu12==9.5.1.17
37
+ nvidia-cufft-cu12==11.3.0.4
38
+ nvidia-cufile-cu12==1.11.1.6
39
+ nvidia-curand-cu12==10.3.7.77
40
+ nvidia-cusolver-cu12==11.7.1.2
41
+ nvidia-cusparse-cu12==12.5.4.2
42
+ nvidia-cusparselt-cu12==0.6.3
43
+ nvidia-nccl-cu12==2.26.2
44
+ nvidia-nvjitlink-cu12==12.6.85
45
+ nvidia-nvtx-cu12==12.6.77
46
+ opencv-python==4.11.0.86
47
+ orjson==3.10.18
48
+ packaging==25.0
49
+ pandas==2.3.0
50
+ pillow==11.2.1
51
+ psutil==7.0.0
52
+ py-cpuinfo==9.0.0
53
+ pydantic==2.11.7
54
+ pydantic_core==2.33.2
55
+ pydub==0.25.1
56
+ Pygments==2.19.2
57
+ pyparsing==3.2.3
58
+ python-dateutil==2.9.0.post0
59
+ python-multipart==0.0.20
60
+ pytz==2025.2
61
+ PyYAML==6.0.2
62
+ regex==2024.11.6
63
+ requests==2.32.4
64
+ rich==14.0.0
65
+ ruff==0.12.0
66
+ safehttpx==0.1.6
67
+ safetensors==0.5.3
68
+ scipy==1.16.0
69
+ semantic-version==2.10.0
70
+ shellingham==1.5.4
71
+ six==1.17.0
72
+ sniffio==1.3.1
73
+ starlette==0.46.2
74
+ sympy==1.14.0
75
+ tokenizers==0.21.2
76
+ tomlkit==0.13.3
77
+ torch==2.7.1
78
+ torchaudio==2.7.1+cpu
79
+ torchvision==0.22.1
80
+ tqdm==4.67.1
81
+ transformers==4.52.4
82
+ triton==3.3.1
83
+ typer==0.16.0
84
+ typing-inspection==0.4.1
85
+ typing_extensions==4.14.0
86
+ tzdata==2025.2
87
+ ultralytics==8.3.159
88
+ ultralytics-thop==2.0.14
89
+ urllib3==2.5.0
90
+ uvicorn==0.34.3
91
+ websockets==15.0.1