bhaskartripathi committed on
Commit e5fd599 · 1 Parent(s): a1084bc

Update app.py

Files changed (1)
  1. app.py +137 -37
app.py CHANGED
@@ -1,56 +1,156 @@
  import torch
  from peft import PeftModel
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

- tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-13b-hf")
  model = LlamaForCausalLM.from_pretrained(
-     "decapoda-research/llama-13b-hf",
-     load_in_8bit=True,
-     torch_dtype=torch.float16,
-     device_map="auto",
  )
  model = PeftModel.from_pretrained(
-     model, "baruga/alpaca-lora-13b",
-     torch_dtype=torch.float16
  )

- def generate_prompt(instruction, input=None):
-     if input:
-         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. Answer step by step.

- ### Instruction:
- {instruction}

- ### Input:
- {input}

- ### Response:"""
-     else:
-         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. Answer step by step.

  ### Instruction:
- {instruction}
-
  ### Response:"""

- generation_config = GenerationConfig(
-     temperature=0.1,
-     top_p=0.75,
-     num_beams=4,
  )

- def evaluate(instruction, input=None):
-     prompt = generate_prompt(instruction, input)
-     inputs = tokenizer(prompt, return_tensors="pt")
-     input_ids = inputs["input_ids"].cuda()
-     generation_output = model.generate(
-         input_ids=input_ids,
-         generation_config=generation_config,
-         return_dict_in_generate=True,
-         output_scores=True,
-         max_new_tokens=256
-     )
-     for s in generation_output.sequences:
-         output = tokenizer.decode(s)
-         print("Response:", output.split("### Response:")[1].strip())

+ """Python file to serve as the frontend"""
+ import streamlit as st
+ from streamlit_chat import message
+
+ from langchain.chains import ConversationChain, LLMChain
+ from langchain import PromptTemplate
+ from langchain.llms.base import LLM
+ from langchain.memory import ConversationBufferWindowMemory
+ from typing import Optional, List, Mapping, Any
+
  import torch
  from peft import PeftModel
+ import transformers
+
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+ from transformers import BitsAndBytesConfig
+
+ tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
+
+
+ quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)

  model = LlamaForCausalLM.from_pretrained(
+     "decapoda-research/llama-7b-hf",
+     # load_in_8bit=True,
+     # torch_dtype=torch.float16,
+     device_map="auto",
+     # device_map={"":"cpu"},
+     max_memory={"cpu":"15GiB"},
+     quantization_config=quantization_config
  )
  model = PeftModel.from_pretrained(
+     model, "tloen/alpaca-lora-7b",
+     # torch_dtype=torch.float16,
+     device_map={"":"cpu"},
  )

+ device = "cpu"
+ print("model device :", model.device, flush=True)
+ # model.to(device)
+ model.eval()

+ def evaluate_raw_prompt(
+     prompt:str,
+     temperature=0.1,
+     top_p=0.75,
+     top_k=40,
+     num_beams=4,
+     **kwargs,
+ ):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=num_beams,
+         **kwargs,
+     )
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=256,
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s)
+     # return output
+     return output.split("### Response:")[1].strip()

+ class AlpacaLLM(LLM):
+     temperature: float
+     top_p: float
+     top_k: int
+     num_beams: int
+     @property
+     def _llm_type(self) -> str:
+         return "custom"
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+         answer = evaluate_raw_prompt(prompt,
+             top_p=self.top_p,
+             top_k=self.top_k,
+             num_beams=self.num_beams,
+             temperature=self.temperature
+         )
+         return answer
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "num_beams": self.num_beams,
+             "temperature": self.temperature
+         }


+ template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  ### Instruction:
+ You are a chatbot, you should answer my last question very briefly. You are consistent and non repetitive.
+ ### Chat:
+ {history}
+ Human: {human_input}
  ### Response:"""

+ prompt = PromptTemplate(
+     input_variables=["history","human_input"],
+     template=template,
  )


+ def load_chain():
+     """Logic for loading the chain you want to use should go here."""
+     llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
+     # chain = ConversationChain(llm=llm)
+     chain = LLMChain(llm=llm, prompt=prompt, memory=ConversationBufferWindowMemory(k=2))
+     return chain
+
+ chain = load_chain()
+
+ # From here down is all the StreamLit UI.
+ st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
+ st.header("LangChain Demo")
+
+ if "generated" not in st.session_state:
+     st.session_state["generated"] = []
+
+ if "past" not in st.session_state:
+     st.session_state["past"] = []
+
+
+ def get_text():
+     input_text = st.text_input("Human: ", "Hello, how are you?", key="input")
+     return input_text
+
+
+ user_input = get_text()
+
+ if user_input:
+     output = chain.predict(human_input=user_input)
+
+     st.session_state.past.append(user_input)
+     st.session_state.generated.append(output)
+
+ if st.session_state["generated"]:
+
+     for i in range(len(st.session_state["generated"]) - 1, -1, -1):
+         message(st.session_state["generated"][i], key=str(i))
+         message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
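
For reference, a minimal sketch (not part of the commit) of driving the same chain outside the Streamlit UI. It assumes the AlpacaLLM class, the prompt PromptTemplate, and the loaded model/tokenizer defined in app.py above are already in scope (for example in a REPL after evaluating the non-UI portion of the file); the input string is illustrative.

# Minimal usage sketch: same chain wiring as load_chain() in app.py above.
# Assumes AlpacaLLM and prompt from app.py are already defined in this scope.
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory

llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
# k=2 keeps only the last two exchanges in {history}, matching load_chain().
chain = LLMChain(llm=llm, prompt=prompt, memory=ConversationBufferWindowMemory(k=2))
# Memory fills {history}; the caller supplies {human_input}, as the UI does.
print(chain.predict(human_input="Hello, how are you?"))

The app itself is launched with: streamlit run app.py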