Bostoncake commited on
Commit
f24bfa5
·
1 Parent(s): 6eef56f

Delete unused files.

Browse files
Files changed (2) hide show
  1. chat_assistant.py +0 -208
  2. get_paper_from_pdf.py +0 -194
chat_assistant.py DELETED
@@ -1,208 +0,0 @@
1
- import numpy as np
2
- import os
3
- import re
4
- import datetime
5
- import time
6
- import openai, tenacity
7
- import argparse
8
- import configparser
9
- import json
10
- import tiktoken
11
- from get_paper_from_pdf import Paper
12
-
13
- class Assistant:
14
- def __init__(self, args=None):
15
- if args.language == 'en':
16
- self.language = 'English'
17
- elif args.language == 'zh':
18
- self.language = 'Chinese'
19
- else:
20
- self.language = 'Chinese'
21
- self.config = configparser.ConfigParser()
22
- self.config.read('apikey.ini')
23
- self.chat_api_list = self.config.get('OpenAI', 'OPENAI_API_KEYS')[1:-1].replace('\'', '').split(',')
24
- self.chat_api_list = [api.strip() for api in self.chat_api_list if len(api) > 5]
25
- self.cur_api = 0
26
- self.file_format = args.file_format
27
- self.max_token_num = 4096
28
- self.encoding = tiktoken.get_encoding("gpt2")
29
- self.result_backup = ''
30
-
31
- def validateTitle(self, title):
32
- rstr = r"[\/\\\:\*\?\"\<\>\|]"
33
- new_title = re.sub(rstr, "_", title)
34
- return new_title
35
-
36
-
37
- def assist_reading_by_chatgpt(self, paper_list):
38
- htmls = []
39
- for paper_index, paper in enumerate(paper_list):
40
- sections_of_interest = self.extract_paper(paper)
41
- # extract the essential parts of the paper
42
- text = ''
43
- text += 'Title:' + paper.title + '. '
44
- text += 'Abstract: ' + paper.section_texts['Abstract']
45
- intro_title = next((item for item in paper.section_names if 'ntroduction' in item.lower()), None)
46
- if intro_title is not None:
47
- text += 'Introduction: ' + paper.section_texts[intro_title]
48
- # Similar for conclusion section
49
- conclusion_title = next((item for item in paper.section_names if 'onclusion' in item), None)
50
- if conclusion_title is not None:
51
- text += 'Conclusion: ' + paper.section_texts[conclusion_title]
52
- for heading in sections_of_interest:
53
- if heading in paper.section_names:
54
- text += heading + ': ' + paper.section_texts[heading]
55
- chat_review_text = self.chat_assist(text=text)
56
- htmls.append('## Paper:' + str(paper_index+1))
57
- htmls.append('\n\n\n')
58
- htmls.append(chat_review_text)
59
-
60
- # 将问题与回答保存起来
61
- date_str = str(datetime.datetime.now())[:19].replace(' ', '-').replace(':', '-')
62
- try:
63
- export_path = os.path.join('./', 'output_file')
64
- os.makedirs(export_path)
65
- except:
66
- pass
67
- mode = 'w' if paper_index == 0 else 'a'
68
- file_name = os.path.join(export_path, date_str+'-'+self.validateTitle(paper.title)+"."+self.file_format)
69
- self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)
70
- htmls = []
71
-
72
-
73
- def extract_paper(self, paper):
74
- htmls = []
75
- text = ''
76
- text += 'Title: ' + paper.title + '. '
77
- text += 'Abstract: ' + paper.section_texts['Abstract']
78
- text_token = len(self.encoding.encode(text))
79
- if text_token > self.max_token_num/2 - 800:
80
- input_text_index = int(len(text)*((self.max_token_num/2)-800)/text_token)
81
- text = text[:input_text_index]
82
- openai.api_key = self.chat_api_list[self.cur_api]
83
- self.cur_api += 1
84
- self.cur_api = 0 if self.cur_api >= len(self.chat_api_list)-1 else self.cur_api
85
- print("\n\n"+"********"*10)
86
- print("Extracting content from PDF.")
87
- print("********"*10)
88
- messages = [
89
- {"role": "system",
90
- "content": f"You are a professional researcher in the field of {args.research_fields}. You are the mentor of a student who is new to this field. "
91
- f"I will give you a paper. You need to help your student to read this paper by instructing him to read the important sections in this paper and answer his questions towards these sections."
92
- f"Due to the length limitations, I am only allowed to provide you the abstract, introduction, conclusion and at most two sections of this paper."
93
- f"Now I will give you the title and abstract and the headings of potential sections. "
94
- f"You need to reply at most two headings. Then I will further provide you the full information, includes aforementioned sections and at most two sections you called for.\n\n"
95
- f"Title: {paper.title}\n\n"
96
- f"Abstract: {paper.section_texts['Abstract']}\n\n"
97
- f"Potential Sections: {paper.section_names[2:-1]}\n\n"
98
- f"Follow the following format to output your choice of sections:"
99
- f"{{chosen section 1}}, {{chosen section 2}}\n\n"},
100
- {"role": "user", "content": text},
101
- ]
102
- response = openai.ChatCompletion.create(
103
- model="gpt-3.5-turbo",
104
- messages=messages,
105
- )
106
- result = ''
107
- for choice in response.choices:
108
- result += choice.message.content
109
- print("\n\n"+"********"*10)
110
- print("Important sections of this paper:")
111
- print(result)
112
- print("********"*10)
113
- print("prompt_token_used:", response.usage.prompt_tokens)
114
- print("completion_token_used:", response.usage.completion_tokens)
115
- print("total_token_used:", response.usage.total_tokens)
116
- print("response_time:", response.response_ms/1000.0, 's')
117
- return result.split(',')
118
-
119
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
120
- stop=tenacity.stop_after_attempt(5),
121
- reraise=True)
122
- def chat_assist(self, text):
123
- openai.api_key = self.chat_api_list[self.cur_api]
124
- self.cur_api += 1
125
- self.cur_api = 0 if self.cur_api >= len(self.chat_api_list)-1 else self.cur_api
126
- review_prompt_token = 1000
127
- text_token = len(self.encoding.encode(text))
128
- input_text_index = int(len(text)*(self.max_token_num-review_prompt_token)/text_token)
129
- input_text = "This is the paper for your review:" + text[:input_text_index] + "\n\n"
130
- input_text_backup = input_text
131
- while True:
132
- print("\n\n"+"********"*10)
133
- print("Ask ChatGPT questions of the important sections. Type \"quit\" to exit the program. To receive better responses, please describe why you ask the question.\nFor example, ask \"Why does the author use residual connections? I want to know how does the residual connections work in the model structure.\" instead of \"Why does the author use residual connections?\"")
134
- print("********"*10)
135
- student_question = input()
136
- if student_question == "quit":
137
- break
138
- input_text = input_text_backup
139
- input_text = input_text + "The question from your student is: " + student_question
140
- messages=[
141
- {"role": "system", "content": "You are a professional researcher in the field of "+args.research_fields+". You are the mentor of a student who is new to this field. Now I will give you a paper. You need to help your student to read this paper by instructing him to read the important sections in this paper and answer his questions towards these sections. Please answer in {}.".format(self.language)},
142
- {"role": "user", "content": input_text},
143
- ]
144
-
145
- response = openai.ChatCompletion.create(
146
- model="gpt-3.5-turbo",
147
- messages=messages,
148
- )
149
- result = ''
150
- for choice in response.choices:
151
- result += choice.message.content
152
- self.result_backup = self.result_backup + "\n\n" + student_question + "\n"
153
- self.result_backup += result
154
- print("\n\n"+"********"*10)
155
- print(result)
156
- print("********"*10)
157
- print("prompt_token_used:", response.usage.prompt_tokens)
158
- print("completion_token_used:", response.usage.completion_tokens)
159
- print("total_token_used:", response.usage.total_tokens)
160
- print("response_time:", response.response_ms/1000.0, 's')
161
- return self.result_backup
162
-
163
- def export_to_markdown(self, text, file_name, mode='w'):
164
- # 使用markdown模块的convert方法,将文本转换为html格式
165
- # html = markdown.markdown(text)
166
- # 打开一个文件,以写入模式
167
- with open(file_name, mode, encoding="utf-8") as f:
168
- # 将html格式的内容写入文件
169
- f.write(text)
170
-
171
- def main(args):
172
-
173
- # Paper reading assistant instructions
174
- print("********"*10)
175
- print("Extracting content from PDF.")
176
- print("********"*10)
177
-
178
-
179
- assistant1 = Assistant(args=args)
180
- # 开始判断是路径还是文件:
181
- paper_list = []
182
- if args.paper_path.endswith(".pdf"):
183
- paper_list.append(Paper(path=args.paper_path))
184
- else:
185
- for root, dirs, files in os.walk(args.paper_path):
186
- print("root:", root, "dirs:", dirs, 'files:', files) #当前目录路径
187
- for filename in files:
188
- # 如果找到PDF文件,则将其复制到目标文件夹中
189
- if filename.endswith(".pdf"):
190
- paper_list.append(Paper(path=os.path.join(root, filename)))
191
- print("------------------paper_num: {}------------------".format(len(paper_list)))
192
- [print(paper_index, paper_name.path.split('\\')[-1]) for paper_index, paper_name in enumerate(paper_list)]
193
- assistant1.assist_reading_by_chatgpt(paper_list=paper_list)
194
-
195
-
196
-
197
- if __name__ == '__main__':
198
- parser = argparse.ArgumentParser()
199
- parser.add_argument("--paper_path", type=str, default='', help="path of papers")
200
- parser.add_argument("--file_format", type=str, default='txt', help="output file format")
201
- parser.add_argument("--research_fields", type=str, default='computer science, artificial intelligence and transfer learning', help="the research fields of paper")
202
- parser.add_argument("--language", type=str, default='en', help="output lauguage, en or zh")
203
-
204
- args = parser.parse_args()
205
- start_time = time.time()
206
- main(args=args)
207
- print("total time:", time.time() - start_time)
208
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
get_paper_from_pdf.py DELETED
@@ -1,194 +0,0 @@
1
- import fitz, io, os
2
- from PIL import Image
3
- from collections import Counter
4
- import json
5
- import re
6
-
7
-
8
- class Paper:
9
- def __init__(self, path, title='', url='', abs='', authors=[]):
10
- # 初始化函数,根据pdf路径初始化Paper对象
11
- self.url = url # 文章链接
12
- self.path = path # pdf路径
13
- self.section_names = [] # 段落标题
14
- self.section_texts = {} # 段落内容
15
- self.abs = abs
16
- self.title_page = 0
17
- if title == '':
18
- self.pdf = fitz.open(self.path) # pdf文档
19
- self.title = self.get_title()
20
- self.parse_pdf()
21
- else:
22
- self.title = title
23
- self.authors = authors
24
- self.roman_num = ["I", "II", 'III', "IV", "V", "VI", "VII", "VIII", "IIX", "IX", "X"]
25
- self.digit_num = [str(d + 1) for d in range(10)]
26
- self.first_image = ''
27
-
28
- def parse_pdf(self):
29
- self.pdf = fitz.open(self.path) # pdf文档
30
- self.text_list = [page.get_text() for page in self.pdf]
31
- self.all_text = ' '.join(self.text_list)
32
- self.extract_section_infomation()
33
- self.section_texts.update({"title": self.title})
34
- self.pdf.close()
35
-
36
- # 定义一个函数,根据字体的大小,识别每个章节名称,并返回一个列表
37
- def get_chapter_names(self, ):
38
- # # 打开一个pdf文件
39
- doc = fitz.open(self.path) # pdf文档
40
- text_list = [page.get_text() for page in doc]
41
- all_text = ''
42
- for text in text_list:
43
- all_text += text
44
- # # 创建一个空列表,用于存储章节名称
45
- chapter_names = []
46
- for line in all_text.split('\n'):
47
- line_list = line.split(' ')
48
- if '.' in line:
49
- point_split_list = line.split('.')
50
- space_split_list = line.split(' ')
51
- if 1 < len(space_split_list) < 5:
52
- if 1 < len(point_split_list) < 5 and (
53
- point_split_list[0] in self.roman_num or point_split_list[0] in self.digit_num):
54
- # print("line:", line)
55
- chapter_names.append(line)
56
-
57
- return chapter_names
58
-
59
- def get_title(self):
60
- doc = self.pdf # 打开pdf文件
61
- max_font_size = 0 # 初始化最大字体大小为0
62
- max_string = "" # 初始化最大字体大小对应的字符串为空
63
- max_font_sizes = [0]
64
- for page_index, page in enumerate(doc): # 遍历每一页
65
- text = page.get_text("dict") # 获取页面上的文本信息
66
- blocks = text["blocks"] # 获取文本块列表
67
- for block in blocks: # 遍历每个文本块
68
- if block["type"] == 0 and len(block['lines']): # 如果是文字类型
69
- if len(block["lines"][0]["spans"]):
70
- font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
71
- max_font_sizes.append(font_size)
72
- if font_size > max_font_size: # 如果字体大小大于当前最大值
73
- max_font_size = font_size # 更新最大值
74
- max_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串
75
- max_font_sizes.sort()
76
- # print("max_font_sizes", max_font_sizes[-10:])
77
- cur_title = ''
78
- for page_index, page in enumerate(doc): # 遍历每一页
79
- text = page.get_text("dict") # 获取页面上的文本信息
80
- blocks = text["blocks"] # 获取文本块列表
81
- for block in blocks: # 遍历每个文本块
82
- if block["type"] == 0 and len(block['lines']): # 如果是文字类型
83
- if len(block["lines"][0]["spans"]):
84
- cur_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串
85
- font_flags = block["lines"][0]["spans"][0]["flags"] # 获取第一行第一段文字的字体特征
86
- font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
87
- # print(font_size)
88
- if abs(font_size - max_font_sizes[-1]) < 0.3 or abs(font_size - max_font_sizes[-2]) < 0.3:
89
- # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
90
- if len(cur_string) > 4 and "arXiv" not in cur_string:
91
- # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
92
- if cur_title == '':
93
- cur_title += cur_string
94
- else:
95
- cur_title += ' ' + cur_string
96
- self.title_page = page_index
97
- # break
98
- title = cur_title.replace('\n', ' ')
99
- return title
100
-
101
- def extract_section_infomation(self):
102
- doc = fitz.open(self.path)
103
-
104
- # 获取文档中所有字体大小
105
- font_sizes = []
106
- for page in doc:
107
- blocks = page.get_text("dict")["blocks"]
108
- for block in blocks:
109
- if 'lines' not in block:
110
- continue
111
- lines = block["lines"]
112
- for line in lines:
113
- for span in line["spans"]:
114
- font_sizes.append(span["size"])
115
- most_common_size, _ = Counter(font_sizes).most_common(1)[0]
116
-
117
- # 按照最频繁的字体大小确定标题字体大小的阈值
118
- threshold = most_common_size * 1
119
-
120
- section_dict = {}
121
- section_dict["Abstract"] = ""
122
- last_heading = None
123
- subheadings = []
124
- heading_font = -1
125
- # 遍历每一页并查找子标题
126
- found_abstract = False
127
- upper_heading = False
128
- font_heading = False
129
- for page in doc:
130
- blocks = page.get_text("dict")["blocks"]
131
- for block in blocks:
132
- if not found_abstract:
133
- try:
134
- text = json.dumps(block)
135
- except:
136
- continue
137
- if re.search(r"\bAbstract\b", text, re.IGNORECASE):
138
- found_abstract = True
139
- last_heading = "Abstract"
140
- if found_abstract:
141
- if 'lines' not in block:
142
- continue
143
- lines = block["lines"]
144
- for line in lines:
145
- for span in line["spans"]:
146
- # 如果当前文本是子标题
147
- if not font_heading and span["text"].isupper() and sum(1 for c in span["text"] if c.isupper() and ('A' <= c <='Z')) > 4: # 针对一些标题大小一样,但是全大写的论文
148
- upper_heading = True
149
- heading = span["text"].strip()
150
- if "References" in heading: # reference 以后的内容不考虑
151
- self.section_names = subheadings
152
- self.section_texts = section_dict
153
- return
154
- subheadings.append(heading)
155
- if last_heading is not None:
156
- section_dict[last_heading] = section_dict[last_heading].strip()
157
- section_dict[heading] = ""
158
- last_heading = heading
159
- if not upper_heading and span["size"] > threshold and re.match( # 正常情况下,通过字体大小判断
160
- r"[A-Z][a-z]+(?:\s[A-Z][a-z]+)*",
161
- span["text"].strip()):
162
- font_heading = True
163
- if heading_font == -1:
164
- heading_font = span["size"]
165
- elif heading_font != span["size"]:
166
- continue
167
- heading = span["text"].strip()
168
- if "References" in heading: # reference 以后的内容不考虑
169
- self.section_names = subheadings
170
- self.section_texts = section_dict
171
- return
172
- subheadings.append(heading)
173
- if last_heading is not None:
174
- section_dict[last_heading] = section_dict[last_heading].strip()
175
- section_dict[heading] = ""
176
- last_heading = heading
177
- # 否则将当前文本添加到上一个子标题的文本中
178
- elif last_heading is not None:
179
- section_dict[last_heading] += " " + span["text"].strip()
180
- self.section_names = subheadings
181
- self.section_texts = section_dict
182
-
183
-
184
- def main():
185
- path = r'demo.pdf'
186
- paper = Paper(path=path)
187
- paper.parse_pdf()
188
- # for key, value in paper.section_text_dict.items():
189
- # print(key, value)
190
- # print("*"*40)
191
-
192
-
193
- if __name__ == '__main__':
194
- main()