Libroru committed on
Commit 612b7f5
1 Parent(s): b2cb9a7

Upload 15 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ img/icon-dark.png filter=lfs diff=lfs merge=lfs -text
+ img/icon-light.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,25 +1,35 @@
- import os
- import gradio
- import time, asyncio
  from theme import CustomTheme
- from llama_index.llms import OpenAI
  from llama_index import (
-     ServiceContext,
      SimpleDirectoryReader,
-     VectorStoreIndex,
-     load_index_from_storage,
      StorageContext,
-     set_global_service_context,
  )

  bot_examples = [
      "Wie kannst du mir helfen?",
      "Welche Sprachen sprichst du?",
      "Wie trainiere ich meinen Bizeps?",
      "Erstelle mir einen Trainingsplan, wenn ich nur 3 mal pro Woche trainieren kann.",
-     "Berechne meinen BMI, wenn ich 75kg bei 175cm Körpergröße wiege.",
-     "Berechne mir meinen Kaloriendefizit, wenn ich in der Woche 0,2kg abnehmen möchte.",
-     "Berechne mir nochmal das Kaloriendefizit, wenn ich Männlich 19 bin.",
      "Wie wechsle ich meine Reifen?"
  ]
 
@@ -50,6 +60,61 @@ context_str = (

  chat_engine = None

  def setup_ai():
      """
      Set up the AI for use with querying questions to OpenAI.
@@ -59,31 +124,29 @@ def setup_ai():
      assigns the context_template and system_prompt used for manipulating
      the AI responses.
      """
-     global chat_engine, context_str, system_prompt
-
-     # Check if storage index exists
-     if not os.path.isdir("storage"):
-         print("Directory does not exist")
-         print("Building Index")
-         documents = SimpleDirectoryReader("data").load_data()
-         index = VectorStoreIndex.from_documents(documents)
-         index.storage_context.persist(persist_dir="storage")
-     else:
-         print("Directory does already exist")
-         print("Reusing index")
-         storage_context = StorageContext.from_defaults(persist_dir="storage")
-         index = load_index_from_storage(storage_context)
-
      api_key = os.environ["OPENAI_API_KEY"]

-     llm = OpenAI(temperature=0.1, model="gpt-4")

-     chat_engine = index.as_chat_engine(chat_mode="context", system_prompt=system_prompt, context_template=context_str)

-     service_context = ServiceContext.from_defaults(
-         llm=llm
      )
-     set_global_service_context(service_context)

  def response(message, history):
      """
@@ -99,18 +162,21 @@ def response(message, history):
      # If we don't assign an empty list if nothing is present,
      # then the program will, in the worst case, crash.
      chat_history = chat_engine.chat_history if chat_engine.chat_history is not None else []
-     print("Sending request to ChatGPT")
-     response = chat_engine.stream_chat(message, chat_history)

-     output_text = ""

-     for token in response.response_gen:
-         time.sleep(0.05)
          output_text += token
          yield output_text

  # For debugging, just to check if the UI looks good.
- def response_no_api(message, history):
      """
      Returns a default message.
      """
@@ -131,20 +197,30 @@ def main():
          elem_classes=["ask-button"],
      )

-     chat_interface = gradio.ChatInterface(
-         fn=response,
-         title="A.R.N.O.L.D.",
-         theme=CustomTheme(),
-         submit_btn=submit_button,
-         chatbot=chatbot,
-         examples=bot_examples,
-         css="style.css",
-     )

      chat_interface.queue()
      chat_interface.launch(
-         inbrowser=True
      )

  if __name__ == "__main__":
      main()
 
+ import os, gradio, torch, openai, fitz, asyncio, qdrant_client, time, math
  from theme import CustomTheme
  from llama_index import (
      SimpleDirectoryReader,
      StorageContext,
  )
+ from llama_index.multi_modal_llms import OpenAIMultiModal
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
+ from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
+ from PIL import Image
+ from microsofttt import detect_and_crop_save_table
+
+ from torchvision import transforms
+
+ from transformers import AutoModelForObjectDetection
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ openai.api_key = os.environ["OPENAI_API_KEY"]
+
+ image_documents = None
+ openai_mm_llm = None

  bot_examples = [
      "Wie kannst du mir helfen?",
      "Welche Sprachen sprichst du?",
      "Wie trainiere ich meinen Bizeps?",
      "Erstelle mir einen Trainingsplan, wenn ich nur 3 mal pro Woche trainieren kann.",
+     "Berechne meinen BMI, wenn ich männlich bin und 75kg bei 175cm Körpergröße wiege.",
+     "Berechne mir meinen Kaloriendefizit, wenn ich in der Woche 0,1kg abnehmen möchte.",
+     "Berechne mir nochmal das Kaloriendefizit, wenn ich Männlich 18 bin.",
      "Wie wechsle ich meine Reifen?"
  ]
 
  chat_engine = None

+ def setup_db():
+     """
+     Set up the qdrant store and convert PDFs with tables into images,
+     which are then used with the Microsoft Table Transformer to extract table information.
+     """
+     if not os.path.exists("./qdrant_db"):
+
+         if not os.path.exists("./table_images"):
+             os.mkdir("./table_images/")
+         # Convert PDFs to images
+         for file in os.listdir("./pdf_with_tables"):
+             pdf_document = fitz.open("./pdf_with_tables/" + file)
+
+             for page_number in range(pdf_document.page_count):
+                 # Get the page
+                 page = pdf_document[page_number]
+
+                 # Convert the page to an image
+                 pix = page.get_pixmap()
+
+                 # Create a Pillow Image object from the pixmap
+                 image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+
+                 # Save the image
+                 image.save(f"./table_images/page_{page_number + 1}_{math.floor(time.time())}.png")
+
+             pdf_document.close()
+
+         # Crop images to tables
+         for image in os.listdir("./table_images"):
+             detect_and_crop_save_table("./table_images/" + image)
+             # Delete old uncropped image
+             os.remove("./table_images/" + image)
+
+     # Read text documents and images
+     text_documents = SimpleDirectoryReader("./data/").load_data()
+     image_documents = SimpleDirectoryReader("./table_images/").load_data()
+
+     # Create the text and image databases
+     client = qdrant_client.QdrantClient(path="qdrant_db")
+
+     text_store = QdrantVectorStore(
+         client=client, collection_name="text_collection"
+     )
+     image_store = QdrantVectorStore(
+         client=client, collection_name="image_collection"
+     )
+
+     # Create a storage_context for the chatbot from the databases
+     storage_context = StorageContext.from_defaults(
+         vector_store=text_store, image_store=image_store
+     )
+
+     return (text_documents, image_documents, storage_context)
+
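A minimal sketch for sanity-checking what setup_db() writes into the local qdrant store after a first run; it assumes only the stock qdrant_client API and the collection names used above, and is not part of this commit:

import qdrant_client

# Open the same on-disk store setup_db() creates and report how many
# points each collection holds (text chunks vs. cropped table images).
client = qdrant_client.QdrantClient(path="qdrant_db")
for name in ("text_collection", "image_collection"):
    print(name, client.count(collection_name=name).count)
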
  def setup_ai():
      """
      Set up the AI for use with querying questions to OpenAI.

      assigns the context_template and system_prompt used for manipulating
      the AI responses.
      """
+     global openai_mm_llm, context_str, system_prompt, chat_engine
+
+     # Set up the database
+     text_documents, image_documents, storage_context = setup_db()
+
      api_key = os.environ["OPENAI_API_KEY"]

+     # Define the model used
+     openai_mm_llm = OpenAIMultiModal(
+         model="gpt-4-vision-preview", api_key=api_key, max_new_tokens=1500
+     )

+     # Give the model the storage_context
+     index = MultiModalVectorStoreIndex.from_documents(
+         documents=text_documents + image_documents,
+         storage_context=storage_context
+     )

+     # Create a chat engine from the index
+     chat_engine = index.as_chat_engine(
+         system_prompt=system_prompt,
+         context_str=context_str
      )
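For comparison, the llama-index multi-modal example credited in microsofttt.py queries this kind of index through a query engine rather than a chat engine. A minimal sketch under llama-index 0.9.x, reusing the index and openai_mm_llm objects built in setup_ai(); the prompt template text is illustrative, not from this repo:

from llama_index.prompts import PromptTemplate

# Illustrative QA template; {context_str} and {query_str} are filled in
# by the query engine with the retrieved nodes and the user query.
qa_tmpl = PromptTemplate(
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)

# Query the multimodal index directly, as in the linked example.
query_engine = index.as_query_engine(
    multi_modal_llm=openai_mm_llm, text_qa_template=qa_tmpl
)
print(query_engine.query("Wie trainiere ich meinen Bizeps?"))
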
 
  def response(message, history):
      """
 
      # If we don't assign an empty list if nothing is present,
      # then the program will, in the worst case, crash.
      chat_history = chat_engine.chat_history if chat_engine.chat_history is not None else []

+     # Send query
+     _response = chat_engine.stream_chat(message, chat_history)

+     # Stream chat answer
+     output_text: str = ""
+     for token in _response.response_gen:
+         time.sleep(0.02)
          output_text += token
          yield output_text

+
+
  # For debugging, just to check if the UI looks good.
+ def response_no_api(message, history) -> str:
      """
      Returns a default message.
      """
 
          elem_classes=["ask-button"],
      )

+     with gradio.Blocks(theme=CustomTheme(), css="style.css") as chat_interface:
+         gradio.Markdown(
+             """<div style='display: flex; justify-content: center; align-items: center; margin-right: 12px;'>
+             <img width='48px' style='margin-right: 12px;' src='/file/img/icon-light.png'/>
+             ARNOLD
+             </div>""",
+             elem_classes=["arnold-title"]
+         )
+         gradio.ChatInterface(
+             fn=response,
+             theme=CustomTheme(),
+             submit_btn=submit_button,
+             chatbot=chatbot,
+             examples=bot_examples,
+             css="style.css",
+         )
+

      chat_interface.queue()
      chat_interface.launch(
+         inbrowser=True,
+         allowed_paths=["./img/"]
      )

+
  if __name__ == "__main__":
      main()
img/.DS_Store ADDED
Binary file (6.15 kB).
 
img/icon-dark.png ADDED

Git LFS Details

  • SHA256: cfaded4ad39679788929c6d2d532415d8f4593ccbc222802a9e8bfec1f5ae7fd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
img/icon-light.png ADDED

Git LFS Details

  • SHA256: 49ce2a7a23376ce538931edb8209cf1553e6322523f6019a89c4e9bc7cc094ee
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
img/test_img.png ADDED
microsofttt.py ADDED
@@ -0,0 +1,154 @@
+ """
+ Microsoft Table Transformer Extension
+ By Niels:
+ https://docs.llamaindex.ai/en/stable/examples/multi_modal/multi_modal_pdf_tables.html#experiment-3-let-s-use-microsoft-table-transformer-to-crop-tables-from-the-images-and-see-if-it-gives-the-correct-answer
+ """
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from matplotlib.patches import Patch
+ import io
+ from PIL import Image, ImageDraw
+ import numpy as np
+ import csv
+ import pandas as pd
+
+ from torchvision import transforms
+
+ from transformers import AutoModelForObjectDetection
+ import torch
+ import openai
+ import os
+ import fitz
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class MaxResize(object):
+     def __init__(self, max_size=800):
+         self.max_size = max_size
+
+     def __call__(self, image):
+         width, height = image.size
+         current_max_size = max(width, height)
+         scale = self.max_size / current_max_size
+         resized_image = image.resize(
+             (int(round(scale * width)), int(round(scale * height)))
+         )
+
+         return resized_image
+
+
+ detection_transform = transforms.Compose(
+     [
+         MaxResize(800),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+
+ structure_transform = transforms.Compose(
+     [
+         MaxResize(1000),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+
+ # load table detection model
+ # processor = TableTransformerImageProcessor(max_size=800)
+ model = AutoModelForObjectDetection.from_pretrained(
+     "microsoft/table-transformer-detection", revision="no_timm"
+ ).to(device)
+
+ # load table structure recognition model
+ # structure_processor = TableTransformerImageProcessor(max_size=1000)
+ structure_model = AutoModelForObjectDetection.from_pretrained(
+     "microsoft/table-transformer-structure-recognition-v1.1-all"
+ ).to(device)
+
+
+ # for output bounding box post-processing
+ def box_cxcywh_to_xyxy(x):
+     x_c, y_c, w, h = x.unbind(-1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=1)
+
+
+ def rescale_bboxes(out_bbox, size):
+     width, height = size
+     boxes = box_cxcywh_to_xyxy(out_bbox)
+     boxes = boxes * torch.tensor(
+         [width, height, width, height], dtype=torch.float32
+     )
+     return boxes
+
+
+ def outputs_to_objects(outputs, img_size, id2label):
+     m = outputs.logits.softmax(-1).max(-1)
+     pred_labels = list(m.indices.detach().cpu().numpy())[0]
+     pred_scores = list(m.values.detach().cpu().numpy())[0]
+     pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
+     pred_bboxes = [
+         elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)
+     ]
+
+     objects = []
+     for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
+         class_label = id2label[int(label)]
+         if not class_label == "no object":
+             objects.append(
+                 {
+                     "label": class_label,
+                     "score": float(score),
+                     "bbox": [float(elem) for elem in bbox],
+                 }
+             )
+
+     return objects
+
+
+ def detect_and_crop_save_table(
+     file_path, cropped_table_directory="./table_images/"
+ ):
+     image = Image.open(file_path)
+
+     filename, _ = os.path.splitext(file_path.split("/")[-1])
+
+     if not os.path.exists(cropped_table_directory):
+         os.makedirs(cropped_table_directory)
+
+     # prepare image for the model
+     # pixel_values = processor(image, return_tensors="pt").pixel_values
+     pixel_values = detection_transform(image).unsqueeze(0).to(device)
+
+     # forward pass
+     with torch.no_grad():
+         outputs = model(pixel_values)
+
+     # postprocess to get detected tables
+     id2label = model.config.id2label
+     id2label[len(model.config.id2label)] = "no object"
+     detected_tables = outputs_to_objects(outputs, image.size, id2label)
+
+     print(f"number of tables detected {len(detected_tables)}")
+
+     for idx in range(len(detected_tables)):
+         # crop detected table out of image
+         cropped_table = image.crop(detected_tables[idx]["bbox"])
+         cropped_table.save(f"./{cropped_table_directory}/{filename}_{idx}.png")
+
+
+ def plot_images(image_paths):
+     images_shown = 0
+     plt.figure(figsize=(16, 9))
+     for img_path in image_paths:
+         if os.path.isfile(img_path):
+             image = Image.open(img_path)
+
+             plt.subplot(2, 3, images_shown + 1)
+             plt.imshow(image)
+             plt.xticks([])
+             plt.yticks([])
+
+             images_shown += 1
+             if images_shown >= 9:
+                 break
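
A minimal usage sketch for the helpers above: render the first page of the bundled PDF with PyMuPDF, crop any detected tables, and preview the crops. Paths mirror app.py; the page filename is illustrative and not from this repo:

import os
import fitz
import matplotlib.pyplot as plt
from PIL import Image
from microsofttt import detect_and_crop_save_table, plot_images

os.makedirs("./table_images", exist_ok=True)

# Render the first PDF page to a PNG, the same way setup_db() does.
pdf = fitz.open("./pdf_with_tables/test.pdf")
pix = pdf[0].get_pixmap()
Image.frombytes("RGB", [pix.width, pix.height], pix.samples).save(
    "./table_images/page_1.png"
)
pdf.close()

# Crops are written as ./table_images/page_1_<idx>.png
detect_and_crop_save_table("./table_images/page_1.png")

plot_images(
    [f"./table_images/{f}" for f in sorted(os.listdir("./table_images"))]
)
plt.show()
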
pdf_with_tables/test.pdf ADDED
Binary file (39.7 kB).
 
requirements.txt CHANGED
@@ -1,4 +1,10 @@
 
+ clip @ git+https://github.com/openai/CLIP.git
  gradio==4.4.1
  openai==1.1.1
  llama-index==0.9.15
- pypdf==3.17.1
+ pypdf==3.17.1
+ qdrant_client
+ pyMuPDF
+ tools
+ frontend
+ easyocr
style.css CHANGED
@@ -12,19 +12,15 @@
      padding: 0 !important;
  }

- div.gap>.stretch {
+ /*div.gap>.stretch {
      display: none !important;
- }
+ }*/

  div.gap.panel>div.gr-group {
      position: absolute;
      bottom: 0;
  }

- h1 {
-     font-size: 48px !important;
- }
-
  .ask-button {
      background-color: var(--color-accent);
      font-weight: bold;
@@ -39,6 +35,8 @@ div.message-wrap {
      margin-bottom: 32px !important;
  }

- .message, .gallery-item, .ask-button, textarea {
-     font-family: "Arial" !important;
+ .arnold-title {
+     font-family: "Saira";
+     font-size: 48px !important;
+     text-align: center;
  }
theme.py CHANGED
@@ -7,7 +7,7 @@ class CustomTheme(Base):

      def __init__(self):
          super().__init__(
-             font=fonts.GoogleFont("Bruno Ace SC")
+             font=(fonts.GoogleFont("Inter"), fonts.GoogleFont("Saira"))
          )

      off_white = "#F0F0F0"
@@ -58,6 +58,6 @@ class CustomTheme(Base):
          color_accent_soft_dark=accent_soft_dark,
          border_color_accent_subdued_dark=accent_soft_dark,

-         block_radius="15px",
+         block_radius="16px",
          container_radius="32px",
      )