fancyfeast committed on
Commit 2fb728d
1 Parent(s): 0ba5137

Initial commit

- app.py +199 -0
- requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,199 @@
import spaces
import gradio as gr
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, LlavaForConditionalGeneration, TextIteratorStreamer
import torch
import torch.amp.autocast_mode
from PIL import Image
import torchvision.transforms.functional as TVF
from threading import Thread
from typing import Generator


MODEL_PATH = "fancyfeast/llama-joycaption-alpha-two-vqa-test-1"
TITLE = "<h1><center>JoyCaption Alpha Two - VQA Test - (2024-11-25a)</center></h1>"
DESCRIPTION = """
<div>
<p>🚨🚨🚨 BY USING THIS SPACE YOU AGREE THAT YOUR QUERIES (but not images) <i>MAY</i> BE LOGGED AND COLLECTED ANONYMOUSLY 🚨🚨🚨</p>
<p>🧪🧪🧪 This is an experiment to see how well JoyCaption Alpha Two can learn to answer questions about images and follow instructions.
I've only finetuned it on 600 examples, so it is highly experimental, very weak, broken, and volatile.  But for only 600 training examples,
I thought it was performing surprisingly well and wanted to share. 🧪🧪🧪</p>
<p>Unlike JoyCaption Alpha Two, you can ask this finetune questions about the image, like "What is he holding in his hand?", "Where might this be?",
and "What are they doing?".  It can also follow instructions, like "Write me a poem about this image",
"Write a caption but don't use any ambiguous language, and make sure you mention that the image is from Instagram.", and
"Output JSON with the following properties: 'skin_tone', 'hair_style', 'hair_length', 'clothing', 'background'."  Remember that this was only finetuned on
600 VQA/instruction examples, so it is _very_ limited right now.  Expect it to frequently fall back to its base behavior of just writing image descriptions.
Expect accuracy to be lower.  Expect glitches.  Despite that, I've found that it will follow most queries I've tested it with, even outside its training,
with enough coaxing and re-rolling.</p>
<p>About the 🚨🚨🚨 above: this space will log all prompts sent to it.  The only thing this space logs is the text query; no images, no user data, etc.
I cannot see what images you send, and frankly, I don't want to.  But knowing what kinds of instructions and queries users want JoyCaption to handle will
help guide me in building JoyCaption's VQA dataset.  I've found out the hard way that almost all public VQA datasets are garbage and don't do a good job of
training and exercising visual understanding.  Certainly not good enough to handle the complicated instructions that will allow JoyCaption users to guide and
direct how JoyCaption writes descriptions and captions.  So I'm building my own dataset, which will be made public.  So, with peace and love, this space logs the text
queries.  As always, the model itself is completely public and free to use outside of this space.  And, of course, I have no control over nor access to what HuggingFace,
which is graciously hosting this space, logs.</p>
</div>
"""

PLACEHOLDER = """
"""


# Load model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Expected PreTrainedTokenizer, got {type(tokenizer)}"

model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype="bfloat16", device_map=0)
assert isinstance(model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(model)}"


def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    # Trim off the prompt
    while True:
        try:
            i = input_ids.index(eoh_id)
        except ValueError:
            break

        input_ids = input_ids[i + 1:]

    # Trim off the end
    try:
        i = input_ids.index(eot_id)
    except ValueError:
        return input_ids

    return input_ids[:i]


end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)


@spaces.GPU()
@torch.no_grad()
def chat_joycaption(message: dict, history, temperature: float, max_new_tokens: int) -> Generator[str, None, None]:
    torch.cuda.empty_cache()

    # Prompts are always stripped in training for now
    prompt = message['text'].strip()

    # Load image
    if "files" not in message or len(message["files"]) != 1:
        raise ValueError("This model requires exactly one image as input.")

    image = Image.open(message["files"][0])

    # Log the prompt
    print(f"Prompt: {prompt}")

    # Preprocess image
    # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)
    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image)

    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]

    # Format the conversation
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    # Tokenize the conversation
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Repeat the image tokens so there is one placeholder per image embedding slot
    input_tokens = []
    for token in convo_tokens:
        if token == model.config.image_token_index:
            input_tokens.extend([model.config.image_token_index] * model.config.image_seq_length)
        else:
            input_tokens.append(token)

    input_ids = torch.tensor(input_tokens, dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)

    # Move to GPU
    input_ids = input_ids.unsqueeze(0).to("cuda")
    attention_mask = attention_mask.unsqueeze(0).to("cuda")
    pixel_values = pixel_values.unsqueeze(0).to("cuda")

    # Normalize the image
    pixel_values = pixel_values / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16)

    # The streamer must be created before generate_kwargs references it
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature,
        top_k=None,
        top_p=0.9,
        streamer=streamer,
    )

    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # Run generation in a background thread and stream partial output back to the UI
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_joycaption,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.6,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=1024,
                      label="Max new tokens",
                      render=False),
        ],
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
huggingface_hub==0.23.4
accelerate
torch
transformers==4.45.2
sentencepiece
torchvision
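The description in app.py notes that the model itself is public and free to use outside of this Space. For reference, below is a minimal, non-streaming sketch of the same pipeline without the Gradio and ZeroGPU wrappers, adapted from app.py above. It is an untested adaptation: the image path "example.jpg", the example question, and the max_new_tokens value are placeholder assumptions, and it assumes a CUDA GPU with enough memory for the bf16 weights.

# Minimal local inference sketch, adapted from app.py above (hypothetical inputs)
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import AutoTokenizer, LlavaForConditionalGeneration

MODEL_PATH = "fancyfeast/llama-joycaption-alpha-two-vqa-test-1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype="bfloat16", device_map=0)

# Same preprocessing as the Space: 384x384 resize, [-1, 1] normalization, no HF processor
image = Image.open("example.jpg").resize((384, 384), Image.LANCZOS).convert("RGB")
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0).to("cuda") / 255.0
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to(torch.bfloat16)

convo = [
    {"role": "system", "content": "You are a helpful image captioner."},
    {"role": "user", "content": "What is she holding in her hand?"},  # placeholder question
]
convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

# Expand the single image placeholder token, exactly as chat_joycaption does
input_tokens = []
for token in convo_tokens:
    if token == model.config.image_token_index:
        input_tokens.extend([token] * model.config.image_seq_length)
    else:
        input_tokens.append(token)

input_ids = torch.tensor([input_tokens], dtype=torch.long).to("cuda")
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    output = model.generate(input_ids=input_ids, pixel_values=pixel_values,
                            attention_mask=attention_mask, max_new_tokens=512,
                            do_sample=True, temperature=0.6, top_p=0.9)

# Decode only the newly generated tokens
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))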
