# %%
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import googletrans

translator = googletrans.Translator()

model = None
tokenizer = None
generator = None

# Hide all GPUs so torch.cuda.is_available() is False and everything runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
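# Note (assumption): the synchronous translator.detect()/translate() calls below
# match the googletrans==4.0.0rc1 API; some newer/forked releases are async-only
# and would need these calls awaited.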
def load_model(model_name, eight_bit=0, device_map="auto"):
    global model, tokenizer, generator
    print("Loading " + model_name + "...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    # config
    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)

    # fp16 on GPU, fp32 on CPU (half precision is poorly supported on CPU).
    if torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    print(model_name)
    # The pre-release LLaMA class names used originally (LLaMATokenizer /
    # LLaMAForCausalLM) were renamed in transformers; the Auto classes resolve
    # the correct ones from the checkpoint's config.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #device_map=device_map,
        #device_map="auto",
        torch_dtype=torch_dtype,
        #max_memory = {0: "14GB", 1: "14GB", 2: "14GB", 3: "14GB", 4: "14GB", 5: "14GB", 6: "14GB", 7: "14GB"},
        #load_in_8bit=eight_bit,
        #from_tf=True,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache"
    )
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.cpu()
    generator = model.generate
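# Sketch (hypothetical): the `eight_bit` flag above is currently ignored. If it
# were wired up, 8-bit loading (with bitsandbytes installed and a GPU available)
# would look roughly like:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, load_in_8bit=True, device_map="auto", cache_dir="cache")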
# chat doctor
def chatdoctor(input, state):
    print('state', state)
    invitation = "ChatDoctor: "
    human_invitation = "Patient: "

    # System prompt followed by the alternating conversation history.
    fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n"

    # `state` holds alternating English-side turns: even indices are patient
    # inputs, odd indices are ChatDoctor responses (see predict() below).
    for i in range(len(state)):
        if i % 2 == 0:
            fulltext += human_invitation + state[i] + "\n\n"
        else:
            fulltext += invitation + state[i] + "\n\n"

    fulltext += human_invitation + input + "\n\n"
    fulltext += invitation

    print('fulltext: ', fulltext)
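    # Illustrative shape of the assembled prompt after one prior exchange:
    #
    #   If you are a doctor, please answer the medical questions based on the patient's description.
    #
    #   Patient: I have had a headache since yesterday.
    #
    #   ChatDoctor: How severe is the pain?
    #
    #   Patient: <current input>
    #
    #   ChatDoctor: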
    generated_text = ""
    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        gen_in = gen_in.cuda()
    else:
        gen_in = gen_in.cpu()

    # gen_in has shape (batch, seq_len); count tokens, not batch entries.
    in_tokens = gen_in.shape[-1]
    print('len token', in_tokens)

    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means 'off'; penalizing too hard suppresses the speaker tags
            temperature=0.5,         # default: 1.0
            top_k=50,                # default: 50
            top_p=1.0,               # default: 1.0
            early_stopping=True,
        )
        # batch_decode returns one string per sequence; we generate a single sequence.
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Strip the prompt, then cut at the next "Patient: " in case the model keeps going.
    text_without_prompt = generated_text[len(fulltext):]
    response = text_without_prompt
    response = response.split(human_invitation)[0]
    response = response.strip()  # strip() returns a new string; the original discarded the result

    print(invitation + response)
    print("")
    return response
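# Quick local smoke test (hypothetical, not part of the app): after load_model(),
# chatdoctor can be called directly with an empty history:
#   print(chatdoctor("I have had a sore throat for three days.", []))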
def predict(input, chatbot, state):
    print('predict state: ', state)

    # If the input is detected as Korean, translate it to English; otherwise use it as-is.
    is_kor = True
    if translator.detect(input).lang == 'ko':
        en_input = translator.translate(input, src='ko', dest='en').text
    else:
        en_input = input
        is_kor = False

    response = chatdoctor(en_input, state)

    # Answer in the language the question was asked in.
    if is_kor:
        ko_response = translator.translate(response, src='en', dest='ko').text
    else:
        ko_response = response

    # Keep the English-side history alternating (patient, doctor) so chatdoctor()
    # can rebuild the prompt; appending only the response would desync the turns.
    state.append(en_input)
    state.append(response)
    chatbot.append((input, ko_response))
    return chatbot, state
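# After two turns, `state` holds the alternating English-side turns, e.g.
# ["I have a headache.", "How long has it lasted?", ...], while `chatbot` holds
# the (question, answer) pairs in the user's language that Gradio renders.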
load_model("mnc-ai/chatdoctor")
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>This is ChatDoctor. What seems to be bothering you?</center></h1>
    """)
    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type your question here and press Enter").style(container=False)
    clear = gr.Button("Start a new consultation")
    txt.submit(predict, inputs=[txt, chatbot, state], outputs=[chatbot, state], queue=False)
    txt.submit(lambda x: "", txt, txt)
    clear.click(lambda: None, None, chatbot, queue=False)
    clear.click(lambda x: "", txt, txt)
    # Reset the conversation state when the clear button is clicked.
    clear.click(lambda x: [], state, state)
demo.launch()
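# When running outside Hugging Face Spaces, demo.launch(share=True) can expose a
# temporary public URL; on Spaces, the plain launch() above is what the platform expects.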