Gary committed
Commit 1d656af · Parent: 1b9a516

use larger model

Files changed (2)
  1. app.py +3 -3
  2. indexer.py +20 -6
app.py CHANGED
@@ -1,5 +1,4 @@
 from indexer import (
-    load_raw_dataset,
     create_vector_database,
     get_llm,
     get_prompt_template,
@@ -24,14 +23,15 @@ class CustomRAG:
 
     def run(self, query):
         retriever = self.vector_db.as_retriever(search_kwargs={"k": 3})
-        contexts = retriever.get_relevant_documents(query)
+        contexts = retriever.invoke(query)
         formatted_context = format_contexts(contexts)
         prompt = self.prompt_template.format(context=formatted_context, question=query)
         return self.llm.invoke(prompt), contexts
 
 
 def answer_question(query):
-    llm = get_llm("google/flan-t5-base")
+    # llm = get_llm("google/flan-t5-base")
+    llm = get_llm("FreedomIntelligence/HuatuoGPT-o1-7B")
     vector_database = create_vector_database("sentence-transformers/all-MiniLM-L6-v2")
     prompt_template = get_prompt_template()
     rag = CustomRAG(
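The retriever.invoke(query) change tracks LangChain's move to the Runnable interface: get_relevant_documents is deprecated on recent releases, and invoke is the replacement with the same return type (a list of Documents). A minimal self-contained sketch of the same call pattern, using a toy BM25 retriever in place of this app's Pinecone store so it runs without an index (the corpus below is purely illustrative, not from this repo):

# Sketch (not from this repo): old vs. new retriever call in LangChain.
# Assumes langchain-community and rank_bm25 are installed; the toy corpus
# stands in for the app's Pinecone vector store.
from langchain_community.retrievers import BM25Retriever

corpus = [
    "Ibuprofen is a common anti-inflammatory drug.",
    "Paracetamol reduces fever and mild pain.",
    "Amoxicillin treats bacterial infections.",
]
retriever = BM25Retriever.from_texts(corpus, k=2)

# Deprecated call removed by this commit (warns on recent versions):
# contexts = retriever.get_relevant_documents("what helps a fever?")

# Runnable-style call the commit switches to; same list-of-Documents result:
contexts = retriever.invoke("what helps a fever?")
print([doc.page_content for doc in contexts])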
indexer.py CHANGED
@@ -7,6 +7,7 @@ from transformers import (
     AutoTokenizer,
     pipeline,
     AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM,
 )
 from langchain.llms import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
@@ -16,6 +17,7 @@ api_key = os.environ["PINECONE_API_KEY"]
 
 from langchain_pinecone import PineconeVectorStore
 
+
 def load_raw_dataset():
     dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")
 
@@ -47,16 +49,28 @@ def create_vector_database(model_name):
 
 def get_llm(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        "google/flan-t5-base", torch_dtype="auto", device_map="auto"
+    # model = AutoModelForSeq2SeqLM.from_pretrained(
+    #     "google/flan-t5-base", torch_dtype="auto", device_map="auto"
+    # )
+
+    # pipe = pipeline(
+    #     "text2text-generation",
+    #     model=model,
+    #     tokenizer=tokenizer,
+    #     max_new_tokens=512,
+    #     temperature=1,
+    #     do_sample=True,
+    # )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype="auto", device_map="auto"
     )
-
     pipe = pipeline(
-        "text2text-generation",
+        "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=512,
-        temperature=1,
+        max_new_tokens=1024,
+        temperature=0.7,
         do_sample=True,
     )
 
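One porting detail this hunk doesn't show: unlike text2text-generation, a text-generation pipeline echoes the prompt in its output unless return_full_text=False is set, so RAG answers can come back with the whole prompt prepended. Below is a runnable sketch of the new code path wrapped for LangChain, with the small distilgpt2 checkpoint standing in for the 7B model; the HuggingFacePipeline wrapper is already imported at the top of indexer.py, but whether get_llm applies return_full_text is an assumption, since the function's tail isn't in the diff:

# Sketch: causal-LM pipeline wrapped for LangChain, mirroring the new get_llm.
# distilgpt2 is a stand-in so this runs on modest hardware; the commit itself
# loads FreedomIntelligence/HuatuoGPT-o1-7B. device_map="auto" needs accelerate.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline

model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,       # the commit uses 1024
    temperature=0.7,
    do_sample=True,
    return_full_text=False,  # drop the echoed prompt from the output
)
llm = HuggingFacePipeline(pipeline=pipe)
print(llm.invoke("Question: what reduces fever?\nAnswer:"))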