Spaces:
Running
Running
fancyfeast
committed on
Commit
·
b72bef3
1
Parent(s):
df9e86f
Prepare images correctly
Browse files
app.py
CHANGED
|
@@ -4,15 +4,45 @@ import huggingface_hub
|
|
| 4 |
from PIL import Image
|
| 5 |
import torch.amp.autocast_mode
|
| 6 |
from pathlib import Path
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
MODEL_REPO = "fancyfeast/joytag"
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
@torch.no_grad()
|
| 13 |
def predict(image: Image.Image):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
tag_preds = preds['tags'].sigmoid().cpu()
|
| 17 |
|
| 18 |
return {top_tags[i]: tag_preds[i] for i in range(len(top_tags))}
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
import torch.amp.autocast_mode
|
| 6 |
from pathlib import Path
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision.transforms.functional as TVF
|
| 9 |
|
| 10 |
|
| 11 |
MODEL_REPO = "fancyfeast/joytag"
|
| 12 |
|
| 13 |
|
| 14 |
+
def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
    """Pad *image* to a white square, resize it to *target_size*, and return a
    normalized CHW float tensor ready for the model.

    The image is centered on a square canvas whose side equals its longest
    dimension, bicubic-resized when the side differs from the target, scaled
    to [0, 1], and normalized with CLIP-style mean/std.
    """
    width, height = image.size
    side = max(width, height)

    # Center the original image on a white square canvas.
    canvas = Image.new('RGB', (side, side), (255, 255, 255))
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))

    # Only resize when the padded square is not already the target size.
    if side != target_size:
        canvas = canvas.resize((target_size, target_size), Image.BICUBIC)

    # To float tensor in [0, 1], then CLIP-style normalization.
    tensor = TVF.pil_to_tensor(canvas) / 255.0
    tensor = TVF.normalize(tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

    return tensor
|
| 35 |
+
|
| 36 |
+
|
| 37 |
@torch.no_grad()
def predict(image: Image.Image):
    """Run the tag model on one PIL image and return a {tag: score} mapping.

    Preprocesses the image to the model's input size, runs inference under
    CPU autocast, and returns sigmoid probabilities for each of `top_tags`.
    """
    # Preprocess and add a leading batch dimension.
    pixels = prepare_image(image, model.image_size)
    batch = {'image': pixels.unsqueeze(0)}

    # Inference under CPU autocast.
    with torch.amp.autocast_mode.autocast('cpu', enabled=True):
        preds = model(batch)
    scores = preds['tags'].sigmoid().cpu()

    return {tag: scores[i] for i, tag in enumerate(top_tags)}
|