Suchinthana commited on
Commit
8152b02
·
1 Parent(s): 467d7be

added pydantic, slim prompt

Browse files
Files changed (2) hide show
  1. app.py +52 -101
  2. requirements.txt +1 -1
app.py CHANGED
@@ -10,6 +10,8 @@ from staticmap import StaticMap, CircleMarker, Polygon
10
  from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
11
  import spaces
12
  import logging
 
 
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -19,6 +21,20 @@ logger = logging.getLogger(__name__)
19
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
20
  geolocator = Nominatim(user_agent="geoapi")
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Function to fetch coordinates
23
  @spaces.GPU
24
  def get_geo_coordinates(location_name):
@@ -37,43 +53,15 @@ def process_openai_response(query):
37
  response = openai_client.chat.completions.create(
38
  model="gpt-4o-mini",
39
  messages=[
40
- {
41
- "role": "system",
42
- "content": [
43
- {
44
- "type": "text",
45
- "text": "\"input\": \"\"\"You are a skilled assistant answering geographical and historical questions. For each question, generate a structured output in JSON format, based on city names without coordinates. The response should include:\
46
- Answer: A concise response to the question.\
47
- Feature Representation: A feature type based on city names (Point, LineString, Polygon, MultiPoint, MultiLineString, MultiPolygon, GeometryCollection).\
48
- Description: A prompt for a diffusion model describing the what should we draw regarding that.\
49
- \
50
- Handle the following cases:\
51
- \
52
- 1. **Single (Point) or Multiple Points (MultiPoint)**: Create a point or a list of points for multiple cities.\
53
- 2. **LineString**: Create a line between two cities.\
54
- 3. **Polygon**: Represent an area formed by three or more cities (closed). Example: Cities forming a triangle (A, B, C).\
55
- 4. **MultiPoint, MultiLineString, MultiPolygon, GeometryCollection**: Use as needed based on the question.\
56
- \
57
- For example, if asked about cities forming a polygon, create a feature like this:\
58
- \
59
- Input: Mark an area with three cities.\
60
- Output: {\"input\": \"Mark an area with three cities.\", \"output\": {\"answer\": \"The cities A, B, and C form a triangle.\", \"feature_representation\": {\"type\": \"Polygon\", \"cities\": [\"A\", \"B\", \"C\"], \"properties\": {\"description\": \"satelite image of a plantation, green fill, 4k, map, detailed, greenary, plants, vegitation, high contrast\"}}}}\
61
- \
62
- Ensure all responses are descriptive and relevant to city names only, without coordinates. **Adhere to given example format always**\
63
- \"}\"}"
64
- }
65
- ]
66
- },
67
- {
68
- "role": "user",
69
- "content": [
70
- {
71
- "type": "text",
72
- "text": query
73
- }
74
- ]
75
- }
76
- ],
77
  temperature=1,
78
  max_tokens=2048,
79
  top_p=1,
@@ -99,48 +87,55 @@ def generate_geojson(response):
99
  if feature_type == "Polygon":
100
  coordinates.append(coordinates[0]) # Close the polygon
101
 
102
- return {
103
  "type": "FeatureCollection",
104
- "features": [{
105
- "type": "Feature",
106
- "properties": properties,
107
- "geometry": {
108
- "type": feature_type,
109
- "coordinates": [coordinates] if feature_type == "Polygon" else coordinates
 
 
110
  }
111
- }]
112
  }
113
 
 
 
 
 
 
 
 
 
 
114
  # Generate static map image
115
  @spaces.GPU
116
  def generate_static_map(geojson_data, invisible=False):
117
- # Create a static map object with specified dimensions
118
  m = StaticMap(600, 600)
119
- #log the geojson data
120
  logger.info(f"GeoJSON data: {geojson_data}")
121
- # Process each feature in the GeoJSON
122
  for feature in geojson_data["features"]:
123
  geom_type = feature["geometry"]["type"]
124
  coords = feature["geometry"]["coordinates"]
125
 
126
  if geom_type == "Point":
127
- m.add_marker(CircleMarker((coords[0][0], coords[0][1]), '#1C00ff00' if invisible == True else 'blue', 1000))
128
  elif geom_type in ["MultiPoint", "LineString"]:
129
  for coord in coords:
130
- m.add_marker(CircleMarker((coord[0], coord[1]), '#1C00ff00' if invisible == True else 'blue', 1000))
131
  elif geom_type in ["Polygon", "MultiPolygon"]:
132
  for polygon in coords:
133
- m.add_polygon(Polygon([(c[0], c[1]) for c in polygon], '#1C00ff00' if invisible == True else 'blue', 3))
134
-
135
- return m.render() #zoom=10
136
 
 
137
 
138
  # ControlNet pipeline setup
139
  controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
140
  pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
141
  "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
142
  )
143
- # ZeroGPU compatibility
144
  pipeline.to('cuda')
145
 
146
  @spaces.GPU
@@ -164,66 +159,23 @@ def generate_satellite_image(init_image, mask_image, prompt):
164
  control_image=control_image,
165
  strength=0.42,
166
  guidance_scale=62
167
- )
168
  return result.images[0]
