Update README.md
Browse files
README.md
CHANGED
|
@@ -94,6 +94,40 @@ model = AutoModel.from_pretrained(
|
|
| 94 |
trust_remote_code=True,
|
| 95 |
).eval().cuda()
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
```
|
| 98 |
|
| 99 |
## Citation
|
|
|
|
| 94 |
trust_remote_code=True,
|
| 95 |
).eval().cuda()
|
| 96 |
|
| 97 |
+
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex1.jpg
|
| 98 |
+
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex2.jpg
|
| 99 |
+
|
| 100 |
+
images = [Image.open("ex1.jpg"),Image.open("ex2.jpg")]
|
| 101 |
+
batch_images = processor.process_images(images)
|
| 102 |
+
|
| 103 |
+
queries = [
|
| 104 |
+
"Cảng Hải Phòng thông báo gì ?",
|
| 105 |
+
"Phí giao hàng bao nhiêu ?",
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
batch_queries = processor.process_queries(queries)
|
| 109 |
+
|
| 110 |
+
batch_images["pixel_values"] = batch_images["pixel_values"].cuda().bfloat16()
|
| 111 |
+
batch_images["input_ids"] = batch_images["input_ids"].cuda()
|
| 112 |
+
batch_images["attention_mask"] = batch_images["attention_mask"].cuda().bfloat16()
|
| 113 |
+
batch_queries["input_ids"] = batch_queries["input_ids"].cuda()
|
| 114 |
+
batch_queries["attention_mask"] = batch_queries["attention_mask"].cuda().bfloat16()
|
| 115 |
+
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
image_embeddings = model(**batch_images)
|
| 118 |
+
query_embeddings = model(**batch_queries)
|
| 119 |
+
|
| 120 |
+
scores = processor.score_multi_vector(query_embeddings, image_embeddings)
|
| 121 |
+
|
| 122 |
+
max_scores, max_indices = torch.max(scores, dim=1)
|
| 123 |
+
# Print the result for each query
|
| 124 |
+
for i, query in enumerate(queries):
|
| 125 |
+
image_name = images[max_indices[i]]
|
| 126 |
+
print(f"Câu hỏi: '{query}'")
|
| 127 |
+
print(f"Điểm số: {max_scores[i].item()}\n")
|
| 128 |
+
plt.figure(figsize=(5,5))
|
| 129 |
+
plt.imshow(image_name)
|
| 130 |
+
plt.show()
|
| 131 |
```
|
| 132 |
|
| 133 |
## Citation
|