Dreamy0 committed
Commit 2c19914 · verified · 1 Parent(s): cebb1df

Update README.md

Files changed (1): README.md (+99 -3)

README.md CHANGED (the previous three-line `license: mit` stub is replaced with the full model card below):
---
license: mit
tags:
- vision
- image-segmentation
- instance-segmentation
datasets:
- custom-germination-dataset
widget:
- src: https://example.com/path/to/germination-image1.jpg
  example_title: Germination Image 1
- src: https://example.com/path/to/germination-image2.jpg
  example_title: Germination Image 2
---

# GermiNet: A MaskFormer Model for Germination Counting

GermiNet is a MaskFormer model trained on a custom germination dataset for instance segmentation (tiny-sized version, Swin backbone). It is based on the MaskFormer architecture introduced in the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) and first released in [this repository](https://github.com/facebookresearch/MaskFormer).

Disclaimer: This model card was written by [Your Name/Organization] with assistance from Grok (xAI) for the Hugging Face community.

## Model Description

GermiNet is a MaskFormer-based instance segmentation model fine-tuned to detect and segment "normal" and "abnormal" seeds in germination images. It uses the `facebook/maskformer-swin-tiny-coco` pre-trained checkpoint as its backbone, with a Swin-Tiny transformer architecture. The model predicts a set of masks and corresponding labels for three classes ("background," "normal," and "abnormal"), with an additional "no object" class handled internally by MaskFormer. It was trained on a small custom dataset as a proof of concept for automating germination counting in agricultural research.

![model image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/maskformer_architecture.png)
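
For reference, the class layout described above maps onto the label configuration a MaskFormer checkpoint carries. The following is a minimal, illustrative sketch of how such a config could be assembled for fine-tuning from the `facebook/maskformer-swin-tiny-coco` checkpoint named above; it is an assumption about the setup, not the exact training code used for GermiNet:

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

# The three dataset classes; MaskFormer handles the extra "no object"
# class internally, so it is not listed here.
id2label = {0: "background", 1: "normal", 2: "abnormal"}

config = MaskFormerConfig.from_pretrained(
    "facebook/maskformer-swin-tiny-coco",
    id2label=id2label,
    label2id={name: idx for idx, name in id2label.items()},
)
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-tiny-coco",
    config=config,
    ignore_mismatched_sizes=True,  # the class head no longer matches the COCO checkpoint
)
```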

## Intended Uses & Limitations

GermiNet is intended for instance segmentation tasks in agricultural research, specifically for detecting and segmenting "normal" and "abnormal" seeds in germination images. It can be used with tools like CVAT for automated annotation workflows.

### Limitations
- The model was trained on a small dataset (18 images), which limits its generalization.
- Local inference shows a bias toward "no object" predictions, with few "normal" detections and no "abnormal" detections, indicating underfitting.
- Mask resolution is 56x56 (upscaled to 224x224 or higher for visualization; see the sketch after this list), which may miss fine details.
- The model requires further training with a larger dataset and more epochs for improved performance.
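
To make the resolution limitation concrete, the sketch below shows what upscaling a 56x56 mask to 224x224 amounts to; the tensor is random stand-in data, and detail finer than the 56x56 grid is interpolated rather than recovered:

```python
import torch
import torch.nn.functional as F

# Dummy 56x56 mask logits standing in for one query's prediction
mask_logits = torch.randn(1, 1, 56, 56)

# Bilinear upsampling to 224x224 for visualization
upscaled = F.interpolate(mask_logits, size=(224, 224), mode="bilinear", align_corners=False)
print(upscaled.shape)  # torch.Size([1, 1, 224, 224])
```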

See the [model hub](https://huggingface.co/models?search=germinet) to look for other fine-tuned versions if needed.

## How to Use

Here’s how to use this model for instance segmentation:

```python
import requests
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation

# Load GermiNet fine-tuned on the custom germination dataset
processor = AutoImageProcessor.from_pretrained("your-username/germi-net")
model = MaskFormerForInstanceSegmentation.from_pretrained("your-username/germi-net")

# Load an image (replace with your image URL or local path)
url = "https://example.com/path/to/germination-image.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# Alternatively, use a local image
# image = Image.open("path/to/your/image.jpg")
inputs = processor(images=image, return_tensors="pt")

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# Model predicts class_queries_logits and masks_queries_logits
class_queries_logits = outputs.class_queries_logits  # Shape: (batch_size, num_queries, num_classes + 1)
masks_queries_logits = outputs.masks_queries_logits  # Shape: (batch_size, num_queries, height, width)

# Post-process predictions
predicted_classes = class_queries_logits.argmax(-1).cpu().numpy()
mask_predictions = masks_queries_logits.sigmoid().cpu().numpy()
binary_masks = (mask_predictions > 0.5).astype(np.uint8)

# Map predictions to labels
id2label = {0: "background", 1: "normal", 2: "abnormal", 3: "no object"}
predicted_labels = [id2label[cls] for cls in predicted_classes[0]]
print("Predicted labels:", predicted_labels)

# Optional: Visualize (requires matplotlib and cv2)
import matplotlib.pyplot as plt
import cv2

# Upscale each query's binary mask to the visualization resolution
visualization_size = (800, 800)
resized_masks = np.zeros((binary_masks.shape[1], *visualization_size), dtype=np.uint8)
for i in range(binary_masks.shape[1]):
    resized_masks[i] = cv2.resize(binary_masks[0, i], visualization_size, interpolation=cv2.INTER_NEAREST)

# Resize the image to the visualization height, preserving aspect ratio
image_np = np.array(image)
aspect_ratio = image_np.shape[1] / image_np.shape[0]
new_height = visualization_size[0]
new_width = int(new_height * aspect_ratio)
resized_image = cv2.resize(image_np, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
# Center-crop or pad the width so the image matches the mask resolution
if new_width > visualization_size[1]:
    start_x = (new_width - visualization_size[1]) // 2
    resized_image = resized_image[:, start_x:start_x + visualization_size[1]]
elif new_width < visualization_size[1]:
    pad = visualization_size[1] - new_width
    resized_image = cv2.copyMakeBorder(resized_image, 0, 0, pad // 2, pad - pad // 2, cv2.BORDER_CONSTANT, value=0)

# Overlay "normal" (green) and "abnormal" (red) masks on the image
overlay = resized_image.copy()
for i, label in enumerate(predicted_labels):
    if label == "normal":
        overlay[resized_masks[i] == 1] = (0, 255, 0)
    elif label == "abnormal":
        overlay[resized_masks[i] == 1] = (255, 0, 0)

plt.imshow(overlay)
plt.axis("off")
plt.show()
```
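
Alternatively, the image processor in Hugging Face Transformers ships a built-in post-processing helper that resizes masks back to the original image size and filters low-confidence queries. A minimal sketch reusing `processor`, `outputs`, `image`, and `id2label` from the snippet above (assuming `processor` resolves to a `MaskFormerImageProcessor`):

```python
# Post-process raw outputs into an instance segmentation map
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    target_sizes=[image.size[::-1]],  # PIL gives (width, height); targets are (height, width)
)[0]

print(results["segmentation"].shape)  # (height, width) map of instance ids
for segment in results["segments_info"]:
    print(segment["id"], id2label.get(segment["label_id"], "unknown"), round(segment["score"], 3))
```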