{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9dc4dd97-1a2a-409f-8db5-4e0f2f07d49d",
   "metadata": {},
   "source": [
    "# Phi-4-multimodal simple demo\n",
    "\n",
    "Make sure that you must install `gradio, soundfile, and pillow`.\n",
    "\n",
    "- `pip install gradio transformers torch soundfile pillow`\n",
    "- Retrieved from https://www.datacamp.com/tutorial/phi-4-multimodal"
   ]
  },
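  {
   "cell_type": "markdown",
   "id": "2b7e1c3a-4d5f-4e6a-9b8c-1d2e3f4a5b6c",
   "metadata": {},
   "source": [
    "Before loading the model, it is worth confirming the environment: the demo moves the model to GPU with `.cuda()` and requests `flash_attention_2`. A minimal sanity-check sketch, assuming a single CUDA GPU is available:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c8f2d4b-5e6a-4f7b-8c9d-2e3f4a5b6c7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: a CUDA GPU is required since the model is moved with .cuda().\n",
    "import importlib.util\n",
    "import torch\n",
    "\n",
    "assert torch.cuda.is_available(), 'A CUDA GPU is required for this demo.'\n",
    "print('GPU:', torch.cuda.get_device_name(0))\n",
    "\n",
    "# flash-attn backs _attn_implementation='flash_attention_2' below.\n",
    "if importlib.util.find_spec('flash_attn') is None:\n",
    "    print('Warning: flash-attn not found; install it or drop the _attn_implementation argument.')"
   ]
  },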
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ffc47b0-a12b-4b8a-9066-6f15acfc9210",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "import requests\n",
    "import io\n",
    "import os\n",
    "import soundfile as sf\n",
    "from PIL import Image\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig\n",
    "\n",
    "\n",
    "max_new_tokens = 256\n",
    "orig_model_path = \"microsoft/Phi-4-multimodal-instruct\"\n",
    "ft_model_path = \"daekeun-ml/Phi-4-multimodal-finetune-ko-speech\"\n",
    "generation_config = GenerationConfig.from_pretrained(ft_model_path, 'generation_config.json')\n",
    "processor = AutoProcessor.from_pretrained(orig_model_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    ft_model_path,\n",
    "    trust_remote_code=True,\n",
    "    torch_dtype='auto',\n",
    "    _attn_implementation='flash_attention_2',\n",
    ").cuda()\n",
    "\n",
    "user_prompt = '<|user|>'\n",
    "assistant_prompt = '<|assistant|>'\n",
    "prompt_suffix = '<|end|>'"
   ]
  },
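  {
   "cell_type": "markdown",
   "id": "4d9a3e5c-6f7b-4a8c-9d0e-3f4a5b6c7d8e",
   "metadata": {},
   "source": [
    "As a quick smoke test of the chat format, the special tokens above can be composed into a text-only prompt. A minimal sketch, assuming the model and processor loaded successfully; the question text is arbitrary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0b4f6d-7a8c-4b9d-0e1f-4a5b6c7d8e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text-only smoke test of the Phi-4 chat format.\n",
    "question = 'Explain the difference between speech recognition and speech translation in one sentence.'\n",
    "prompt = f'{user_prompt}{question}{prompt_suffix}{assistant_prompt}'\n",
    "\n",
    "inputs = processor(text=prompt, return_tensors='pt').to(model.device)\n",
    "out_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, generation_config=generation_config)\n",
    "# Decode only the newly generated tokens, not the prompt.\n",
    "out_ids = out_ids[:, inputs['input_ids'].shape[1]:]\n",
    "print(processor.batch_decode(out_ids, skip_special_tokens=True)[0])"
   ]
  },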
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4058364b-d041-4168-b8d7-26813467f454",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def clean_response(response, instruction_keywords):\n",
    "    \"\"\"Removes the prompt text dynamically based on instruction keywords.\"\"\"\n",
    "    for keyword in instruction_keywords:\n",
    "        if response.lower().startswith(keyword.lower()):\n",
    "            response = response[len(keyword):].strip()\n",
    "    return response\n",
    "\n",
    "# task prompt is from technical report\n",
    "asr_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio clip into text.{prompt_suffix}{assistant_prompt}'\n",
    "ast_ko_prompt = f'{user_prompt}<|audio_1|>Translate the audio to Korean.{prompt_suffix}{assistant_prompt}'\n",
    "ast_cot_ko_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to Korean. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'\n",
    "ast_en_prompt = f'{user_prompt}<|audio_1|>Translate the audio to English.{prompt_suffix}{assistant_prompt}'\n",
    "ast_cot_en_prompt = f'{user_prompt}<|audio_1|>Transcribe the audio to text, and then translate the audio to English. Use <sep> as a separator between the original transcript and the translation.{prompt_suffix}{assistant_prompt}'\n",
    "\n",
    "def process_input(file, input_type, question):\n",
    "    user_prompt = \"<|user|>\"\n",
    "    assistant_prompt = \"<|assistant|>\"\n",
    "    prompt_suffix = \"<|end|>\"\n",
    "    \n",
    "    if input_type == \"Image\":\n",
    "        prompt= f'{user_prompt}<|image_1|>{question}{prompt_suffix}{assistant_prompt}'\n",
    "        image = Image.open(file)\n",
    "        inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device)\n",
    "    elif input_type == \"Audio\":\n",
    "        prompt= f'{user_prompt}<|audio_1|>{question}{prompt_suffix}{assistant_prompt}'\n",
    "        audio, samplerate = sf.read(file)\n",
    "        inputs = processor(text=prompt, audios=[(audio, samplerate)], return_tensors='pt').to(model.device)\n",
    "    elif input_type == \"Text\":\n",
    "        prompt = f'{user_prompt}{question} \"{file}\"{prompt_suffix}{assistant_prompt}'\n",
    "        inputs = processor(text=prompt, return_tensors='pt').to(model.device)\n",
    "    else:\n",
    "        return \"Invalid input type\"    \n",
    "    \n",
    "    generate_ids = model.generate(**inputs, max_new_tokens=1000, generation_config=generation_config)\n",
    "    response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]\n",
    "    return clean_response(response, [question])\n",
    "\n",
    "def process_text_translate(text, target_language):\n",
    "    prompt = f'Transcribe the audio to text, and then Translate the following text to {target_language}: \"{text}\"'\n",
    "    return process_input(text, \"Text\", prompt)\n",
    "def process_text_grammar(text):\n",
    "    prompt = f'Check the grammar and provide corrections if needed for the following text: \"{text}\"'\n",
    "    return process_input(text, \"Text\", prompt)\n",
    "\n",
    "def gradio_interface():\n",
    "    with gr.Blocks() as demo:\n",
    "        gr.Markdown(\"# Phi 4 Powered - Multimodal Language Tutor\")        \n",
    "        with gr.Tab(\"Text-Based Learning\"):\n",
    "            text_input = gr.Textbox(label=\"Enter Text\")\n",
    "            language_input = gr.Textbox(label=\"Target Language\", value=\"Korean\")\n",
    "            text_output = gr.Textbox(label=\"Response\")\n",
    "            text_translate_btn = gr.Button(\"Translate\")\n",
    "            text_grammar_btn = gr.Button(\"Check Grammar\")\n",
    "            text_clear_btn = gr.Button(\"Clear\")\n",
    "            text_translate_btn.click(process_text_translate, inputs=[text_input, language_input], outputs=text_output)\n",
    "            text_grammar_btn.click(process_text_grammar, inputs=[text_input], outputs=text_output)\n",
    "            text_clear_btn.click(lambda: (\"\", \"\", \"\"), outputs=[text_input, language_input, text_output])        \n",
    "        with gr.Tab(\"Image-Based Learning\"):\n",
    "            image_input = gr.Image(type=\"filepath\", label=\"Upload Image\")\n",
    "            language_input_image = gr.Textbox(label=\"Target Language for Translation\", value=\"English\")\n",
    "            image_output = gr.Textbox(label=\"Response\")\n",
    "            image_clear_btn = gr.Button(\"Clear\")\n",
    "            image_translate_btn = gr.Button(\"Translate Text in Image\")\n",
    "            image_summarize_btn = gr.Button(\"Summarize Image\")\n",
    "            image_translate_btn.click(process_input, inputs=[image_input, gr.Textbox(value=\"Image\", visible=False), gr.Textbox(value=\"Extract and translate text\", visible=False)], outputs=image_output)\n",
    "            image_summarize_btn.click(process_input, inputs=[image_input, gr.Textbox(value=\"Image\", visible=False), gr.Textbox(value=\"Summarize this image\", visible=False)], outputs=image_output)\n",
    "            image_clear_btn.click(lambda: (None, \"\", \"\"), outputs=[image_input, language_input_image, image_output])\n",
    "        with gr.Tab(\"Audio-Based Learning\"):\n",
    "            audio_input = gr.Audio(type=\"filepath\", label=\"Upload Audio\")\n",
    "            language_input_audio = gr.Textbox(label=\"Target Language for Translation\", value=\"English\")\n",
    "            transcript_output = gr.Textbox(label=\"Transcribed Text\")\n",
    "            translated_output = gr.Textbox(label=\"Translated Text\")\n",
    "            audio_clear_btn = gr.Button(\"Clear\")\n",
    "            audio_transcribe_btn = gr.Button(\"Transcribe & Translate\")\n",
    "            audio_transcribe_btn.click(process_input, inputs=[audio_input, gr.Textbox(value=\"Audio\", visible=False), gr.Textbox(value=\"Transcribe this audio\", visible=False)], outputs=transcript_output)\n",
    "            audio_transcribe_btn.click(process_input, inputs=[audio_input, gr.Textbox(value=\"Audio\", visible=False), language_input_audio], outputs=translated_output)\n",
    "            audio_clear_btn.click(lambda: (None, \"\", \"\", \"\"), outputs=[audio_input, language_input_audio, transcript_output, translated_output])        \n",
    "        demo.launch(debug=True, share=True)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    gradio_interface()"
   ]
  }
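  ,
  {
   "cell_type": "markdown",
   "id": "6f1c5a7e-8b9d-4c0e-1f2a-5b6c7d8e9f0a",
   "metadata": {},
   "source": [
    "The ASR/AST task prompts defined earlier are not wired into the Gradio app, so the cell below exercises one of them directly. A minimal sketch, assuming a speech clip named `sample.wav` sits next to the notebook (a hypothetical placeholder path; any speech file works):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2d6b8f-9c0e-4d1f-2a3b-6c7d8e9f0a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct ASR with the task prompt from the technical report.\n",
    "# 'sample.wav' is a placeholder path; point it at any speech clip.\n",
    "audio, samplerate = sf.read('sample.wav')\n",
    "inputs = processor(text=asr_prompt, audios=[(audio, samplerate)], return_tensors='pt').to(model.device)\n",
    "\n",
    "out_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, generation_config=generation_config)\n",
    "out_ids = out_ids[:, inputs['input_ids'].shape[1]:]\n",
    "print('Transcript:', processor.batch_decode(out_ids, skip_special_tokens=True)[0])\n",
    "\n",
    "# Swap in ast_ko_prompt / ast_en_prompt (or the CoT variants) for speech translation."
   ]
  }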
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10 - SDK v2",
   "language": "python",
   "name": "python310-sdkv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}