Spaces:
Running
Running
ts startup
Browse files- app/main.py +106 -25
app/main.py
CHANGED
|
@@ -20,9 +20,9 @@ import tempfile
|
|
| 20 |
from utils import getconfig
|
| 21 |
|
| 22 |
config = getconfig("params.cfg")
|
| 23 |
-
RETRIEVER = config.get("retriever", "RETRIEVER")
|
| 24 |
-
GENERATOR = config.get("generator", "GENERATOR")
|
| 25 |
-
INGESTOR = config.get("ingestor", "INGESTOR")
|
| 26 |
MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
|
| 27 |
|
| 28 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -84,7 +84,7 @@ def ingest_node(state: GraphState) -> GraphState:
|
|
| 84 |
try:
|
| 85 |
# Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
|
| 86 |
ingestor_context = client.predict(
|
| 87 |
-
file
|
| 88 |
api_name="/ingest"
|
| 89 |
)
|
| 90 |
|
|
@@ -122,6 +122,52 @@ def ingest_node(state: GraphState) -> GraphState:
|
|
| 122 |
"ingestion_error": str(e)
|
| 123 |
})
|
| 124 |
return {"ingestor_context": "", "metadata": metadata}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
def retrieve_node(state: GraphState) -> GraphState:
|
| 127 |
start_time = datetime.now()
|
|
@@ -170,8 +216,8 @@ def generate_node(state: GraphState) -> GraphState:
|
|
| 170 |
ingestor_context = state.get("ingestor_context", "")
|
| 171 |
|
| 172 |
# Limit context size to prevent token overflow
|
| 173 |
-
MAX_CONTEXT_CHARS = int(MAX_CONTEXT_CHARS)
|
| 174 |
-
|
| 175 |
combined_context = ""
|
| 176 |
if ingestor_context and retrieved_context:
|
| 177 |
# Prioritize ingestor context, truncate if needed
|
|
@@ -355,7 +401,6 @@ def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
|
|
| 355 |
)
|
| 356 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 357 |
|
| 358 |
-
# This is not working currently... Problematic because HF doesn't allow > 1 port open at the same time
|
| 359 |
def create_gradio_interface():
|
| 360 |
with gr.Blocks(title="ChatFed Orchestrator") as demo:
|
| 361 |
gr.Markdown("# ChatFed Orchestrator")
|
|
@@ -416,25 +461,42 @@ async def root():
|
|
| 416 |
}
|
| 417 |
}
|
| 418 |
|
| 419 |
-
#
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
)
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
|
|
|
| 438 |
@app.post("/chatfed-with-file")
|
| 439 |
async def chatfed_with_file(
|
| 440 |
query: str = Form(...),
|
|
@@ -469,6 +531,25 @@ async def chatfed_with_file(
|
|
| 469 |
|
| 470 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
def run_gradio_server():
|
| 473 |
demo = create_gradio_interface()
|
| 474 |
demo.launch(
|
|
|
|
| 20 |
from utils import getconfig
|
| 21 |
|
| 22 |
config = getconfig("params.cfg")
|
| 23 |
+
RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
|
| 24 |
+
GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
|
| 25 |
+
INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
|
| 26 |
MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
|
| 27 |
|
| 28 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 84 |
try:
|
| 85 |
# Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
|
| 86 |
ingestor_context = client.predict(
|
| 87 |
+
file(tmp_file_path), # Use gradio_client.file() to properly format
|
| 88 |
api_name="/ingest"
|
| 89 |
)
|
| 90 |
|
|
|
|
| 122 |
"ingestion_error": str(e)
|
| 123 |
})
|
| 124 |
return {"ingestor_context": "", "metadata": metadata}
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
client = Client(INGESTOR)
|
| 128 |
+
|
| 129 |
+
# Create a temporary file to upload
|
| 130 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
|
| 131 |
+
tmp_file.write(state["file_content"])
|
| 132 |
+
tmp_file_path = tmp_file.name
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
# Call the ingestor's ingest endpoint - returns context directly
|
| 136 |
+
ingestor_context = client.predict(
|
| 137 |
+
file=tmp_file_path,
|
| 138 |
+
api_name="/ingest"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
|
| 142 |
+
|
| 143 |
+
finally:
|
| 144 |
+
# Clean up temporary file
|
| 145 |
+
os.unlink(tmp_file_path)
|
| 146 |
+
|
| 147 |
+
duration = (datetime.now() - start_time).total_seconds()
|
| 148 |
+
metadata = state.get("metadata", {})
|
| 149 |
+
metadata.update({
|
| 150 |
+
"ingestion_duration": duration,
|
| 151 |
+
"ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
|
| 152 |
+
"ingestion_success": True
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
"ingestor_context": ingestor_context,
|
| 157 |
+
"metadata": metadata
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
except Exception as e:
|
| 161 |
+
duration = (datetime.now() - start_time).total_seconds()
|
| 162 |
+
logger.error(f"Ingestion failed: {str(e)}")
|
| 163 |
+
|
| 164 |
+
metadata = state.get("metadata", {})
|
| 165 |
+
metadata.update({
|
| 166 |
+
"ingestion_duration": duration,
|
| 167 |
+
"ingestion_success": False,
|
| 168 |
+
"ingestion_error": str(e)
|
| 169 |
+
})
|
| 170 |
+
return {"ingestor_context": "", "metadata": metadata}
|
| 171 |
|
| 172 |
def retrieve_node(state: GraphState) -> GraphState:
|
| 173 |
start_time = datetime.now()
|
|
|
|
| 216 |
ingestor_context = state.get("ingestor_context", "")
|
| 217 |
|
| 218 |
# Limit context size to prevent token overflow
|
| 219 |
+
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
|
| 220 |
+
|
| 221 |
combined_context = ""
|
| 222 |
if ingestor_context and retrieved_context:
|
| 223 |
# Prioritize ingestor context, truncate if needed
|
|
|
|
| 401 |
)
|
| 402 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 403 |
|
|
|
|
| 404 |
def create_gradio_interface():
|
| 405 |
with gr.Blocks(title="ChatFed Orchestrator") as demo:
|
| 406 |
gr.Markdown("# ChatFed Orchestrator")
|
|
|
|
| 461 |
}
|
| 462 |
}
|
| 463 |
|
| 464 |
+
# Additional endpoint for file uploads via API
|
| 465 |
+
@app.post("/chatfed-with-file")
|
| 466 |
+
async def chatfed_with_file(
|
| 467 |
+
query: str = Form(...),
|
| 468 |
+
file: Optional[UploadFile] = File(None),
|
| 469 |
+
reports_filter: Optional[str] = Form(""),
|
| 470 |
+
sources_filter: Optional[str] = Form(""),
|
| 471 |
+
subtype_filter: Optional[str] = Form(""),
|
| 472 |
+
year_filter: Optional[str] = Form(""),
|
| 473 |
+
session_id: Optional[str] = Form(None),
|
| 474 |
+
user_id: Optional[str] = Form(None)
|
| 475 |
+
):
|
| 476 |
+
"""Endpoint for queries with optional file attachments"""
|
| 477 |
+
file_content = None
|
| 478 |
+
filename = None
|
| 479 |
+
|
| 480 |
+
if file:
|
| 481 |
+
file_content = await file.read()
|
| 482 |
+
filename = file.filename
|
| 483 |
+
|
| 484 |
+
result = process_query_core(
|
| 485 |
+
query=query,
|
| 486 |
+
reports_filter=reports_filter,
|
| 487 |
+
sources_filter=sources_filter,
|
| 488 |
+
subtype_filter=subtype_filter,
|
| 489 |
+
year_filter=year_filter,
|
| 490 |
+
file_content=file_content,
|
| 491 |
+
filename=filename,
|
| 492 |
+
session_id=session_id,
|
| 493 |
+
user_id=user_id,
|
| 494 |
+
return_metadata=True
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 498 |
|
| 499 |
+
# Additional endpoint for file uploads via API
|
| 500 |
@app.post("/chatfed-with-file")
|
| 501 |
async def chatfed_with_file(
|
| 502 |
query: str = Form(...),
|
|
|
|
| 531 |
|
| 532 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 533 |
|
| 534 |
+
# LangServe routes (these are the main endpoints)
|
| 535 |
+
add_routes(
|
| 536 |
+
app,
|
| 537 |
+
RunnableLambda(process_query_langserve),
|
| 538 |
+
path="/chatfed",
|
| 539 |
+
input_type=ChatFedInput,
|
| 540 |
+
output_type=ChatFedOutput
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
add_routes(
|
| 544 |
+
app,
|
| 545 |
+
RunnableLambda(chatui_adapter),
|
| 546 |
+
path="/chatfed-ui-stream",
|
| 547 |
+
input_type=ChatUIInput,
|
| 548 |
+
output_type=str,
|
| 549 |
+
enable_feedback_endpoint=True,
|
| 550 |
+
enable_public_trace_link_endpoint=True,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
def run_gradio_server():
|
| 554 |
demo = create_gradio_interface()
|
| 555 |
demo.launch(
|