169
 
170
  # Gradio UI
171
  @spaces.GPU
172
  def handle_query(query):
173
- # Process OpenAI response
174
  response = process_openai_response(query)
175
  geojson_data = generate_geojson(response)
176
 
177
- geojson_data = {
178
- "type": "FeatureCollection",
179
- "features": [
180
- {
181
- "type": "Feature",
182
- "properties": {
183
- "description": "satellite image of the Coconut Triangle region, green fill, 4k, map, detailed, coconut palms, lush vegetation, high contrast"
184
- },
185
- "geometry": {
186
- "type": "Polygon",
187
- "coordinates": [
188
- [
189
- [
190
- 80.364908,
191
- 7.4870464
192
- ],
193
- [
194
- 79.82933709234904,
195
- 7.981840249999999
196
- ],
197
- [
198
- 79.91598756451819,
199
- 7.1190247499999995
200
- ],
201
- [
202
- 80.364908,
203
- 7.4870464
204
- ]
205
- ]
206
- ]
207
- }
208
- }
209
- ]
210
- }
211
-
212
-
213
- # Generate the main map image
214
  map_image = generate_static_map(geojson_data)
 
215
 
216
- empty_map_image = generate_static_map(geojson_data, invisible=True) # Empty map with the same bounds
217
-
218
- # Create the mask
219
  difference = np.abs(np.array(map_image.convert("RGB")) - np.array(empty_map_image.convert("RGB")))
220
- threshold = 10 # Tolerance for difference
221
  mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
222
 
223
- # Convert the mask to a PIL image
224
  mask_image = Image.fromarray(mask, mode="L")
225
-
226
- # Generate the satellite image
227
  satellite_image = generate_satellite_image(
228
  empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
229
  )
@@ -238,7 +190,6 @@ query_options = [
238
  "Due to considerable rainfall in the up- and mid- stream areas of Kala Oya, the Rajanganaya reservoir is now spilling at a rate of 17,000 cubic feet per second, the department said."
239
  ]
240
 
241
- # Gradio interface
242
  with gr.Blocks() as demo:
243
  with gr.Row():
244
  selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1])
 
10
  from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
11
  import spaces
12
  import logging
13
+ from pydantic import BaseModel, ValidationError, Field
14
+ from typing import List, Union
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
21
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
22
  geolocator = Nominatim(user_agent="geoapi")
23
 
24
+ # Define Pydantic models for GeoJSON validation
25
+ class Geometry(BaseModel):
26
+ type: str
27
+ coordinates: Union[List[List[float]], List[List[List[float]]]]
28
+
29
+ class Feature(BaseModel):
30
+ type: str = "Feature"
31
+ properties: dict
32
+ geometry: Geometry
33
+
34
+ class FeatureCollection(BaseModel):
35
+ type: str = "FeatureCollection"
36
+ features: List[Feature]
37
+
38
  # Function to fetch coordinates
39
  @spaces.GPU
