Pranjal2041's picture
Initial demo
970a7a2
raw
history blame
2.15 kB
import sys
# Put the 'DenseMammogram' directory (resolved relative to the current working
# directory) on the module search path. This has to run before the project
# imports below; presumably `models`, `detection` and `dataloaders` live in
# that directory -- TODO confirm.
sys.path.append('DenseMammogram')
import torch
from models import get_FRCNN_model, Bilateral_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The bilateral model is constructed around the Faster R-CNN model instance.
frcnn_model = get_FRCNN_model().to(device)
bilat_model = Bilateral_model(frcnn_model).to(device)

# Checkpoint locations; these are relative paths, so they are resolved against
# the working directory of the process.
FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'

# Load the weights at import time; map_location remaps the stored tensors onto
# `device`, so checkpoints saved on a GPU also load on a CPU-only machine.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects -- only point these paths at trusted files.
frcnn_model.load_state_dict(torch.load(FRCNN_PATH, map_location=device))
bilat_model.load_state_dict(torch.load(BILAR_PATH, map_location=device))

# NOTE(review): `os`, `tqdm` and `get_direction` are not referenced in the
# visible part of this file; they are kept in case other code relies on them.
import os
import torchvision.transforms as T
import cv2
from tqdm import tqdm
import detection.transforms as transforms
from dataloaders import get_direction
def _read_image(path):
    """Read *path* with OpenCV, raising instead of silently returning ``None``."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports a missing or undecodable file by returning None,
        # which would otherwise fail later with an unrelated-looking error.
        raise ValueError('Could not read image file: {!r}'.format(path))
    return image


def predict(left_file, right_file, threshold = 0.80, baseIsLeft = True):
    """Run the bilateral model on an image pair and draw its detections.

    Args:
        left_file: Path of the image that is passed to the model first. The
            detections are drawn on this image and it is the one returned.
        right_file: Path of the second image of the pair.
        threshold: A detection is drawn only if its score is strictly greater
            than this value.
        baseIsLeft: When true, the first image is mirrored horizontally before
            inference and the predicted boxes are mirrored back afterwards.
            When false, the second image is mirrored instead and the
            predictions are used as returned by the model.

    Returns:
        The image read from ``left_file`` (as returned by ``cv2.imread``) with
        a green rectangle and a ``'Cancer: NN.N%'`` caption for every
        detection whose label equals 1 and whose score exceeds ``threshold``.

    Raises:
        ValueError: If either file cannot be read as an image.
    """
    model = bilat_model
    flip = transforms.RandomHorizontalFlip(1.0)  # p=1.0: always flips
    to_tensor = T.Compose([T.ToPILImage(), T.ToTensor()])

    # Decode both files once; `image` doubles as the canvas for the drawing
    # step, so the first file does not have to be read from disk a second time.
    image = _read_image(left_file)
    second_image = _read_image(right_file)

    with torch.no_grad():
        model.eval()
        # First is left, then right
        img1 = to_tensor(image)
        img2 = to_tensor(second_image)

        if baseIsLeft:
            img1, _ = flip(img1)
        else:
            img2, _ = flip(img2)

        output = model([[img1.to(device), img2.to(device)]])[0]
        if baseIsLeft:
            # The model saw a mirrored first image: mirror the predictions back
            # so the boxes line up with the unflipped image drawn on below.
            _, output = flip(img1, output)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 3.6
    text_thickness = 6

    for box, score, label in zip(output['boxes'], output['scores'], output['labels']):
        if label == 1 and score > threshold:
            # Plain Python ints are what the OpenCV drawing calls expect.
            x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist()
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Print the % probability just above the box
            caption = 'Cancer: ' + str(round(round(score.item(), 2) * 100, 1)) + '%'
            (_, text_height), _ = cv2.getTextSize(caption, font, font_scale, text_thickness)
            # Keep the text baseline low enough that the caption stays inside
            # the image even when the box touches the top edge.
            caption_y = max(y1 - 40, text_height)
            cv2.putText(image, caption, (x1, caption_y), font, font_scale, (36,255,12), text_thickness)

    return image