abdo-Mansour committed
Commit ab74ea1 · 1 Parent(s): 13f15f6
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .env
+ *.ipynb
+ venv
+ *.csv
+ *.json
+ *.jsonl
+ vllm*
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.txt ADDED
File without changes
README.md CHANGED
@@ -1,12 +1,15 @@
 ---
- title: MCP STRUCTRA
- emoji: 🐢
- colorFrom: gray
- colorTo: blue
+ title: MCP Server Web2JSON
+ emoji: 🖇️
+ colorFrom: blue
+ colorTo: green
 sdk: gradio
- sdk_version: 5.35.0
+ sdk_version: 5.33.0
 app_file: app.py
- pinned: false
+ pinned: True
+ tags: [mcp-server-track]
 ---

+ [Video overview of the agent demo](https://youtu.be/wd0kjOVoGn8)
+
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
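The new front matter registers the Space as a Gradio app (`app_file: app.py`, `tags: [mcp-server-track]`), and `app.py` below launches it with `mcp_server=True`. For orientation, here is a minimal, hedged sketch of calling such a Space programmatically with `gradio_client`; the Space ID and `api_name` are illustrative assumptions, not values taken from this commit.

```python
# Hedged sketch: driving the Space once deployed (Space ID and api_name are
# placeholders; the Space's "Use via API" panel shows the real values).
from gradio_client import Client

client = Client("abdo-Mansour/web2json")   # hypothetical Space ID
result = client.predict(
    "https://example.com",                 # content: URL or raw HTML/text
    True,                                  # flag: the content above is a URL
    "title: str = Page title",             # schema in the simple field format
    api_name="/predict",                   # assumed default endpoint of gr.Interface
)
print(result)                              # extracted JSON
```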
app.py ADDED
@@ -0,0 +1,308 @@
1
+ import json
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from typing import Dict, Any, Type
5
+ from web2json.preprocessor import BasicPreprocessor
6
+ from web2json.ai_extractor import AIExtractor,LLMClassifierExtractor,NvidiaLLMClient, NvidiaRerankerClient , ModalRerankerClient
7
+ from web2json.postprocessor import PostProcessor
8
+ from web2json.pipeline import Pipeline
9
+ from pydantic import BaseModel, Field, create_model
10
+ import os
11
+ import dotenv
12
+ import random
13
+ import numpy as np
14
+ import torch
15
+
16
+ dotenv.load_dotenv()
17
+
18
+ def seed_everything(seed=42):
19
+ random.seed(seed)
20
+ np.random.seed(seed)
21
+ torch.manual_seed(seed)
22
+
23
+ if torch.cuda.is_available():
24
+ torch.cuda.manual_seed(seed)
25
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
26
+
27
+ torch.backends.cudnn.deterministic = True
28
+ torch.backends.cudnn.benchmark = False
29
+
30
+ seed_everything(22)
31
+
32
+ def parse_schema_input(schema_input: str) -> Type[BaseModel]:
33
+ """
34
+ Convert user schema input to a Pydantic BaseModel.
35
+ Supports multiple input formats:
36
+ 1. JSON schema format
37
+ 2. Python class definition
38
+ 3. Simple field definitions
39
+ """
40
+ schema_input = schema_input.strip()
41
+
42
+ if not schema_input:
43
+ # Default schema if none provided
44
+ return create_model('DefaultSchema',
45
+ title=(str, Field(description="Title of the content")),
46
+ content=(str, Field(description="Main content")))
47
+
48
+ try:
49
+ # Try parsing as JSON schema
50
+ if schema_input.startswith('{'):
51
+ schema_dict = json.loads(schema_input)
52
+ return json_schema_to_basemodel(schema_dict)
53
+
54
+ # Try parsing as Python class definition
55
+ elif 'class ' in schema_input and 'BaseModel' in schema_input:
56
+ return python_class_to_basemodel(schema_input)
57
+
58
+ # Try parsing as simple field definitions
59
+ else:
60
+ return simple_fields_to_basemodel(schema_input)
61
+
62
+ except Exception as e:
63
+ raise ValueError(f"Could not parse schema: {str(e)}. Please check your schema format.")
64
+
65
+ def json_schema_to_basemodel(schema_dict: Dict) -> Type[BaseModel]:
66
+ """Convert JSON schema to BaseModel"""
67
+ fields = {}
68
+ properties = schema_dict.get('properties', {})
69
+ required = schema_dict.get('required', [])
70
+
71
+ for field_name, field_info in properties.items():
72
+ field_type = get_python_type(field_info.get('type', 'string'))
73
+ field_description = field_info.get('description', '')
74
+
75
+ if field_name in required:
76
+ fields[field_name] = (field_type, Field(description=field_description))
77
+ else:
78
+ fields[field_name] = (field_type, Field(default=None, description=field_description))
79
+
80
+ return create_model('DynamicSchema', **fields)
81
+
82
+ def python_class_to_basemodel(class_definition: str) -> Type[BaseModel]:
83
+ """Convert Python class definition to BaseModel"""
84
+ try:
85
+ # Execute the class definition in a safe namespace
86
+ namespace = {'BaseModel': BaseModel, 'Field': Field, 'str': str, 'int': int,
87
+ 'float': float, 'bool': bool, 'list': list, 'dict': dict}
88
+ exec(class_definition, namespace)
89
+
90
+ # Find the class that inherits from BaseModel
91
+ for name, obj in namespace.items():
92
+ if (isinstance(obj, type) and
93
+ issubclass(obj, BaseModel) and
94
+ obj != BaseModel):
95
+ return obj
96
+
97
+ raise ValueError("No BaseModel class found in definition")
98
+ except Exception as e:
99
+ raise ValueError(f"Invalid Python class definition: {str(e)}")
100
+
101
+ def simple_fields_to_basemodel(fields_text: str) -> Type[BaseModel]:
102
+ """Convert simple field definitions to BaseModel"""
103
+ fields = {}
104
+
105
+ for line in fields_text.strip().split('\n'):
106
+ line = line.strip()
107
+ if not line or line.startswith('#'):
108
+ continue
109
+
110
+ # Parse field definition (e.g., "name: str = description")
111
+ if ':' in line:
112
+ parts = line.split(':', 1)
113
+ field_name = parts[0].strip()
114
+
115
+ type_and_desc = parts[1].strip()
116
+ if '=' in type_and_desc:
117
+ type_part, desc_part = type_and_desc.split('=', 1)
118
+ field_type = get_python_type(type_part.strip())
119
+ description = desc_part.strip().strip('"\'')
120
+ else:
121
+ field_type = get_python_type(type_and_desc.strip())
122
+ description = ""
123
+
124
+ fields[field_name] = (field_type, Field(description=description))
125
+ else:
126
+ # Simple field name only
127
+ field_name = line.strip()
128
+ fields[field_name] = (str, Field(description=""))
129
+
130
+ if not fields:
131
+ raise ValueError("No valid fields found in schema definition")
132
+
133
+ return create_model('DynamicSchema', **fields)
134
+
135
+ def get_python_type(type_str: str):
136
+ """Convert type string to Python type"""
137
+ type_str = type_str.lower().strip()
138
+ type_mapping = {
139
+ 'string': str, 'str': str,
140
+ 'integer': int, 'int': int,
141
+ 'number': float, 'float': float,
142
+ 'boolean': bool, 'bool': bool,
143
+ 'array': list, 'list': list,
144
+ 'object': dict, 'dict': dict
145
+ }
146
+ return type_mapping.get(type_str, str)
147
+
148
+ def webpage_to_json_wrapper(content: str, is_url: bool, schema_input: str) -> Dict[str, Any]:
149
+ """Wrapper function that converts schema input to BaseModel"""
150
+ try:
151
+ # Parse the schema input into a BaseModel
152
+ schema_model = parse_schema_input(schema_input)
153
+
154
+ # Call the original function
155
+ return webpage_to_json(content, is_url, schema_model)
156
+
157
+ except Exception as e:
158
+ return {"error": f"Schema parsing error: {str(e)}"}
159
+
160
+ def webpage_to_json(content: str, is_url: bool, schema: BaseModel) -> Dict[str, Any]:
161
+ """
162
+ Extracts structured JSON information from a given content based on a specified schema.
163
+ This function sets up a processing pipeline that includes:
164
+ - Preprocessing the input content.
165
+ - Utilizing an AI language model to extract information according to the provided schema.
166
+ - Postprocessing the extracted output to match the exact schema requirements.
167
+ Parameters:
168
+ content (str): The input content to be analyzed. This can be direct text or a URL content.
169
+ is_url (bool): A flag indicating whether the provided content is a URL (True) or raw text (False).
170
+ schema (BaseModel): A Pydantic BaseModel defining the expected structure and data types for the output.
171
+ Returns:
172
+ Dict[str, Any]: A dictionary containing the extracted data matching the schema. In case of errors during initialization
173
+ or processing, the dictionary will include an "error" key with a descriptive message.
174
+ """
175
+ prompt_template = """Extract the following information from the provided content according to the specified schema.
176
+
177
+ Content to analyze:
178
+ {content}
179
+
180
+ Schema requirements:
181
+ {schema}
182
+
183
+ Instructions:
184
+ - Extract only information that is explicitly present in the content
185
+ - Follow the exact structure and data types specified in the schema
186
+ - If a required field cannot be found, indicate this clearly
187
+ - Preserve the original formatting and context where relevant
188
+ - Return the extracted data in the format specified by the schema
189
+ - STICK TO THE SCHEMA DON'T EVEN THINK OF DOING SOMETHING ELSE
190
+ - IF THE SCHEMA ASKS FOR AN ARRAY THEN YOU MAY TRY TO EXTRACT ONE IF THERE IS
191
+ - OR I WILL KILL AND KIDNAP YOUR FAMILY AND TORTURE THEM """
192
+
193
+ classification_prompt_template = schema.model_json_schema()
194
+ # Initialize pipeline components
195
+ # TODO: improve the RAG system and optimize (don't instantiate every time)
196
+ preprocessor = BasicPreprocessor(config={'keep_tags': True})
197
+ try:
198
+ # llm = GeminiLLMClient(config={'api_key': os.getenv('GEMINI_API_KEY')})
199
+ llm = NvidiaLLMClient(config={'api_key': os.getenv('NVIDIA_API_KEY'),'model_name': 'google/gemma-3n-e2b-it'})
200
+ # reranker = NvidiaRerankerClient(config={'api_key': os.getenv('NVIDIA_API_KEY'),'model_name': 'nv-rerank-qa-mistral-4b:1'})\
201
+ reranker = ModalRerankerClient("https://abdulrahmanmfam2003--qwen3-reranker-rerank.modal.run")
202
+ except Exception as e:
203
+ return {"error": f"Failed to initialize LLM client: {str(e)}"}
204
+
205
+ # ai_extractor = RAGExtractor(llm_client=llm, prompt_template=prompt_template)
206
+ ai_extractor = LLMClassifierExtractor(reranker=reranker, llm_client=llm, prompt_template=prompt_template, classifier_prompt=classification_prompt_template)
207
+ postprocessor = PostProcessor()
208
+ pipeline = Pipeline(preprocessor, ai_extractor, postprocessor)
209
+
210
+ try:
211
+ result = pipeline.run(content, is_url, schema)
212
+ print("-"*80)
213
+ print(f"Processed result: {result}")
214
+ return result
215
+ except Exception as e:
216
+ return {"error": f"Processing error: {str(e)}"}
217
+
218
+ # Example schemas for the user
219
+ example_schemas = """
220
+ **Example Schema Formats:**
221
+
222
+ 1. **Simple field definitions:**
223
+ ```
224
+ title: str = Page title
225
+ price: float = Product price
226
+ description: str = Product description
227
+ available: bool = Is available
228
+ ```
229
+
230
+ 2. **JSON Schema:**
231
+ ```json
232
+ {
233
+ "properties": {
234
+ "title": {"type": "string", "description": "Page title"},
235
+ "price": {"type": "number", "description": "Product price"},
236
+ "description": {"type": "string", "description": "Product description"}
237
+ },
238
+ "required": ["title"]
239
+ }
240
+ ```
241
+
242
+ 3. **Python Class Definition:**
243
+ ```python
244
+ class ProductSchema(BaseModel):
245
+ title: str = Field(description="Product title")
246
+ price: float = Field(description="Product price")
247
+ description: str = Field(description="Product description")
248
+ available: bool = Field(default=False, description="Availability status")
249
+ ```
250
+ """
251
+
252
+ # Build Gradio Interface
253
+ demo = gr.Interface(
254
+ fn=webpage_to_json_wrapper,
255
+ inputs=[
256
+ gr.Textbox(
257
+ label="Content (URL or Raw Text)",
258
+ lines=10,
259
+ placeholder="Enter URL or paste raw HTML/text here."
260
+ ),
261
+ gr.Checkbox(label="Content is URL?", value=False),
262
+ gr.Textbox(
263
+ label="Schema Definition",
264
+ lines=15,
265
+ placeholder="Define your extraction schema (see examples below)",
266
+ info=example_schemas
267
+ )
268
+ ],
269
+ outputs=gr.JSON(label="Output JSON"),
270
+ title="Webpage to JSON Converter",
271
+ description="Convert web pages or raw text into structured JSON using customizable schemas. Define your schema using simple field definitions, JSON schema, or Python class syntax.",
272
+ examples=[
273
+ [
274
+ "https://example.com",
275
+ True,
276
+ "title: str = Page title\nprice: float = Product price\ndescription: str = Description"
277
+ ],
278
+ [
279
+ "<h1>Sample Product</h1><p>Price: $29.99</p><p>Great quality item</p>",
280
+ False,
281
+ '''{
282
+ "type": "object",
283
+ "properties": {
284
+ "title": {
285
+ "type": "string",
286
+ "description": "Name of the product"
287
+ },
288
+ "price": {
289
+ "type": "number",
290
+ "description": "Price of the product"
291
+ },
292
+ "description": {
293
+ "type": "string",
294
+ "description": "Detailed description of the product"
295
+ },
296
+ "availability": {
297
+ "type": "boolean",
298
+ "description": "Whether the product is in stock (true) or not (false)"
299
+ }
300
+ },
301
+ "required": ["title", "price"]
302
+ }'''
303
+ ]
304
+ ]
305
+ )
306
+
307
+ if __name__ == "__main__":
308
+ demo.launch(mcp_server=True)
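`parse_schema_input` above accepts three schema formats (simple field lines, a JSON Schema object, or a Pydantic class body) and normalizes each into a `BaseModel` via `create_model`. A minimal sketch of that path, assuming `app.py` and the packages in `requirements.txt` are importable in your environment:

```python
# Hedged sketch: exercising the schema parser defined in app.py above.
from app import parse_schema_input

# Simple field definitions -> dynamically created Pydantic model
Model = parse_schema_input("title: str = Page title\nprice: float = Product price")
print(Model.model_json_schema()["properties"])
# Expected shape (Pydantic v2): {'title': {... 'type': 'string'}, 'price': {... 'type': 'number'}}
```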
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ pandas
+ gradio
+ gradio[mcp]
+ pydantic
+ python-dotenv
+ beautifulsoup4
+ requests
+ google-genai
+ json_repair
+ numpy
+ langchain
+ langchain-text-splitters
+ sentence-transformers
+ openai
+ html_chunking
+ langchain_nvidia_ai_endpoints
+ langchain_core
+ lxml
+ pdfkit
+ html2text
+ inscriptis
+ trafilatura
+ markdownify
+ beautifulsoup4
+ readabilipy
+ docling
+ htmlrag
web2json/__pycache__/ai_extractor.cpython-311.pyc ADDED
Binary file (41.3 kB).
web2json/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (2.29 kB).
web2json/__pycache__/postprocessor.cpython-311.pyc ADDED
Binary file (1.64 kB).
web2json/__pycache__/preprocessor.cpython-311.pyc ADDED
Binary file (9.89 kB).
web2json/ai_extractor.py ADDED
@@ -0,0 +1,732 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from google import genai
5
+ from openai import OpenAI
6
+ import time
7
+ import random
8
+ from openai import RateLimitError
9
+ from functools import wraps
10
+ from google.genai import types
11
+ from pydantic import BaseModel
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from html_chunking import get_html_chunks
14
+ from langchain_nvidia_ai_endpoints import NVIDIARerank
15
+ from langchain_core.documents import Document
16
+ from abc import ABC, abstractmethod
17
+ from typing import List, Any, Dict, Tuple, Optional
18
+ import re
19
+ import json
20
+ from langchain_text_splitters import HTMLHeaderTextSplitter
21
+ from sentence_transformers import SentenceTransformer
22
+ import requests
23
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
24
+ import torch
25
+ from typing import List, Dict
26
+ from tenacity import retry, wait_exponential, stop_after_attempt
27
+ import trafilatura
28
+
29
+
30
+ class LLMClient(ABC):
31
+ """
32
+ Abstract base class for calling LLM APIs.
33
+ """
34
+ def __init__(self, config: dict = None):
35
+ """
36
+ Initializes the LLMClient with a configuration dictionary.
37
+
38
+ Args:
39
+ config (dict): Configuration settings for the LLM client.
40
+ """
41
+ self.config = config or {}
42
+
43
+ @abstractmethod
44
+ def call_api(self, prompt: str) -> str:
45
+ """
46
+ Call the underlying LLM API with the given prompt.
47
+
48
+ Args:
49
+ prompt (str): The prompt or input text for the LLM.
50
+
51
+ Returns:
52
+ str: The response from the LLM.
53
+ """
54
+ pass
55
+
56
+ class RerankerClient(ABC):
57
+ """
58
+ Abstract base class for reranker APIs.
59
+ """
60
+ def __init__(self, config: dict = None):
61
+ """
62
+ Initializes the RerankerClient with a configuration dictionary.
63
+
64
+ Args:
65
+ config (dict): Configuration settings for the reranker client.
66
+ """
67
+ self.config = config or {}
68
+
69
+ @abstractmethod
70
+ def rerank(self, query: str, passages: List[str], top_k: int = 3) -> List[str]:
71
+ """
72
+ Rerank passages based on relevance to query.
73
+
74
+ Args:
75
+ query (str): Query string.
76
+ passages (List[str]): List of passages.
77
+ top_k (int): Number of top passages to return.
78
+
79
+ Returns:
80
+ List[str]: Top-k most relevant passages.
81
+ """
82
+ pass
83
+
84
+
85
+ class GeminiLLMClient(LLMClient):
86
+ """
87
+ Concrete implementation of LLMClient for the Gemini API.
88
+ """
89
+
90
+ def __init__(self, config: dict):
91
+ """
92
+ Initializes the GeminiLLMClient with an API key, model name, and optional generation settings.
93
+
94
+ Args:
95
+ config (dict): Configuration containing:
96
+ - 'api_key': (optional) API key for Gemini (falls back to GEMINI_API_KEY env var)
97
+ - 'model_name': (optional) the model to use (default 'gemini-2.0-flash')
98
+ - 'generation_config': (optional) dict of GenerateContentConfig parameters
99
+ """
100
+ api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
101
+ if not api_key:
102
+ raise ValueError(
103
+ "API key for Gemini must be provided in config['api_key'] or GEMINI_API_KEY env var."
104
+ )
105
+ self.client = genai.Client(api_key=api_key)
106
+ self.model_name = config.get("model_name", "gemini-2.0-flash")
107
+ # allow custom generation settings, fallback to sensible defaults
108
+ gen_conf = config.get("generation_config", {})
109
+ self.generate_config = types.GenerateContentConfig(
110
+ response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
111
+ temperature=gen_conf.get("temperature"),
112
+ max_output_tokens=gen_conf.get("max_output_tokens"),
113
+ top_p=gen_conf.get("top_p"),
114
+ top_k=gen_conf.get("top_k"),
115
+ # add any other fields you want to expose
116
+ )
117
+
118
+ def call_api(self, prompt: str) -> str:
119
+ """
120
+ Call the Gemini API with the given prompt (non-streaming).
121
+
122
+ Args:
123
+ prompt (str): The input text for the API.
124
+
125
+ Returns:
126
+ str: The generated text from the Gemini API.
127
+ """
128
+ contents = [
129
+ types.Content(
130
+ role="user",
131
+ parts=[types.Part.from_text(text=prompt)],
132
+ )
133
+ ]
134
+
135
+ # Non-streaming call returns a full response object
136
+ response = self.client.models.generate_content(
137
+ model=self.model_name,
138
+ contents=contents,
139
+ config=self.generate_config,
140
+ )
141
+
142
+ # Combine all output parts into a single string
143
+ return response.text
144
+
145
+ def extract_markdown_json(text: str) -> Optional[Dict[str, Any]]:
146
+ """
147
+ Find the first Markdown ```json ...``` block in `text`,
148
+ parse it as JSON, and return the resulting dict.
149
+ Returns None if no valid JSON block is found.
150
+ """
151
+ # 1) Look specifically for a ```json code fence
152
+ fence_match = re.search(
153
+ r"```json\s*(\{.*?\})\s*```",
154
+ text,
155
+ re.DOTALL | re.IGNORECASE
156
+ )
157
+ if not fence_match:
158
+ return None
159
+
160
+ json_str = fence_match.group(1)
161
+ try:
162
+ return json.loads(json_str)
163
+ except json.JSONDecodeError:
164
+ return None
165
+
166
+ def retry_on_ratelimit(max_retries=5, base_delay=1.0, max_delay=10.0):
167
+ def deco(fn):
168
+ @wraps(fn)
169
+ def wrapped(*args, **kwargs):
170
+ delay = base_delay
171
+ for attempt in range(max_retries):
172
+ try:
173
+ return fn(*args, **kwargs)
174
+ except RateLimitError:
175
+ if attempt == max_retries - 1:
176
+ # give up
177
+ raise
178
+ # back off + jitter
179
+ sleep = min(max_delay, delay) + random.uniform(0, delay)
180
+ time.sleep(sleep)
181
+ delay *= 2
182
+ # unreachable
183
+ return wrapped
184
+ return deco
185
+ class NvidiaLLMClient(LLMClient):
186
+ """
187
+ Concrete implementation of LLMClient for the NVIDIA API (non-streaming).
188
+ """
189
+
190
+ def __init__(self, config: dict):
191
+ """
192
+ Initializes the NvidiaLLMClient with an API key, model name, and optional generation settings.
193
+
194
+ Args:
195
+ config (dict): Configuration containing:
196
+ - 'api_key': (optional) API key for NVIDIA (falls back to NVIDIA_API_KEY env var)
197
+ - 'model_name': (optional) the model to use (default 'google/gemma-3-1b-it')
198
+ - 'generation_config': (optional) dict of generation parameters like temperature, top_p, etc.
199
+ """
200
+ api_key = config.get("api_key") or os.environ.get("NVIDIA_API_KEY")
201
+ if not api_key:
202
+ raise ValueError(
203
+ "API key for NVIDIA must be provided in config['api_key'] or NVIDIA_API_KEY env var."
204
+ )
205
+
206
+ self.client = OpenAI(
207
+ base_url="https://integrate.api.nvidia.com/v1",
208
+ api_key=api_key
209
+ )
210
+ self.model_name = config.get("model_name", "google/gemma-3-1b-it")
211
+
212
+ # Store generation settings with sensible defaults
213
+ gen_conf = config.get("generation_config", {})
214
+ self.temperature = gen_conf.get("temperature", 0)
215
+ self.top_p = gen_conf.get("top_p", 0.7)
216
+ self.max_tokens = gen_conf.get("max_tokens", 8192)
217
+
218
+ def set_model(self, model_name: str):
219
+ """
220
+ Set the model name for the NVIDIA API client.
221
+
222
+ Args:
223
+ model_name (str): The name of the model to use.
224
+ """
225
+ self.model_name = model_name
226
+
227
+ @retry_on_ratelimit(max_retries=20, base_delay=0.5, max_delay=5.0)
228
+ def call_api(self, prompt: str) -> str:
229
+ """
230
+ Call the NVIDIA API with the given prompt (non-streaming).
231
+
232
+ Args:
233
+ prompt (str): The input text for the API.
234
+
235
+ Returns:
236
+ str: The generated text from the NVIDIA API.
237
+ """
238
+ print("prompt: ", prompt)
239
+ response = self.client.chat.completions.create(
240
+ model=self.model_name,
241
+ messages=[{"role": "user", "content": prompt}],
242
+ temperature=self.temperature,
243
+ top_p=self.top_p,
244
+ max_tokens=self.max_tokens,
245
+ extra_body={"chat_template_kwargs": {"thinking":True}},
246
+ # stream is omitted (defaults to False)
247
+ )
248
+ # print("DONE")
249
+ # For the standard (non-streaming) response:
250
+ # choices[0].message.content holds the generated text
251
+ return response.choices[0].message.content
252
+
253
+ def call_batch(self, prompts, max_workers=8):
254
+ """
255
+ Parallel batch with isolated errors: each prompt that still
256
+ fails after retries yields a placeholder string, while the others succeed.
257
+ """
258
+ from concurrent.futures import ThreadPoolExecutor, as_completed
259
+ results = [None] * len(prompts)
260
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
261
+ futures = {ex.submit(self.call_api, p): i for i, p in enumerate(prompts)}
262
+ for fut in as_completed(futures):
263
+ idx = futures[fut]
264
+ try:
265
+ results[idx] = fut.result()
266
+ print("DONE")
267
+ except RateLimitError:
268
+ # You could set results[idx] = None or a default string
269
+ results[idx] = f"<failed after retries>"
270
+ return results
271
+
272
+
273
+ class NvidiaRerankerClient(RerankerClient):
274
+ """
275
+ Concrete implementation of RerankerClient for the NVIDIA reranking API (non-streaming).
276
+ """
277
+
278
+ def __init__(self, config: dict):
279
+ self.model_name = config.get("model_name", "nvidia/llama-3.2-nv-rerankqa-1b-v2")
280
+ self.client = NVIDIARerank(
281
+ model=self.model_name,
282
+ api_key=os.getenv("NVIDIA_API_KEY"),
283
+ )
284
+
285
+ def set_model(self, model_name: str):
286
+ """
287
+ Set the model name for the NVIDIA API client.
288
+
289
+ Args:
290
+ model_name (str): The name of the model to use.
291
+ """
292
+ self.model_name = model_name
293
+
294
+ @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
295
+ def rerank(self, query: str, passages: List[str], top_k: int = 3, threshold: float = 0.5) -> List[Document]:
296
+ # 1. Prepare and send documents for scoring
297
+ docs = [Document(page_content=p) for p in passages]
298
+ scored_docs = self.client.compress_documents(
299
+ query=str(query),
300
+ documents=docs
301
+ )
302
+
303
+ # 2. Extract raw scores and compute sigmoid probabilities
304
+ raw_scores = np.array([doc.metadata['relevance_score'] for doc in scored_docs], dtype=float)
305
+ print(f"raw scores {raw_scores}")
306
+ p_scores = 1 / (1 + np.exp(-raw_scores))
307
+ print(f"Sigmoid scores: {p_scores}")
308
+
309
+ # 3. Max normalization
310
+ max_score = np.max(p_scores)
311
+ if max_score == 0:
312
+ norm_scores = np.zeros_like(p_scores)
313
+ else:
314
+ norm_scores = p_scores / max_score
315
+ print(f"Normalized scores: {norm_scores}")
316
+
317
+ # 4. Filter by threshold using normalized scores
318
+ scored_pairs = [(doc, norm) for doc, norm in zip(scored_docs, norm_scores) if norm > threshold]
319
+ print(f"Filtered pairs:\n{scored_pairs}")
320
+
321
+ # 5. Return top_k documents (already sorted by model, no need to re-sort)
322
+ top_docs = [doc.page_content for doc, _ in scored_pairs]
323
+ return top_docs
324
+
325
+
326
+
327
+
328
+ # TODO: will I need it ?
329
+ # def call_batch(self, prompts, max_workers=8):
330
+ # pass
331
+
332
+ def retry_on_error(fn):
333
+ """Simple retry decorator (exponential back-off, max 6 tries)."""
334
+ return retry(
335
+ wait=wait_exponential(multiplier=0.5, min=0.5, max=5),
336
+ stop=stop_after_attempt(6),
337
+ reraise=True,
338
+ )(fn)
339
+
340
+
341
+ class ModalRerankerClient(RerankerClient):
342
+ """Client for the Modal Qwen3-Reranker endpoint (non-streaming)."""
343
+
344
+ def __init__(self, endpoint_url: str):
345
+ self.endpoint_url = endpoint_url.rstrip("/") # ensure no trailing slash
346
+
347
+ def set_endpoint(self, url: str):
348
+ self.endpoint_url = url.rstrip("/")
349
+
350
+ @retry_on_error
351
+ def rerank(
352
+ self,
353
+ query: str,
354
+ passages: List[str],
355
+ threshold: float = 0.5,
356
+ ) -> List[Document]:
357
+ """Call the remote endpoint and return filtered passages."""
358
+ if not isinstance(query,str):
359
+ query = str(query)
360
+ payload = {"query": query, "passages": passages}
361
+ print(payload)
362
+ res = requests.post(self.endpoint_url, json=payload, timeout=60)
363
+ res.raise_for_status()
364
+ data = res.json()
365
+
366
+ # The endpoint already returns probabilities (0-1). Extract them.
367
+ ranked = data.get("ranked_passages", [])
368
+ # Extract scores
369
+ scores = np.array([p["score"] for p in ranked], dtype=float)
370
+ # Max normalization
371
+ max_score = scores.max() if len(scores) > 0 else 1.0
372
+ # max_score = 1
373
+ if max_score == 0:
374
+ norm_scores = np.zeros_like(scores)
375
+ else:
376
+ norm_scores = scores / max_score
377
+ # Filter by threshold using normalized scores
378
+ filtered = [
379
+ (p, norm) for p, norm in zip(ranked, norm_scores) if norm >= threshold
380
+ ]
381
+ # Convert to LangChain Documents
382
+ docs = [
383
+ Document(page_content=p["passage"], metadata={"score": p["score"], "norm_score": norm})
384
+ for p, norm in filtered
385
+ ]
386
+
387
+ # docs.reverse()
388
+
389
+ return docs
390
+
391
+ class HFRerankerClient(LLMClient):
392
+ """
393
+ Hugging Face reranker client using the Qwen/Qwen3-Reranker-0.6B model (scored via sequence classification).
394
+ """
395
+
396
+ def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-0.6B", device: str = None):
397
+ """
398
+ Initialize the Hugging Face reranker.
399
+ """
400
+ self.model_name = model_name
401
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
402
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
403
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
404
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
405
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
406
+
407
+ def rerank(self, query: str, passages: List[str], top_k: int = 3, threshold: float = 0.5) -> List[str]:
408
+ """
409
+ Rerank passages based on relevance to query using min-max normalized scores.
410
+
411
+ Args:
412
+ query (str): Query string.
413
+ passages (List[str]): List of passages.
414
+ top_k (int): Number of top passages to return.
415
+ threshold (float): Minimum normalized score to include passage.
416
+
417
+ Returns:
418
+ List[str]: Top-k most relevant passages above threshold.
419
+ """
420
+ inputs = [
421
+ self.tokenizer(f"{query} [SEP] {p}", return_tensors="pt", truncation=True, padding=True).to(self.device)
422
+ for p in passages
423
+ ]
424
+ scores = []
425
+
426
+ with torch.no_grad():
427
+ for inp in inputs:
428
+ logits = self.model(**inp).logits
429
+ # print("logits:", logits)
430
+ score = torch.softmax(logits, dim=1)[0, 1].item() # probability of relevance
431
+ scores.append(score)
432
+
433
+ print(f"Softmax Scores: {scores}")
434
+
435
+ # Min-max normalize the scores
436
+ scores_np = np.array(scores)
437
+ min_score = scores_np.min()
438
+ max_score = scores_np.max()
439
+ if max_score == min_score:
440
+ norm_scores = np.ones_like(scores_np)
441
+ else:
442
+ norm_scores = (scores_np - min_score) / (max_score - min_score)
443
+
444
+ print(f"Normalized Scores: {norm_scores}")
445
+ # Filter based on normalized threshold
446
+ filtered = [(i, s) for i, s in enumerate(norm_scores) if s > threshold]
447
+ print(f"Filtered: {filtered}")
448
+
449
+ # Sort by normalized score descending
450
+ filtered.sort(key=lambda x: x[1], reverse=True)
451
+
452
+ # Select top_k passages
453
+ top_passages = [passages[i] for i, _ in filtered]
454
+
455
+ return top_passages
456
+
457
+
458
+ @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
459
+ def call_api(self, prompt: str) -> str:
460
+ pass
461
+
462
+ def call_batch(self, prompts, max_workers=8):
463
+ pass
464
+
465
+
466
+ class AIExtractor:
467
+ def __init__(self, llm_client: LLMClient, prompt_template: str):
468
+ """
469
+ Initializes the AIExtractor with a specific LLM client and configuration.
470
+
471
+ Args:
472
+ llm_client (LLMClient): An instance of a class that implements the LLMClient interface.
473
+ prompt_template (str): The template to use for generating prompts for the LLM.
474
+ should contain placeholders for dynamic content.
475
+ e.g., "Extract the following information: {content} based on schema: {schema}"
476
+ """
477
+ self.llm_client = llm_client
478
+ self.prompt_template = prompt_template
479
+
480
+ def extract(self, content: str, schema: BaseModel) -> str:
481
+ """
482
+ Extracts structured information from the given content based on the provided schema.
483
+
484
+ Args:
485
+ content (str): The raw content to extract information from.
486
+ schema (BaseModel): A Pydantic model defining the structure of the expected output.
487
+
488
+ Returns:
489
+ str: The structured JSON object as a string.
490
+ """
491
+ prompt = self.prompt_template.format(content=content, schema=schema.model_json_schema())
492
+ # print(f"Generated prompt: {prompt}")
493
+ response = self.llm_client.call_api(prompt)
494
+ return response
495
+
496
+ class LLMClassifierExtractor(AIExtractor):
497
+ """
498
+ Extractor that uses an LLM to classify and extract structured information from text content.
499
+ This class is designed to handle classification tasks where the LLM generates structured output based on a provided schema.
500
+ """
501
+ def __init__(self, reranker: RerankerClient, llm_client: LLMClient, prompt_template: str, classifier_prompt: str, ):
502
+ """
503
+ Initializes the LLMClassifierExtractor with an LLM client and a prompt template.
504
+
505
+ Args:
506
+ llm_client (LLMClient): An instance of a class that implements the LLMClient interface.
507
+ prompt_template (str): The template to use for generating prompts for the LLM.
508
+ """
509
+ super().__init__(llm_client, prompt_template)
510
+ self.reranker = reranker
511
+ self.classifier_prompt = classifier_prompt
512
+
513
+ def chunk_content(self, content: str , max_tokens: int = 500, is_clean: bool = True) -> List[str]:
514
+ """
515
+ Splits the content into manageable chunks for processing.
516
+
517
+ Args:
518
+ content (str): The raw content to be chunked.
519
+
520
+ Returns:
521
+ List[str]: A list of text chunks.
522
+ """
523
+ # Use the get_html_chunks function to split the content into chunks
524
+ return get_html_chunks(html=content, max_tokens=max_tokens, is_clean_html=is_clean, attr_cutoff_len=5)
525
+
526
+
527
+ def classify_chunks(self, passages, top_k=3, hf: bool = False): # reranker
528
+ # print("TIME TO CLASSIFY")
529
+ query = self.classifier_prompt
530
+
531
+ if hf:
532
+ # print("Using Hugging Face reranker for classification.")
533
+ return self.reranker.rerank(query, passages, top_k=top_k)
534
+ response = self.reranker.rerank(query,passages)
535
+ print(f"response: {response}")
536
+ # print("DONNNNE")
537
+ # NVIDIA reranker path
538
+ return response
539
+
540
+ def extract(self, content, schema, hf: bool = False):
541
+ """
542
+ Extracts structured information from the given content based on the provided schema.
543
+
544
+ Args:
545
+ content (str): The raw content to extract information from.
546
+ schema (BaseModel): A Pydantic model defining the structure of the expected output.
547
+ hf (bool): Whether to use the Hugging Face reranker or NVIDIA (default).
548
+ """
549
+ # print("TIME TO EXTRACT")
550
+ chunks = self.chunk_content(content, max_tokens=500)
551
+ print(f"Content successfully chunked into {len(chunks)} chunks.")
552
+ # print(f"Content successfully chunked: {chunks}")
553
+ # chunks = [trafilatura.extract(chunk,favor_recall=True) for chunk in chunks]
554
+ # chunks = [chunk for chunk in chunks if chunk is not None]
555
+ classified_chunks = self.classify_chunks(chunks, hf=hf) # conditional reranker
556
+ # extracting the content
557
+
558
+ if isinstance(classified_chunks[0],Document):
559
+ classified_chunks = [chunk.page_content for chunk in classified_chunks]
560
+ print(f"Classified Chunks {len(classified_chunks)}")
561
+ # print(classified_chunks)
562
+ # print('='*80)
563
+ # NOTE: More preprocessing
564
+ # classified_chunks = [trafilatura.extract(chunk,favor_recall=True) for chunk in classified_chunks]
565
+ # classified_chunks = [chunk for chunk in classified_chunks if chunk is not None]
566
+ filtered_content = "\n\n".join(classified_chunks)
567
+
568
+ if not filtered_content:
569
+ print("Warning: No relevant chunks found. Returning empty response.")
570
+ return "{}"
571
+
572
+ prompt = self.prompt_template.format(content=filtered_content, schema=schema.model_json_schema())
573
+ # print(f"Generated prompt for extraction: {prompt[:500]}...")
574
+ llm_response = self.llm_client.call_api(prompt)
575
+ # print(f"LLM response: {llm_response[:500]}...")
576
+
577
+ return llm_response or "{}"
578
+
579
+
580
+ # TODO: RAGExtractor class
581
+ class RAGExtractor(AIExtractor):
582
+ """
583
+ RAG-enhanced extractor that uses similarity search to find relevant chunks
584
+ before performing extraction, utilizing HTML header-based chunking and SentenceTransformer embeddings.
585
+ """
586
+
587
+ def __init__(self,
588
+ llm_client: LLMClient,
589
+ prompt_template: str,
590
+ embedding_model_path: str = "sentence-transformers/all-mpnet-base-v2",
591
+ top_k: int = 3):
592
+ """
593
+ Initialize RAG extractor with embedding and chunking capabilities.
594
+
595
+ Args:
596
+ llm_client: LLM client for generation.
597
+ prompt_template: Template for prompts.
598
+ embedding_model_path: Path/name for the SentenceTransformer embedding model.
599
+ top_k: Number of top similar chunks to retrieve.
600
+ """
601
+ super().__init__(llm_client, prompt_template)
602
+ self.embedding_model_path = embedding_model_path
603
+ # Initialize the SentenceTransformer model for embeddings
604
+ self.embedding_model_instance = SentenceTransformer(self.embedding_model_path)
605
+ self.top_k = top_k
606
+
607
+ @staticmethod
608
+ def _langchain_HHTS(text: str) -> List[str]:
609
+ """
610
+ Chunks HTML text using Langchain's HTMLHeaderTextSplitter based on h1 and h2 headers.
611
+
612
+ Args:
613
+ text (str): The HTML content to chunk.
614
+
615
+ Returns:
616
+ List[str]: A list of chunked text strings (extracted from Document objects' page_content).
617
+ """
618
+ headers_to_split_on = [
619
+ ("h1", "Header 1"),
620
+ ("h2", "Header 2"),
621
+ # ("h3", "Header 3"), # This header was explicitly commented out in the request
622
+ ]
623
+ html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
624
+ return [doc.page_content for doc in html_splitter.split_text(text)]
625
+
626
+ def embed_text(self, text: str) -> np.ndarray:
627
+ """
628
+ Generate embeddings for text using the initialized SentenceTransformer model.
629
+
630
+ Args:
631
+ text: The text string to embed.
632
+
633
+ Returns:
634
+ np.ndarray: The embedding vector for the input text as a NumPy array.
635
+ """
636
+ try:
637
+ return self.embedding_model_instance.encode(text)
638
+ except Exception as e:
639
+ print(f"Warning: Embedding failed for text: '{text[:50]}...', using random embedding: {e}")
640
+
641
+ return None
642
+
643
+ def search_similar_chunks(self,
644
+ query: str,
645
+ chunks: List[str],
646
+ embeddings: np.ndarray) -> List[str]:
647
+ """
648
+ Find the most similar chunks to the query within the given list of chunks
649
+ by calculating cosine similarity between their embeddings.
650
+
651
+ Args:
652
+ query (str): The query text whose embedding will be used for similarity comparison.
653
+ chunks (List[str]): A list of text chunks to search within.
654
+ embeddings (np.ndarray): Precomputed embeddings for the chunks, corresponding to the 'chunks' list.
655
+
656
+ Returns:
657
+ List[str]: A list of the 'top_k' most similar chunks to the query.
658
+ """
659
+ query_embedding = self.embed_text(query)
660
+
661
+ similarities = []
662
+
663
+ if query_embedding.ndim > 1:
664
+ query_embedding = query_embedding.flatten()
665
+
666
+ for i, chunk_embedding in enumerate(embeddings):
667
+ if chunk_embedding.ndim > 1:
668
+ chunk_embedding = chunk_embedding.flatten()
669
+
670
+ norm_query = np.linalg.norm(query_embedding)
671
+ norm_chunk = np.linalg.norm(chunk_embedding)
672
+
673
+ if norm_query == 0 or norm_chunk == 0:
674
+ similarity = 0.0
675
+ else:
676
+ similarity = np.dot(query_embedding, chunk_embedding) / (norm_query * norm_chunk)
677
+ similarities.append((similarity, i))
678
+
679
+ similarities.sort(key=lambda x: x[0], reverse=True)
680
+ top_indices = [idx for _, idx in similarities[:self.top_k]]
681
+
682
+ return [chunks[i] for i in top_indices]
683
+
684
+ def extract(self, content: str, schema: BaseModel, query: str = None) -> str:
685
+ """
686
+ Overrides the base AIExtractor's method to implement RAG-enhanced extraction.
687
+ This function first chunks the input HTML content, then uses a query to find
688
+ the most relevant chunks via embedding similarity, and finally sends these
689
+ relevant chunks as context to the LLM for structured information extraction.
690
+
691
+ Args:
692
+ content (str): The raw HTML content from which to extract information.
693
+ schema (BaseModel): A Pydantic model defining the desired output structure for the LLM.
694
+ query (str, optional): An optional query string to guide the retrieval of relevant chunks.
695
+ If not provided, a default query based on the schema will be used.
696
+
697
+ Returns:
698
+ str: The structured JSON object as a string, as generated by the LLM.
699
+ """
700
+ start_time = time.time()
701
+
702
+ if not query:
703
+ query = f"Extract information based on the following JSON schema: {schema.model_json_schema()}"
704
+ # print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
705
+
706
+ chunks = self._langchain_HHTS(content)
707
+ print(f"Content successfully chunked into {len(chunks)} pieces.")
708
+
709
+ combined_content_for_llm = ""
710
+ if not chunks:
711
+ print("Warning: No chunks were generated from the provided content. The entire original content will be sent to the LLM.")
712
+ combined_content_for_llm = content
713
+ else:
714
+ chunk_embeddings = np.array([self.embed_text(chunk) for chunk in chunks])
715
+ print(f"Generated embeddings for {len(chunks)} chunks.")
716
+
717
+ similar_chunks = self.search_similar_chunks(query, chunks, chunk_embeddings)
718
+ print(f"Retrieved {len(similar_chunks)} similar chunks based on the query.")
719
+
720
+ combined_content_for_llm = "\n\n".join(similar_chunks)
721
+ print(f"Combined content for LLM (truncated): '{combined_content_for_llm[:200]}...'")
722
+
723
+ prompt = self.prompt_template.format(content=combined_content_for_llm, schema=schema.model_json_schema())
724
+ print(f"Sending prompt to LLM (truncated): '{prompt[:500]}...'")
725
+ llm_response = self.llm_client.call_api(prompt)
726
+
727
+ execution_time = (time.time() - start_time) * 1000
728
+ print(f"Extraction process completed in {execution_time:.2f} milliseconds.")
729
+ print(f"LLM's final response: {llm_response}")
730
+ print("=" * 78)
731
+
732
+ return llm_response
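`LLMClient` is the seam this module is built around: `AIExtractor` only ever calls `call_api(prompt)`, so any backend that maps a prompt string to a response string can be plugged in. A minimal sketch of that contract, assuming the dependencies from `requirements.txt` are installed; `EchoClient` is a stand-in written for this illustration, not part of the commit.

```python
# Hedged sketch: the LLMClient / AIExtractor contract.
from pydantic import BaseModel
from web2json.ai_extractor import AIExtractor, LLMClient

class EchoClient(LLMClient):
    """Stub backend: returns a canned response instead of calling an LLM API."""
    def call_api(self, prompt: str) -> str:
        return '{"title": "stub"}'

class Page(BaseModel):
    title: str

extractor = AIExtractor(
    llm_client=EchoClient(),
    prompt_template="Extract {content} according to {schema}",
)
print(extractor.extract("<h1>Hello</h1>", Page))   # -> '{"title": "stub"}'
```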
web2json/contentextractors.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import pdfkit
5
+ import requests
6
+ import warnings
7
+ import tempfile
8
+ # import textract
9
+ import html2text
10
+ import inscriptis
11
+ import trafilatura
12
+ from pathlib import Path
13
+ from markdownify import markdownify
14
+ from json_repair import repair_json
15
+ from bs4 import BeautifulSoup, Comment
16
+ from html_chunking import get_html_chunks
17
+ from urllib.error import URLError, HTTPError
18
+ from html_to_markdown import convert_to_markdown
19
+ from readabilipy import simple_json_from_html_string
20
+ from docling.document_converter import DocumentConverter
21
+ # from dateparser_scripts.update_supported_languages_and_locales import to_string  # unused auto-import; to_string is never referenced
22
+
23
+
24
+ def clean_html(html_content: str) -> str:
25
+ """
26
+ Cleans up the given HTML content by:
27
+ - Removing <script> and <style> tags and their content.
28
+ - Removing HTML comments.
29
+ - Extracting and returning the visible text with normalized whitespace.
30
+
31
+ Args:
32
+ html_content (str): The HTML content to clean.
33
+
34
+ Returns:
35
+ str: The cleaned, visible text from the HTML.
36
+ """
37
+ # Parse the HTML content
38
+ soup = BeautifulSoup(html_content, "html.parser")
39
+
40
+ # Remove script and style elements
41
+ # Remove unwanted tags
42
+ for tag in soup(["script", "style", "img", "a", "table", "tr", "td", "th", "thead", "tbody",
43
+ "tfoot", "header", "footer", "link", "rel"]):
44
+ tag.decompose()
45
+
46
+ # Remove elements that do not contain any visible text
47
+ for element in soup.find_all():
48
+ # If the element has no text (after stripping whitespace), remove it
49
+ if not element.get_text(strip=True):
50
+ element.decompose()
51
+
52
+ # Remove HTML comments
53
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
54
+ comment.extract()
55
+
56
+ # Extract text and normalize whitespace
57
+ # text = soup.get_text(separator=" ", strip=True)
58
+ # clean_text = re.sub(r'\s+', ' ', text)
59
+
60
+ # return clean_text
61
+ return str(soup)
62
+
63
+
64
+ def print_content_extractors():
65
+ print(
66
+ [
67
+ "Default: the plain text of the HTML page",
68
+ "Inscriptis",
69
+ "Trafilatura",
70
+ ]
71
+ )
72
+
73
+
74
+ class ContentExtractor:
75
+ def get_text(self, html):
76
+ return clean_html(html)
77
+
78
+ # TODO: Clean this mess
79
+ def url_to_html(self, url,clean=False):
80
+ # Define custom headers to mimic a browser request
81
+ headers = {
82
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
83
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
84
+ "Accept-Language": "en-US,en;q=0.6",
85
+ "Cache-Control": "max-age=0",
86
+ "Sec-Ch-Ua": "\"Not(A:Brand\";v=\"99\", \"Brave\";v=\"133\", \"Chromium\";v=\"133\"",
87
+ "Sec-Ch-Ua-Mobile": "?0",
88
+ "Sec-Ch-Ua-Platform": "\"Windows\"",
89
+ "Sec-Fetch-Dest": "document",
90
+ "Sec-Fetch-Mode": "navigate",
91
+ "Sec-Fetch-Site": "none",
92
+ "Sec-Fetch-User": "?1",
93
+ "Upgrade-Insecure-Requests": "1"
94
+ }
95
+
96
+ try:
97
+ # Create a Request object with custom headers
98
+ response = requests.get(url, headers=headers, timeout=10)
99
+
100
+ html = None
101
+
102
+ if response.status_code == 200:
103
+ html = response.text
104
+ else:
105
+ print(f"Failed to retrieve HTML. Status code: {response.status_code}")
106
+ return None
107
+
108
+ if clean:
109
+ return self.get_text(html)
110
+
111
+ return html
112
+
113
+ except HTTPError as e:
114
+ print(f"HTTP Error: {e.code} - {e.reason}")
115
+ return None
116
+ except URLError as e:
117
+ print(f"URL Error: {e.reason}")
118
+ return None
119
+ except Exception as e:
120
+ print(f"An unexpected error occurred: {e}")
121
+ return None
122
+
123
+
124
+ class Inscriptis(ContentExtractor):
125
+ def __init__(self):
126
+ super()
127
+ self.headers = {
128
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Brave/119.0.0.0",
129
+ "Accept-Language": "en-US,en;q=0.9,ar;q=0.8",
130
+ }
131
+
132
+ warnings.warn("\nBeware, put only clean links with no trackers, or it may produce unexpected results.")
133
+
134
+ def get_text(self, html):
135
+ """Extract text from HTML using inscriptis."""
136
+ return inscriptis.get_text(html)
137
+
138
+ def url_to_html(self, url):
139
+ response = requests.get(url, headers=self.headers)
140
+ return response.text
141
+
142
+
143
+ class Docling(ContentExtractor):
144
+ def __init__(self):
145
+ super().__init__()
146
+
147
+ # TODO: This is an unexpected behaviour but due to docling docs website being down, it's what works for now
148
+ def get_text(self, text_content):
149
+ result = None
150
+ with tempfile.NamedTemporaryFile(mode='w+', suffix='.html', delete=False, encoding='utf-8') as tmpfile:
151
+ tmpfile.write(text_content)
152
+ tmpfile.flush()
153
+ tmpfile_path = tmpfile.name.replace("\\", "/")
154
+ tmpfile_path = Path(tmpfile_path)
155
+ try:
156
+ converter = DocumentConverter()
157
+ document = converter.convert(tmpfile_path).document
158
+ tables = []
159
+ for table_ix, table in enumerate(document.tables):
160
+ table_text = table.export_to_markdown()
161
+ tables.append(table_text)
162
+
163
+ result = document.export_to_markdown()
164
+ for table in tables:
165
+ result += "\n\n" + table
166
+ finally:
167
+ os.remove(tmpfile_path)
168
+ return result
169
+
170
+
171
+ class ReadabiliPy(ContentExtractor):
172
+ def __init__(self):
173
+ super().__init__()
174
+
175
+ def get_text(self, html):
176
+ content = simple_json_from_html_string(html, use_readability=True)
177
+ json_object = json.dumps(content, indent=4)
178
+ repaired = repair_json(json_object)
179
+ return repaired
180
+
181
+
182
+ class Trafilatura(ContentExtractor):
183
+ def __init__(self):
184
+ super().__init__()
185
+ self.headers = {
186
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
187
+ "Accept-Language": "en-US,en;q=0.9",
188
+ }
189
+
190
+ warnings.warn("\nTrafilatura Content Extractor: Beware, put only clean links with no trackers, or it may produce unexpected results.")
191
+
192
+ from copy import deepcopy
193
+ from trafilatura.settings import DEFAULT_CONFIG
194
+ config = deepcopy(DEFAULT_CONFIG)
195
+ # config['DEFAULT']['MIN_EXTRACTED_SIZE'] = '5000' # Configurable but this value worked well for me
196
+ self.config = config
197
+
198
+ def url_to_html(self, url):
199
+ response = requests.get(url, headers=self.headers)
200
+ return response.text
201
+
202
+ def get_text(self, html, output_format="markdown", min_extracted_size_char=20_000):
203
+ # self.config['DEFAULT']['MIN_EXTRACTED_SIZE'] = f"{min_extracted_size_char}"
204
+ # self.config['DEFAULT']['MIN_OUTPUT_SIZE'] = f"{min_extracted_size_char}"
205
+ return trafilatura.extract(filecontent=html, favor_recall=True, config=self.config, output_format=output_format)
206
+
207
+
208
+ class Markdownify(ContentExtractor):
209
+ def get_text(self, html):
210
+ alt = re.sub(r"\n{3,}", "\n\n", html)
211
+ md = markdownify(alt, strip=['href', 'table', 'tr', 'td', 'header', 'footer'])
212
+
213
+ md = re.sub(r'!?\[[^\]]*\]\([^)]*\)', '', md)
214
+ # Remove extra newlines
215
+ md = re.sub(r"\n{3,}", "\n\n", md)
216
+ md = md.strip()
217
+
218
+ return md
219
+
220
+
221
+ class HTML2Text(ContentExtractor):
222
+ def get_text(self, html):
223
+ converter = html2text.HTML2Text()
224
+ converter.ignore_tables=True
225
+ converter.ignore_links=True
226
+ converter.ignore_images=True
227
+ converter.ignore_mailto_links=True
228
+ return converter.handle(html)
229
+
230
+
231
+ class HTML_TO_Markdown(ContentExtractor):
232
+ def get_text(self, html):
233
+ alt = re.sub(r"\n{3,}", "\n\n", html)
234
+ md = convert_to_markdown(alt, strip=['href', 'table', 'tr', 'td', 'header', 'footer'])
235
+
236
+ md = re.sub(r'!?\[[^\]]*\]\([^)]*\)', '', md)
237
+ # Remove extra newlines
238
+ md = re.sub(r"\n{3,}", "\n\n", md)
239
+ md = md.strip()
240
+
241
+ return md
242
+
243
+
244
+ class PDFkitDocling(ContentExtractor):
245
+ def get_text(self, html):
246
+ soup = BeautifulSoup(html, "html.parser")
247
+
248
+ # Remove <a>, <link>, <img>, and other unwanted tags
249
+ for tag in soup.find_all(['a', 'link', 'img', 'base', 'meta', 'style', 'script', 'noscript', 'head']):
250
+ tag.decompose()
251
+
252
+ # Remove HTML comments
253
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
254
+ comment.extract()
255
+
256
+
257
+ content = str(soup)
258
+
259
+ # PDF path to save
260
+ pdf_path = 'test.pdf'
261
+
262
+ # Create PDF
263
+ pdfkit.from_string(content, pdf_path)
264
+
265
+ converter = DocumentConverter()
266
+
267
+ return converter.convert(pdf_path).document.export_to_markdown()
268
+
269
+
270
+ class TrafilatraCHUNKS(ContentExtractor):
271
+ def __init__(self):
272
+ super().__init__()
273
+ # self.trafi = Trafilatura()
274
+
275
+ def get_text(self, html, max_tokens=1000):
276
+ soup = BeautifulSoup(html, "html.parser")
277
+
278
+ # Remove <a>, <link>, <img>, and other unwanted tags
279
+ for tag in soup.find_all(['a', 'link', 'img', 'base', 'meta', 'style', 'script', 'noscript', 'head']):
280
+ tag.decompose()
281
+
282
+ # Remove HTML comments
283
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
284
+ comment.extract()
285
+
286
+
287
+ content = str(soup)
288
+
289
+ chunks = get_html_chunks(content, max_tokens=max_tokens, is_clean_html=True, attr_cutoff_len=50)
290
+
291
+ cleaned = [trafilatura.extract(chunk) for chunk in chunks]
292
+ cleaned = [chunk for chunk in cleaned if chunk is not None]
293
+
294
+
295
+ combined_text = ""
296
+ for chunk in cleaned:
297
+ if chunk is None:
298
+ continue
299
+ combined_text += chunk + "\n"
300
+
301
+ return combined_text
302
+
303
+
304
+ class TrafilaCHUNKSRobust(ContentExtractor):
305
+ def __init__(self):
306
+ super().__init__()
307
+ # self.trafi = Trafilatura()
308
+
309
+ def get_text(self, html, max_tokens=1000):
310
+ soup = BeautifulSoup(html, "html.parser")
311
+
312
+ for tag in soup.find_all(['style', 'script', 'head', 'img', 'base', 'noscript']):
313
+ tag.decompose()
314
+
315
+ for tag in soup.find_all(lambda tag: tag.attrs and any("nav" in str(v) for v in tag.attrs.values())):
316
+ tag.decompose()
317
+
318
+ # Remove HTML comments
319
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
320
+ comment.extract()
321
+
322
+ content = str(soup)
323
+
324
+ chunks = get_html_chunks(content, max_tokens=max_tokens, is_clean_html=True, attr_cutoff_len=50)
325
+
326
+ cleaned = [trafilatura.extract(chunk) for chunk in chunks]
327
+ cleaned = [chunk for chunk in cleaned if chunk is not None]
328
+
329
+ combined_text = ""
330
+ for chunk in cleaned:
331
+ if chunk is None:
332
+ continue
333
+ combined_text += chunk + "\n"
334
+
335
+ return combined_text
336
+
337
+ class TrafilaCHUNKSRobustV2(ContentExtractor):
338
+ def __init__(self):
339
+ super().__init__()
340
+ # self.trafi = Trafilatura()
341
+
342
+ def get_text(self, html, max_tokens=1000):
343
+ soup = BeautifulSoup(html, "html.parser")
344
+
345
+ for tag in soup.find_all(['style', 'script', 'head', 'img', 'base', 'noscript']):
346
+ tag.decompose()
347
+
348
+ # Remove HTML comments
349
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
350
+ comment.extract()
351
+
352
+ content = str(soup)
353
+
354
+ chunks = get_html_chunks(content, max_tokens=max_tokens, is_clean_html=True, attr_cutoff_len=50)
355
+
356
+ cleaned = [trafilatura.extract(chunk) for chunk in chunks]
357
+ cleaned = [chunk for chunk in cleaned if chunk is not None]
358
+
359
+ combined_text = ""
360
+ for chunk in cleaned:
361
+ if chunk is None:
362
+ continue
363
+ combined_text += chunk + "\n"
364
+
365
+ return combined_text
366
+
367
+ # Textract-based extraction performed very poorly in testing; kept commented out for reference.
368
+ # class Textract(ContentExtractor):
369
+ # def get_text(self, html):
370
+ # with tempfile.NamedTemporaryFile(mode='w+', suffix='.html', delete=False, encoding='utf-8') as tmpfile:
371
+ # tmpfile.write(html)
372
+ # tmpfile.flush()
373
+ # tmpfile_path = tmpfile.name.replace("\\", "/")
374
+ # tmpfile_path = Path(tmpfile_path)
375
+ # try:
376
+ # result = textract.process(tmpfile_path)
377
+ # finally:
378
+ # os.remove(tmpfile_path)
379
+ # return result
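The three chunk-based extractors above share one core idea: strip obviously non-content tags, split the remaining HTML into token-bounded chunks, run trafilatura on each chunk, and join whatever survives. Below is a minimal standalone sketch of that pattern; the `get_html_chunks` import path is an assumption (it mirrors the helper already used in this module), so adjust it to match the repo.

```python
# Sketch of the chunk-then-extract pattern used by the *CHUNKS extractors.
# Assumption: get_html_chunks comes from the html_chunking helper this
# module already imports; swap the import if the project sources it elsewhere.
import trafilatura
from html_chunking import get_html_chunks  # assumed import path

def chunked_extract(html: str, max_tokens: int = 1000) -> str:
    """Split cleaned HTML into token-bounded chunks and extract each one."""
    chunks = get_html_chunks(html, max_tokens=max_tokens,
                             is_clean_html=True, attr_cutoff_len=50)
    texts = (trafilatura.extract(chunk) for chunk in chunks)
    # trafilatura.extract returns None when a chunk has no main content.
    return "\n".join(t for t in texts if t)
```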
web2json/pipeline.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from web2json.ai_extractor import *
2
+ from web2json.postprocessor import *
3
+ from web2json.preprocessor import *
4
+ from pydantic import BaseModel
5
+
6
+ class Pipeline:
7
+ # constructor
8
+ def __init__(self,
9
+ preprocessor: Preprocessor,
10
+ ai_extractor: AIExtractor,
11
+ postprocessor: PostProcessor):
12
+ self.preprocessor = preprocessor
13
+ self.ai_extractor = ai_extractor
14
+ self.postprocessor = postprocessor
15
+
16
+ def run(self, content: str, is_url: bool, schema: BaseModel, hf: bool = False) -> dict:
17
+ """
18
+ Run the entire pipeline: preprocess, extract, and postprocess.
19
+
20
+ Args:
21
+ content (str): The raw content to process.
22
+ is_url (bool): Whether the content is a URL or raw text.
23
+ schema (BaseModel): The schema defining the structure of the expected output.
24
+ hf (bool): Flag forwarded unchanged to the AI extractor's extract() call.
25
+ Returns:
26
+ dict: The final structured data after processing.
27
+ """
28
+ # Step 1: Preprocess the content
29
+ preprocessed_content = self.preprocessor.preprocess(content, is_url)
30
+ # print(f"Preprocessed content: {preprocessed_content}...")
31
+ print('+'*80)
32
+ # Step 2: Extract structured information using AI
33
+ extracted_data = self.ai_extractor.extract(preprocessed_content, schema, hf=hf)
34
+ # print(f"Extracted data: {extracted_data[:100]}...")
35
+ print('+'*80)
36
+ # Step 3: Post-process the extracted data
37
+ final_output = self.postprocessor.process(extracted_data)
38
+ print(f"Final output: {final_output}")
39
+ print('+'*80)
40
+
41
+ return final_output
42
+
43
+
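For orientation, here is a hedged, runnable sketch of how the Pipeline is wired. The Product schema and the EchoExtractor stub are illustrative assumptions; the real app instantiates one of the AIExtractor subclasses from web2json/ai_extractor.py instead of a stub.

```python
# Illustrative wiring of the Pipeline above. EchoExtractor is a duck-typed
# stand-in for a real AIExtractor subclass (assumption): it returns fixed
# JSON so the preprocess -> extract -> postprocess flow can be run locally.
from pydantic import BaseModel
from web2json.pipeline import Pipeline
from web2json.preprocessor import BasicPreprocessor
from web2json.postprocessor import PostProcessor

class Product(BaseModel):              # example schema (assumption)
    name: str
    price: str

class EchoExtractor:                   # stub; real extractors call an LLM
    def extract(self, content, schema, hf=False):
        return '{"name": "demo", "price": "unknown"}'

pipeline = Pipeline(
    preprocessor=BasicPreprocessor(config={'keep_tags': True}),
    ai_extractor=EchoExtractor(),
    postprocessor=PostProcessor(),
)

html = "<html><body><h1>demo</h1><p>price unknown</p></body></html>"
print(pipeline.run(html, is_url=False, schema=Product))
# -> {'name': 'demo', 'price': 'unknown'}
```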
web2json/postprocessor.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from json_repair import repair_json
2
+ import json
3
+
4
+ class PostProcessor:
5
+
6
+ def process(self, response: str) -> dict:
7
+ json_response = {}
8
+ try:
9
+ # Extract the JSON from the generated text. Handle variations in output format.
10
+ json_string = response
11
+ if "```json" in response:
12
+ json_string = response.split("```json")[1].split("```")[0]
13
+ elif "{" in response and "}" in response:
14
+ # try to grab the json
15
+ start_index = response.find("{")
16
+ end_index = response.rfind("}") + 1
17
+ json_string = response[start_index:end_index]
18
+
19
+ json_response = json.loads(repair_json(json_string)) # Added for robustness
20
+ except Exception as e:
21
+ print(f"Error parsing JSON: {e}")
22
+ print(f"Generated text: {response}")
23
+ json_response = {}
24
+
25
+
26
+ return json_response
27
+
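A quick illustration of the fallback order implemented above: a fenced json block wins, otherwise the outermost braces are grabbed, and the candidate string is repaired before parsing. The second case below relies on json_repair fixing the trailing comma.

```python
# Usage example for PostProcessor; expected outputs are shown as comments.
from web2json.postprocessor import PostProcessor

pp = PostProcessor()

tick = "`" * 3  # build the fence marker without embedding it literally
fenced = "Here you go:\n" + tick + 'json\n{"title": "Hello"}\n' + tick
print(pp.process(fenced))   # {'title': 'Hello'}  (fenced-block branch)

chatty = 'Sure! The data is {"title": "Hello",} as requested.'
print(pp.process(chatty))   # {'title': 'Hello'}  (brace-grab + repair_json)
```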
web2json/preprocessor.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import requests
3
+ from bs4 import BeautifulSoup, Comment
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Dict, Optional
6
+ from htmlrag import clean_html
7
+
8
+ class HTMLCleaner:
9
+ DEFAULT_REMOVE_TAGS = [
10
+ "script", "style"
11
+ ]
12
+
13
+ def __init__(self, config: dict = None):
14
+ self.config = config or {}
15
+ # allow custom tags to remove
16
+ self.remove_tags = set(self.DEFAULT_REMOVE_TAGS) | set(self.config.get("extra_remove_tags", []))
17
+
18
+ def _clean_html(self, html_content: str) -> str:
19
+ """
20
+ Cleans up the given HTML content by:
21
+ - Removing specified tags and their content.
22
+ - Stripping HTML comments.
23
+ - Optionally stripping out all attributes.
24
+ - Optionally flattening hyperlinks.
25
+ - Removing empty tags.
26
+ - Extracting and returning cleaned HTML or visible text.
27
+
28
+ Args:
29
+ html_content (str): The HTML content to clean.
30
+
31
+ Returns:
32
+ str: The cleaned HTML (if keep_tags=True) or normalized text.
33
+ """
34
+ soup = BeautifulSoup(html_content, "html.parser")
35
+
36
+ # Remove unwanted tags entirely
37
+ for tag_name in self.remove_tags:
38
+ for tag in soup.find_all(tag_name):
39
+ tag.decompose()
40
+
41
+ # Remove HTML comments
42
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
43
+ comment.extract()
44
+
45
+ # Strip attributes if requested
46
+ if self.config.get("strip_attrs", False):
47
+ for tag in soup.find_all(True):
48
+ tag.attrs = {}
49
+
50
+ # Flatten hyperlinks if requested
51
+ if self.config.get("strip_links", False):
52
+ for a in soup.find_all('a'):
53
+ a.replace_with(a.get_text())
54
+
55
+ # Remove empty tags (no text and no non-empty children)
56
+ for tag in soup.find_all(True):
57
+ if not tag.get_text(strip=True):
58
+ tag.decompose()
59
+
60
+ # Convert soup to HTML string if preserving tags
61
+ if self.config.get('keep_tags', False):
62
+ html_str = str(soup)
63
+ # Remove any empty lines
64
+ html_str = re.sub(r'(?m)^[ \t]*\n', '', html_str)
65
+ return html_str.strip()
66
+
67
+ # Extract visible text
68
+ text = soup.get_text(separator="\n", strip=True)
69
+ # Remove empty lines
70
+ lines = [line for line in text.splitlines() if line.strip()]
71
+ clean_text = "\n".join(lines)
72
+ # Normalize whitespace within lines
73
+ clean_text = re.sub(r'[ \t]+', ' ', clean_text)
74
+
75
+ return clean_text.strip()
76
+
77
+ class Preprocessor(ABC):
78
+ """
79
+ Abstract base class for preprocessors.
80
+ Defines the interface for transforming raw inputs into structured data.
81
+ """
82
+
83
+ def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
84
+ """
85
+ Initialize the preprocessor with optional configuration.
86
+
87
+ Args:
88
+ config: A dictionary of configuration settings.
89
+ - keep_tags (bool): If True, keeps HTML tags in the output; otherwise, cleans them.
90
+ """
91
+ self.config = config if config is not None else {'keep_tags': False}
92
+
93
+ def _fetch_content(self, url: str) -> str:
94
+ """
95
+ Fetches and parses the text content from a URL.
96
+
97
+ Args:
98
+ url: The URL to fetch content from.
99
+
100
+ Returns:
101
+ The raw HTML content of the fetched page.
102
+
103
+ Raises:
104
+ ValueError: If the URL cannot be fetched or processed.
105
+ """
106
+ try:
107
+ # Set browser-like request headers; some sites refuse requests
108
+ # that do not present a realistic User-Agent.
109
+ # The values below mimic a Chromium-based desktop browser.
110
+ headers = {
111
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
112
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
113
+ "Accept-Language": "en-US,en;q=0.6",
114
+ "Cache-Control": "max-age=0",
115
+ "Sec-Ch-Ua": "\"Not(A:Brand\";v=\"99\", \"Brave\";v=\"133\", \"Chromium\";v=\"133\"",
116
+ "Sec-Ch-Ua-Mobile": "?0",
117
+ "Sec-Ch-Ua-Platform": "\"Windows\"",
118
+ "Sec-Fetch-Dest": "document",
119
+ "Sec-Fetch-Mode": "navigate",
120
+ "Sec-Fetch-Site": "none",
121
+ "Sec-Fetch-User": "?1",
122
+ "Upgrade-Insecure-Requests": "1",
123
+ }
124
+
125
+ # Make the HTTP GET request with a timeout.
126
+ response = requests.get(url, headers=headers, timeout=15)
127
+ # Surface HTTP errors (4xx/5xx) so they are re-raised as ValueError below.
128
+ response.raise_for_status()
129
+ return response.text
130
+
131
+ except requests.exceptions.RequestException as e:
132
+ # Catch any network-related errors (DNS, connection, timeout, etc.)
133
+ # and re-raise them as a more user-friendly ValueError.
134
+ raise ValueError(f"Failed to fetch content from URL: {url}. Error: {e}")
135
+
136
+
137
+ @abstractmethod
138
+ def preprocess(self, content: str, is_url: bool) -> str:
139
+ """
140
+ Take raw content (HTML, text, etc.) and apply preprocessing steps.
141
+
142
+ Args:
143
+ content: The raw data to preprocess.
144
+
145
+ Returns:
146
+ The cleaned content as a string, ready for downstream extraction.
147
+ """
148
+ pass
149
+
150
+ class BasicPreprocessor(Preprocessor):
151
+ """
152
+ Base preprocessor with common functionality.
153
+ Can be extended for specific preprocessing tasks.
154
+ """
155
+ # TODO: Might need to think of how to improve this later
156
+ def _clean_html(self, html_content: str) -> str:
157
+ """
158
+ Cleans up the given HTML content by:
159
+ - Removing <script> and <style> tags and their content.
160
+ - Removing HTML comments.
161
+ - Extracting and returning the visible text with normalized whitespace if keep_tags is False.
162
+
163
+ Args:
164
+ html_content (str): The HTML content to clean.
165
+
166
+ Returns:
167
+ str: The cleaned, visible text from the HTML.
168
+ """
169
+ # Parse the HTML content
170
+ soup = BeautifulSoup(html_content, "html.parser")
171
+
172
+ # Remove script and style elements
173
+ for tag in soup(["script", "style"]):
174
+ tag.decompose()
175
+
176
+ # Remove HTML comments
177
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
178
+ comment.extract()
179
+
180
+ # Extract text and normalize whitespace
181
+ if self.config.get('keep_tags', False):
182
+ # If keep_tags is True, return the raw HTML
183
+ return str(soup)
184
+
185
+ text = soup.get_text(separator=" ", strip=True)
186
+ clean_text = re.sub(r'\s+', ' ', text)
187
+
188
+ return clean_text
189
+
190
+ def preprocess(self, content: str, is_url: bool) -> str:
191
+ """
192
+ Take raw content (HTML, text, etc.) and apply preprocessing steps.
193
+
194
+ Args:
195
+ content: The raw data to preprocess.
196
+
197
+ Returns:
198
+ A dictionary containing structured, cleaned data ready for downstream tasks.
199
+ """
200
+
201
+ html_content = content
202
+ if is_url:
203
+ # Fetch content from the URL
204
+ html_content = self._fetch_content(content)
205
+
206
+
207
+ # Clean the HTML content
208
+ # cleaned_content = self._clean_html(html_content)
209
+ cleaner = HTMLCleaner({
210
+ 'keep_tags': bool(self.config.get('keep_tags', False)),
211
+ 'strip_attrs': True,
212
+ 'strip_links': True,
213
+ 'extra_remove_tags': ['header', 'footer']
214
+ })
215
+ clean = cleaner._clean_html(html_content=html_content)
216
+ clean = clean_html(clean)
217
+ # clean = clean_html(html_content)
218
+ return clean.strip() # Return the cleaned text content, stripped of leading/trailing whitespace
219
+
220
+
221
+
222
+
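To make the cleaning path above concrete, here is a short usage sketch: HTMLCleaner handles tag/attribute/link stripping, and BasicPreprocessor.preprocess chains it with htmlrag's clean_html. The sample HTML is made up and the printed output is approximate, since htmlrag may prune the markup further.

```python
# Usage sketch for HTMLCleaner and BasicPreprocessor (illustrative input).
from web2json.preprocessor import BasicPreprocessor, HTMLCleaner

raw = (
    "<html><head><style>p{color:red}</style></head><body>"
    "<header>Site navigation</header>"
    "<p id='x'>Widget X - <a href='/buy'>order here</a></p>"
    "<!-- tracking comment -->"
    "<footer>Copyright 2024</footer>"
    "</body></html>"
)

# Tag-preserving cleanup, mirroring the config BasicPreprocessor builds.
cleaner = HTMLCleaner({
    'keep_tags': True,
    'strip_attrs': True,
    'strip_links': True,
    'extra_remove_tags': ['header', 'footer'],
})
print(cleaner._clean_html(raw))
# roughly: <html><body><p>Widget X - order here</p></body></html>

# Full preprocessing path (URL fetching skipped by passing is_url=False).
pre = BasicPreprocessor(config={'keep_tags': True})
print(pre.preprocess(raw, is_url=False))
```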