File size: 2,096 Bytes
2e8ea21
08422e4
121448b
 
08422e4
121448b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e8ea21
08422e4
121448b
 
 
 
 
 
 
 
 
 
 
 
2e8ea21
121448b
2e8ea21
 
121448b
 
2e8ea21
 
121448b
2e8ea21
 
121448b
 
2e8ea21
 
121448b
 
 
2e8ea21
08422e4
121448b
 
 
2e8ea21
 
121448b
 
2e8ea21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Checkpoint identifier and the placement strategy handed to from_pretrained
model_name = "ruslanmv/Medical-Llama3-8B"
device_map = "auto"

# bitsandbytes settings: 4-bit loading, "nf4" quant type, float16 compute dtype
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Load the tokenizer first, then the quantized causal-LM weights.
# Both come from the same checkpoint named by `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=bnb_config,
    use_cache=False,
    trust_remote_code=True,
)

# Set the chat template.
#
# `tokenizer.chat_template` is rendered by `apply_chat_template` as a Jinja
# template that receives `messages` and `add_generation_prompt`. Plain
# `{system}` / `{user}` placeholders are not Jinja syntax: they would be
# emitted verbatim and the real message contents would never reach the prompt.
#
# Every message is rendered as
#     <|im_start|>{role}\n{content}\n<|im_end|>\n
# and, when `add_generation_prompt` is true, the prompt ends with an open
# `<|im_start|>assistant\n` turn for the model to complete.
#
# The newlines live inside Jinja string literals (not between tags) so the
# output does not depend on the environment's whitespace-trimming settings.
chat_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '\n<|im_end|>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)
tokenizer.chat_template = chat_template

# Define the askme function
def askme(question, *, max_new_tokens=100):
    """Answer a medical question with the module-level model and tokenizer.

    The question is wrapped in a system + user conversation, rendered through
    the tokenizer's chat template, and passed to ``model.generate``. Only the
    text generated after the prompt is returned.

    Args:
        question: The user's question as plain text.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Defaults to 100.

    Returns:
        The decoded completion with surrounding whitespace stripped.
    """
    # Kept byte-for-byte: this text is part of the prompt the model sees.
    sys_message = (
        " \n"
        "    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and\n"
        "    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.\n"
        "    "
    )

    # Structure messages for the chat
    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]

    # Render the conversation to a prompt string ending in an open assistant turn
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Follow the model's own device instead of assuming "cuda": the model is
    # loaded with device_map="auto", so its placement is decided at load time.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate response
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)

    # For a causal LM, generate() returns the prompt tokens followed by the
    # new tokens. Drop the prompt by length rather than by splitting the
    # decoded text on a marker string, which fails if the marker is stripped
    # during decoding or repeated by the model.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# Example usage: send a sample patient question through askme() and print the reply.
# The text is spelled out line by line so the trailing spaces stay visible.
question = (
    "\n"
    "I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, \n"
    "increased sensitivity to cold, and dry, itchy skin. \n"
    "Could these symptoms be related to hypothyroidism? \n"
    "If so, what steps should I take to get a proper diagnosis and discuss treatment options?\n"
)
print(askme(question))