Files changed (1) hide show
  1. app.py +0 -131
app.py DELETED
@@ -1,131 +0,0 @@
1
- import os
2
- import warnings
3
- from typing import *
4
- from dotenv import load_dotenv
5
- from transformers import logging
6
-
7
- from langgraph.checkpoint.memory import MemorySaver
8
- from langchain_openai import ChatOpenAI
9
- from langgraph.checkpoint.memory import MemorySaver
10
- from langchain_openai import ChatOpenAI
11
-
12
- from interface import create_demo
13
- from medrax.agent import *
14
- from medrax.tools import *
15
- from medrax.utils import *
16
-
17
- warnings.filterwarnings("ignore")
18
- logging.set_verbosity_error()
19
- _ = load_dotenv()
20
-
21
- def initialize_agent(
22
- prompt_file,
23
- tools_to_use=None,
24
- model_dir="./model-weights",
25
- temp_dir="temp",
26
- device="cuda",
27
- model="gpt-4o-mini",
28
- temperature=0.7,
29
- top_p=0.95,
30
- openai_kwargs={}
31
- ):
32
- """Initialize the MedRAX agent with specified tools and configuration.
33
- Args:
34
- prompt_file (str): Path to file containing system prompts
35
- tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
36
- model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
37
- temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
38
- device (str, optional): Device to run models on. Defaults to "cuda".
39
- model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
40
- temperature (float, optional): Temperature for the model. Defaults to 0.7.
41
- top_p (float, optional): Top P for the model. Defaults to 0.95.
42
- openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
43
- Returns:
44
- Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
45
- """
46
- prompts = load_prompts_from_file(prompt_file)
47
- prompt = prompts["MEDICAL_ASSISTANT"]
48
-
49
- all_tools = {
50
- "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
51
- "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
52
- "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
53
- "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
54
- "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
55
- cache_dir=model_dir, device=device
56
- ),
57
- "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
58
- cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
59
- ),
60
- "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
61
- model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
62
- ),
63
- "ImageVisualizerTool": lambda: ImageVisualizerTool(),
64
- "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
65
- }
66
-
67
- # Initialize only selected tools or all if none specified
68
- tools_dict = {}
69
- tools_to_use = tools_to_use or all_tools.keys()
70
- for tool_name in tools_to_use:
71
- if tool_name in all_tools:
72
- tools_dict[tool_name] = all_tools[tool_name]()
73
-
74
- checkpointer = MemorySaver()
75
- model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
76
- agent = Agent(
77
- model,
78
- tools=list(tools_dict.values()),
79
- log_tools=True,
80
- log_dir="logs",
81
- system_prompt=prompt,
82
- checkpointer=checkpointer,
83
- )
84
-
85
- print("Agent initialized")
86
- return agent, tools_dict
87
-
88
- if __name__ == "__main__":
89
- """
90
- This is the main entry point for the MedRAX application.
91
- It initializes the agent with the selected tools and creates the demo.
92
- """
93
- print("Starting server...")
94
-
95
- # Example: initialize with only specific tools
96
- # Here three tools are commented out, you can uncomment them to use them
97
- selected_tools = [
98
- "ImageVisualizerTool",
99
- "DicomProcessorTool",
100
- "ChestXRayClassifierTool",
101
- "ChestXRaySegmentationTool",
102
- "ChestXRayReportGeneratorTool",
103
- "XRayVQATool",
104
- # "LlavaMedTool",
105
- # "XRayPhraseGroundingTool",
106
- # "ChestXRayGeneratorTool",
107
- ]
108
-
109
- # Collect the ENV variables
110
- openai_kwargs = {}
111
- if api_key := os.getenv("OPENAI_API_KEY"):
112
- openai_kwargs["api_key"] = api_key
113
-
114
- if base_url := os.getenv("OPENAI_BASE_URL"):
115
- openai_kwargs["base_url"] = base_url
116
-
117
- agent, tools_dict = initialize_agent(
118
- "medrax/docs/system_prompts.txt",
119
- tools_to_use=selected_tools,
120
- model_dir="./model-weights", # Change this to the path of the model weights
121
- temp_dir="temp", # Change this to the path of the temporary directory
122
- device="cuda", # Change this to the device you want to use
123
- model="gpt-4o-mini", # Change this to the model you want to use, e.g. gpt-4o-mini
124
- temperature=0.7,
125
- top_p=0.95,
126
- openai_kwargs=openai_kwargs
127
- )
128
- demo = create_demo(agent, tools_dict)
129
-
130
- # demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
131
- demo.launch(debug=True, ssr_mode=False)