Spaces:
Sleeping
Sleeping
Commit
·
8630bef
1
Parent(s):
c9cae08
feat: add GCS
Browse files- src/cloud_storage/__init__.py +11 -0
- src/cloud_storage/base_storage.py +13 -0
- src/cloud_storage/google_cloud_storage.py +34 -0
- src/fetch_places.py +6 -6
- src/openai_predictions.py +7 -4
src/cloud_storage/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .google_cloud_storage import GoogleCloudStorage
|
2 |
+
|
3 |
+
# Registry of supported cloud storage backends, keyed by ingestion mode.
storage_map = {
    "gcp": GoogleCloudStorage,
}


def get_cloud_storage(ingestion_mode, *args, **kwargs):
    """Instantiate the storage backend registered for ``ingestion_mode``.

    Args:
        ingestion_mode: Key into ``storage_map`` (for example ``"gcp"``).
        *args: Positional arguments forwarded to the backend constructor.
        **kwargs: Keyword arguments forwarded to the backend constructor.

    Returns:
        A new instance of the backend class registered under
        ``ingestion_mode``.

    Raises:
        ValueError: If ``ingestion_mode`` is not a key of ``storage_map``.
    """
    if ingestion_mode not in storage_map:
        raise ValueError(f"Invalid storage mode: {ingestion_mode}")
    # Backend constructors may take arguments (GoogleCloudStorage requires a
    # service-account path), so forward whatever the caller supplied instead
    # of always calling the class with no arguments.
    return storage_map[ingestion_mode](*args, **kwargs)
|
src/cloud_storage/base_storage.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseStorage:
    """Common interface for cloud storage backends.

    Subclasses are expected to assign ``self.client`` and override the upload
    methods; the implementations here only raise ``NotImplementedError``.
    """

    def __init__(self):
        # Backend-specific client handle; stays None until a subclass sets it.
        self.client = None

    def upload_file_from_content(
        self, bucket_name: str, bucket_file_path: str, file_content: bytes
    ) -> str:
        """Upload in-memory content to ``bucket_file_path`` in ``bucket_name``.

        Args:
            bucket_name (str): Name of the destination bucket.
            bucket_file_path (str): Destination path of the object in the bucket.
            file_content (bytes): Raw content to upload.

        Returns:
            str: Location of the uploaded object, as defined by the backend.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def upload_file_from_path(
        self, bucket_name: str, bucket_file_path: str, source_file_path: str
    ) -> str:
        """Upload a local file to ``bucket_file_path`` in ``bucket_name``.

        Args:
            bucket_name (str): Name of the destination bucket.
            bucket_file_path (str): Destination path of the object in the bucket.
            source_file_path (str): Path of the local file to upload.

        Returns:
            str: Location of the uploaded object, as defined by the backend.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
|
src/cloud_storage/google_cloud_storage.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from google.cloud import storage
|
2 |
+
from google.oauth2 import service_account
|
3 |
+
|
4 |
+
from .base_storage import BaseStorage
|
5 |
+
|
6 |
+
|
7 |
+
class GoogleCloudStorage(BaseStorage):
    """Storage backend that uploads objects to Google Cloud Storage."""

    def __init__(self, service_account_path: str):
        """Create a GCS client authenticated from a service-account file.

        Args:
            service_account_path (str): Path to the service-account file
                passed to ``Credentials.from_service_account_file``.
        """
        # Run the base initializer so any state it sets up is present before
        # the real client is assigned below.
        super().__init__()
        credentials = service_account.Credentials.from_service_account_file(
            service_account_path
        )
        self.client = storage.Client(credentials=credentials)

    def upload_file_from_content(
        self, bucket_name: str, bucket_file_path: str, file_content: bytes
    ) -> str:
        """Upload in-memory content as an object in a Google Cloud Storage bucket.

        Args:
            bucket_name (str): Name of the destination bucket.
            bucket_file_path (str): Object path (name) inside the bucket.
            file_content (bytes): Raw content to store in the object.

        Returns:
            str: ``gs://`` URI of the uploaded object.
        """
        # Build the URI once; it is printed before and after the upload and
        # is also the return value.
        gs_uri = f"gs://{bucket_name}/{bucket_file_path}"
        print(gs_uri)

        bucket = self.client.get_bucket(bucket_name)
        blob = bucket.blob(bucket_file_path)
        blob.upload_from_string(file_content)

        print(f"File uploaded to {gs_uri}")
        return gs_uri
