# Author: sheikhDipta003
# Commit 49655cc: add api key input option
import gradio as gr
import json
import asyncio
import time
from typing import Any, Dict
from src.enrichment_agent import graph
# from dotenv import load_dotenv
# load_dotenv()
import os
def report_api_key(name):
    """Return environment variable *name*, printing whether it is set.

    Prints ``"<name> found!"`` when the variable is non-empty, otherwise a
    reminder to provide it. The return value is exactly what ``os.getenv``
    gives back (``None`` when unset, ``""`` when set but empty).
    """
    value = os.getenv(name)
    if value:
        print(f"{name} found!")
    else:
        print(f"Please provide your {name}.")
    return value


# Startup check only; keys can also be supplied per request from the UI form.
TAVILY_API_KEY = report_api_key("TAVILY_API_KEY")
OPENAI_API_KEY = report_api_key("OPENAI_API_KEY")
def extract_leaf_nodes(data, parent_key=''):
    """Flatten a nested dict into ``{path: leaf_value}`` pairs.

    Nested dicts extend the path with ``.key``; lists whose items are all
    dicts extend it with ``[index]`` and are flattened item by item. Every
    other value -- scalars, lists of scalars, mixed lists, and empty
    dicts/lists -- is kept as a leaf under its full path, so empty fields
    remain visible instead of disappearing from the output.

    Args:
        data: The mapping to flatten.
        parent_key: Path prefix for every key in *data* (``''`` at the root).

    Returns:
        A new dict mapping each leaf path to its value, in traversal order.

    Example:
        >>> extract_leaf_nodes({"a": {"b": 1}, "c": [{"d": 2}], "e": []})
        {'a.b': 1, 'c[0].d': 2, 'e': []}
    """
    leaf_nodes = {}
    for key, value in data.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        # Only non-empty containers are recursed into: recursing into an
        # empty one yields nothing and would silently drop the field.
        if isinstance(value, dict) and value:
            leaf_nodes.update(extract_leaf_nodes(value, new_key))
        elif isinstance(value, list) and value and all(isinstance(item, dict) for item in value):
            for idx, item in enumerate(value):
                item_key = f"{new_key}[{idx}]"
                if item:
                    leaf_nodes.update(extract_leaf_nodes(item, item_key))
                else:
                    leaf_nodes[item_key] = item
        else:
            leaf_nodes[new_key] = value
    return leaf_nodes
def agent_response(schema_json: str, topic: str):
    """Run the enrichment graph for *topic* and format the extracted fields.

    Args:
        schema_json: JSON text of the extraction schema; it must decode to a
            JSON object.
        topic: Research topic, passed to the graph under ``"topic"``.

    Returns:
        A ``(markdown, seconds)`` tuple. ``markdown`` lists every leaf of the
        result's ``info`` dict as ``**path**: value`` paragraphs (or a short
        notice when nothing was extracted); ``seconds`` is the elapsed time of
        the graph call. For invalid schema input, ``markdown`` is an error
        message and ``seconds`` is ``0.0``.
    """
    try:
        schema = json.loads(schema_json)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) schema value.
        return "Invalid JSON schema.", 0.0
    # The graph receives this as its extraction schema; reject arrays and
    # scalars up front rather than letting them fail inside the agent.
    if not isinstance(schema, dict):
        return "Invalid JSON schema: expected a JSON object.", 0.0

    async def fetch_data(schema: Dict[str, Any], topic: str) -> Dict[str, Any]:
        return await graph.ainvoke({
            "topic": topic,
            "extraction_schema": schema,
        })

    # perf_counter is monotonic, so wall-clock adjustments can't skew timing.
    start_time = time.perf_counter()
    result = asyncio.run(fetch_data(schema, topic))
    processing_time = time.perf_counter() - start_time

    # Treat a missing or null 'info' entry as "nothing extracted".
    info = result.get('info') or {}
    leaf_nodes = extract_leaf_nodes(info)
    if not leaf_nodes:
        return "No information extracted.", processing_time
    # One Markdown paragraph per field (blank line between entries).
    display_data = "\n\n".join(f"**{key}**: {value}" for key, value in leaf_nodes.items())
    return display_data, processing_time
HEADER_HTML = """
<div style="text-align: center;">
<h1 style="color: #4CAF50;">🌟 Enrichment Agent Interface 🌟</h1>
<p style="font-size: 1.2em; color: #555;">
Gathers information about a topic and shows them in structured format.
</p>
</div>
"""

DEFAULT_SCHEMA = {
    "type": "object",
    "properties": {
        "founder": {"type": "string", "description": "Name of the founder"},
        "websiteUrl": {"type": "string", "description": "Website URL"},
        "products_sold": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["founder", "websiteUrl", "products_sold"]
}


def on_submit(schema, topic, tavily_key, openai_key):
    """Click handler: run the agent using the API keys typed into the form.

    A non-blank key from the form overrides the matching environment variable
    for the duration of this call only; a blank box leaves any server-side key
    untouched. The previous environment values are restored afterwards, so
    keys entered by one visitor of the shared app are not reused for later
    requests.

    Returns:
        ``(markdown, seconds_text)`` for the output and timing components.
    """
    overrides = {
        "TAVILY_API_KEY": (tavily_key or "").strip(),
        "OPENAI_API_KEY": (openai_key or "").strip(),
    }
    previous = {name: os.environ.get(name) for name in overrides}
    try:
        for name, value in overrides.items():
            if value:
                os.environ[name] = value
        data, time_taken = agent_response(schema, topic)
        return data, f"{time_taken:.2f}"
    except Exception as e:
        # UI boundary: report the failure in the output panel instead of
        # surfacing a raw traceback.
        return f"❌ An error occurred: {str(e)}", "0.00"
    finally:
        # NOTE(review): os.environ is process-global; if this handler can run
        # concurrently, overlapping requests may still see each other's keys.
        # Passing keys via the graph's config would avoid that -- confirm
        # what src.enrichment_agent supports.
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


with gr.Blocks() as demo:
    gr.Markdown(HEADER_HTML)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🛠 Input")
            tavily_key_input = gr.Textbox(label="Tavily API Key", type="password",
                                          placeholder="Enter your Tavily API Key")
            openai_key_input = gr.Textbox(label="OpenAI API Key", type="password",
                                          placeholder="Enter your OpenAI API Key")
            schema_input = gr.Textbox(
                label="Extraction Schema (JSON)",
                value=json.dumps(DEFAULT_SCHEMA, indent=2),
                lines=10,
                placeholder="Enter the extraction schema in JSON format."
            )
            topic_input = gr.Textbox(label="Topic", placeholder="Enter the research topic, e.g., 'Google'")
            submit_button = gr.Button("Submit 🚀")
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Output")
            output_display = gr.Markdown(label="Extracted Information")
            time_display = gr.Textbox(label="Processing Time (seconds)", interactive=False)

    submit_button.click(
        on_submit,
        inputs=[schema_input, topic_input, tavily_key_input, openai_key_input],
        outputs=[output_display, time_display],
    )

if __name__ == "__main__":
    demo.launch(share=True)