ryanzhangfan committed
Commit c5cfecb
1 Parent(s): f6e699f

Update README.md

Files changed (1):
  1. README.md +63 -3
README.md CHANGED
@@ -1,3 +1,63 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ library_name: transformers
+ ---
+
+ #### Quickstart
+
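+ The snippet below loads Emu3-Chat together with the Emu3 vision tokenizer (BAAI/Emu3-VisionTokenizer) and asks the model to describe an image in understanding mode.
+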
+ ```python
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
+ from transformers.generation.configuration_utils import GenerationConfig
+ import torch
+
+ import sys
+ sys.path.append("PATH_TO_BAAI_Emu3-Chat_MODEL")  # local path of the downloaded BAAI/Emu3-Chat repo
+ from processing_emu3 import Emu3Processor
+
+ # model paths on the Hugging Face Hub
+ EMU_HUB = "BAAI/Emu3-Chat"
+ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
+
+ # prepare model and processor
+ model = AutoModelForCausalLM.from_pretrained(
+     EMU_HUB,
+     device_map="cuda:0",
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",
+     trust_remote_code=True,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
+
+ # prepare input
+ text = "Please describe the image"
+ image = Image.open("assets/demo.png")
+
+ inputs = processor(
+     text=text,
+     image=image,
+     mode='U',  # 'U' selects understanding (vision-language) mode
+     padding_side="left",
+     padding="longest",
+     return_tensors="pt",
+ )
+
+ # prepare hyperparameters
+ GENERATION_CONFIG = GenerationConfig(
+     pad_token_id=tokenizer.pad_token_id,
+     bos_token_id=tokenizer.bos_token_id,
+     eos_token_id=tokenizer.eos_token_id,
+ )
+
+ # generate
+ outputs = model.generate(
+     inputs.input_ids.to("cuda:0"),
+     GENERATION_CONFIG,
+     max_new_tokens=320,
+ )
+
+ # drop the prompt tokens and decode only the generated answer
+ outputs = outputs[:, inputs.input_ids.shape[-1]:]
+ print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
+ ```
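+
+ The helper below is a hypothetical convenience wrapper, not part of the Emu3 repo: it simply repackages the processor / generate / decode steps from the quickstart so that several questions can be asked about the same image.
+
+ ```python
+ def ask(image, question, max_new_tokens=320):
+     """Hypothetical helper: reuses the model, processor, and GENERATION_CONFIG created above."""
+     inputs = processor(
+         text=question,
+         image=image,
+         mode='U',
+         padding_side="left",
+         padding="longest",
+         return_tensors="pt",
+     )
+     outputs = model.generate(
+         inputs.input_ids.to("cuda:0"),
+         GENERATION_CONFIG,
+         max_new_tokens=max_new_tokens,
+     )
+     # keep only the newly generated tokens, then decode
+     outputs = outputs[:, inputs.input_ids.shape[-1]:]
+     return processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+ print(ask(image, "What colors dominate the image?"))
+ ```
+
+ If flash-attn is not installed, loading the model with attn_implementation="sdpa" (PyTorch's built-in scaled-dot-product attention) instead of "flash_attention_2" should also work, at some speed cost.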