|
src/fetch_places.py
CHANGED
@@ -8,11 +8,12 @@ from data_models.park_manager import ParkManager
|
|
8 |
from data_models.image_manager import ImageManager
|
9 |
from data_models.openai_manager import OpenAIManager
|
10 |
from openai_predictions import process_agent_predictions
|
11 |
-
|
12 |
|
13 |
park_manager = ParkManager()
|
14 |
image_manager = ImageManager()
|
15 |
prediction_manager = OpenAIManager()
|
|
|
16 |
|
17 |
|
18 |
def nearby_search(coordinates: tuple) -> Tuple[str, str, dict]:
|
@@ -144,12 +145,11 @@ def fetch_place_photos(place_id: str, place_name: str, photos: dict) -> int:
|
|
144 |
max_height = photo["heightPx"]
|
145 |
|
146 |
photo_binary = fetch_photo(place_id, photo_id, max_width, max_height)
|
147 |
-
file_name = f"{folder_path}/{photo_id[:150]}.jpg"
|
148 |
-
with open(file_name, "wb") as f:
|
149 |
-
f.write(photo_binary)
|
150 |
|
151 |
-
|
152 |
-
|
|
|
|
|
153 |
prediction_manager.add_predictions(image_id, predictions)
|
154 |
|
155 |
print(f"{len(photos)} photos fetched for place: {place_name}")
|
|
|
8 |
from data_models.image_manager import ImageManager
|
9 |
from data_models.openai_manager import OpenAIManager
|
10 |
from openai_predictions import process_agent_predictions
|
11 |
+
from cloud_storage.google_cloud_storage import GoogleCloudStorage
|
12 |
|
13 |
park_manager = ParkManager()
|
14 |
image_manager = ImageManager()
|
15 |
prediction_manager = OpenAIManager()
|
16 |
+
gcs = GoogleCloudStorage("ialab-fr-7047e56bef0b.json")
|
17 |
|
18 |
|
19 |
def nearby_search(coordinates: tuple) -> Tuple[str, str, dict]:
|
|
|
145 |
max_height = photo["heightPx"]
|
146 |
|
147 |
photo_binary = fetch_photo(place_id, photo_id, max_width, max_height)
|
|
|
|
|
|
|
148 |
|
149 |
+
file_name = f"{place_name}/{photo_id[:150]}.jpg"
|
150 |
+
bucket_path = gcs.upload_file_from_content("suad_park", file_name, photo_binary)
|
151 |
+
image_id = image_manager.add_image(bucket_path, datetime.now(), park_id=park_id)
|
152 |
+
predictions = process_agent_predictions(file_name, photo_binary)
|
153 |
prediction_manager.add_predictions(image_id, predictions)
|
154 |
|
155 |
print(f"{len(photos)} photos fetched for place: {place_name}")
|
src/openai_predictions.py
CHANGED
@@ -6,9 +6,8 @@ from openai_agent import Agent
|
|
6 |
|
7 |
agent = Agent("./prompts")
|
8 |
|
9 |
-
|
10 |
# Function to encode the image
|
11 |
-
def
|
12 |
if not os.path.exists(image_path):
|
13 |
raise FileNotFoundError(f"Image not found at {image_path}")
|
14 |
|
@@ -16,8 +15,12 @@ def encode_image(image_path: str):
|
|
16 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
17 |
|
18 |
|
19 |
-
def
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
22 |
prompts = [
|
23 |
"built_elements",
|
|
|
6 |
|
7 |
agent = Agent("./prompts")
|
8 |
|
|
|
9 |
# Function to encode the image
|
10 |
+
def encode_image_from_path(image_path: str) -> str:
|
11 |
if not os.path.exists(image_path):
|
12 |
raise FileNotFoundError(f"Image not found at {image_path}")
|
13 |
|
|
|
15 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
16 |
|
17 |
|
18 |
+
def encode_image(photo_binary: bytes) -> str:
    """Return the Base64 encoding of ``photo_binary`` as a text string."""
    encoded = base64.b64encode(photo_binary)
    return encoded.decode("utf-8")
|
20 |
+
|
21 |
+
|
22 |
+
def process_agent_predictions(file_path: str, photo_binary: bytes):
|
23 |
+
base64_image = encode_image(photo_binary)
|
24 |
|
25 |
prompts = [
|
26 |
"built_elements",
|