|
--- |
|
library_name: transformers |
|
tags: |
|
- deepfake |
|
datasets: |
|
- elsaEU/ELSA_D3 |
|
license: mit |
|
--- |
|
|
|
# Contrasting Deepfakes Diffusion via Contrastive Learning and Global-Local Similarities (ECCV 2024) |
|
|
|
[Project page](https://aimagelab.github.io/CoDE/) |
|
|
|
[Source code](https://github.com/aimagelab/CoDE) |
|
|
|
# Model Card for CoDE |
|
|
|
<!-- Provide the basic links for the model. --> |
|
|
|
```python |
|
import transformers |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
import faiss |
|
import timm |
|
import torch |
|
import torch.nn as nn |
|
import joblib |
|
from torchvision import transforms |
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
class VITContrastiveHF(nn.Module):
    """Backbone loaded with ``transformers`` plus a pre-fitted classifier head.

    The backbone and its processor are loaded from ``repo_name``; the
    classifier is downloaded from the same repository and deserialised with
    ``joblib``.

    Label convention of the predictions returned by ``forward``:
        * ``'linear'`` and ``'knn'``: 0 = Real, 1 = Fake
        * ``'svm'``: -1 = Real, 1 = Fake
    """

    # Supported classifier type -> path of the serialised classifier
    # inside the Hub repository.
    _CLASSIFIER_FILES = {
        'svm': 'sklearn/ocsvm_kernel_poly_gamma_auto_nu_0_1_crop.joblib',
        'linear': 'sklearn/linear_tot_classifier_epoch-32.sav',
        'knn': 'sklearn/knn_tot_classifier_epoch-32.sav',
    }

    def __init__(self, repo_name: str, classificator_type: str) -> None:
        """Load the backbone, its processor and the selected classifier.

        Args:
            repo_name: Hub repository id passed to ``from_pretrained`` and
                ``hf_hub_download``.
            classificator_type: One of ``'svm'``, ``'linear'`` or ``'knn'``.

        Raises:
            ValueError: If ``classificator_type`` is not a supported value.
        """
        # Validate before anything is downloaded or loaded, so a typo in
        # the classifier name fails immediately.
        if (not isinstance(classificator_type, str)
                or classificator_type not in self._CLASSIFIER_FILES):
            raise ValueError('Selected an invalid classifier')

        super().__init__()

        self.model = transformers.AutoModel.from_pretrained(repo_name)
        # Replace the backbone's pooler with a pass-through module.
        self.model.pooler = nn.Identity()

        self.processor = transformers.AutoProcessor.from_pretrained(repo_name)
        self.processor.do_resize = False

        # NOTE: consider passing `cache_dir` to hf_hub_download to control
        # where the classifier file is stored.
        file_path = hf_hub_download(
            repo_id=repo_name,
            filename=self._CLASSIFIER_FILES[classificator_type],
        )
        # SECURITY: joblib.load unpickles arbitrary Python objects; only use
        # this with a repository you trust.
        self.classifier = joblib.load(file_path)

    def forward(self, x: torch.Tensor, return_feature: bool = False):
        """Classify a batch, or return the raw backbone output.

        Args:
            x: Input passed straight to the backbone (presumably a batch of
                preprocessed images of shape (batch, 3, H, W) -- confirm
                against the caller).
            return_feature: When true, return the backbone output object
                unchanged instead of class predictions.

        Returns:
            The backbone output if ``return_feature`` is true; otherwise a
            tensor built from the classifier's ``predict`` result.
        """
        outputs = self.model(x)
        if return_feature:
            return outputs

        # Use the first token of the last hidden state (presumably the
        # [CLS] token) as the embedding; the classifier works on NumPy
        # arrays, so move it to the CPU and detach it first.
        embedding = outputs.last_hidden_state[:, 0, :]
        predictions = self.classifier.predict(embedding.cpu().detach().numpy())
        return torch.from_numpy(predictions)
|
|
|
# HF inference code: classify a single image with the selected classifier.
classificator_type = 'linear'
model = VITContrastiveHF(repo_name='aimagelab/CoDE', classificator_type=classificator_type)

# Preprocessing applied to the input image: centre crop to 224 pixels,
# conversion to a tensor, then per-channel normalisation.
transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

y_pred = []
model.eval()
# Only the backbone is moved to the device; the classifier receives NumPy
# arrays taken from the CPU copy of the features.
model.model.to(device)

# Put the image you want to test here.
img = Image.open('206496010652.png').convert('RGB')

with torch.no_grad():
    # Alternative preprocessing through the bundled processor:
    # in_tens = model.processor(img, return_tensors='pt')['pixel_values']
    in_tens = transform(img).unsqueeze(0).to(device)
    y_pred.extend(model(in_tens).flatten().tolist())

# Print the label of each prediction. 'linear'/'knn' use 0 for Real,
# 'svm' uses -1 for Real; every classifier uses 1 for Fake.
label_names = {1: 'Fake', 0: 'Real', -1: 'Real'}
for el in y_pred:
    print(label_names.get(el, 'Error'))
|
|
|
``` |