File size: 1,552 Bytes
5365632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from PIL import Image
import torch

def predict_turkish(image, turkish_question, tokenizer, model, image_processor, tr_en_tokenizer, tr_en_model, en_tr_tokenizer, en_tr_model, device="cuda"):
    """Answer a Turkish question about an image by pivoting through English.

    The question is translated Turkish -> English, wrapped in a ``llava-v1``
    conversation prompt together with the image placeholder, answered by the
    multimodal model, and the English answer is translated back to Turkish.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) or an
            object exposing ``.convert("RGB")`` (e.g. a ``PIL.Image.Image``).
        turkish_question: The question text, in Turkish.
        tokenizer: Tokenizer paired with ``model``.
        model: Multimodal model whose ``generate`` accepts ``images=``.
        image_processor: Object whose ``preprocess(image, return_tensors="pt")``
            returns a mapping containing ``"pixel_values"``.
        tr_en_tokenizer: Tokenizer for the Turkish -> English model.
        tr_en_model: Turkish -> English seq2seq model exposing ``generate``.
        en_tr_tokenizer: Tokenizer for the English -> Turkish model.
        en_tr_model: English -> Turkish seq2seq model exposing ``generate``.
        device: Device that every input tensor is moved to. The models are
            assumed to already live on this device.

    Returns:
        The decoded Turkish translation of the model's answer.
    """
    import os

    from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from llava.conversation import conv_templates
    from llava.mm_utils import tokenizer_image_token

    # Turkish -> English, so the (English-language) multimodal model can read it.
    question_inputs = tr_en_tokenizer([turkish_question], return_tensors="pt").to(device)
    english_question = tr_en_tokenizer.decode(
        tr_en_model.generate(**question_inputs)[0], skip_special_tokens=True
    )

    if isinstance(image, (str, os.PathLike)):
        # convert() returns a fully loaded copy, so the file handle can be
        # released immediately instead of lingering until garbage collection.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    else:
        image = image.convert("RGB")

    conv = conv_templates["llava-v1"].copy()
    conv.messages = []
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + english_question)
    conv.append_message(conv.roles[1], None)  # open assistant turn to be generated
    prompt = conv.get_prompt()

    image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(device)

    # The image placeholder must be encoded as the IMAGE_TOKEN_INDEX sentinel;
    # a plain tokenizer call would turn "<image>" into ordinary text tokens and
    # the model would never splice the image features into the sequence.
    # tokenizer_image_token returns a 1-D tensor, hence the added batch dim.
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)

    with torch.inference_mode():
        # Greedy decoding: temperature is meaningless when do_sample=False,
        # so it is not passed (avoids a generation-config warning).
        output_ids = model.generate(
            input_ids=input_ids,
            images=image_tensor,
            do_sample=False,
            max_new_tokens=512,
        )

    # Depending on the LLaVA release, generate() either echoes the prompt ids
    # before the new tokens or returns only the new tokens. Strip the prompt
    # only when it is really there; the sentinel id inside input_ids can never
    # be generated, so the prefix comparison is unambiguous.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]

    english_answer = tokenizer.decode(generated, skip_special_tokens=True)

    # English -> Turkish for the final answer.
    answer_inputs = en_tr_tokenizer([english_answer], return_tensors="pt").to(device)
    return en_tr_tokenizer.decode(en_tr_model.generate(**answer_inputs)[0], skip_special_tokens=True)