File size: 2,154 Bytes
970a7a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
sys.path.append('DenseMammogram')

import torch

from models import get_FRCNN_model, Bilateral_model

# Checkpoint locations for the two networks (relative paths, so they are
# resolved against the current working directory).
FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The bilateral model is constructed around the Faster R-CNN instance.
frcnn_model = get_FRCNN_model().to(device)
bilat_model = Bilateral_model(frcnn_model).to(device)

# Restore weights in the same order as before: Faster R-CNN first, then the
# bilateral model. map_location keeps the tensors on the selected device.
for _net, _checkpoint in ((frcnn_model, FRCNN_PATH), (bilat_model, BILAR_PATH)):
    _net.load_state_dict(torch.load(_checkpoint, map_location=device))
del _net, _checkpoint

import os
import torchvision.transforms as T
import cv2
from tqdm import tqdm
import detection.transforms as transforms
from dataloaders import get_direction

def _read_bgr_image(path):
    """Decode an image file with OpenCV, raising if it cannot be read.

    cv2.imread reports failure by returning None instead of raising, so the
    result is checked here to fail early with a message naming the path.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {path!r}")
    return image


def predict(left_file: str, right_file: str, threshold: float = 0.80, baseIsLeft: bool = True):
    """Run the bilateral detector on an image pair and annotate the left image.

    Both files are decoded with OpenCV, converted to tensors and passed to
    the module-level ``bilat_model`` as one (left, right) pair. Detections
    with label 1 and a score strictly above ``threshold`` are drawn on the
    image decoded from ``left_file`` as a green rectangle plus a
    ``Cancer: NN.N%`` caption.

    Args:
        left_file: Path of the first image; this is the image that gets
            annotated and returned.
        right_file: Path of the second image of the pair.
        threshold: Minimum score (exclusive) for a detection to be drawn.
        baseIsLeft: When true, the left image is mirrored horizontally
            before inference and the predictions are mirrored back
            afterwards. When false, the right image is mirrored instead and
            the predictions are used as returned by the model.

    Returns:
        The array decoded from ``left_file`` by ``cv2.imread`` with the
        accepted detections drawn on it.

    Raises:
        FileNotFoundError: If either path cannot be decoded by OpenCV.
    """
    model = bilat_model
    to_tensor = T.Compose([T.ToPILImage(), T.ToTensor()])
    # Probability 1.0: the project transform is expected to flip on every
    # call (NOTE(review): defined in detection.transforms, not shown here).
    flip = transforms.RandomHorizontalFlip(1.0)

    # Decode once; the left array doubles as the drawing canvas below, which
    # avoids reading the same file from disk a second time.
    left_bgr = _read_bgr_image(left_file)
    right_bgr = _read_bgr_image(right_file)

    with torch.no_grad():
        model.eval()
        # First is left, then right
        img1 = to_tensor(left_bgr)
        img2 = to_tensor(right_bgr)

        if baseIsLeft:
            img1, _ = flip(img1)
        else:
            img2, _ = flip(img2)

        pair = [img1.to(device), img2.to(device)]
        output = model([pair])[0]
        if baseIsLeft:
            # Mirror the predictions back so they line up with the original,
            # un-flipped left image that is annotated below.
            _, output = flip(img1, output)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 3.6
    font_thickness = 6
    for box, score, label in zip(output['boxes'], output['scores'], output['labels']):
        if not (label == 1 and score > threshold):
            continue
        # Plain Python ints for the OpenCV drawing calls.
        x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist()
        cv2.rectangle(left_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Caption sits 40 px above the box; clamp it so that boxes close to
        # the top edge still get a visible label instead of one drawn
        # outside the canvas.
        caption = f"Cancer: {round(round(score.item(), 2) * 100, 1)}%"
        (_, text_height), _ = cv2.getTextSize(caption, font, font_scale, font_thickness)
        caption_y = max(y1 - 40, text_height)
        cv2.putText(left_bgr, caption, (x1, caption_y), font, font_scale, (36, 255, 12), font_thickness)
    return left_bgr