|
from PIL import Image |
|
import torch |
|
|
|
def predict_turkish(image, turkish_question, tokenizer, model, image_processor, tr_en_tokenizer, tr_en_model, en_tr_tokenizer, en_tr_model, device="cuda", conv_template="llava-v1", max_new_tokens=512):
    """Answer a Turkish question about an image via an English VQA model.

    Pipeline: translate the question TR->EN, run LLaVA on the image with the
    English question, then translate the answer EN->TR.

    Args:
        image: Filesystem path (str) or a PIL image.
        turkish_question: Question text in Turkish.
        tokenizer / model / image_processor: LLaVA tokenizer, model and
            image processor.
        tr_en_tokenizer / tr_en_model: Turkish->English translation pair.
        en_tr_tokenizer / en_tr_model: English->Turkish translation pair.
        device: Torch device string all tensors are moved to.
        conv_template: LLaVA conversation template name (default "llava-v1").
        max_new_tokens: Generation budget for the VQA answer.

    Returns:
        The answer translated back into Turkish (str).
    """
    # Translate the Turkish question into English for the VQA model.
    question_inputs = tr_en_tokenizer([turkish_question], return_tensors="pt").to(device)
    english_question = tr_en_tokenizer.decode(
        tr_en_model.generate(**question_inputs)[0], skip_special_tokens=True
    )

    # Accept either a path or an already-loaded PIL image; force RGB either way.
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    else:
        image = image.convert("RGB")

    # Imported lazily so the module can be imported without llava installed.
    from llava.conversation import conv_templates
    from llava.constants import DEFAULT_IMAGE_TOKEN

    conv = conv_templates[conv_template].copy()
    conv.messages = []
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + english_question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # NOTE(review): .half() assumes the model runs in fp16 — confirm against loader.
    image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(device)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=input_ids,
            images=image_tensor,
            # Greedy decoding; the original also passed temperature=0.2, which
            # is ignored (and warned about) when do_sample=False, so it is dropped.
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )

    # Decode only the newly generated tokens (skip the echoed prompt); strip the
    # leading space the continuation usually carries before back-translating.
    english_answer = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()

    # Translate the English answer back into Turkish.
    answer_inputs = en_tr_tokenizer([english_answer], return_tensors="pt").to(device)
    return en_tr_tokenizer.decode(en_tr_model.generate(**answer_inputs)[0], skip_special_tokens=True)
|
|