File size: 818 Bytes
9ff97fa
2590409
9ff97fa
2590409
 
 
 
9ff97fa
e7d85ce
4e8a01a
e7d85ce
2590409
 
 
 
9ff97fa
 
2590409
9ff97fa
 
 
2590409
9ff97fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
import torch
from config import MODEL_NAME


class RadarDetectionModel:
    """Single-image object detector backed by a Hugging Face checkpoint.

    The feature extractor and the ``AutoModelForObjectDetection`` model are
    loaded once at construction time; :meth:`detect` then runs inference
    without gradient tracking.
    """

    def __init__(self, model_name=None):
        """Load the feature extractor and detection model.

        Args:
            model_name: Hugging Face Hub id or local path of the checkpoint to
                load. When omitted, ``MODEL_NAME`` from ``config`` is used.
        """
        # Previously a PaliGemma id was hard-coded here (twice) while the
        # imported MODEL_NAME went unused. PaliGemma is a vision-language
        # model, not an AutoModelForObjectDetection architecture, so it could
        # not be loaded as a detector. Resolve the checkpoint in one place.
        name = MODEL_NAME if model_name is None else model_name
        self.model_name = name
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(name)
        self.model = AutoModelForObjectDetection.from_pretrained(name)
        # Inference-only usage: put the model in evaluation mode.
        self.model.eval()

    @torch.no_grad()
    def detect(self, image, threshold=0.5):
        """Run object detection on a single image.

        Args:
            image: the input image. Assumed to be a PIL image: ``image.size``
                is read as a ``(width, height)`` pair — TODO confirm callers
                never pass numpy arrays / tensors.
            threshold: minimum confidence score a detection must reach to be
                kept by post-processing (default ``0.5``).

        Returns:
            The first (and only) element of the list returned by the feature
            extractor's ``post_process_object_detection`` — presumably a dict
            with ``scores``, ``labels`` and ``boxes``; verify against the
            installed transformers version.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)

        # PIL reports (width, height); post-processing wants one
        # (height, width) row per image so boxes are rescaled to pixels.
        width, height = image.size
        target_sizes = torch.tensor([[height, width]])
        results = self.feature_extractor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes)[0]

        return results