{
"cells": [
{
"cell_type": "markdown",
"id": "9dc4dd97-1a2a-409f-8db5-4e0f2f07d49d",
"metadata": {},
"source": [
"# Phi-4-multimodal simple demo\n",
"\n",
"Make sure that you must install `gradio, soundfile, and pillow`.\n",
"\n",
"- `pip install gradio transformers torch soundfile pillow`\n",
"- Retrieved from https://www.datacamp.com/tutorial/phi-4-multimodal"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ffc47b0-a12b-4b8a-9066-6f15acfc9210",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import gradio as gr\n",
"import torch\n",
"import requests\n",
"import io\n",
"import os\n",
"import soundfile as sf\n",
"from PIL import Image\n",
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig\n",
"\n",
"\n",
"max_new_tokens = 256\n",
"orig_model_path = \"microsoft/Phi-4-multimodal-instruct\"\n",
"ft_model_path = \"daekeun-ml/Phi-4-multimodal-finetune-ko-speech\"\n",
"generation_config = GenerationConfig.from_pretrained(ft_model_path, 'generation_config.json')\n",
"processor = AutoProcessor.from_pretrained(orig_model_path, trust_remote_code=True)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" ft_model_path,\n",
" trust_remote_code=True,\n",
" torch_dtype='auto',\n",
" _attn_implementation='flash_attention_2',\n",
").cuda()\n",
"\n",
"user_prompt = '<|user|>'\n",
"assistant_prompt = '<|assistant|>'\n",
"prompt_suffix = '<|end|>'"
]
},
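{
"cell_type": "markdown",
"id": "3f2a9c1e-7b4d-4e8a-9c6f-1d2e3a4b5c6d",
"metadata": {},
"source": [
"A quick text-only sanity check that the model and processor above load and generate correctly. This is a minimal sketch: the sample question is arbitrary, and it simply reuses the chat-format tokens (`<|user|>`, `<|end|>`, `<|assistant|>`) defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e5f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Minimal text-only sanity check that the model and processor loaded correctly.\n",
"sample_question = 'Briefly introduce yourself.'  # arbitrary example question\n",
"sample_prompt = f'{user_prompt}{sample_question}{prompt_suffix}{assistant_prompt}'\n",
"sample_inputs = processor(text=sample_prompt, return_tensors='pt').to(model.device)\n",
"with torch.no_grad():\n",
"    sample_ids = model.generate(**sample_inputs, max_new_tokens=max_new_tokens, generation_config=generation_config)\n",
"# Drop the prompt tokens so only the generated answer is decoded.\n",
"sample_ids = sample_ids[:, sample_inputs['input_ids'].shape[1]:]\n",
"print(processor.batch_decode(sample_ids, skip_special_tokens=True)[0])"
]
},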
{
"cell_type": "code",
"execution_count": null,
"id": "4058364b-d041-4168-b8d7-26813467f454",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def clean_response(response, instruction_keywords):\n",
" \"\"\"Removes the prompt text dynamically based on instruction keywords.\"\"\"\n",
" for keyword in instruction_keywords:\n",
" if response.lower().startswith(keyword.lower()):\n",
" response = response[len(keyword):].strip()\n",
" return response\n",
"\n",
"# task prompt is from technical report\n",
"asr_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio clip into text.{prompt_suffix}{assistant_prompt}'\n",
"ast_ko_prompt = f'{user_prompt}<|audio_1|>Translate the audio to Korean.{prompt_suffix}{assistant_prompt}'\n",
"ast_cot_ko_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to Korean. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'\n",
"ast_en_prompt = f'{user_prompt}<|audio_1|>Translate the audio to English.{prompt_suffix}{assistant_prompt}'\n",
"ast_cot_en_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to English. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'\n",
"\n",
"def process_input(file, input_type, question):\n",
" user_prompt = \"<|user|>\"\n",
" assistant_prompt = \"<|assistant|>\"\n",
" prompt_suffix = \"<|end|>\"\n",
" \n",
" if input_type == \"Image\":\n",
" prompt= f'{user_prompt}<|image_1|>{question}{prompt_suffix}{assistant_prompt}'\n",
" image = Image.open(file)\n",
" inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device)\n",
" elif input_type == \"Audio\":\n",
" prompt= f'{user_prompt}<|audio_1|>{question}{prompt_suffix}{assistant_prompt}'\n",
" audio, samplerate = sf.read(file)\n",
" inputs = processor(text=prompt, audios=[(audio, samplerate)], return_tensors='pt').to(model.device)\n",
" elif input_type == \"Text\":\n",
" prompt = f'{user_prompt}{question} \"{file}\"{prompt_suffix}{assistant_prompt}'\n",
" inputs = processor(text=prompt, return_tensors='pt').to(model.device)\n",
" else:\n",
" return \"Invalid input type\" \n",
" \n",
" generate_ids = model.generate(**inputs, max_new_tokens=1000, generation_config=generation_config)\n",
" response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]\n",
" return clean_response(response, [question])\n",
"\n",
"def process_text_translate(text, target_language):\n",
" prompt = f'Transcribe the audio to text, and then Translate the following text to {target_language}: \"{text}\"'\n",
" return process_input(text, \"Text\", prompt)\n",
"def process_text_grammar(text):\n",
" prompt = f'Check the grammar and provide corrections if needed for the following text: \"{text}\"'\n",
" return process_input(text, \"Text\", prompt)\n",
"\n",
"def gradio_interface():\n",
" with gr.Blocks() as demo:\n",
" gr.Markdown(\"# Phi 4 Powered - Multimodal Language Tutor\") \n",
" with gr.Tab(\"Text-Based Learning\"):\n",
" text_input = gr.Textbox(label=\"Enter Text\")\n",
" language_input = gr.Textbox(label=\"Target Language\", value=\"Korean\")\n",
" text_output = gr.Textbox(label=\"Response\")\n",
" text_translate_btn = gr.Button(\"Translate\")\n",
" text_grammar_btn = gr.Button(\"Check Grammar\")\n",
" text_clear_btn = gr.Button(\"Clear\")\n",
" text_translate_btn.click(process_text_translate, inputs=[text_input, language_input], outputs=text_output)\n",
" text_grammar_btn.click(process_text_grammar, inputs=[text_input], outputs=text_output)\n",
" text_clear_btn.click(lambda: (\"\", \"\", \"\"), outputs=[text_input, language_input, text_output]) \n",
" with gr.Tab(\"Image-Based Learning\"):\n",
" image_input = gr.Image(type=\"filepath\", label=\"Upload Image\")\n",
" language_input_image = gr.Textbox(label=\"Target Language for Translation\", value=\"English\")\n",
" image_output = gr.Textbox(label=\"Response\")\n",
" image_clear_btn = gr.Button(\"Clear\")\n",
" image_translate_btn = gr.Button(\"Translate Text in Image\")\n",
" image_summarize_btn = gr.Button(\"Summarize Image\")\n",
" image_translate_btn.click(process_input, inputs=[image_input, gr.Textbox(value=\"Image\", visible=False), gr.Textbox(value=\"Extract and translate text\", visible=False)], outputs=image_output)\n",
" image_summarize_btn.click(process_input, inputs=[image_input, gr.Textbox(value=\"Image\", visible=False), gr.Textbox(value=\"Summarize this image\", visible=False)], outputs=image_output)\n",
" image_clear_btn.click(lambda: (None, \"\", \"\"), outputs=[image_input, language_input_image, image_output])\n",
" with gr.Tab(\"Audio-Based Learning\"):\n",
" audio_input = gr.Audio(type=\"filepath\", label=\"Upload Audio\")\n",
" language_input_audio = gr.Textbox(label=\"Target Language for Translation\", value=\"English\")\n",
" transcript_output = gr.Textbox(label=\"Transcribed Text\")\n",
" translated_output = gr.Textbox(label=\"Translated Text\")\n",
" audio_clear_btn = gr.Button(\"Clear\")\n",
" audio_transcribe_btn = gr.Button(\"Transcribe & Translate\")\n",
" audio_transcribe_btn.click(process_input, inputs=[audio_input, gr.Textbox(value=\"Audio\", visible=False), gr.Textbox(value=\"Transcribe this audio\", visible=False)], outputs=transcript_output)\n",
" audio_transcribe_btn.click(process_input, inputs=[audio_input, gr.Textbox(value=\"Audio\", visible=False), language_input_audio], outputs=translated_output)\n",
" audio_clear_btn.click(lambda: (None, \"\", \"\", \"\"), outputs=[audio_input, language_input_audio, transcript_output, translated_output]) \n",
" demo.launch(debug=True, share=True)\n",
"\n",
"if __name__ == \"__main__\":\n",
" gradio_interface()"
]
}
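,
{
"cell_type": "markdown",
"id": "b7e6d5c4-3a2b-4f1e-9d8c-7a6b5c4d3e2f",
"metadata": {},
"source": [
"The Gradio app above is a thin wrapper around `process_input`, so the function can also be called directly. The cell below is a minimal sketch of that: it runs the transcription and Korean-translation instructions against a local audio file. The path `sample.wav` is a placeholder, not a file shipped with this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9d8e7f6-5a4b-4c3d-8e2f-1a0b9c8d7e6f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Example of calling process_input directly, without launching the Gradio UI.\n",
"# 'sample.wav' is a placeholder path; point it at any local speech recording.\n",
"audio_path = 'sample.wav'\n",
"if os.path.exists(audio_path):\n",
"    print('Transcript :', process_input(audio_path, 'Audio', 'Transcribe the audio clip into text.'))\n",
"    print('Translation:', process_input(audio_path, 'Audio', 'Translate the audio to Korean.'))\n",
"else:\n",
"    print(f'{audio_path} not found - upload an audio file to try this cell.')"
]
}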
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10 - SDK v2",
"language": "python",
"name": "python310-sdkv2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}