40
  def get_geo_coordinates(location_name):
 
53
  response = openai_client.chat.completions.create(
54
  model="gpt-4o-mini",
55
  messages=[
56
+ {
57
+ "role": "system",
58
+ "content": "You are a skilled assistant answering geographical and historical questions in JSON format."
59
+ },
60
+ {
61
+ "role": "user",
62
+ "content": query
63
+ }
64
+ ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  temperature=1,
66
  max_tokens=2048,
67
  top_p=1,
 
87
  if feature_type == "Polygon":
88
  coordinates.append(coordinates[0]) # Close the polygon
89
 
90
+ geojson_data = {
91
  "type": "FeatureCollection",
92
+ "features": [
93
+ {
94
+ "type": "Feature",
95
+ "properties": properties,
96
+ "geometry": {
97
+ "type": feature_type,
98
+ "coordinates": [coordinates] if feature_type == "Polygon" else coordinates
99
+ }
100
  }
101
+ ]
102
  }
103
 
104
+ # Validate the generated GeoJSON using Pydantic
105
+ try:
106
+ validated_geojson = FeatureCollection(**geojson_data)
107
+ logger.info("GeoJSON validation successful.")
108
+ return validated_geojson.dict()
109
+ except ValidationError as e:
110
+ logger.error(f"GeoJSON validation failed: {e}")
111
+ raise
112
+
113
  # Generate static map image
114
  @spaces.GPU
115
  def generate_static_map(geojson_data, invisible=False):
 
116
  m = StaticMap(600, 600)
 
117
  logger.info(f"GeoJSON data: {geojson_data}")
118
+
119
  for feature in geojson_data["features"]:
120
  geom_type = feature["geometry"]["type"]
121
  coords = feature["geometry"]["coordinates"]
122
 
123
  if geom_type == "Point":
124
+ m.add_marker(CircleMarker((coords[0][0], coords[0][1]), '#1C00ff00' if invisible else 'blue', 1000))
125
  elif geom_type in ["MultiPoint", "LineString"]:
126
  for coord in coords:
127
+ m.add_marker(CircleMarker((coord[0], coord[1]), '#1C00ff00' if invisible else 'blue', 1000))
128
  elif geom_type in ["Polygon", "MultiPolygon"]:
129
  for polygon in coords:
130
+ m.add_polygon(Polygon([(c[0], c[1]) for c in polygon], '#1C00ff00' if invisible else 'blue', 3))
 
 
131
 
132
+ return m.render()
133
 
134
  # ControlNet pipeline setup
135
  controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
136
  pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
137
  "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
138
  )
 
139
  pipeline.to('cuda')
140
 
141
  @spaces.GPU
 
159
  control_image=control_image,
160
  strength=0.42,
161
  guidance_scale=62
162
+ )
163
  return result.images[0]
164
 
165
  # Gradio UI
166
  @spaces.GPU
167
  def handle_query(query):
 
168
  response = process_openai_response(query)
169
  geojson_data = generate_geojson(response)
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  map_image = generate_static_map(geojson_data)
172
+ empty_map_image = generate_static_map(geojson_data, invisible=True)
173
 
 
 
 
174
  difference = np.abs(np.array(map_image.convert("RGB")) - np.array(empty_map_image.convert("RGB")))
175
+ threshold = 10
176
  mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
177
 
 
178
  mask_image = Image.fromarray(mask, mode="L")
 
 
179
  satellite_image = generate_satellite_image(
180
  empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
181
  )
 
190
  "Due to considerable rainfall in the up- and mid- stream areas of Kala Oya, the Rajanganaya reservoir is now spilling at a rate of 17,000 cubic feet per second, the department said."
191
  ]
192
 
 
193
  with gr.Blocks() as demo:
194
  with gr.Row():
195
  selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1])
requirements.txt CHANGED
@@ -11,4 +11,4 @@ torchvision
11
  opencv-python
12
  torch
13
  staticmap
14
- selenium
 
11
  opencv-python
12
  torch
13
  staticmap
14
+ pydantic