88hours committed on
Commit ad022d3 · verified · 1 Parent(s): f7c72f7

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,36 +1,36 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *.mp4 filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,12 +1,16 @@
1
- myenv
2
- __pycache__
3
- .gradio
4
- .venv
5
- .env
6
- .env.*
7
- !.env.example
8
- .github
9
- # LanceDB files
10
- shared_data/.lancedb/
11
- shared_data/.lancedb/**/*
12
- shared_data/videos/yt_video/blackholes101nationalgeographic/audio.mp3
 
 
 
 
 
1
+ myenv
2
+ __pycache__
3
+ .gradio
4
+ .venv
5
+ .env
6
+ .env.*
7
+ !.env.example
8
+ .github
9
+ # LanceDB files
10
+ shared_data/.lancedb/
11
+ shared_data/.lancedb/**/*
12
+ shared_data/videos/yt_video/blackholes101nationalgeographic/audio.mp3
13
+ mm_rag/embeddings/__pycache__/
14
+ mm_rag/embeddings/__pycache__/**
15
+ .DS_Store
16
+ .devcontainer/
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM mcr.microsoft.com/devcontainers/python:3.11
2
+
3
+ # Install system dependencies (ffmpeg for audio extraction, libgl1 for OpenCV)
4
+ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
5
+ && apt-get -y install --no-install-recommends \
6
+ ffmpeg libgl1-mesa-glx \
7
+ && apt-get clean -y && rm -rf /var/lib/apt/lists/*
8
+
9
+ WORKDIR /app
10
+
11
+ # Install Python dependencies (requirements.txt must be copied in before pip can use it)
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+ COPY . .
15
+
16
+ # Run the application
17
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,182 +1,182 @@
1
- ---
2
- title: multimodel-rag-chat-with-videos
3
- app_file: app.py
4
- sdk: gradio
5
- sdk_version: 5.17.1
6
- ---
7
-
8
- # Demo
9
- ## Sample Video
10
- - https://www.youtube.com/watch?v=kOEDG3j1bjs
11
- - https://www.youtube.com/watch?v=7Hcg-rLYwdM
12
- ## Questions
13
- - Event Horizon
14
- - show me a group of astronauts, AStronaut name
15
- # ReArchitecture Multimodal RAG System Pipeline Journey
16
- I ported it locally and isolated each concept into a step as Python runnable
17
- It is simplified, refactored and bug-fixed now.
18
- I migrated from Prediction Guard to HuggingFace.
19
-
20
- [**Interactive Video Chat Demo and Multimodal RAG System Architecture**](https://learn.deeplearning.ai/courses/multimodal-rag-chat-with-videos/lesson/2/interactive-demo-and-multimodal-rag-system-architecture)
21
-
22
- ### A multimodal AI system should be able to understand both text and video content.
23
-
24
- ## Setup
25
- ```bash
26
- python -m venv venv
27
- source venv/bin/activate
28
- ```
29
- For Fish
30
- ```bash
31
- source venv/bin/activate.fish
32
- ```
33
-
34
- ## Step 1 - Learn Gradio (UI) (30 mins)
35
-
36
- Gradio is a powerful Python library for quickly building browser-based UIs. It supports hot reloading for fast development.
37
-
38
- ### Key Concepts:
39
- - **fn**: The function wrapped by the UI.
40
- - **inputs**: The Gradio components used for input (should match function arguments).
41
- - **outputs**: The Gradio components used for output (should match return values).
42
-
43
- 📖 [**Gradio Documentation**](https://www.gradio.app/docs/gradio/introduction)
44
-
45
- Gradio includes **30+ built-in components**.
46
-
47
- 💡 **Tip**: For `inputs` and `outputs`, you can pass either:
48
- - The **component name** as a string (e.g., `"textbox"`)
49
- - An **instance of the component class** (e.g., `gr.Textbox()`)
50
-
51
- ### Sharing Your Demo
52
- ```python
53
- demo.launch(share=True) # Share your demo with just one extra parameter.
54
- ```
55
-
56
- ## Gradio Advanced Features
57
-
58
- ### **Gradio.Blocks**
59
- Gradio provides `gr.Blocks`, a flexible way to design web apps with **custom layouts and complex interactions**:
60
- - Arrange components freely on the page.
61
- - Handle multiple data flows.
62
- - Use outputs as inputs for other components.
63
- - Dynamically update components based on user interaction.
64
-
65
- ### **Gradio.ChatInterface**
66
- - Always set `type="messages"` in `gr.ChatInterface`.
67
- - The default (`type="tuples"`) is **deprecated** and will be removed in future versions.
68
- - For more UI flexibility, use `gr.ChatBot`.
69
- - `gr.ChatInterface` supports **Markdown** (not tested yet).
70
-
71
- ---
72
-
73
- ## Step 2 - Learn Bridge Tower Embedding Model (Multimodal Learning) (15 mins)
74
-
75
- Developed in collaboration with Intel, this model maps image-caption pairs into **512-dimensional vectors**.
76
-
77
- ### Measuring Similarity
78
- - **Cosine Similarity** → Measures how close images are in vector space (**efficient & commonly used**).
79
- - **Euclidean Distance** → Uses `cv2.NORM_L2` to compute similarity between two images.
80
-
81
- ### Converting to 2D for Visualization
82
- - **UMAP** reduces 512D embeddings to **2D for display purposes**.
83
-
84
- ## Preprocessing Videos for Multimodal RAG
85
-
86
- ### **Case 1: WEBVTT → Extracting Text Segments from Video**
87
- - Converts video + text into structured metadata.
88
- - Splits content into multiple segments.
89
-
90
- ### **Case 2: Whisper (Small) → Video Only**
91
- - Extracts **audio** → `model.transcribe()`.
92
- - Applies `getSubs()` helper function to retrieve **WEBVTT** subtitles.
93
- - Uses **Case 1** processing.
94
-
95
- ### **Case 3: LvLM → Video + Silent/Music Extraction**
96
- - Uses **Llava (LvLM model)** for **frame-based captioning**.
97
- - Encodes each frame as a **Base64 image**.
98
- - Extracts context and captions from video frames.
99
- - Uses **Case 1** processing.
100
-
101
- # Step 4 - What is LLaVA?
102
- LLaVA (Large Language-and-Vision Assistant) is a large multimodal model that connects a vision encoder to a language model. It doesn't just see images: it understands them, reads the text embedded in them, and reasons about their context.
103
-
104
- # Step 5 - what is a vector Store?
105
- A vector store is a specialized database designed to:
106
-
107
- - Store and manage high-dimensional vector data efficiently
108
- - Perform similarity-based searches where K=1 returns the most similar result
109
-
110
- - In LanceDB specifically, store multiple data types:
111
- . Text content (captions)
112
- . Image file paths
113
- . Metadata
114
- . Vector embeddings
115
-
116
- ```python
117
- _ = MultimodalLanceDB.from_text_image_pairs(
118
- texts=updated_vid1_trans+vid2_trans,
119
- image_paths=vid1_img_path+vid2_img_path,
120
- embedding=BridgeTowerEmbeddings(),
121
- metadatas=vid1_metadata+vid2_metadata,
122
- connection=db,
123
- table_name=TBL_NAME,
124
- mode="overwrite",
125
- )
126
- ```
127
- # Gotchas and Solutions
128
- Image Processing: When working with base64 encoded images, convert them to PIL.Image format before processing with BridgeTower
129
- Model Selection: Using BridgeTowerForContrastiveLearning instead of PredictionGuard due to API access limitations
130
- Model Size: BridgeTower model requires ~3.5GB download
131
- Image Downloads: Some Flickr images may be unavailable; implement robust error handling
132
- Token Decoding: BridgeTower contrastive learning model works with embeddings, not token predictions
133
- Install from git+https://github.com/openai/whisper.git
134
-
135
- # Install ffmpeg using brew
136
- ```bash
137
- brew install ffmpeg
138
- brew link ffmpeg
139
- ```
140
-
141
-
142
- # Learning and Skills
143
-
144
- ## Technical Skills:
145
-
146
- Basic Machine learning and deep learning
147
- Vector embeddings and similarity search
148
- Multimodal data processing
149
-
150
- ## Framework & Library Expertise:
151
-
152
- Hugging Face Transformers
153
- Gradio UI development
154
- LangChain integration (Basic)
155
- PyTorch basics
156
- LanceDB vector storage
157
-
158
- ## AI/ML Concepts:
159
-
160
- Multimodal RAG system architecture
161
- Vector embeddings and similarity search
162
- Large Language Models (LLaVA)
163
- Image-text pair processing
164
- Dimensionality reduction techniques
165
-
166
-
167
- ## Multimedia Processing:
168
-
169
- Video frame extraction
170
- Audio transcription (Whisper)
171
- Image processing (PIL)
172
- Base64 encoding/decoding
173
- WebVTT handling
174
-
175
- ## System Design:
176
-
177
- Client-server architecture
178
- API endpoint design
179
- Data pipeline construction
180
- Vector store implementation
181
- Multimodal system integration
182
-
 
1
+ ---
2
+ title: multimodel-rag-chat-with-videos
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.17.1
6
+ ---
7
+
8
+ # Demo
9
+ ## Sample Video
10
+ - https://www.youtube.com/watch?v=kOEDG3j1bjs
11
+ - https://www.youtube.com/watch?v=7Hcg-rLYwdM
12
+ ## Questions
13
+ - Event Horizon
14
+ - Show me a group of astronauts; astronaut name
15
+ # Re-Architecting the Multimodal RAG System: Pipeline Journey
16
+ I ported the course project locally and isolated each concept into its own runnable Python step.
17
+ The code is now simplified, refactored, and bug-fixed.
18
+ I migrated from Prediction Guard to Hugging Face.
19
+
20
+ [**Interactive Video Chat Demo and Multimodal RAG System Architecture**](https://learn.deeplearning.ai/courses/multimodal-rag-chat-with-videos/lesson/2/interactive-demo-and-multimodal-rag-system-architecture)
21
+
22
+ ### A multimodal AI system should be able to understand both text and video content.
23
+
24
+ ## Setup
25
+ ```bash
26
+ python -m venv venv
27
+ source venv/bin/activate
28
+ ```
29
+ For the Fish shell:
30
+ ```bash
31
+ source venv/bin/activate.fish
32
+ ```
33
+
34
+ ## Step 1 - Learn Gradio (UI) (30 mins)
35
+
36
+ Gradio is a powerful Python library for quickly building browser-based UIs. It supports hot reloading for fast development.
37
+
38
+ ### Key Concepts:
39
+ - **fn**: The function wrapped by the UI.
40
+ - **inputs**: The Gradio components used for input (should match function arguments).
41
+ - **outputs**: The Gradio components used for output (should match return values).
42
+
43
+ 📖 [**Gradio Documentation**](https://www.gradio.app/docs/gradio/introduction)
44
+
45
+ Gradio includes **30+ built-in components**.
46
+
47
+ 💡 **Tip**: For `inputs` and `outputs`, you can pass either:
48
+ - The **component name** as a string (e.g., `"textbox"`)
49
+ - An **instance of the component class** (e.g., `gr.Textbox()`)
50
+
51
+ ### Sharing Your Demo
52
+ ```python
53
+ demo.launch(share=True) # Share your demo with just one extra parameter.
54
+ ```
55
+
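+ As a minimal sketch of the `fn` / `inputs` / `outputs` concepts above (the greeting function and component choices are illustrative, not part of this project):
+
+ ```python
+ import gradio as gr
+
+ def greet(name: str) -> str:
+     # The wrapped function: one argument per input component, one return value per output component.
+     return f"Hello, {name}!"
+
+ demo = gr.Interface(
+     fn=greet,              # function wrapped by the UI
+     inputs="textbox",      # component name as a string...
+     outputs=gr.Textbox(),  # ...or an instance of the component class
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)  # share=True exposes a temporary public link
+ ```
+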
56
+ ## Gradio Advanced Features
57
+
58
+ ### **Gradio.Blocks**
59
+ Gradio provides `gr.Blocks`, a flexible way to design web apps with **custom layouts and complex interactions**:
60
+ - Arrange components freely on the page.
61
+ - Handle multiple data flows.
62
+ - Use outputs as inputs for other components.
63
+ - Dynamically update components based on user interaction.
64
+
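+ A small `gr.Blocks` sketch wiring one component's output into another (the components and lambda functions are illustrative):
+
+ ```python
+ import gradio as gr
+
+ with gr.Blocks() as demo:
+     with gr.Row():                       # arrange components freely
+         inp = gr.Textbox(label="Query")
+         out = gr.Textbox(label="Echo")
+     btn = gr.Button("Run")
+     length = gr.Number(label="Length")
+     # First flow: the button click sends inp to a function and writes the result to out.
+     btn.click(fn=lambda q: q.upper(), inputs=inp, outputs=out)
+     # Second flow: out itself is the input that drives another component's update.
+     out.change(fn=lambda s: len(s), inputs=out, outputs=length)
+
+ demo.launch()
+ ```
+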
65
+ ### **Gradio.ChatInterface**
66
+ - Always set `type="messages"` in `gr.ChatInterface`.
67
+ - The default (`type="tuples"`) is **deprecated** and will be removed in future versions.
68
+ - For more UI flexibility, use `gr.ChatBot`.
69
+ - `gr.ChatInterface` supports **Markdown** (not tested yet).
70
+
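+ A minimal `gr.ChatInterface` sketch using the `messages` format (the echo-style bot is a placeholder):
+
+ ```python
+ import gradio as gr
+
+ def respond(message, history):
+     # With type="messages", history is a list of {"role": ..., "content": ...} dicts.
+     return f"You said: {message}"
+
+ demo = gr.ChatInterface(fn=respond, type="messages")
+ demo.launch()
+ ```
+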
71
+ ---
72
+
73
+ ## Step 2 - Learn Bridge Tower Embedding Model (Multimodal Learning) (15 mins)
74
+
75
+ Developed in collaboration with Intel, this model maps image-caption pairs into **512-dimensional vectors**.
76
+
77
+ ### Measuring Similarity
78
+ - **Cosine Similarity** → Measures how close images are in vector space (**efficient & commonly used**).
79
+ - **Euclidean Distance** → Uses `cv2.NORM_L2` to compute similarity between two images.
80
+
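+ A sketch of embedding an image-caption pair and comparing two embeddings with cosine similarity. It assumes the `BridgeTower/bridgetower-large-itm-mlm-itc` checkpoint and the `transformers` contrastive-learning head; the frame paths and captions are placeholders, and the project wraps this logic inside `BridgeTowerEmbeddings`.
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
+
+ CKPT = "BridgeTower/bridgetower-large-itm-mlm-itc"  # assumed checkpoint
+ processor = BridgeTowerProcessor.from_pretrained(CKPT)
+ model = BridgeTowerForContrastiveLearning.from_pretrained(CKPT)
+
+ def embed_pair(image_path: str, caption: str) -> torch.Tensor:
+     inputs = processor(images=Image.open(image_path), text=caption, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     # cross_embeds is the pooled joint image-text embedding from the contrastive head.
+     return outputs.cross_embeds.squeeze(0)
+
+ e1 = embed_pair("frame_001.jpg", "an astronaut on a spacewalk")
+ e2 = embed_pair("frame_002.jpg", "a view of the space station")
+ print(torch.nn.functional.cosine_similarity(e1, e2, dim=0).item())
+ ```
+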
81
+ ### Converting to 2D for Visualization
82
+ - **UMAP** reduces 512D embeddings to **2D for display purposes**.
83
+
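+ A sketch of the projection step, assuming the `umap-learn` package and a matrix of embeddings stacked into a NumPy array:
+
+ ```python
+ import numpy as np
+ import umap  # pip install umap-learn
+
+ embeddings = np.random.rand(100, 512).astype("float32")   # stand-in for real BridgeTower embeddings
+ reducer = umap.UMAP(n_components=2, metric="cosine", random_state=42)
+ embeddings_2d = reducer.fit_transform(embeddings)          # shape (100, 2), ready for a scatter plot
+ ```
+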
84
+ ## Preprocessing Videos for Multimodal RAG
85
+
86
+ ### **Case 1: WEBVTT → Extracting Text Segments from Video**
87
+ - Converts video + text into structured metadata.
88
+ - Splits content inhttps://www.youtube.com/watch?v=kOEDG3j1bjsto multiple segments.
89
+
90
+ ### **Case 2: Whisper (Small) → Video Only**
91
+ - Extracts **audio** → `model.transcribe()`.
92
+ - Applies `getSubs()` helper function to retrieve **WEBVTT** subtitles.
93
+ - Uses **Case 1** processing.
94
+
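+ A sketch of the Whisper step; the audio path is illustrative, and `getSubs()` is this project's own helper for converting segments into WEBVTT:
+
+ ```python
+ import whisper  # installed from git+https://github.com/openai/whisper.git
+
+ model = whisper.load_model("small")
+ result = model.transcribe("shared_data/videos/yt_video/example/audio.mp3")
+ # result["segments"] carries start/end timestamps plus text for each chunk;
+ # the project's getSubs() helper turns these segments into a WEBVTT subtitle file.
+ print(result["text"][:200])
+ ```
+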
95
+ ### **Case 3: LvLM → Video + Silent/Music Extraction**
96
+ - Uses **Llava (LvLM model)** for **frame-based captioning**.
97
+ - Encodes each frame as a **Base64 image**.
98
+ - Extracts context and captions from video frames.
99
+ - Uses **Case 1** processing.
100
+
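+ A sketch of frame captioning with a local LLaVA model via Ollama, following the commented-out `ollama.generate` call in `app.py` (the frame path and prompt are illustrative):
+
+ ```python
+ import ollama
+
+ frame_path = "shared_data/videos/yt_video/example/extracted_frame/frame_0001.jpg"
+ result = ollama.generate(
+     model="llava",
+     prompt="Describe this video frame in one sentence.",
+     images=[frame_path],   # the frame can also be passed as a Base64-encoded string
+ )
+ print(result["response"])
+ ```
+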
101
+ # Step 4 - What is LLaVA?
102
+ LLaVA (Large Language-and-Vision Assistant) is a large multimodal model that connects a vision encoder to a language model. It doesn't just see images: it understands them, reads the text embedded in them, and reasons about their context.
103
+
104
+ # Step 5 - What is a Vector Store?
105
+ A vector store is a specialized database designed to:
106
+
107
+ - Store and manage high-dimensional vector data efficiently
108
+ - Perform similarity-based searches where K=1 returns the most similar result
109
+
110
+ - In LanceDB specifically, store multiple data types:
111
+   - Text content (captions)
112
+   - Image file paths
113
+   - Metadata
114
+   - Vector embeddings
115
+
116
+ ```python
117
+ _ = MultimodalLanceDB.from_text_image_pairs(
118
+ texts=updated_vid1_trans+vid2_trans,
119
+ image_paths=vid1_img_path+vid2_img_path,
120
+ embedding=BridgeTowerEmbeddings(),
121
+ metadatas=vid1_metadata+vid2_metadata,
122
+ connection=db,
123
+ table_name=TBL_NAME,
124
+ mode="overwrite",
125
+ )
126
+ ```
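+
+ A sketch of querying the store afterwards, mirroring the retriever usage in `app.py` (the query string is illustrative):
+
+ ```python
+ from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
+
+ vectorstore = MultimodalLanceDB(
+     uri="./shared_data/.lancedb",
+     embedding=BridgeTowerEmbeddings(),
+     table_name=TBL_NAME,
+ )
+ retriever = vectorstore.as_retriever(
+     search_type="similarity",
+     search_kwargs={"k": 1},   # K=1 returns only the single most similar caption/frame pair
+ )
+ results = retriever.invoke("an astronaut's spacewalk")
+ print(results[0].page_content)                        # retrieved caption text
+ print(results[0].metadata["extracted_frame_path"])    # path to the matching frame
+ ```
+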
127
+ # Gotchas and Solutions
128
+ - **Image Processing**: When working with base64-encoded images, convert them to `PIL.Image` format before processing with BridgeTower.
129
+ - **Model Selection**: Using `BridgeTowerForContrastiveLearning` instead of Prediction Guard due to API access limitations.
130
+ - **Model Size**: The BridgeTower model requires a ~3.5 GB download.
131
+ - **Image Downloads**: Some Flickr images may be unavailable; implement robust error handling.
132
+ - **Token Decoding**: The BridgeTower contrastive-learning model works with embeddings, not token predictions.
133
+ - **Whisper**: Install from `git+https://github.com/openai/whisper.git`.
134
+
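+ A sketch of the base64 → `PIL.Image` conversion mentioned above (the variable holding the encoded frame is illustrative):
+
+ ```python
+ import base64
+ import io
+ from PIL import Image
+
+ def b64_to_pil(b64_frame: str) -> Image.Image:
+     # Decode the base64 string back into raw bytes, then open it as a PIL image.
+     return Image.open(io.BytesIO(base64.b64decode(b64_frame))).convert("RGB")
+ ```
+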
135
+ # Install ffmpeg using brew
136
+ ```bash
137
+ brew install ffmpeg
138
+ brew link ffmpeg
139
+ ```
140
+
141
+
142
+ # Learning and Skills
143
+
144
+ ## Technical Skills:
145
+
146
+ Basic Machine learning and deep learning
147
+ Vector embeddings and similarity search
148
+ Multimodal data processing
149
+
150
+ ## Framework & Library Expertise:
151
+
152
+ Hugging Face Transformers
153
+ Gradio UI development
154
+ LangChain integration (Basic)
155
+ PyTorch basics
156
+ LanceDB vector storage
157
+
158
+ ## AI/ML Concepts:
159
+
160
+ Multimodal RAG system architecture
161
+ Vector embeddings and similarity search
162
+ Large Language Models (LLaVA)
163
+ Image-text pair processing
164
+ Dimensionality reduction techniques
165
+
166
+
167
+ ## Multimedia Processing:
168
+
169
+ Video frame extraction
170
+ Audio transcription (Whisper)
171
+ Image processing (PIL)
172
+ Base64 encoding/decoding
173
+ WebVTT handling
174
+
175
+ ## System Design:
176
+
177
+ Client-server architecture
178
+ API endpoint design
179
+ Data pipeline construction
180
+ Vector store implementation
181
+ Multimodal system integration
182
+
app.py CHANGED
@@ -1,376 +1,385 @@
1
- from pathlib import Path
2
- import gradio as gr
3
- import os
4
- from PIL import Image
5
- import ollama
6
- from utility import download_video, get_transcript_vtt, extract_meta_data, lvlm_inference_with_phi, lvlm_inference_with_tiny_model, lvlm_inference_with_tiny_model
7
- from mm_rag.embeddings.bridgetower_embeddings import (
8
- BridgeTowerEmbeddings
9
- )
10
- from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
11
- import lancedb
12
- import json
13
- import os
14
- from PIL import Image
15
- from utility import load_json_file, display_retrieved_results
16
- import pyarrow as pa
17
-
18
- # declare host file
19
- LANCEDB_HOST_FILE = "./shared_data/.lancedb"
20
- # declare table name
21
- # initialize vectorstore
22
- db = lancedb.connect(LANCEDB_HOST_FILE)
23
- # initialize an BridgeTower embedder
24
- embedder = BridgeTowerEmbeddings()
25
-
26
- base_dir = "./shared_data/videos/yt_video"
27
- Path(base_dir).mkdir(parents=True, exist_ok=True)
28
-
29
-
30
- def open_table(table_name):
31
- # open a connection to table TBL_NAME
32
- tbl = db.open_table(table_name)
33
-
34
- print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
35
- # display the first 3 rows of the table
36
- tbl.to_pandas()[['text', 'image_path']].head(3)
37
-
38
- def check_if_table_exists(table_name):
39
- return table_name in db.table_names()
40
-
41
- def store_in_rag(vid_table_name, vid_metadata_path):
42
-
43
- # load metadata files
44
-
45
- vid_metadata = load_json_file(vid_metadata_path)
46
-
47
-
48
- vid_subs = [vid['transcript'] for vid in vid_metadata]
49
- vid_img_path = [vid['extracted_frame_path'] for vid in vid_metadata]
50
-
51
-
52
- # for video1, we pick n = 7
53
- n = 7
54
- updated_vid_subs = [
55
- ' '.join(vid_subs[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
56
- ' '.join(vid_subs[0 : i + int(n/2)]) for i in range(len(vid_subs))
57
- ]
58
-
59
- # also need to update the updated transcripts in metadata
60
- for i in range(len(updated_vid_subs)):
61
- vid_metadata[i]['transcript'] = updated_vid_subs[i]
62
-
63
- # you can pass in mode="append"
64
- # to add more entries to the vector store
65
- # in case you want to start with a fresh vector store,
66
- # you can pass in mode="overwrite" instead
67
-
68
-
69
- print("Creating vid_table_name ", vid_table_name)
70
- _ = MultimodalLanceDB.from_text_image_pairs(
71
- texts=updated_vid_subs,
72
- image_paths=vid_img_path,
73
- embedding=embedder,
74
- metadatas=vid_metadata,
75
- connection=db,
76
- table_name=vid_table_name,
77
- mode="overwrite",
78
- )
79
- open_table(vid_table_name)
80
-
81
- return vid_table_name
82
-
83
- def get_metadata_of_yt_video_with_captions(vid_url, from_gen=False):
84
- vid_filepath, vid_folder_path, is_downloaded = download_video(vid_url, base_dir)
85
- if is_downloaded:
86
- print("Video downloaded at ", vid_filepath)
87
- if from_gen:
88
- # Delete existing caption and metadata files if they exist
89
- caption_file = f"{vid_folder_path}/captions.vtt"
90
- metadata_file = f"{vid_folder_path}/metadatas.json"
91
- if os.path.exists(caption_file):
92
- os.remove(caption_file)
93
- print(f"Deleted existing caption file: {caption_file}")
94
- if os.path.exists(metadata_file):
95
- os.remove(metadata_file)
96
- print(f"Deleted existing metadata file: {metadata_file}")
97
-
98
- print("checking transcript")
99
- vid_transcript_filepath = get_transcript_vtt(vid_folder_path, vid_url, vid_filepath, from_gen)
100
- vid_metadata_path = f"{vid_folder_path}/metadatas.json"
101
- print("checking metadatas at", vid_metadata_path)
102
- if os.path.exists(vid_metadata_path):
103
- print('Metadatas already exists')
104
- else:
105
- print("Downloading metadatas for the video ", vid_filepath)
106
- extract_meta_data(vid_folder_path, vid_filepath, vid_transcript_filepath) #should return lowercase file name without spaces
107
-
108
- parent_dir_name = os.path.basename(os.path.dirname(vid_metadata_path))
109
- vid_table_name = f"{parent_dir_name}_table"
110
- print("Checking db and Table name ", vid_table_name)
111
- if not check_if_table_exists(vid_table_name):
112
- print("Table does not exists Storing in RAG")
113
- else:
114
- print("Table exists")
115
- def delete_table(table_name):
116
- db.drop_table(table_name)
117
- print(f"Deleted table {table_name}")
118
- delete_table(vid_table_name)
119
-
120
- store_in_rag(vid_table_name, vid_metadata_path)
121
- return vid_filepath, vid_table_name
122
-
123
- """
124
- def chat_response_llvm(instruction):
125
- #file_path = the_metadatas[0]
126
- file_path = 'shared_data/videos/yt_video/extracted_frame/'
127
- result = ollama.generate(
128
- model='llava',
129
- prompt=instruction,
130
- images=[file_path],
131
- stream=True
132
- )['response']
133
- return result
134
- """
135
-
136
- def return_top_k_most_similar_docs(vid_table_name, query, use_llm=False):
137
- # Initialize results variable outside the if condition
138
- max_docs = 2
139
- print("Querying ", vid_table_name)
140
- vectorstore = MultimodalLanceDB(
141
- uri=LANCEDB_HOST_FILE,
142
- embedding=embedder,
143
- table_name=vid_table_name
144
- )
145
-
146
- retriever = vectorstore.as_retriever(
147
- search_type='similarity',
148
- search_kwargs={"k": max_docs}
149
- )
150
-
151
- # Get results first
152
- results = retriever.invoke(query)
153
-
154
- if use_llm:
155
- # Read captions.vtt file
156
- def read_vtt_file(file_path):
157
- with open(file_path, 'r', encoding='utf-8') as f:
158
- return f.read()
159
-
160
- vid_table_name = vid_table_name.split('_table')[0]
161
- caption_file = 'shared_data/videos/yt_video/' + vid_table_name + '/captions.vtt'
162
- print("Caption file path ", caption_file)
163
- captions = read_vtt_file(caption_file)
164
- prompt = "Answer this query : " + query + " from the content " + captions
165
- print("Prompt ", prompt)
166
- all_page_content = lvlm_inference_with_phi(prompt)
167
- else:
168
- all_page_content = "\n\n".join([result.page_content for result in results])
169
-
170
- page_content = gr.Textbox(all_page_content, label="Response", elem_id='chat-response', visible=True, interactive=False)
171
- image1 = Image.open(results[0].metadata['extracted_frame_path'])
172
- image2_path = results[1].metadata['extracted_frame_path']
173
-
174
- if results[0].metadata['extracted_frame_path'] == image2_path:
175
- image2 = gr.update(visible=False)
176
- else:
177
- image2 = Image.open(image2_path)
178
- image2 = gr.update(value=image2, visible=True)
179
-
180
- return page_content, image1, image2
181
-
182
-
183
- def process_url_and_init(youtube_url, from_gen=False):
184
- url_input = gr.update(visible=False)
185
- submit_btn = gr.update(visible=True)
186
- chatbox = gr.update(visible=True)
187
- submit_btn2 = gr.update(visible=True)
188
- frame1 = gr.update(visible=True)
189
- frame2 = gr.update(visible=False)
190
- chatbox_llm, submit_btn_chat = gr.update(visible=True), gr.update(visible=True)
191
- vid_filepath, vid_table_name = get_metadata_of_yt_video_with_captions(youtube_url, from_gen)
192
- video = gr.Video(vid_filepath,render=True)
193
- return url_input, submit_btn, video, vid_table_name, chatbox,submit_btn2, frame1, frame2, chatbox_llm, submit_btn_chat
194
-
195
- def test_btn():
196
- text = "hi"
197
- res = lvlm_inference_with_phi(text)
198
- response = gr.Textbox(res, visible=True,interactive=False)
199
- return response
200
-
201
- def init_ui():
202
- with gr.Blocks() as demo:
203
-
204
- gr.Markdown("Welcome to video chat demo - Initial processing can take up to 2 minutes, and responses may be slow. Please be patient and avoid clicking repeatedly.")
205
- url_input = gr.Textbox(label="Enter YouTube URL", visible=False, elem_id='url-inp',value="https://www.youtube.com/watch?v=kOEDG3j1bjs", interactive=True)
206
- vid_table_name = gr.Textbox(label="Enter Table Name", visible=False, interactive=False)
207
- video = gr.Video()
208
- with gr.Row():
209
- submit_btn = gr.Button("Process Video By Download Subtitles")
210
- submit_btn_gen = gr.Button("Process Video By Generating Subtitles")
211
-
212
- with gr.Row():
213
- chatbox = gr.Textbox(label="Enter the keyword/s and AI will get related captions and images", visible=False, value="event horizan", scale=4)
214
- submit_btn_whisper = gr.Button("Submit", elem_id='chat-submit', visible=False, scale=1)
215
- with gr.Row():
216
- chatbox_llm = gr.Textbox(label="Ask a Question", visible=False, value="what this video is about?", scale=4)
217
- submit_btn_chat = gr.Button("Ask", visible=False, scale=1)
218
-
219
- response = gr.Textbox(label="Response", elem_id='chat-response', visible=False,interactive=False)
220
-
221
- with gr.Row():
222
- frame1 = gr.Image(visible=False, interactive=False, scale=2)
223
- frame2 = gr.Image(visible=False, interactive=False, scale=2)
224
- submit_btn.click(fn=process_url_and_init, inputs=[url_input], outputs=[url_input, submit_btn, video, vid_table_name, chatbox,submit_btn_whisper, frame1, frame2, chatbox_llm, submit_btn_chat])
225
- submit_btn_gen.click(fn=lambda x: process_url_and_init(x, from_gen=True), inputs=[url_input], outputs=[url_input, submit_btn, video, vid_table_name, chatbox,submit_btn_whisper, frame1, frame2,chatbox_llm, submit_btn_chat])
226
- submit_btn_whisper.click(fn=return_top_k_most_similar_docs, inputs=[vid_table_name, chatbox], outputs=[response, frame1, frame2])
227
-
228
- submit_btn_chat.click(
229
- fn=lambda table_name, query: return_top_k_most_similar_docs(
230
- vid_table_name=table_name,
231
- query=query,
232
- use_llm=True
233
- ),
234
- inputs=[vid_table_name, chatbox_llm],
235
- outputs=[response, frame1, frame2]
236
- )
237
- reset_btn = gr.Button("Reload Page")
238
- reset_btn.click(None, js="() => { location.reload(); }")
239
-
240
- test_llama = gr.Button("Test Llama")
241
- test_llama.click(test_btn, None, outputs=[response])
242
- return demo
243
-
244
- def init_improved_ui():
245
-
246
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
247
- # Header Section with Introduction
248
- with gr.Group():
249
- gr.Markdown("""
250
- # 🎬 Video Analysis Assistant
251
-
252
- ## How it Works:
253
- 1. 📥 Provide a YouTube URL.
254
- 2. 🔄 Choose a processing method:
255
- - Download the video and its captions/subtitles from YouTube.
256
- - Download the video and generate captions using Whisper AI.
257
- The system will load the video in video player for preview and process the video and extract frames from it.
258
- It will then pass the captions and images to the RAG model to store them in the database.
259
- The RAG (Lance DB) uses a pre-trained BridgeTower model to generate embeddings that provide pairs of captions and related images.
260
- 3. 🤖 Analyze video content through:
261
- - Keyword Search - Use this functionality to search for keywords in the video. Our RAG model will return the most relevant captions and images.
262
- - AI-powered Q&A - Use this functionality to ask questions about the video content. Our system will use the Meta/LLaMA model to analyze the captions and images and provide detailed answers.
263
- 4. 📊 Results will be displayed in the response section with related images.
264
-
265
- > **Note**: Initial processing takes several minutes. Please be patient and monitor the logs for progress updates.
266
- """)
267
-
268
- # Video Input Section
269
- with gr.Group():
270
- url_input = gr.Textbox(
271
- label="YouTube URL",
272
- value="https://www.youtube.com/watch?v=kOEDG3j1bjs",
273
- visible=True,
274
- elem_id='url-inp',
275
- interactive=False
276
- )
277
- vid_table_name = gr.Textbox(label="Table Name", visible=False)
278
- video = gr.Video(label="Video Preview")
279
-
280
- with gr.Row():
281
- submit_btn = gr.Button("📥 Process with Existing Subtitles", variant="primary")
282
- submit_btn_gen = gr.Button("🎯 Generate New Subtitles", variant="secondary")
283
-
284
- # Analysis Tools Section
285
- with gr.Group():
286
- gr.Markdown("### 🔍 Analysis Tools")
287
-
288
- with gr.Tab("Keyword Search"):
289
- with gr.Row():
290
- chatbox = gr.Textbox(
291
- label="Search Keywords",
292
- value="event horizon",
293
- visible=False,
294
- scale=4
295
- )
296
- submit_btn_whisper = gr.Button(
297
- "🔎 Search",
298
- elem_id='chat-submit',
299
- visible=False,
300
- scale=1
301
- )
302
-
303
- with gr.Tab("AI Q&A"):
304
- with gr.Row():
305
- chatbox_llm = gr.Textbox(
306
- label="Ask AI about the video",
307
- value="What is this video about?",
308
- visible=False,
309
- scale=4
310
- )
311
- submit_btn_chat = gr.Button(
312
- "🤖 Ask",
313
- visible=False,
314
- scale=1
315
- )
316
-
317
- # Results Display Section
318
- with gr.Group():
319
- gr.Markdown("### 📊 Results")
320
- response = gr.Textbox(
321
- label="AI Response",
322
- elem_id='chat-response',
323
- visible=False,
324
- interactive=False
325
- )
326
-
327
- with gr.Row():
328
- frame1 = gr.Image(visible=False, label="Related Frame 1", scale=2)
329
- frame2 = gr.Image(visible=False, label="Related Frame 2", scale=2)
330
-
331
- # Control Buttons
332
- with gr.Row():
333
- reset_btn = gr.Button("🔄 Start Over", variant="secondary")
334
- test_llama = gr.Button("🧪 Say Hi to Llama", variant="secondary")
335
-
336
- # Event Handlers
337
- submit_btn.click(
338
- fn=process_url_and_init,
339
- inputs=[url_input],
340
- outputs=[url_input, submit_btn, video, vid_table_name,
341
- chatbox, submit_btn_whisper, frame1, frame2,
342
- chatbox_llm, submit_btn_chat]
343
- )
344
-
345
- submit_btn_gen.click(
346
- fn=lambda x: process_url_and_init(x, from_gen=True),
347
- inputs=[url_input],
348
- outputs=[url_input, submit_btn, video, vid_table_name,
349
- chatbox, submit_btn_whisper, frame1, frame2,
350
- chatbox_llm, submit_btn_chat]
351
- )
352
-
353
- submit_btn_whisper.click(
354
- fn=return_top_k_most_similar_docs,
355
- inputs=[vid_table_name, chatbox],
356
- outputs=[response, frame1, frame2]
357
- )
358
-
359
- submit_btn_chat.click(
360
- fn=lambda table_name, query: return_top_k_most_similar_docs(
361
- vid_table_name=table_name,
362
- query=query,
363
- use_llm=True
364
- ),
365
- inputs=[vid_table_name, chatbox_llm],
366
- outputs=[response, frame1, frame2]
367
- )
368
-
369
- reset_btn.click(None, js="() => { location.reload(); }")
370
- test_llama.click(test_btn, None, outputs=[response])
371
-
372
- return demo
373
-
374
- if __name__ == '__main__':
375
- demo = init_improved_ui() # Updated function name here
376
- demo.launch(share=True, debug=True)
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import gradio as gr
3
+ import os
4
+ from PIL import Image
5
+ import ollama
6
+ from utility import download_video, get_transcript_vtt, extract_meta_data, lvlm_inference_with_phi, lvlm_inference_with_tiny_model
7
+ from mm_rag.embeddings.bridgetower_embeddings import (
8
+ BridgeTowerEmbeddings
9
+ )
10
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
11
+ import lancedb
12
+ import json
13
+ import os
14
+ from PIL import Image
15
+ from utility import load_json_file, display_retrieved_results
16
+ import pyarrow as pa
17
+
18
+ # declare host file
19
+ LANCEDB_HOST_FILE = "./shared_data/.lancedb"
20
+ # declare table name
21
+ # initialize vectorstore
22
+ db = lancedb.connect(LANCEDB_HOST_FILE)
23
+ # initialize an BridgeTower embedder
24
+ embedder = BridgeTowerEmbeddings()
25
+ video_processed = False
26
+ base_dir = "./shared_data/videos/yt_video"
27
+ Path(base_dir).mkdir(parents=True, exist_ok=True)
28
+
29
+
30
+ def open_table(table_name):
31
+ # open a connection to table TBL_NAME
32
+ tbl = db.open_table(table_name)
33
+
34
+ print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
35
+ # display the first 3 rows of the table
36
+ tbl.to_pandas()[['text', 'image_path']].head(3)
37
+
38
+
39
+ def check_if_table_exists(table_name):
40
+ return table_name in db.table_names()
41
+
42
+
43
+ def store_in_rag(vid_table_name, vid_metadata_path):
44
+
45
+ # load metadata files
46
+
47
+ vid_metadata = load_json_file(vid_metadata_path)
48
+
49
+ vid_subs = [vid['transcript'] for vid in vid_metadata]
50
+ vid_img_path = [vid['extracted_frame_path'] for vid in vid_metadata]
51
+
52
+ # for video1, we pick n = 7
53
+ n = 7
54
+ updated_vid_subs = [
55
+ ' '.join(vid_subs[i-int(n/2): i+int(n/2)]) if i-int(n/2) >= 0 else
56
+ ' '.join(vid_subs[0: i + int(n/2)]) for i in range(len(vid_subs))
57
+ ]
58
+
59
+ # also need to update the updated transcripts in metadata
60
+ for i in range(len(updated_vid_subs)):
61
+ vid_metadata[i]['transcript'] = updated_vid_subs[i]
62
+
63
+ # you can pass in mode="append"
64
+ # to add more entries to the vector store
65
+ # in case you want to start with a fresh vector store,
66
+ # you can pass in mode="overwrite" instead
67
+
68
+ print("Creating vid_table_name ", vid_table_name)
69
+ _ = MultimodalLanceDB.from_text_image_pairs(
70
+ texts=updated_vid_subs,
71
+ image_paths=vid_img_path,
72
+ embedding=embedder,
73
+ metadatas=vid_metadata,
74
+ connection=db,
75
+ table_name=vid_table_name,
76
+ mode="overwrite",
77
+ )
78
+ open_table(vid_table_name)
79
+
80
+ return vid_table_name
81
+
82
+
83
+ def get_metadata_of_yt_video_with_captions(vid_url, from_gen=False):
84
+ vid_filepath, vid_folder_path, is_downloaded = download_video(
85
+ vid_url, base_dir)
86
+ if is_downloaded:
87
+ print("Video downloaded at ", vid_filepath)
88
+ if from_gen:
89
+ # Delete existing caption and metadata files if they exist
90
+ caption_file = f"{vid_folder_path}/captions.vtt"
91
+ metadata_file = f"{vid_folder_path}/metadatas.json"
92
+ if os.path.exists(caption_file):
93
+ os.remove(caption_file)
94
+ print(f"Deleted existing caption file: {caption_file}")
95
+ if os.path.exists(metadata_file):
96
+ os.remove(metadata_file)
97
+ print(f"Deleted existing metadata file: {metadata_file}")
98
+
99
+ print("checking transcript")
100
+ vid_transcript_filepath = get_transcript_vtt(
101
+ vid_folder_path, vid_url, vid_filepath, from_gen)
102
+ vid_metadata_path = f"{vid_folder_path}/metadatas.json"
103
+ print("checking metadatas at", vid_metadata_path)
104
+ if os.path.exists(vid_metadata_path):
105
+ print('Metadatas already exists')
106
+ else:
107
+ print("Downloading metadatas for the video ", vid_filepath)
108
+ # should return lowercase file name without spaces
109
+ extract_meta_data(vid_folder_path, vid_filepath,
110
+ vid_transcript_filepath)
111
+
112
+ parent_dir_name = os.path.basename(os.path.dirname(vid_metadata_path))
113
+ vid_table_name = f"{parent_dir_name}_table"
114
+ print("Checking db and Table name ", vid_table_name)
115
+ if not check_if_table_exists(vid_table_name):
116
+ print("Table does not exists Storing in RAG")
117
+ else:
118
+ print("Table exists")
119
+
120
+ def delete_table(table_name):
121
+ db.drop_table(table_name)
122
+ print(f"Deleted table {table_name}")
123
+ delete_table(vid_table_name)
124
+
125
+ store_in_rag(vid_table_name, vid_metadata_path)
126
+ return vid_filepath, vid_table_name
127
+
128
+
129
+ def return_top_k_most_similar_docs(vid_table_name, query, use_llm=False):
130
+ if not video_processed:
131
+ raise gr.Error("Please process the video first in Step 1")
132
+ # Initialize results variable outside the if condition
133
+ max_docs = 2
134
+ print("Querying ", vid_table_name)
135
+ vectorstore = MultimodalLanceDB(
136
+ uri=LANCEDB_HOST_FILE,
137
+ embedding=embedder,
138
+ table_name=vid_table_name
139
+ )
140
+
141
+ retriever = vectorstore.as_retriever(
142
+ search_type='similarity',
143
+ search_kwargs={"k": max_docs}
144
+ )
145
+
146
+ # Get results first
147
+ results = retriever.invoke(query)
148
+
149
+ if use_llm:
150
+ # Read captions.vtt file
151
+ def read_vtt_file(file_path):
152
+ with open(file_path, 'r', encoding='utf-8') as f:
153
+ return f.read()
154
+
155
+ vid_table_name = vid_table_name.split('_table')[0]
156
+ caption_file = 'shared_data/videos/yt_video/' + vid_table_name + '/captions.vtt'
157
+ print("Caption file path ", caption_file)
158
+ captions = read_vtt_file(caption_file)
159
+ prompt = "Answer this query : " + query + " from the content " + captions
160
+ print("Prompt ", prompt)
161
+ all_page_content = lvlm_inference_with_phi(prompt)
162
+ else:
163
+ all_page_content = "\n\n".join(
164
+ [result.page_content for result in results])
165
+
166
+ page_content = gr.Textbox(all_page_content, label="Response",
167
+ elem_id='chat-response', visible=True, interactive=False)
168
+ image1 = Image.open(results[0].metadata['extracted_frame_path'])
169
+ image2_path = results[1].metadata['extracted_frame_path']
170
+
171
+ if results[0].metadata['extracted_frame_path'] == image2_path:
172
+ image2 = gr.update(visible=False)
173
+ else:
174
+ image2 = Image.open(image2_path)
175
+ image2 = gr.update(value=image2, visible=True)
176
+
177
+ return page_content, image1, image2
178
+
179
+
180
+ def process_url_and_init(youtube_url, from_gen=False):
181
+ global video_processed
+ video_processed = True
182
+ url_input = gr.update(visible=False)
183
+ submit_btn = gr.update(visible=True)
184
+ chatbox = gr.update(visible=True)
185
+ submit_btn2 = gr.update(visible=True)
186
+ frame1 = gr.update(visible=True)
187
+ frame2 = gr.update(visible=False)
188
+ chatbox_llm, submit_btn_chat = gr.update(
189
+ visible=True), gr.update(visible=True)
190
+ vid_filepath, vid_table_name = get_metadata_of_yt_video_with_captions(
191
+ youtube_url, from_gen)
192
+ video = gr.Video(vid_filepath, render=True)
193
+ return url_input, submit_btn, video, vid_table_name, chatbox, submit_btn2, frame1, frame2, chatbox_llm, submit_btn_chat
194
+
195
+
196
+ def test_btn():
197
+ text = "hi"
198
+ res = lvlm_inference_with_phi(text)
199
+ response = gr.Textbox(res, visible=True, interactive=False)
200
+ return response
201
+
202
+
203
+ def init_ui():
204
+ with gr.Blocks() as demo:
205
+
206
+ gr.Markdown("Welcome to video chat demo - Initial processing can take up to 2 minutes, and responses may be slow. Please be patient and avoid clicking repeatedly.")
207
+ url_input = gr.Textbox(label="Enter YouTube URL", visible=False, elem_id='url-inp',
208
+ value="https://www.youtube.com/watch?v=kOEDG3j1bjs", interactive=True)
209
+ vid_table_name = gr.Textbox(
210
+ label="Enter Table Name", visible=False, interactive=False)
211
+ video = gr.Video()
212
+ with gr.Row():
213
+ submit_btn = gr.Button("Process Video By Download Subtitles")
214
+ submit_btn_gen = gr.Button("Process Video By Generating Subtitles")
215
+
216
+ with gr.Row():
217
+ chatbox = gr.Textbox(label="Enter the keyword/s and AI will get related captions and images",
218
+ visible=False, value="event horizon", scale=4)
219
+ submit_btn_whisper = gr.Button(
220
+ "Submit", elem_id='chat-submit', visible=False, scale=1)
221
+ with gr.Row():
222
+ chatbox_llm = gr.Textbox(
223
+ label="Ask a Question", visible=False, value="what this video is about?", scale=4)
224
+ submit_btn_chat = gr.Button("Ask", visible=False, scale=1)
225
+
226
+ response = gr.Textbox(
227
+ label="Response", elem_id='chat-response', visible=False, interactive=False)
228
+
229
+ with gr.Row():
230
+ frame1 = gr.Image(visible=False, interactive=False, scale=2)
231
+ frame2 = gr.Image(visible=False, interactive=False, scale=2)
232
+ submit_btn.click(fn=process_url_and_init, inputs=[url_input], outputs=[
233
+ url_input, submit_btn, video, vid_table_name, chatbox, submit_btn_whisper, frame1, frame2, chatbox_llm, submit_btn_chat])
234
+ submit_btn_gen.click(fn=lambda x: process_url_and_init(x, from_gen=True), inputs=[url_input], outputs=[
235
+ url_input, submit_btn, video, vid_table_name, chatbox, submit_btn_whisper, frame1, frame2, chatbox_llm, submit_btn_chat])
236
+ submit_btn_whisper.click(fn=return_top_k_most_similar_docs, inputs=[
237
+ vid_table_name, chatbox], outputs=[response, frame1, frame2])
238
+
239
+ submit_btn_chat.click(
240
+ fn=lambda table_name, query: return_top_k_most_similar_docs(
241
+ vid_table_name=table_name,
242
+ query=query,
243
+ use_llm=True
244
+ ),
245
+ inputs=[vid_table_name, chatbox_llm],
246
+ outputs=[response, frame1, frame2]
247
+ )
248
+ reset_btn = gr.Button("Reload Page")
249
+ reset_btn.click(None, js="() => { location.reload(); }")
250
+
251
+ test_llama = gr.Button("Test Llama")
252
+ test_llama.click(test_btn, None, outputs=[response])
253
+ return demo
254
+
255
+
256
+ def init_improved_ui():
257
+
258
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
259
+ # Header Section with Introduction
260
+ with gr.Accordion(label=" # 🎬 Video Analysis Assistant", open=True):
261
+ gr.Markdown("""
262
+ ## How it Works:
263
+ 1. 📥 Provide a YouTube URL.
264
+ 2. 🔄 Choose a processing method:
265
+ - Download the video and its captions/subtitles from YouTube.
266
+ - Download the video and generate captions using Whisper AI.
267
+ The system will load the video in video player for preview and process the video and extract frames from it.
268
+ It will then pass the captions and images to the RAG model to store them in the database.
269
+ The RAG (Lance DB) uses a pre-trained BridgeTower model to generate embeddings that provide pairs of captions and related images.
270
+ 3. 🤖 Analyze video content through:
271
+ - Keyword Search - Use this functionality to search for keywords in the video. Our RAG model will return the most relevant captions and images.
272
+ - AI-powered Q&A - Use this functionality to ask questions about the video content. Our system will use the Meta/LLaMA model to analyze the captions and images and provide detailed answers.
273
+ 4. 📊 Results will be displayed in the response section with related images.
274
+
275
+ > **Note**: Initial processing takes several minutes. Please be patient and monitor the logs for progress updates.
276
+ """)
277
+
278
+ # Video Input Section
279
+ with gr.Group():
280
+ url_input = gr.Textbox(
281
+ label="YouTube URL",
282
+ value="https://www.youtube.com/watch?v=kOEDG3j1bjs",
283
+ visible=True,
284
+ interactive=False
285
+ )
286
+ vid_table_name = gr.Textbox(label="Table Name", visible=False)
287
+ video = gr.Video(label="Video Preview")
288
+
289
+ with gr.Row():
290
+ submit_btn = gr.Button(
291
+ "📥 Step 1: Process with Existing Subtitles", variant="primary", size='md')
292
+ submit_btn_gen = gr.Button(
293
+ "🎯 Generate New Subtitles", variant="secondary", visible=False)
294
+
295
+ # Analysis Tools Section
296
+ with gr.Group():
297
+ gr.Markdown("### 🔍 Step 2: Chat AI about the video")
298
+
299
+ with gr.Row():
300
+ chatbox = gr.Textbox(
301
+ label="Step 2: Search Keywords",
302
+ value="event horizon, black holes, space",
303
+ visible=False
304
+ )
305
+ submit_btn_whisper = gr.Button(
306
+ "🔎 Search",
307
+ visible=False,
308
+ variant="primary"
309
+ )
310
+
311
+ with gr.Row():
312
+ chatbox_llm = gr.Textbox(
313
+ label="",
314
+ value="What is this video about?",
315
+ visible=True
316
+ )
317
+ submit_btn_chat = gr.Button(
318
+ "🤖 Ask",
319
+ visible=True,
320
+ scale=1
321
+ )
322
+
323
+ # Results Display Section
324
+ with gr.Group():
325
+ gr.Markdown("### 📊 AI Response")
326
+ response = gr.Textbox(
327
+ label="AI Response",
328
+ visible=True,
329
+ interactive=False
330
+ )
331
+
332
+ with gr.Row():
333
+ frame1 = gr.Image(
334
+ visible=False, label="Related Frame 1", scale=1)
335
+ frame2 = gr.Image(
336
+ visible=False, label="Related Frame 2", scale=2)
337
+
338
+ # Control Buttons
339
+ with gr.Row():
340
+ reset_btn = gr.Button("🔄 Start Over", variant="secondary")
341
+ test_llama = gr.Button("🧪 Say Hi to Llama",
342
+ visible=False, variant="secondary")
343
+
344
+ # Event Handlers
345
+ submit_btn.click(
346
+ fn=process_url_and_init,
347
+ inputs=[url_input],
348
+ outputs=[url_input, submit_btn, video, vid_table_name,
349
+ chatbox, submit_btn_whisper, frame1, frame2,
350
+ chatbox_llm, submit_btn_chat]
351
+ )
352
+
353
+ submit_btn_gen.click(
354
+ fn=lambda x: process_url_and_init(x, from_gen=True),
355
+ inputs=[url_input],
356
+ outputs=[url_input, submit_btn, video, vid_table_name,
357
+ chatbox, submit_btn_whisper, frame1, frame2,
358
+ chatbox_llm, submit_btn_chat]
359
+ )
360
+
361
+ submit_btn_whisper.click(
362
+ fn=return_top_k_most_similar_docs,
363
+ inputs=[vid_table_name, chatbox],
364
+ outputs=[response, frame1, frame2]
365
+ )
366
+
367
+ submit_btn_chat.click(
368
+ fn=lambda table_name, query: return_top_k_most_similar_docs(
369
+ vid_table_name=table_name,
370
+ query=query,
371
+ use_llm=True
372
+ ),
373
+ inputs=[vid_table_name, chatbox_llm],
374
+ outputs=[response, frame1, frame2]
375
+ )
376
+
377
+ reset_btn.click(None, js="() => { location.reload(); }")
378
+ test_llama.click(test_btn, None, outputs=[response])
379
+
380
+ return demo
381
+
382
+
383
+ if __name__ == '__main__':
384
+ demo = init_improved_ui() # Updated function name here
385
+ demo.launch(share=True, debug=True)
gradio_utils.py CHANGED
@@ -1,483 +1,483 @@
1
- import gradio as gr
2
- import io
3
- import sys
4
- import time
5
- import dataclasses
6
- from pathlib import Path
7
- import os
8
- from enum import auto, Enum
9
- from typing import List, Tuple, Any
10
- from utility import prediction_guard_llava_conv
11
- import lancedb
12
- from utility import load_json_file
13
- from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
14
- from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
15
- from mm_rag.MLM.client import PredictionGuardClient
16
- from mm_rag.MLM.lvlm import LVLM
17
- from PIL import Image
18
- from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
19
- from moviepy.video.io.VideoFileClip import VideoFileClip
20
- from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation
21
-
22
- server_error_msg="**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
23
-
24
- # function to split video at a timestamp
25
- def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str="video_tmp.mp4", play_before_sec: int=3, play_after_sec: int=3):
26
- timestamp_in_sec = int(timestamp_in_ms / 1000)
27
- # create output_video_name folder if not exist:
28
- Path(output_video_path).mkdir(parents=True, exist_ok=True)
29
- output_video = os.path.join(output_video_path, output_video_name)
30
- with VideoFileClip(video_path) as video:
31
- duration = video.duration
32
- start_time = max(timestamp_in_sec - play_before_sec, 0)
33
- end_time = min(timestamp_in_sec + play_after_sec, duration)
34
- new = video.subclip(start_time, end_time)
35
- new.write_videofile(output_video, audio_codec='aac')
36
- return output_video
37
-
38
-
39
- prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
40
-
41
- # define default rag_chain
42
- def get_default_rag_chain():
43
- # declare host file
44
- LANCEDB_HOST_FILE = "./shared_data/.lancedb"
45
- # declare table name
46
- TBL_NAME = "demo_tbl"
47
-
48
- # initialize vectorstore
49
- db = lancedb.connect(LANCEDB_HOST_FILE)
50
-
51
- # initialize an BridgeTower embedder
52
- embedder = BridgeTowerEmbeddings()
53
-
54
- ## Creating a LanceDB vector store
55
- vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
56
- ### creating a retriever for the vector store
57
- retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})
58
-
59
- # initialize a client as PredictionGuardClient
60
- client = PredictionGuardClient()
61
- # initialize LVLM with the given client
62
- lvlm_inference_module = LVLM(client=client)
63
-
64
- def prompt_processing(input):
65
- # get the retrieved results and user's query
66
- retrieved_results, user_query = input['retrieved_results'], input['user_query']
67
- # get the first retrieved result by default
68
- retrieved_result = retrieved_results[0]
69
- # prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
70
-
71
- # get all metadata of the retrieved video segment
72
- metadata_retrieved_video_segment = retrieved_result.metadata['metadata']
73
-
74
- # get the frame and the corresponding transcript, path to extracted frame, path to whole video, and time stamp of the retrieved video segment.
75
- transcript = metadata_retrieved_video_segment['transcript']
76
- frame_path = metadata_retrieved_video_segment['extracted_frame_path']
77
- return {
78
- 'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
79
- 'image' : frame_path,
80
- 'metadata' : metadata_retrieved_video_segment,
81
- }
82
- # initialize prompt processing module as a Langchain RunnableLambda of function prompt_processing
83
- prompt_processing_module = RunnableLambda(prompt_processing)
84
-
85
- # the output of this new chain will be a dictionary
86
- mm_rag_chain_with_retrieved_image = (
87
- RunnableParallel({"retrieved_results": retriever_module ,
88
- "user_query": RunnablePassthrough()})
89
- | prompt_processing_module
90
- | RunnableParallel({'final_text_output': lvlm_inference_module,
91
- 'input_to_lvlm' : RunnablePassthrough()})
92
- )
93
- return mm_rag_chain_with_retrieved_image
94
-
95
- class SeparatorStyle(Enum):
96
- """Different separator style."""
97
- SINGLE = auto()
98
-
99
- @dataclasses.dataclass
100
- class GradioInstance:
101
- """A class that keeps all conversation history."""
102
- system: str
103
- roles: List[str]
104
- messages: List[List[str]]
105
- offset: int
106
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
107
- sep: str = "\n"
108
- sep2: str = None
109
- version: str = "Unknown"
110
- path_to_img: str = None
111
- video_title: str = None
112
- path_to_video: str = None
113
- caption: str = None
114
- mm_rag_chain: Any = None
115
-
116
- skip_next: bool = False
117
-
118
- def _template_caption(self):
119
- out = ""
120
- if self.caption is not None:
121
- out = f"The caption associated with the image is '{self.caption}'. "
122
- return out
123
-
124
- def get_prompt_for_rag(self):
125
- messages = self.messages
126
- assert len(messages) == 2, "length of current conversation should be 2"
127
- assert messages[1][1] is None, "the first response message of current conversation should be None"
128
- ret = messages[0][1]
129
- return ret
130
-
131
- def get_conversation_for_lvlm(self):
132
- pg_conv = prediction_guard_llava_conv.copy()
133
- image_path = self.path_to_img
134
- b64_img = encode_image(image_path)
135
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
136
- if msg is None:
137
- break
138
- if i == 0:
139
- pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
140
- elif i == len(self.messages[self.offset:]) - 2:
141
- pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
142
- else:
143
- pg_conv.append_message(role, [msg])
144
- return pg_conv
145
-
146
- def append_message(self, role, message):
147
- self.messages.append([role, message])
148
-
149
- def get_images(self, return_pil=False):
150
- images = []
151
- if self.path_to_img is not None:
152
- path_to_image = self.path_to_img
153
- images.append(path_to_image)
154
- return images
155
-
156
- def to_gradio_chatbot(self):
157
- ret = []
158
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
159
- if i % 2 == 0:
160
- if type(msg) is tuple:
161
- import base64
162
- from io import BytesIO
163
- msg, image, image_process_mode = msg
164
- max_hw, min_hw = max(image.size), min(image.size)
165
- aspect_ratio = max_hw / min_hw
166
- max_len, min_len = 800, 400
167
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
168
- longest_edge = int(shortest_edge * aspect_ratio)
169
- W, H = image.size
170
- if H > W:
171
- H, W = longest_edge, shortest_edge
172
- else:
173
- H, W = shortest_edge, longest_edge
174
- image = image.resize((W, H))
175
- buffered = BytesIO()
176
- image.save(buffered, format="JPEG")
177
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
178
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
179
- msg = img_str + msg.replace('<image>', '').strip()
180
- ret.append([msg, None])
181
- else:
182
- ret.append([msg, None])
183
- else:
184
- ret[-1][-1] = msg
185
- return ret
186
-
187
- def copy(self):
188
- return GradioInstance(
189
- system=self.system,
190
- roles=self.roles,
191
- messages=[[x, y] for x, y in self.messages],
192
- offset=self.offset,
193
- sep_style=self.sep_style,
194
- sep=self.sep,
195
- sep2=self.sep2,
196
- version=self.version,
197
- mm_rag_chain=self.mm_rag_chain,
198
- )
199
-
200
- def dict(self):
201
- return {
202
- "system": self.system,
203
- "roles": self.roles,
204
- "messages": self.messages,
205
- "offset": self.offset,
206
- "sep": self.sep,
207
- "sep2": self.sep2,
208
- "path_to_img": self.path_to_img,
209
- "video_title" : self.video_title,
210
- "path_to_video": self.path_to_video,
211
- "caption" : self.caption,
212
- }
213
- def get_path_to_subvideos(self):
214
- if self.video_title is not None and self.path_to_img is not None:
215
- info = video_helper_map[self.video_title]
216
- path = info['path']
217
- prefix = info['prefix']
218
- vid_index = self.path_to_img.split('/')[-1]
219
- vid_index = vid_index.split('_')[-1]
220
- vid_index = vid_index.replace('.jpg', '')
221
- ret = f"{prefix}{vid_index}.mp4"
222
- ret = os.path.join(path, ret)
223
- return ret
224
- elif self.path_to_video is not None:
225
- return self.path_to_video
226
- return None
227
-
228
- def get_gradio_instance(mm_rag_chain=None):
229
- if mm_rag_chain is None:
230
- mm_rag_chain = get_default_rag_chain()
231
-
232
- instance = GradioInstance(
233
- system="",
234
- roles=prediction_guard_llava_conv.roles,
235
- messages=[],
236
- offset=0,
237
- sep_style=SeparatorStyle.SINGLE,
238
- sep="\n",
239
- path_to_img=None,
240
- video_title=None,
241
- caption=None,
242
- mm_rag_chain=mm_rag_chain,
243
- )
244
- return instance
245
-
246
- gr.set_static_paths(paths=["./assets/"])
247
- theme = gr.themes.Base(
248
- primary_hue=gr.themes.Color(
249
- c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
250
- secondary_hue=gr.themes.Color(
251
- c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
252
- ).set(
253
- body_background_fill_dark='*primary_950',
254
- body_text_color_dark='*neutral_300',
255
- border_color_accent='*primary_700',
256
- border_color_accent_dark='*neutral_800',
257
- block_background_fill_dark='*primary_950',
258
- block_border_width='2px',
259
- block_border_width_dark='2px',
260
- button_primary_background_fill_dark='*primary_500',
261
- button_primary_border_color_dark='*primary_500'
262
- )
263
-
264
- css='''
265
- @font-face {
266
- font-family: IntelOne;
267
- src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
268
- }
269
- .gradio-container {background-color: #0a0c2b}
270
- table {
271
- border-collapse: collapse;
272
- border: none;
273
- }
274
- '''
275
-
276
- ## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
277
-
278
- # html_title = '''
279
- # <table style="bordercolor=#0a0c2b; border=0">
280
- # <tr style="height:150px; border:0">
281
- # <td style="border:0"><img src="/file=../assets/intel-labs.png" height="100" width="100"></td>
282
- # <td style="vertical-align:bottom; border:0">
283
- # <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
284
- # Multimodal RAG:
285
- # <br>
286
- # Chat with Videos
287
- # </p>
288
- # </td>
289
- # <td style="border:0"><img src="/file=../assets/gaudi.png" width="100" height="100"></td>
290
-
291
- # <td style="border:0"><img src="/file=../assets/IDC7.png" width="300" height="350"></td>
292
- # <td style="border:0"><img src="/file=../assets/prediction_guard3.png" width="120" height="120"></td>
293
- # </tr>
294
- # </table>
295
-
296
- # '''
297
-
298
- html_title = '''
299
- <table style="bordercolor=#0a0c2b; border=0">
300
- <tr style="height:150px; border:0">
301
- <td style="border:0"><img src="/file=./assets/header.png"></td>
302
- </tr>
303
- </table>
304
-
305
- '''
306
-
307
- #<td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>
308
- dropdown_list = [
309
- "What is the name of one of the astronauts?",
310
- "An astronaut's spacewalk",
311
- "What does the astronaut say?",
312
-
313
- ]
314
-
315
- no_change_btn = gr.Button()
316
- enable_btn = gr.Button(interactive=True)
317
- disable_btn = gr.Button(interactive=False)
318
-
319
- def clear_history(state, request: gr.Request):
320
- state = get_gradio_instance(state.mm_rag_chain)
321
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1
322
-
323
- def add_text(state, text, request: gr.Request):
324
- if len(text) <= 0 :
325
- state.skip_next = True
326
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1
327
-
328
- text = text[:1536] # Hard cut-off
329
-
330
- state.append_message(state.roles[0], text)
331
- state.append_message(state.roles[1], None)
332
- state.skip_next = False
333
- return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
334
-
335
- def http_bot(
336
- state, request: gr.Request
337
- ):
338
- start_tstamp = time.time()
339
-
340
- if state.skip_next:
341
- # This generate call is skipped due to invalid inputs
342
- path_to_sub_videos = state.get_path_to_subvideos()
343
- yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
344
- return
345
-
346
- if len(state.messages) == state.offset + 2:
347
- # First round of conversation
348
- new_state = get_gradio_instance(state.mm_rag_chain)
349
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
350
- new_state.append_message(new_state.roles[1], None)
351
- state = new_state
352
-
353
- all_images = state.get_images(return_pil=False)
354
-
355
- # Make requests
356
- is_very_first_query = True
357
- if len(all_images) == 0:
358
- # first query need to do RAG
359
- # Construct prompt
360
- prompt_or_conversation = state.get_prompt_for_rag()
361
- else:
362
- # subsequence queries, no need to do Retrieval
363
- is_very_first_query = False
364
- prompt_or_conversation = state.get_conversation_for_lvlm()
365
-
366
- if is_very_first_query:
367
- executor = state.mm_rag_chain
368
- else:
369
- executor = lvlm_inference_with_conversation
370
-
371
- state.messages[-1][-1] = "▌"
372
- path_to_sub_videos = state.get_path_to_subvideos()
373
- yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
374
-
375
- try:
376
- if is_very_first_query:
377
- # get response by invoke executor chain
378
- response = executor.invoke(prompt_or_conversation)
379
- message = response['final_text_output']
380
- if 'metadata' in response['input_to_lvlm']:
381
- metadata = response['input_to_lvlm']['metadata']
382
- if (state.path_to_img is None
383
- and 'input_to_lvlm' in response
384
- and 'image' in response['input_to_lvlm']
385
- ):
386
- state.path_to_img = response['input_to_lvlm']['image']
387
-
388
- if state.path_to_video is None and 'video_path' in metadata:
389
- video_path = metadata['video_path']
390
- mid_time_ms = metadata['mid_time_ms']
391
- splited_video_path = split_video(video_path, mid_time_ms)
392
- state.path_to_video = splited_video_path
393
-
394
- if state.caption is None and 'transcript' in metadata:
395
- state.caption = metadata['transcript']
396
- else:
397
- raise ValueError("Response's format is changed")
398
- else:
399
- # get the response message by directly call PredictionGuardAPI
400
- message = executor(prompt_or_conversation)
401
-
402
- except Exception as e:
403
- print(e)
404
- state.messages[-1][-1] = server_error_msg
405
- yield (state, state.to_gradio_chatbot(), None) + (
406
- enable_btn,
407
- )
408
- return
409
-
410
- state.messages[-1][-1] = message
411
- path_to_sub_videos = state.get_path_to_subvideos()
412
- # path_to_image = state.path_to_img
413
- # caption = state.caption
414
- # # print(path_to_sub_videos)
415
- # # print(path_to_image)
416
- # # print('caption: ', caption)
417
- yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1
418
-
419
- finish_tstamp = time.time()
420
- return
421
-
422
- def get_demo(rag_chain=None):
423
- if rag_chain is None:
424
- rag_chain = get_default_rag_chain()
425
-
426
- with gr.Blocks(theme=theme, css=css) as demo:
427
- # gr.Markdown(description)
428
- instance = get_gradio_instance(rag_chain)
429
- state = gr.State(instance)
430
- demo.load(
431
- None,
432
- None,
433
- js="""
434
- () => {
435
- const params = new URLSearchParams(window.location.search);
436
- if (!params.has('__theme')) {
437
- params.set('__theme', 'dark');
438
- window.location.search = params.toString();
439
- }
440
- }""",
441
- )
442
- gr.HTML(value=html_title)
443
- with gr.Row():
444
- with gr.Column(scale=4):
445
- video = gr.Video(height=512, width=512, elem_id="video", interactive=False )
446
- with gr.Column(scale=7):
447
- chatbot = gr.Chatbot(
448
- elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
449
- )
450
- with gr.Row():
451
- with gr.Column(scale=8):
452
- # textbox.render()
453
- textbox = gr.Dropdown(
454
- dropdown_list,
455
- allow_custom_value=True,
456
- # show_label=False,
457
- # container=False,
458
- label="Query",
459
- info="Enter your query here or choose a sample from the dropdown list!"
460
- )
461
- with gr.Column(scale=1, min_width=50):
462
- submit_btn = gr.Button(
463
- value="Send", variant="primary", interactive=True
464
- )
465
- with gr.Row(elem_id="buttons") as button_row:
466
- clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
467
-
468
- btn_list = [clear_btn]
469
-
470
- clear_btn.click(
471
- clear_history, [state], [state, chatbot, textbox, video] + btn_list
472
- )
473
- submit_btn.click(
474
- add_text,
475
- [state, textbox],
476
- [state, chatbot, textbox,] + btn_list,
477
- ).then(
478
- http_bot,
479
- [state],
480
- [state, chatbot, video] + btn_list,
481
- )
482
- return demo
483
-
 
1
+ import gradio as gr
2
+ import io
3
+ import sys
4
+ import time
5
+ import dataclasses
6
+ from pathlib import Path
7
+ import os
8
+ from enum import auto, Enum
9
+ from typing import List, Tuple, Any
10
+ from utility import prediction_guard_llava_conv
11
+ import lancedb
12
+ from utility import load_json_file
13
+ from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
14
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
15
+ from mm_rag.MLM.client import PredictionGuardClient
16
+ from mm_rag.MLM.lvlm import LVLM
17
+ from PIL import Image
18
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
19
+ from moviepy.video.io.VideoFileClip import VideoFileClip
20
+ from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation
21
+
22
+ server_error_msg="**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
23
+
24
+ # function to split video at a timestamp
25
+ def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str="video_tmp.mp4", play_before_sec: int=3, play_after_sec: int=3):
26
+ timestamp_in_sec = int(timestamp_in_ms / 1000)
27
+ # create the output folder if it does not exist
28
+ Path(output_video_path).mkdir(parents=True, exist_ok=True)
29
+ output_video = os.path.join(output_video_path, output_video_name)
30
+ with VideoFileClip(video_path) as video:
31
+ duration = video.duration
32
+ start_time = max(timestamp_in_sec - play_before_sec, 0)
33
+ end_time = min(timestamp_in_sec + play_after_sec, duration)
34
+ new = video.subclip(start_time, end_time)
35
+ new.write_videofile(output_video, audio_codec='aac')
36
+ return output_video
37
+
38
+
39
+ prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
40
+
41
+ # define default rag_chain
42
+ def get_default_rag_chain():
43
+ # declare host file
44
+ LANCEDB_HOST_FILE = "./shared_data/.lancedb"
45
+ # declare table name
46
+ TBL_NAME = "demo_tbl"
47
+
48
+ # initialize vectorstore
49
+ db = lancedb.connect(LANCEDB_HOST_FILE)
50
+
51
+ # initialize a BridgeTower embedder
52
+ embedder = BridgeTowerEmbeddings()
53
+
54
+ ## Creating a LanceDB vector store
55
+ vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
56
+ ### creating a retriever for the vector store
57
+ retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})
58
+
59
+ # initialize a client as PredictionGuardClient
60
+ client = PredictionGuardClient()
61
+ # initialize LVLM with the given client
62
+ lvlm_inference_module = LVLM(client=client)
63
+
64
+ def prompt_processing(input):
65
+ # get the retrieved results and user's query
66
+ retrieved_results, user_query = input['retrieved_results'], input['user_query']
67
+ # get the first retrieved result by default
68
+ retrieved_result = retrieved_results[0]
69
+ # prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
70
+
71
+ # get all metadata of the retrieved video segment
72
+ metadata_retrieved_video_segment = retrieved_result.metadata['metadata']
73
+
74
+ # get the frame and the corresponding transcript, path to extracted frame, path to whole video, and time stamp of the retrieved video segment.
75
+ transcript = metadata_retrieved_video_segment['transcript']
76
+ frame_path = metadata_retrieved_video_segment['extracted_frame_path']
77
+ return {
78
+ 'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
79
+ 'image' : frame_path,
80
+ 'metadata' : metadata_retrieved_video_segment,
81
+ }
82
+ # initialize prompt processing module as a Langchain RunnableLambda of function prompt_processing
83
+ prompt_processing_module = RunnableLambda(prompt_processing)
84
+
85
+ # the output of this new chain will be a dictionary
86
+ mm_rag_chain_with_retrieved_image = (
87
+ RunnableParallel({"retrieved_results": retriever_module ,
88
+ "user_query": RunnablePassthrough()})
89
+ | prompt_processing_module
90
+ | RunnableParallel({'final_text_output': lvlm_inference_module,
91
+ 'input_to_lvlm' : RunnablePassthrough()})
92
+ )
93
+ return mm_rag_chain_with_retrieved_image
94
+
95
+ class SeparatorStyle(Enum):
96
+ """Different separator style."""
97
+ SINGLE = auto()
98
+
99
+ @dataclasses.dataclass
100
+ class GradioInstance:
101
+ """A class that keeps all conversation history."""
102
+ system: str
103
+ roles: List[str]
104
+ messages: List[List[str]]
105
+ offset: int
106
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
107
+ sep: str = "\n"
108
+ sep2: str = None
109
+ version: str = "Unknown"
110
+ path_to_img: str = None
111
+ video_title: str = None
112
+ path_to_video: str = None
113
+ caption: str = None
114
+ mm_rag_chain: Any = None
115
+
116
+ skip_next: bool = False
117
+
118
+ def _template_caption(self):
119
+ out = ""
120
+ if self.caption is not None:
121
+ out = f"The caption associated with the image is '{self.caption}'. "
122
+ return out
123
+
124
+ def get_prompt_for_rag(self):
125
+ messages = self.messages
126
+ assert len(messages) == 2, "length of current conversation should be 2"
127
+ assert messages[1][1] is None, "the first response message of current conversation should be None"
128
+ ret = messages[0][1]
129
+ return ret
130
+
131
+ def get_conversation_for_lvlm(self):
132
+ pg_conv = prediction_guard_llava_conv.copy()
133
+ image_path = self.path_to_img
134
+ b64_img = encode_image(image_path)
135
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
136
+ if msg is None:
137
+ break
138
+ if i == 0:
139
+ pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
140
+ elif i == len(self.messages[self.offset:]) - 2:
141
+ pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
142
+ else:
143
+ pg_conv.append_message(role, [msg])
144
+ return pg_conv
145
+
146
+ def append_message(self, role, message):
147
+ self.messages.append([role, message])
148
+
149
+ def get_images(self, return_pil=False):
150
+ images = []
151
+ if self.path_to_img is not None:
152
+ path_to_image = self.path_to_img
153
+ images.append(path_to_image)
154
+ return images
155
+
156
+ def to_gradio_chatbot(self):
157
+ ret = []
158
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
159
+ if i % 2 == 0:
160
+ if type(msg) is tuple:
161
+ import base64
162
+ from io import BytesIO
163
+ msg, image, image_process_mode = msg
164
+ max_hw, min_hw = max(image.size), min(image.size)
165
+ aspect_ratio = max_hw / min_hw
166
+ max_len, min_len = 800, 400
167
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
168
+ longest_edge = int(shortest_edge * aspect_ratio)
169
+ W, H = image.size
170
+ if H > W:
171
+ H, W = longest_edge, shortest_edge
172
+ else:
173
+ H, W = shortest_edge, longest_edge
174
+ image = image.resize((W, H))
175
+ buffered = BytesIO()
176
+ image.save(buffered, format="JPEG")
177
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
178
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
179
+ msg = img_str + msg.replace('<image>', '').strip()
180
+ ret.append([msg, None])
181
+ else:
182
+ ret.append([msg, None])
183
+ else:
184
+ ret[-1][-1] = msg
185
+ return ret
186
+
187
+ def copy(self):
188
+ return GradioInstance(
189
+ system=self.system,
190
+ roles=self.roles,
191
+ messages=[[x, y] for x, y in self.messages],
192
+ offset=self.offset,
193
+ sep_style=self.sep_style,
194
+ sep=self.sep,
195
+ sep2=self.sep2,
196
+ version=self.version,
197
+ mm_rag_chain=self.mm_rag_chain,
198
+ )
199
+
200
+ def dict(self):
201
+ return {
202
+ "system": self.system,
203
+ "roles": self.roles,
204
+ "messages": self.messages,
205
+ "offset": self.offset,
206
+ "sep": self.sep,
207
+ "sep2": self.sep2,
208
+ "path_to_img": self.path_to_img,
209
+ "video_title" : self.video_title,
210
+ "path_to_video": self.path_to_video,
211
+ "caption" : self.caption,
212
+ }
213
+ def get_path_to_subvideos(self):
214
+ if self.video_title is not None and self.path_to_img is not None:
215
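+ # NOTE: video_helper_map is assumed to be provided elsewhere; it is not defined or imported in this file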
+ info = video_helper_map[self.video_title]
216
+ path = info['path']
217
+ prefix = info['prefix']
218
+ vid_index = self.path_to_img.split('/')[-1]
219
+ vid_index = vid_index.split('_')[-1]
220
+ vid_index = vid_index.replace('.jpg', '')
221
+ ret = f"{prefix}{vid_index}.mp4"
222
+ ret = os.path.join(path, ret)
223
+ return ret
224
+ elif self.path_to_video is not None:
225
+ return self.path_to_video
226
+ return None
227
+
228
+ def get_gradio_instance(mm_rag_chain=None):
229
+ if mm_rag_chain is None:
230
+ mm_rag_chain = get_default_rag_chain()
231
+
232
+ instance = GradioInstance(
233
+ system="",
234
+ roles=prediction_guard_llava_conv.roles,
235
+ messages=[],
236
+ offset=0,
237
+ sep_style=SeparatorStyle.SINGLE,
238
+ sep="\n",
239
+ path_to_img=None,
240
+ video_title=None,
241
+ caption=None,
242
+ mm_rag_chain=mm_rag_chain,
243
+ )
244
+ return instance
245
+
246
+ gr.set_static_paths(paths=["./assets/"])
247
+ theme = gr.themes.Base(
248
+ primary_hue=gr.themes.Color(
249
+ c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
250
+ secondary_hue=gr.themes.Color(
251
+ c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
252
+ ).set(
253
+ body_background_fill_dark='*primary_950',
254
+ body_text_color_dark='*neutral_300',
255
+ border_color_accent='*primary_700',
256
+ border_color_accent_dark='*neutral_800',
257
+ block_background_fill_dark='*primary_950',
258
+ block_border_width='2px',
259
+ block_border_width_dark='2px',
260
+ button_primary_background_fill_dark='*primary_500',
261
+ button_primary_border_color_dark='*primary_500'
262
+ )
263
+
264
+ css='''
265
+ @font-face {
266
+ font-family: IntelOne;
267
+ src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
268
+ }
269
+ .gradio-container {background-color: #0a0c2b}
270
+ table {
271
+ border-collapse: collapse;
272
+ border: none;
273
+ }
274
+ '''
275
+
276
+ ## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>
277
+
278
+ # html_title = '''
279
+ # <table style="bordercolor=#0a0c2b; border=0">
280
+ # <tr style="height:150px; border:0">
281
+ # <td style="border:0"><img src="/file=../assets/intel-labs.png" height="100" width="100"></td>
282
+ # <td style="vertical-align:bottom; border:0">
283
+ # <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
284
+ # Multimodal RAG:
285
+ # <br>
286
+ # Chat with Videos
287
+ # </p>
288
+ # </td>
289
+ # <td style="border:0"><img src="/file=../assets/gaudi.png" width="100" height="100"></td>
290
+
291
+ # <td style="border:0"><img src="/file=../assets/IDC7.png" width="300" height="350"></td>
292
+ # <td style="border:0"><img src="/file=../assets/prediction_guard3.png" width="120" height="120"></td>
293
+ # </tr>
294
+ # </table>
295
+
296
+ # '''
297
+
298
+ html_title = '''
299
+ <table style="bordercolor=#0a0c2b; border=0">
300
+ <tr style="height:150px; border:0">
301
+ <td style="border:0"><img src="/file=./assets/header.png"></td>
302
+ </tr>
303
+ </table>
304
+
305
+ '''
306
+
307
+ #<td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>
308
+ dropdown_list = [
309
+ "What is the name of one of the astronauts?",
310
+ "An astronaut's spacewalk",
311
+ "What does the astronaut say?",
312
+
313
+ ]
314
+
315
+ no_change_btn = gr.Button()
316
+ enable_btn = gr.Button(interactive=True)
317
+ disable_btn = gr.Button(interactive=False)
318
+
319
+ def clear_history(state, request: gr.Request):
320
+ state = get_gradio_instance(state.mm_rag_chain)
321
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1
322
+
323
+ def add_text(state, text, request: gr.Request):
324
+ if len(text) <= 0:
325
+ state.skip_next = True
326
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1
327
+
328
+ text = text[:1536] # Hard cut-off
329
+
330
+ state.append_message(state.roles[0], text)
331
+ state.append_message(state.roles[1], None)
332
+ state.skip_next = False
333
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
334
+
335
+ def http_bot(
336
+ state, request: gr.Request
337
+ ):
338
+ start_tstamp = time.time()
339
+
340
+ if state.skip_next:
341
+ # This generate call is skipped due to invalid inputs
342
+ path_to_sub_videos = state.get_path_to_subvideos()
343
+ yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
344
+ return
345
+
346
+ if len(state.messages) == state.offset + 2:
347
+ # First round of conversation
348
+ new_state = get_gradio_instance(state.mm_rag_chain)
349
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
350
+ new_state.append_message(new_state.roles[1], None)
351
+ state = new_state
352
+
353
+ all_images = state.get_images(return_pil=False)
354
+
355
+ # Make requests
356
+ is_very_first_query = True
357
+ if len(all_images) == 0:
358
+ # the first query needs to do RAG retrieval
359
+ # Construct prompt
360
+ prompt_or_conversation = state.get_prompt_for_rag()
361
+ else:
362
+ # subsequent queries do not need retrieval
363
+ is_very_first_query = False
364
+ prompt_or_conversation = state.get_conversation_for_lvlm()
365
+
366
+ if is_very_first_query:
367
+ executor = state.mm_rag_chain
368
+ else:
369
+ executor = lvlm_inference_with_conversation
370
+
371
+ state.messages[-1][-1] = "▌"
372
+ path_to_sub_videos = state.get_path_to_subvideos()
373
+ yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
374
+
375
+ try:
376
+ if is_very_first_query:
377
+ # get the response by invoking the executor chain
378
+ response = executor.invoke(prompt_or_conversation)
379
+ message = response['final_text_output']
380
+ if 'metadata' in response['input_to_lvlm']:
381
+ metadata = response['input_to_lvlm']['metadata']
382
+ if (state.path_to_img is None
383
+ and 'input_to_lvlm' in response
384
+ and 'image' in response['input_to_lvlm']
385
+ ):
386
+ state.path_to_img = response['input_to_lvlm']['image']
387
+
388
+ if state.path_to_video is None and 'video_path' in metadata:
389
+ video_path = metadata['video_path']
390
+ mid_time_ms = metadata['mid_time_ms']
391
+ splited_video_path = split_video(video_path, mid_time_ms)
392
+ state.path_to_video = splited_video_path
393
+
394
+ if state.caption is None and 'transcript' in metadata:
395
+ state.caption = metadata['transcript']
396
+ else:
397
+ raise ValueError("Response format has changed")
398
+ else:
399
+ # get the response message by directly calling the PredictionGuard API
400
+ message = executor(prompt_or_conversation)
401
+
402
+ except Exception as e:
403
+ print(e)
404
+ state.messages[-1][-1] = server_error_msg
405
+ yield (state, state.to_gradio_chatbot(), None) + (
406
+ enable_btn,
407
+ )
408
+ return
409
+
410
+ state.messages[-1][-1] = message
411
+ path_to_sub_videos = state.get_path_to_subvideos()
412
+ # path_to_image = state.path_to_img
413
+ # caption = state.caption
414
+ # # print(path_to_sub_videos)
415
+ # # print(path_to_image)
416
+ # # print('caption: ', caption)
417
+ yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1
418
+
419
+ finish_tstamp = time.time()
420
+ return
421
+
422
+ def get_demo(rag_chain=None):
423
+ if rag_chain is None:
424
+ rag_chain = get_default_rag_chain()
425
+
426
+ with gr.Blocks(theme=theme, css=css) as demo:
427
+ # gr.Markdown(description)
428
+ instance = get_gradio_instance(rag_chain)
429
+ state = gr.State(instance)
430
+ demo.load(
431
+ None,
432
+ None,
433
+ js="""
434
+ () => {
435
+ const params = new URLSearchParams(window.location.search);
436
+ if (!params.has('__theme')) {
437
+ params.set('__theme', 'dark');
438
+ window.location.search = params.toString();
439
+ }
440
+ }""",
441
+ )
442
+ gr.HTML(value=html_title)
443
+ with gr.Row():
444
+ with gr.Column(scale=4):
445
+ video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
446
+ with gr.Column(scale=7):
447
+ chatbot = gr.Chatbot(
448
+ elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
449
+ )
450
+ with gr.Row():
451
+ with gr.Column(scale=8):
452
+ # textbox.render()
453
+ textbox = gr.Dropdown(
454
+ dropdown_list,
455
+ allow_custom_value=True,
456
+ # show_label=False,
457
+ # container=False,
458
+ label="Query",
459
+ info="Enter your query here or choose a sample from the dropdown list!"
460
+ )
461
+ with gr.Column(scale=1, min_width=50):
462
+ submit_btn = gr.Button(
463
+ value="Send", variant="primary", interactive=True
464
+ )
465
+ with gr.Row(elem_id="buttons") as button_row:
466
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
467
+
468
+ btn_list = [clear_btn]
469
+
470
+ clear_btn.click(
471
+ clear_history, [state], [state, chatbot, textbox, video] + btn_list
472
+ )
473
+ submit_btn.click(
474
+ add_text,
475
+ [state, textbox],
476
+ [state, chatbot, textbox,] + btn_list,
477
+ ).then(
478
+ http_bot,
479
+ [state],
480
+ [state, chatbot, video] + btn_list,
481
+ )
482
+ return demo
483
+
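A minimal launch sketch for the Gradio demo defined above, assuming the file is importable as `app` (the module name is an assumption for illustration, not part of the upload):

    # hypothetical launcher; `app` is an assumed module name for the file above
    from app import get_demo

    demo = get_demo()  # builds the default multimodal RAG chain internally
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)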
mm_rag/MLM/client.py CHANGED
@@ -1,135 +1,135 @@
1
- """Base interface for client making requests/call to visual language model provider API"""
2
-
3
- from abc import ABC, abstractmethod
4
- from typing import List, Optional, Dict, Union, Iterator
5
- import requests
6
- import json
7
- from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference
8
-
9
- class BaseClient(ABC):
10
- def __init__(self,
11
- hostname: str = "127.0.0.1",
12
- port: int = 8090,
13
- timeout: int = 60,
14
- url: Optional[str] = None):
15
- self.connection_url = f"http://{hostname}:{port}" if url is None else url
16
- self.timeout = timeout
17
- # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
18
- self.headers = {'Content-Type': 'application/json'}
19
-
20
- def root(self):
21
- """Request for showing welcome message"""
22
- connection_route = f"{self.connection_url}/"
23
- return requests.get(connection_route)
24
-
25
- @abstractmethod
26
- def generate(self,
27
- prompt: str,
28
- image: str,
29
- **kwargs
30
- ) -> str:
31
- """Send request to visual language model API
32
- and return generated text that was returned by the visual language model API
33
-
34
- Use this method when you want to call visual language model API to generate text without streaming
35
-
36
- Args:
37
- prompt: A prompt.
38
- image: A string that can be either path to image or base64 of an image.
39
- **kwargs: Arbitrary additional keyword arguments.
40
- These are usually passed to the model provider API call as hyperparameter for generation.
41
-
42
- Returns:
43
- Text returned from visual language model provider API call
44
- """
45
-
46
-
47
- def generate_stream(
48
- self,
49
- prompt: str,
50
- image: str,
51
- **kwargs
52
- ) -> Iterator[str]:
53
- """Send request to visual language model API
54
- and return an iterator of streaming text that were returned from the visual language model API call
55
-
56
- Use this method when you want to call visual language model API to stream generated text.
57
-
58
- Args:
59
- prompt: A prompt.
60
- image: A string that can be either path to image or base64 of an image.
61
- **kwargs: Arbitrary additional keyword arguments.
62
- These are usually passed to the model provider API call as hyperparameter for generation.
63
-
64
- Returns:
65
- Iterator of text streamed from visual language model provider API call
66
- """
67
- raise NotImplementedError()
68
-
69
- def generate_batch(
70
- self,
71
- prompt: List[str],
72
- image: List[str],
73
- **kwargs
74
- ) -> List[str]:
75
- """Send a request to visual language model API for multi-batch generation
76
- and return a list of generated text that was returned by the visual language model API
77
-
78
- Use this method when you want to call visual language model API to multi-batch generate text.
79
- Multi-batch generation does not support streaming.
80
-
81
- Args:
82
- prompt: List of prompts.
83
- image: List of strings; each of which can be either path to image or base64 of an image.
84
- **kwargs: Arbitrary additional keyword arguments.
85
- These are usually passed to the model provider API call as hyperparameter for generation.
86
-
87
- Returns:
88
- List of texts returned from visual language model provider API call
89
- """
90
- raise NotImplementedError()
91
-
92
- class PredictionGuardClient(BaseClient):
93
-
94
- generate_kwargs = ['max_tokens',
95
- 'temperature',
96
- 'top_p',
97
- 'top_k']
98
-
99
- def filter_accepted_genkwargs(self, kwargs):
100
- gen_args = {}
101
- if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
102
- gen_args = {k:kwargs["generate_kwargs"][k]
103
- for k in self.generate_kwargs
104
- if k in kwargs["generate_kwargs"]}
105
- return gen_args
106
-
107
- def generate(self,
108
- prompt: str,
109
- image: str,
110
- **kwargs
111
- ) -> str:
112
- """Send request to PredictionGuard's API
113
- and return generated text that was returned by LLAVA model
114
-
115
- Use this method when you want to call LLAVA model API to generate text without streaming
116
-
117
- Args:
118
- prompt: A prompt.
119
- image: A string that can be either path/URL to image or base64 of an image.
120
- **kwargs: Arbitrary additional keyword arguments.
121
- These are usually passed to the model provider API call as hyperparameter for generation.
122
-
123
- Returns:
124
- Text returned from visual language model provider API call
125
- """
126
-
127
- assert image is not None and len(image) != "", "the input image cannot be None, it must be either base64-encoded image or path/URL to image"
128
- if isBase64(image):
129
- base64_image = image
130
- else: # this is path to image or URL to image
131
- base64_image = encode_image_from_path_or_url(image)
132
-
133
- args = self.filter_accepted_genkwargs(kwargs)
134
- return lvlm_inference(prompt=prompt, image=base64_image, **args)
135
 
 
1
+ """Base interface for client making requests/call to visual language model provider API"""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional, Dict, Union, Iterator
5
+ import requests
6
+ import json
7
+ from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference
8
+
9
+ class BaseClient(ABC):
10
+ def __init__(self,
11
+ hostname: str = "127.0.0.1",
12
+ port: int = 8090,
13
+ timeout: int = 60,
14
+ url: Optional[str] = None):
15
+ self.connection_url = f"http://{hostname}:{port}" if url is None else url
16
+ self.timeout = timeout
17
+ # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
18
+ self.headers = {'Content-Type': 'application/json'}
19
+
20
+ def root(self):
21
+ """Request for showing welcome message"""
22
+ connection_route = f"{self.connection_url}/"
23
+ return requests.get(connection_route)
24
+
25
+ @abstractmethod
26
+ def generate(self,
27
+ prompt: str,
28
+ image: str,
29
+ **kwargs
30
+ ) -> str:
31
+ """Send request to visual language model API
32
+ and return the generated text returned by the visual language model API
33
+
34
+ Use this method when you want to call the visual language model API to generate text without streaming
35
+
36
+ Args:
37
+ prompt: A prompt.
38
+ image: A string that can be either path to image or base64 of an image.
39
+ **kwargs: Arbitrary additional keyword arguments.
40
+ These are usually passed to the model provider API call as hyperparameter for generation.
41
+
42
+ Returns:
43
+ Text returned from visual language model provider API call
44
+ """
45
+
46
+
47
+ def generate_stream(
48
+ self,
49
+ prompt: str,
50
+ image: str,
51
+ **kwargs
52
+ ) -> Iterator[str]:
53
+ """Send request to visual language model API
54
+ and return an iterator of streamed text chunks returned from the visual language model API call
55
+
56
+ Use this method when you want to call the visual language model API to stream generated text.
57
+
58
+ Args:
59
+ prompt: A prompt.
60
+ image: A string that can be either path to image or base64 of an image.
61
+ **kwargs: Arbitrary additional keyword arguments.
62
+ These are usually passed to the model provider API call as hyperparameter for generation.
63
+
64
+ Returns:
65
+ Iterator of text streamed from visual language model provider API call
66
+ """
67
+ raise NotImplementedError()
68
+
69
+ def generate_batch(
70
+ self,
71
+ prompt: List[str],
72
+ image: List[str],
73
+ **kwargs
74
+ ) -> List[str]:
75
+ """Send a request to visual language model API for multi-batch generation
76
+ and return a list of generated texts returned by the visual language model API
77
+
78
+ Use this method when you want to call the visual language model API to generate text in batches.
79
+ Multi-batch generation does not support streaming.
80
+
81
+ Args:
82
+ prompt: List of prompts.
83
+ image: List of strings; each of which can be either path to image or base64 of an image.
84
+ **kwargs: Arbitrary additional keyword arguments.
85
+ These are usually passed to the model provider API call as hyperparameter for generation.
86
+
87
+ Returns:
88
+ List of texts returned from visual language model provider API call
89
+ """
90
+ raise NotImplementedError()
91
+
92
+ class PredictionGuardClient(BaseClient):
93
+
94
+ generate_kwargs = ['max_tokens',
95
+ 'temperature',
96
+ 'top_p',
97
+ 'top_k']
98
+
99
+ def filter_accepted_genkwargs(self, kwargs):
100
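+ # keep only the generation hyperparameters whitelisted in generate_kwargs above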
+ gen_args = {}
101
+ if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
102
+ gen_args = {k:kwargs["generate_kwargs"][k]
103
+ for k in self.generate_kwargs
104
+ if k in kwargs["generate_kwargs"]}
105
+ return gen_args
106
+
107
+ def generate(self,
108
+ prompt: str,
109
+ image: str,
110
+ **kwargs
111
+ ) -> str:
112
+ """Send request to PredictionGuard's API
113
+ and return the generated text returned by the LLaVA model
114
+
115
+ Use this method when you want to call the LLaVA model API to generate text without streaming
116
+
117
+ Args:
118
+ prompt: A prompt.
119
+ image: A string that can be either path/URL to image or base64 of an image.
120
+ **kwargs: Arbitrary additional keyword arguments.
121
+ These are usually passed to the model provider API call as hyperparameter for generation.
122
+
123
+ Returns:
124
+ Text returned from visual language model provider API call
125
+ """
126
+
127
+ assert image is not None and len(image) != 0, "the input image cannot be None or empty; it must be either a base64-encoded image or a path/URL to an image"
128
+ if isBase64(image):
129
+ base64_image = image
130
+ else: # this is path to image or URL to image
131
+ base64_image = encode_image_from_path_or_url(image)
132
+
133
+ args = self.filter_accepted_genkwargs(kwargs)
134
+ return lvlm_inference(prompt=prompt, image=base64_image, **args)
135
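A quick call sketch for PredictionGuardClient.generate, assuming a local frame path and the default client settings (the path and query below are placeholders):

    # hypothetical example; ./frame.jpg is a placeholder path
    from mm_rag.MLM.client import PredictionGuardClient

    client = PredictionGuardClient()
    answer = client.generate(
        prompt="What is happening in this frame?",
        image="./frame.jpg",  # path, URL, or base64-encoded image
        generate_kwargs={"max_tokens": 128, "temperature": 0.6},
    )
    print(answer)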
 
mm_rag/MLM/lvlm.py CHANGED
@@ -1,301 +1,301 @@
1
- from .client import PredictionGuardClient
2
- from langchain_core.language_models.llms import LLM
3
- from langchain_core.pydantic_v1 import Extra, root_validator
4
- from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
5
- from langchain_core.callbacks import CallbackManagerForLLMRun
6
- from utility import get_from_dict_or_env, MultimodalModelInput
7
-
8
- from langchain_core.runnables import RunnableConfig, ensure_config
9
- from langchain_core.language_models.base import LanguageModelInput
10
- from langchain_core.prompt_values import StringPromptValue
11
- # from langchain_core.outputs import GenerationChunk, LLMResult
12
- from langchain_core.language_models.llms import BaseLLM
13
- from langchain_core.callbacks import (
14
- # CallbackManager,
15
- CallbackManagerForLLMRun,
16
- )
17
- # from langchain_core.load import dumpd
18
- from langchain_core.runnables.config import run_in_executor
19
-
20
- class LVLM(LLM):
21
- """This class extends LLM class for implementing a custom request to LVLM provider API"""
22
-
23
-
24
- client: Any = None #: :meta private:
25
- hostname: Optional[str] = None
26
- port: Optional[int] = None
27
- url: Optional[str] = None
28
- max_new_tokens: Optional[int] = 200
29
- temperature: Optional[float] = 0.6
30
- top_k: Optional[float] = 0
31
- stop: Optional[List[str]] = None
32
- ignore_eos: Optional[bool] = False
33
- do_sample: Optional[bool] = True
34
- lazy_mode: Optional[bool] = True
35
- hpu_graphs: Optional[bool] = True
36
-
37
- @root_validator()
38
- def validate_environment(cls, values: Dict) -> Dict:
39
- """Validate that the access token and python package exists in environment if needed"""
40
- if values['client'] is None:
41
- # check if url of API is provided
42
- url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
43
- if url is None:
44
- hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
45
- port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
46
- if hostname is not None and port is not None:
47
- values['client'] = PredictionGuardClient(hostname=hostname, port=port)
48
- else:
49
- # using default hostname and port to create Client
50
- values['client'] = PredictionGuardClient()
51
- else:
52
- values['client'] = PredictionGuardClient(url=url)
53
- return values
54
-
55
- @property
56
- def _llm_type(self) -> str:
57
- """Return type of llm"""
58
- return "Large Vision Language Model"
59
-
60
- @property
61
- def _default_params(self) -> Dict[str, Any]:
62
- """Get the default parameters for calling the Prediction Guard API."""
63
- return {
64
- "max_tokens": self.max_new_tokens,
65
- "temperature": self.temperature,
66
- "top_k": self.top_k,
67
- "ignore_eos": self.ignore_eos,
68
- "do_sample": self.do_sample,
69
- "stop" : self.stop,
70
- }
71
-
72
- def get_params(self, **kwargs):
73
- params = self._default_params
74
- params.update(kwargs)
75
- return params
76
-
77
-
78
- def _call(
79
- self,
80
- prompt: str,
81
- image: str,
82
- stop: Optional[List[str]] = None,
83
- run_manager: Optional[CallbackManagerForLLMRun] = None,
84
- **kwargs: Any,
85
- ) -> str:
86
- """Run the VLM on the given input.
87
-
88
- Args:
89
- prompt: The prompt to generate from.
90
- image: This can be either path to image or base64 encode of the image.
91
- stop: Stop words to use when generating. Model output is cut off at the
92
- first occurrence of any of the stop substrings.
93
- If stop tokens are not supported consider raising NotImplementedError.
94
- Returns:
95
- The model output as a string. Actual completions DOES NOT include the prompt
96
- Example: TBD
97
- """
98
- params = {}
99
- if stop is not None:
100
- raise ValueError("stop kwargs are not permitted.")
101
- params['generate_kwargs'] = self.get_params(**kwargs)
102
- response = self.client.generate(prompt=prompt, image=image, **params)
103
- return response
104
-
105
- def _stream(
106
- self,
107
- prompt: str,
108
- image: str,
109
- stop: Optional[List[str]] = None,
110
- run_manager: Optional[CallbackManagerForLLMRun] = None,
111
- **kwargs: Any,
112
- ) -> Iterator[str]:
113
- """Stream the VLM on the given prompt and image.
114
-
115
- Args:
116
- prompt: The prompt to generate from.
117
- image: This can be either path to image or base64 encode of the image.
118
- stop: Stop words to use when generating. Model output is cut off at the
119
- first occurrence of any of the stop substrings.
120
- If stop tokens are not supported consider raising NotImplementedError.
121
- Returns:
122
- The model outputs an iterator of string. Actual completions DOES NOT include the prompt
123
- Example: TBD
124
- """
125
- params = {}
126
- params['generate_kwargs'] = self.get_params(**kwargs)
127
- for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
128
- yield chunk
129
-
130
- async def _astream(
131
- self,
132
- prompt: str,
133
- image: str,
134
- stop: Optional[List[str]] = None,
135
- run_manager: Optional[CallbackManagerForLLMRun] = None,
136
- **kwargs: Any,
137
- ) -> AsyncIterator[str]:
138
- """An async version of _stream method that stream the VLM on the given prompt and image.
139
-
140
- Args:
141
- prompt: The prompt to generate from.
142
- image: This can be either path to image or base64 encode of the image.
143
- stop: Stop words to use when generating. Model output is cut off at the
144
- first occurrence of any of the stop substrings.
145
- If stop tokens are not supported consider raising NotImplementedError.
146
- Returns:
147
- The model outputs an async iterator of string. Actual completions DOES NOT include the prompt
148
- Example: TBD
149
- """
150
- iterator = await run_in_executor(
151
- None,
152
- self._stream,
153
- prompt,
154
- image,
155
- stop,
156
- run_manager.get_sync() if run_manager else None,
157
- **kwargs,
158
- )
159
- done = object()
160
- while True:
161
- item = await run_in_executor(
162
- None,
163
- next,
164
- iterator,
165
- done, # type: ignore[call-arg, arg-type]
166
- )
167
- if item is done:
168
- break
169
- yield item # type: ignore[misc]
170
-
171
- def invoke(
172
- self,
173
- input: MultimodalModelInput,
174
- config: Optional[RunnableConfig] = None,
175
- *,
176
- stop: Optional[List[str]] = None,
177
- **kwargs: Any,
178
- ) -> str:
179
- config = ensure_config(config)
180
- if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
181
- return (
182
- self.generate_prompt(
183
- [self._convert_input(StringPromptValue(text=input['prompt']))],
184
- stop=stop,
185
- callbacks=config.get("callbacks"),
186
- tags=config.get("tags"),
187
- metadata=config.get("metadata"),
188
- run_name=config.get("run_name"),
189
- run_id=config.pop("run_id", None),
190
- image= input['image'],
191
- **kwargs,
192
- )
193
- .generations[0][0]
194
- .text
195
- )
196
- return (
197
- self.generate_prompt(
198
- [self._convert_input(input)],
199
- stop=stop,
200
- callbacks=config.get("callbacks"),
201
- tags=config.get("tags"),
202
- metadata=config.get("metadata"),
203
- run_name=config.get("run_name"),
204
- run_id=config.pop("run_id", None),
205
- **kwargs,
206
- )
207
- .generations[0][0]
208
- .text
209
- )
210
-
211
- async def ainvoke(
212
- self,
213
- input: MultimodalModelInput,
214
- config: Optional[RunnableConfig] = None,
215
- *,
216
- stop: Optional[List[str]] = None,
217
- **kwargs: Any,
218
- ) -> str:
219
- config = ensure_config(config)
220
- if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
221
- llm_result = await self.agenerate_prompt(
222
- [self._convert_input(StringPromptValue(text=input['prompt']))],
223
- stop=stop,
224
- callbacks=config.get("callbacks"),
225
- tags=config.get("tags"),
226
- metadata=config.get("metadata"),
227
- run_name=config.get("run_name"),
228
- run_id=config.pop("run_id", None),
229
- image=input['image'],
230
- **kwargs,
231
- )
232
- else:
233
- llm_result = await self.agenerate_prompt(
234
- [self._convert_input(input)],
235
- stop=stop,
236
- callbacks=config.get("callbacks"),
237
- tags=config.get("tags"),
238
- metadata=config.get("metadata"),
239
- run_name=config.get("run_name"),
240
- run_id=config.pop("run_id", None),
241
- **kwargs,
242
- )
243
- return llm_result.generations[0][0].text
244
-
245
- def stream(
246
- self,
247
- input: MultimodalModelInput,
248
- config: Optional[RunnableConfig] = None,
249
- *,
250
- stop: Optional[List[str]] = None,
251
- **kwargs: Any,
252
- ) -> Iterator[str]:
253
- if type(self)._stream == BaseLLM._stream:
254
- # model doesn't implement streaming, so use default implementation
255
- yield self.invoke(input, config=config, stop=stop, **kwargs)
256
- else:
257
- if stop is not None:
258
- raise ValueError("stop kwargs are not permitted.")
259
- image = None
260
- prompt = None
261
- if isinstance(input, dict) and 'prompt' in input.keys():
262
- prompt = self._convert_input(input['prompt']).to_string()
263
- else:
264
- raise ValueError("prompt must be provided")
265
- if isinstance(input, dict) and 'image' in input.keys():
266
- image = input['image']
267
-
268
- for chunk in self._stream(
269
- prompt=prompt, image=image, **kwargs
270
- ):
271
- yield chunk
272
-
273
- async def astream(
274
- self,
275
- input: LanguageModelInput,
276
- config: Optional[RunnableConfig] = None,
277
- *,
278
- stop: Optional[List[str]] = None,
279
- **kwargs: Any,
280
- ) -> AsyncIterator[str]:
281
- if (
282
- type(self)._astream is BaseLLM._astream
283
- and type(self)._stream is BaseLLM._stream
284
- ):
285
- yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
286
- return
287
- else:
288
- if stop is not None:
289
- raise ValueError("stop kwargs are not permitted.")
290
- image = None
291
- if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
292
- prompt = self._convert_input(input['prompt']).to_string()
293
- image = input['image']
294
- else:
295
- raise ValueError("missing image is not permitted")
296
- prompt = self._convert_input(input).to_string()
297
-
298
- async for chunk in self._astream(
299
- prompt=prompt, image=image, **kwargs
300
- ):
301
  yield chunk
 
1
+ from .client import PredictionGuardClient
2
+ from langchain_core.language_models.llms import LLM
3
+ from langchain_core.pydantic_v1 import Extra, root_validator
4
+ from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
5
+ from langchain_core.callbacks import CallbackManagerForLLMRun
6
+ from utility import get_from_dict_or_env, MultimodalModelInput
7
+
8
+ from langchain_core.runnables import RunnableConfig, ensure_config
9
+ from langchain_core.language_models.base import LanguageModelInput
10
+ from langchain_core.prompt_values import StringPromptValue
11
+ # from langchain_core.outputs import GenerationChunk, LLMResult
12
+ from langchain_core.language_models.llms import BaseLLM
13
+ from langchain_core.callbacks import (
14
+ # CallbackManager,
15
+ CallbackManagerForLLMRun,
16
+ )
17
+ # from langchain_core.load import dumpd
18
+ from langchain_core.runnables.config import run_in_executor
19
+
20
+ class LVLM(LLM):
21
+ """This class extends the LLM class to implement custom requests to an LVLM provider API"""
22
+
23
+
24
+ client: Any = None #: :meta private:
25
+ hostname: Optional[str] = None
26
+ port: Optional[int] = None
27
+ url: Optional[str] = None
28
+ max_new_tokens: Optional[int] = 200
29
+ temperature: Optional[float] = 0.6
30
+ top_k: Optional[float] = 0
31
+ stop: Optional[List[str]] = None
32
+ ignore_eos: Optional[bool] = False
33
+ do_sample: Optional[bool] = True
34
+ lazy_mode: Optional[bool] = True
35
+ hpu_graphs: Optional[bool] = True
36
+
37
+ @root_validator()
38
+ def validate_environment(cls, values: Dict) -> Dict:
39
+ """Validate that the access token and Python package exist in the environment if needed"""
40
+ if values['client'] is None:
41
+ # check if url of API is provided
42
+ url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
43
+ if url is None:
44
+ hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
45
+ port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
46
+ if hostname is not None and port is not None:
47
+ values['client'] = PredictionGuardClient(hostname=hostname, port=port)
48
+ else:
49
+ # using default hostname and port to create Client
50
+ values['client'] = PredictionGuardClient()
51
+ else:
52
+ values['client'] = PredictionGuardClient(url=url)
53
+ return values
54
+
55
+ @property
56
+ def _llm_type(self) -> str:
57
+ """Return type of llm"""
58
+ return "Large Vision Language Model"
59
+
60
+ @property
61
+ def _default_params(self) -> Dict[str, Any]:
62
+ """Get the default parameters for calling the Prediction Guard API."""
63
+ return {
64
+ "max_tokens": self.max_new_tokens,
65
+ "temperature": self.temperature,
66
+ "top_k": self.top_k,
67
+ "ignore_eos": self.ignore_eos,
68
+ "do_sample": self.do_sample,
69
+ "stop" : self.stop,
70
+ }
71
+
72
+ def get_params(self, **kwargs):
73
+ params = self._default_params
74
+ params.update(kwargs)
75
+ return params
76
+
77
+
78
+ def _call(
79
+ self,
80
+ prompt: str,
81
+ image: str,
82
+ stop: Optional[List[str]] = None,
83
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
84
+ **kwargs: Any,
85
+ ) -> str:
86
+ """Run the VLM on the given input.
87
+
88
+ Args:
89
+ prompt: The prompt to generate from.
90
+ image: This can be either a path to an image or a base64 encoding of the image.
91
+ stop: Stop words to use when generating. Model output is cut off at the
92
+ first occurrence of any of the stop substrings.
93
+ If stop tokens are not supported consider raising NotImplementedError.
94
+ Returns:
95
+ The model output as a string. Actual completions DO NOT include the prompt
96
+ Example: TBD
97
+ """
98
+ params = {}
99
+ if stop is not None:
100
+ raise ValueError("stop kwargs are not permitted.")
101
+ params['generate_kwargs'] = self.get_params(**kwargs)
102
+ response = self.client.generate(prompt=prompt, image=image, **params)
103
+ return response
104
+
105
+ def _stream(
106
+ self,
107
+ prompt: str,
108
+ image: str,
109
+ stop: Optional[List[str]] = None,
110
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
111
+ **kwargs: Any,
112
+ ) -> Iterator[str]:
113
+ """Stream the VLM on the given prompt and image.
114
+
115
+ Args:
116
+ prompt: The prompt to generate from.
117
+ image: This can be either a path to an image or a base64 encoding of the image.
118
+ stop: Stop words to use when generating. Model output is cut off at the
119
+ first occurrence of any of the stop substrings.
120
+ If stop tokens are not supported consider raising NotImplementedError.
121
+ Returns:
122
+ The model outputs an iterator of strings. Actual completions DO NOT include the prompt
123
+ Example: TBD
124
+ """
125
+ params = {}
126
+ params['generate_kwargs'] = self.get_params(**kwargs)
127
+ for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
128
+ yield chunk
129
+
130
+ async def _astream(
131
+ self,
132
+ prompt: str,
133
+ image: str,
134
+ stop: Optional[List[str]] = None,
135
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
136
+ **kwargs: Any,
137
+ ) -> AsyncIterator[str]:
138
+ """An async version of the _stream method that streams the VLM on the given prompt and image.
139
+
140
+ Args:
141
+ prompt: The prompt to generate from.
142
+ image: This can be either a path to an image or a base64 encoding of the image.
143
+ stop: Stop words to use when generating. Model output is cut off at the
144
+ first occurrence of any of the stop substrings.
145
+ If stop tokens are not supported consider raising NotImplementedError.
146
+ Returns:
147
+ The model outputs an async iterator of strings. Actual completions DO NOT include the prompt
148
+ Example: TBD
149
+ """
150
+ iterator = await run_in_executor(
151
+ None,
152
+ self._stream,
153
+ prompt,
154
+ image,
155
+ stop,
156
+ run_manager.get_sync() if run_manager else None,
157
+ **kwargs,
158
+ )
159
+ done = object()
160
+ while True:
161
+ item = await run_in_executor(
162
+ None,
163
+ next,
164
+ iterator,
165
+ done, # type: ignore[call-arg, arg-type]
166
+ )
167
+ if item is done:
168
+ break
169
+ yield item # type: ignore[misc]
170
+
171
+ def invoke(
172
+ self,
173
+ input: MultimodalModelInput,
174
+ config: Optional[RunnableConfig] = None,
175
+ *,
176
+ stop: Optional[List[str]] = None,
177
+ **kwargs: Any,
178
+ ) -> str:
179
+ config = ensure_config(config)
180
+ if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
181
+ return (
182
+ self.generate_prompt(
183
+ [self._convert_input(StringPromptValue(text=input['prompt']))],
184
+ stop=stop,
185
+ callbacks=config.get("callbacks"),
186
+ tags=config.get("tags"),
187
+ metadata=config.get("metadata"),
188
+ run_name=config.get("run_name"),
189
+ run_id=config.pop("run_id", None),
190
+ image= input['image'],
191
+ **kwargs,
192
+ )
193
+ .generations[0][0]
194
+ .text
195
+ )
196
+ return (
197
+ self.generate_prompt(
198
+ [self._convert_input(input)],
199
+ stop=stop,
200
+ callbacks=config.get("callbacks"),
201
+ tags=config.get("tags"),
202
+ metadata=config.get("metadata"),
203
+ run_name=config.get("run_name"),
204
+ run_id=config.pop("run_id", None),
205
+ **kwargs,
206
+ )
207
+ .generations[0][0]
208
+ .text
209
+ )
210
+
211
+ async def ainvoke(
212
+ self,
213
+ input: MultimodalModelInput,
214
+ config: Optional[RunnableConfig] = None,
215
+ *,
216
+ stop: Optional[List[str]] = None,
217
+ **kwargs: Any,
218
+ ) -> str:
219
+ config = ensure_config(config)
220
+ if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
221
+ llm_result = await self.agenerate_prompt(
222
+ [self._convert_input(StringPromptValue(text=input['prompt']))],
223
+ stop=stop,
224
+ callbacks=config.get("callbacks"),
225
+ tags=config.get("tags"),
226
+ metadata=config.get("metadata"),
227
+ run_name=config.get("run_name"),
228
+ run_id=config.pop("run_id", None),
229
+ image=input['image'],
230
+ **kwargs,
231
+ )
232
+ else:
233
+ llm_result = await self.agenerate_prompt(
234
+ [self._convert_input(input)],
235
+ stop=stop,
236
+ callbacks=config.get("callbacks"),
237
+ tags=config.get("tags"),
238
+ metadata=config.get("metadata"),
239
+ run_name=config.get("run_name"),
240
+ run_id=config.pop("run_id", None),
241
+ **kwargs,
242
+ )
243
+ return llm_result.generations[0][0].text
244
+
245
+ def stream(
246
+ self,
247
+ input: MultimodalModelInput,
248
+ config: Optional[RunnableConfig] = None,
249
+ *,
250
+ stop: Optional[List[str]] = None,
251
+ **kwargs: Any,
252
+ ) -> Iterator[str]:
253
+ if type(self)._stream == BaseLLM._stream:
254
+ # model doesn't implement streaming, so use default implementation
255
+ yield self.invoke(input, config=config, stop=stop, **kwargs)
256
+ else:
257
+ if stop is not None:
258
+ raise ValueError("stop kwargs are not permitted.")
259
+ image = None
260
+ prompt = None
261
+ if isinstance(input, dict) and 'prompt' in input.keys():
262
+ prompt = self._convert_input(input['prompt']).to_string()
263
+ else:
264
+ raise ValueError("prompt must be provided")
265
+ if isinstance(input, dict) and 'image' in input.keys():
266
+ image = input['image']
267
+
268
+ for chunk in self._stream(
269
+ prompt=prompt, image=image, **kwargs
270
+ ):
271
+ yield chunk
272
+
273
+ async def astream(
274
+ self,
275
+ input: LanguageModelInput,
276
+ config: Optional[RunnableConfig] = None,
277
+ *,
278
+ stop: Optional[List[str]] = None,
279
+ **kwargs: Any,
280
+ ) -> AsyncIterator[str]:
281
+ if (
282
+ type(self)._astream is BaseLLM._astream
283
+ and type(self)._stream is BaseLLM._stream
284
+ ):
285
+ yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
286
+ return
287
+ else:
288
+ if stop is not None:
289
+ raise ValueError("stop kwargs are not permitted.")
290
+ image = None
291
+ if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
292
+ prompt = self._convert_input(input['prompt']).to_string()
293
+ image = input['image']
294
+ else:
295
+ raise ValueError("missing image is not permitted")
296
+ prompt = self._convert_input(input).to_string()
297
+
298
+ async for chunk in self._astream(
299
+ prompt=prompt, image=image, **kwargs
300
+ ):
301
  yield chunk
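A hedged sketch of invoking the LVLM runnable above with a prompt/image dict, assuming a PredictionGuardClient can be constructed (the frame path is a placeholder):

    # hypothetical example; ./frame.jpg is a placeholder path
    from mm_rag.MLM.client import PredictionGuardClient
    from mm_rag.MLM.lvlm import LVLM

    lvlm = LVLM(client=PredictionGuardClient())
    text = lvlm.invoke({
        "prompt": "Describe the scene in this frame.",
        "image": "./frame.jpg",  # path or base64-encoded image
    })
    print(text)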
mm_rag/embeddings/bridgetower_embeddings.py CHANGED
@@ -1,89 +1,89 @@
1
- from typing import List
2
- from langchain_core.embeddings import Embeddings
3
- import torch
4
- from transformers import (
5
- BridgeTowerProcessor,
6
- BridgeTowerForContrastiveLearning
7
- )
8
- from langchain_core.pydantic_v1 import (
9
- BaseModel,
10
- )
11
- from lrn_vector_embeddings import bt_embeddings_from_local
12
- from utility import encode_image, bt_embedding_from_prediction_guard
13
- from tqdm import tqdm
14
- from PIL import Image
15
- class BridgeTowerEmbeddings(BaseModel, Embeddings):
16
- """ BridgeTower embedding model """
17
-
18
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
19
- """Embed a list of documents using BridgeTower.
20
-
21
- Args:
22
- texts: The list of texts to embed.
23
-
24
- Returns:
25
- List of embeddings, one for each text.
26
- """
27
-
28
- embeddings = []
29
- img = Image.new('RGB', (100, 100))
30
- for text in texts:
31
- embedding = bt_embeddings_from_local(text, img)
32
- embeddings.append(embedding)
33
- return embeddings
34
-
35
- def embed_query(self, text: str) -> List[float]:
36
- """Embed a query using BridgeTower.
37
-
38
- Args:
39
- text: The text to embed.
40
-
41
- Returns:
42
- Embeddings for the text as a flat list of floats.
43
- """
44
- # Get embeddings
45
- embeddings = self.embed_documents([text])[0]
46
-
47
- # If embeddings is a dict, extract the text embeddings
48
- if isinstance(embeddings, dict):
49
- embeddings = embeddings["text_embeddings"]
50
-
51
- # If embeddings is a nested list or tensor, flatten it
52
- if isinstance(embeddings, (list, torch.Tensor)) and len(embeddings) == 1:
53
- embeddings = embeddings[0]
54
-
55
- # Convert tensor to list if needed
56
- if torch.is_tensor(embeddings):
57
- embeddings = embeddings.detach().tolist()
58
-
59
- return embeddings
60
-
61
-
62
- def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
63
- """Embed a list of image-text pairs using BridgeTower.
64
-
65
- Args:
66
- texts: The list of texts to embed.
67
- images: The list of path-to-images to embed
68
- batch_size: the batch size to process, default to 2
69
- Returns:
70
- List of embeddings, one for each image-text pairs.
71
- """
72
-
73
- # the length of texts must be equal to the length of images
74
- assert len(texts)==len(images), "the len of captions should be equal to the len of images"
75
-
76
- processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
77
- model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
78
-
79
-
80
-
81
- embeddings = []
82
- for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
83
- inputs = processor(text=[text], images=[Image.open(path_to_img)], return_tensors="pt")
84
- outputs = model(**inputs)
85
- # Get embeddings and convert to list
86
- embedding = outputs.text_embeds.detach().numpy().tolist()[0]
87
- embeddings.append(embedding)
88
-
89
  return embeddings
 
1
+ from typing import List
2
+ from langchain_core.embeddings import Embeddings
3
+ import torch
4
+ from transformers import (
5
+ BridgeTowerProcessor,
6
+ BridgeTowerForContrastiveLearning
7
+ )
8
+ from langchain_core.pydantic_v1 import (
9
+ BaseModel,
10
+ )
11
+ from lrn_vector_embeddings import bt_embeddings_from_local
12
+ from utility import encode_image, bt_embedding_from_prediction_guard
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ class BridgeTowerEmbeddings(BaseModel, Embeddings):
16
+ """ BridgeTower embedding model """
17
+
18
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
19
+ """Embed a list of documents using BridgeTower.
20
+
21
+ Args:
22
+ texts: The list of texts to embed.
23
+
24
+ Returns:
25
+ List of embeddings, one for each text.
26
+ """
27
+
28
+ embeddings = []
29
+ img = Image.new('RGB', (100, 100))
30
+ for text in texts:
31
+ embedding = bt_embeddings_from_local(text, img)
32
+ embeddings.append(embedding)
33
+ return embeddings
34
+
35
+ def embed_query(self, text: str) -> List[float]:
36
+ """Embed a query using BridgeTower.
37
+
38
+ Args:
39
+ text: The text to embed.
40
+
41
+ Returns:
42
+ Embeddings for the text as a flat list of floats.
43
+ """
44
+ # Get embeddings
45
+ embeddings = self.embed_documents([text])[0]
46
+
47
+ # If embeddings is a dict, extract the text embeddings
48
+ if isinstance(embeddings, dict):
49
+ embeddings = embeddings["text_embeddings"]
50
+
51
+ # If embeddings is a nested list or tensor, flatten it
52
+ if isinstance(embeddings, (list, torch.Tensor)) and len(embeddings) == 1:
53
+ embeddings = embeddings[0]
54
+
55
+ # Convert tensor to list if needed
56
+ if torch.is_tensor(embeddings):
57
+ embeddings = embeddings.detach().tolist()
58
+
59
+ return embeddings
60
+
61
+
62
+ def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
63
+ """Embed a list of image-text pairs using BridgeTower.
64
+
65
+ Args:
66
+ texts: The list of texts to embed.
67
+ images: The list of path-to-images to embed
68
+ batch_size: the batch size to process, default to 2
69
+ Returns:
70
+ List of embeddings, one for each image-text pairs.
71
+ """
72
+
73
+ # the length of texts must be equal to the length of images
74
+ assert len(texts)==len(images), "the len of captions should be equal to the len of images"
75
+
76
+ processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
77
+ model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
78
+
79
+
80
+
81
+ embeddings = []
82
+ for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
83
+ inputs = processor(text=[text], images=[Image.open(path_to_img)], return_tensors="pt")
84
+ outputs = model(**inputs)
85
+ # Get embeddings and convert to list
86
+ embedding = outputs.text_embeds.detach().numpy().tolist()[0]
87
+ embeddings.append(embedding)
88
+
89
  return embeddings
mm_rag/vectorstores/multimodal_lancedb.py CHANGED
@@ -1,131 +1,131 @@
1
- from typing import Any, Iterable, List, Optional
2
- from langchain_core.embeddings import Embeddings
3
- import uuid
4
- from langchain_community.vectorstores.lancedb import LanceDB
5
-
6
- class MultimodalLanceDB(LanceDB):
7
- """`LanceDB` vector store to process multimodal data
8
-
9
- To use, you should have ``lancedb`` python package installed.
10
- You can install it with ``pip install lancedb``.
11
-
12
- Args:
13
- connection: LanceDB connection to use. If not provided, a new connection
14
- will be created.
15
- embedding: Embedding to use for the vectorstore.
16
- vector_key: Key to use for the vector in the database. Defaults to ``vector``.
17
- id_key: Key to use for the id in the database. Defaults to ``id``.
18
- text_key: Key to use for the text in the database. Defaults to ``text``.
19
- image_path_key: Key to use for the path to image in the database. Defaults to ``image_path``.
20
- table_name: Name of the table to use. Defaults to ``vectorstore``.
21
- api_key: API key to use for LanceDB cloud database.
22
- region: Region to use for LanceDB cloud database.
23
- mode: Mode to use for adding data to the table. Defaults to ``overwrite``.
24
-
25
-
26
-
27
- Example:
28
- .. code-block:: python
29
- vectorstore = MultimodalLanceDB(uri='/lancedb', embedding_function)
30
- vectorstore.add_texts(['text1', 'text2'])
31
- result = vectorstore.similarity_search('text1')
32
- """
33
-
34
- def __init__(
35
- self,
36
- connection: Optional[Any] = None,
37
- embedding: Optional[Embeddings] = None,
38
- uri: Optional[str] = "/tmp/lancedb",
39
- vector_key: Optional[str] = "vector",
40
- id_key: Optional[str] = "id",
41
- text_key: Optional[str] = "text",
42
- image_path_key: Optional[str] = "image_path",
43
- table_name: Optional[str] = "vectorstore",
44
- api_key: Optional[str] = None,
45
- region: Optional[str] = None,
46
- mode: Optional[str] = "append",
47
- ):
48
- super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
49
- self._image_path_key = image_path_key
50
-
51
- def add_text_image_pairs(
52
- self,
53
- texts: Iterable[str],
54
- image_paths: Iterable[str],
55
- metadatas: Optional[List[dict]] = None,
56
- ids: Optional[List[str]] = None,
57
- **kwargs: Any,
58
- ) -> List[str]:
59
- """Turn text-image pairs into embedding and add it to the database
60
-
61
- Args:
62
- texts: Iterable of strings to combine with corresponding images to add to the vectorstore.
63
- images: Iterable of path-to-images as strings to combine with corresponding texts to add to the vectorstore.
64
- metadatas: Optional list of metadatas associated with the texts.
65
- ids: Optional list of ids to associate w ith the texts.
66
-
67
- Returns:
68
- List of ids of the added text-image pairs.
69
- """
70
- # the length of texts must be equal to the length of images
71
- assert len(texts)==len(image_paths), "the len of transcripts should be equal to the len of images"
72
-
73
- # Embed texts and create documents
74
- docs = []
75
- ids = ids or [str(uuid.uuid4()) for _ in texts]
76
- embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths)) # type: ignore
77
- for idx, text in enumerate(texts):
78
- embedding = embeddings[idx]
79
- metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
80
- docs.append(
81
- {
82
- self._vector_key: embedding,
83
- self._id_key: ids[idx],
84
- self._text_key: text,
85
- self._image_path_key : image_paths[idx],
86
- "metadata": metadata,
87
- }
88
- )
89
-
90
- if 'mode' in kwargs:
91
- mode = kwargs['mode']
92
- else:
93
- mode = self.mode
94
- if self._table_name in self._connection.table_names():
95
- tbl = self._connection.open_table(self._table_name)
96
- if self.api_key is None:
97
- tbl.add(docs, mode=mode)
98
- else:
99
- tbl.add(docs)
100
- else:
101
- self._connection.create_table(self._table_name, data=docs)
102
- return ids
103
-
104
- @classmethod
105
- def from_text_image_pairs(
106
- cls,
107
- texts: List[str],
108
- image_paths: List[str],
109
- embedding: Embeddings,
110
- metadatas: Optional[List[dict]] = None,
111
- connection: Any = None,
112
- vector_key: Optional[str] = "vector",
113
- id_key: Optional[str] = "id",
114
- text_key: Optional[str] = "text",
115
- image_path_key: Optional[str] = "image_path",
116
- table_name: Optional[str] = "vectorstore",
117
- **kwargs: Any,
118
- ):
119
-
120
- instance = MultimodalLanceDB(
121
- connection=connection,
122
- embedding=embedding,
123
- vector_key=vector_key,
124
- id_key=id_key,
125
- text_key=text_key,
126
- image_path_key=image_path_key,
127
- table_name=table_name,
128
- )
129
- instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)
130
-
131
  return instance
 
1
+ from typing import Any, Iterable, List, Optional
2
+ from langchain_core.embeddings import Embeddings
3
+ import uuid
4
+ from langchain_community.vectorstores.lancedb import LanceDB
5
+
6
+ class MultimodalLanceDB(LanceDB):
7
+ """`LanceDB` vector store to process multimodal data
8
+
9
+ To use, you should have ``lancedb`` python package installed.
10
+ You can install it with ``pip install lancedb``.
11
+
12
+ Args:
13
+ connection: LanceDB connection to use. If not provided, a new connection
14
+ will be created.
15
+ embedding: Embedding to use for the vectorstore.
16
+ vector_key: Key to use for the vector in the database. Defaults to ``vector``.
17
+ id_key: Key to use for the id in the database. Defaults to ``id``.
18
+ text_key: Key to use for the text in the database. Defaults to ``text``.
19
+ image_path_key: Key to use for the path to image in the database. Defaults to ``image_path``.
20
+ table_name: Name of the table to use. Defaults to ``vectorstore``.
21
+ api_key: API key to use for LanceDB cloud database.
22
+ region: Region to use for LanceDB cloud database.
23
+ mode: Mode to use for adding data to the table. Defaults to ``overwrite``.
24
+
25
+
26
+
27
+ Example:
28
+ .. code-block:: python
29
+ vectorstore = MultimodalLanceDB(uri='/lancedb', embedding_function)
30
+ vectorstore.add_texts(['text1', 'text2'])
31
+ result = vectorstore.similarity_search('text1')
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ connection: Optional[Any] = None,
37
+ embedding: Optional[Embeddings] = None,
38
+ uri: Optional[str] = "/tmp/lancedb",
39
+ vector_key: Optional[str] = "vector",
40
+ id_key: Optional[str] = "id",
41
+ text_key: Optional[str] = "text",
42
+ image_path_key: Optional[str] = "image_path",
43
+ table_name: Optional[str] = "vectorstore",
44
+ api_key: Optional[str] = None,
45
+ region: Optional[str] = None,
46
+ mode: Optional[str] = "append",
47
+ ):
48
+ super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
49
+ self._image_path_key = image_path_key
50
+
51
+ def add_text_image_pairs(
52
+ self,
53
+ texts: Iterable[str],
54
+ image_paths: Iterable[str],
55
+ metadatas: Optional[List[dict]] = None,
56
+ ids: Optional[List[str]] = None,
57
+ **kwargs: Any,
58
+ ) -> List[str]:
59
+ """Turn text-image pairs into embedding and add it to the database
60
+
61
+ Args:
62
+ texts: Iterable of strings to combine with corresponding images to add to the vectorstore.
63
+ images: Iterable of path-to-images as strings to combine with corresponding texts to add to the vectorstore.
64
+ metadatas: Optional list of metadatas associated with the texts.
65
+ ids: Optional list of ids to associate w ith the texts.
66
+
67
+ Returns:
68
+ List of ids of the added text-image pairs.
69
+ """
70
+ # the length of texts must be equal to the length of images
71
+ assert len(texts)==len(image_paths), "the len of transcripts should be equal to the len of images"
72
+
73
+ # Embed texts and create documents
74
+ docs = []
75
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
76
+ embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths)) # type: ignore
77
+ for idx, text in enumerate(texts):
78
+ embedding = embeddings[idx]
79
+ metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
80
+ docs.append(
81
+ {
82
+ self._vector_key: embedding,
83
+ self._id_key: ids[idx],
84
+ self._text_key: text,
85
+ self._image_path_key : image_paths[idx],
86
+ "metadata": metadata,
87
+ }
88
+ )
89
+
90
+ if 'mode' in kwargs:
91
+ mode = kwargs['mode']
92
+ else:
93
+ mode = self.mode
94
+ if self._table_name in self._connection.table_names():
95
+ tbl = self._connection.open_table(self._table_name)
96
+ if self.api_key is None:
97
+ tbl.add(docs, mode=mode)
98
+ else:
99
+ tbl.add(docs)
100
+ else:
101
+ self._connection.create_table(self._table_name, data=docs)
102
+ return ids
103
+
104
+ @classmethod
105
+ def from_text_image_pairs(
106
+ cls,
107
+ texts: List[str],
108
+ image_paths: List[str],
109
+ embedding: Embeddings,
110
+ metadatas: Optional[List[dict]] = None,
111
+ connection: Any = None,
112
+ vector_key: Optional[str] = "vector",
113
+ id_key: Optional[str] = "id",
114
+ text_key: Optional[str] = "text",
115
+ image_path_key: Optional[str] = "image_path",
116
+ table_name: Optional[str] = "vectorstore",
117
+ **kwargs: Any,
118
+ ):
119
+
120
+ instance = MultimodalLanceDB(
121
+ connection=connection,
122
+ embedding=embedding,
123
+ vector_key=vector_key,
124
+ id_key=id_key,
125
+ text_key=text_key,
126
+ image_path_key=image_path_key,
127
+ table_name=table_name,
128
+ )
129
+ instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)
130
+
131
  return instance
requirements.txt CHANGED
@@ -1,25 +1,25 @@
1
- gradio
2
- langchain-predictionguard
3
- IPython
4
- umap-learn
5
- pytubefix
6
- youtube_transcript_api
7
- torch
8
- transformers
9
- matplotlib
10
- seaborn
11
- datasets
12
- moviepy
13
- whisper
14
- webvtt-py
15
- tqdm
16
- lancedb
17
- langchain-core
18
- langchain-community
19
- ollama
20
- opencv-python
21
- openai-whisper
22
- huggingface_hub[cli]
23
- huggingface_hub
24
- pillow
25
- accelerate>=0.26.0
 
1
+ gradio
2
+ langchain-predictionguard
3
+ IPython
4
+ umap-learn
5
+ pytubefix
6
+ youtube_transcript_api
7
+ torch
8
+ transformers
9
+ matplotlib
10
+ seaborn
11
+ datasets
12
+ moviepy
13
+ whisper
14
+ webvtt-py
15
+ tqdm
16
+ lancedb
17
+ langchain-core
18
+ langchain-community
19
+ ollama
20
+ opencv-python
21
+ openai-whisper
22
+ huggingface_hub[cli]
23
+ huggingface_hub
24
+ pillow
25
+ accelerate>=0.26.0
s6_prepare_video_input.py CHANGED
@@ -1,90 +1,90 @@
1
- from pathlib import Path
2
- import os
3
- from os import path as osp
4
- import whisper
5
- from moviepy import VideoFileClip
6
- from PIL import Image
7
- from utility import download_video, extract_meta_data, get_transcript_vtt, getSubs
8
- from urllib.request import urlretrieve
9
- from IPython.display import display
10
- import ollama
11
-
12
- def demp_video_input_that_has_transcript():
13
- # first video's url
14
- vid_url = "https://www.youtube.com/watch?v=7Hcg-rLYwdM"
15
-
16
- # download Youtube video to ./shared_data/videos/video1
17
- vid_dir = "./shared_data/videos/video1"
18
- vid_filepath = download_video(vid_url, vid_dir)
19
-
20
- # download Youtube video's subtitle to ./shared_data/videos/video1
21
- vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)
22
-
23
- return extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)
24
-
25
- def demp_video_input_that_has_no_transcript():
26
- # second video's url
27
- vid_url=(
28
- "https://multimedia-commons.s3-us-west-2.amazonaws.com/"
29
- "data/videos/mp4/010/a07/010a074acb1975c4d6d6e43c1faeb8.mp4"
30
- )
31
- vid_dir = "./shared_data/videos/video2"
32
- vid_name = "toddler_in_playground.mp4"
33
-
34
- # create folder to which video2 will be downloaded
35
- Path(vid_dir).mkdir(parents=True, exist_ok=True)
36
- vid_filepath = urlretrieve(
37
- vid_url,
38
- osp.join(vid_dir, vid_name)
39
- )[0]
40
-
41
- path_to_video_no_transcript = vid_filepath
42
-
43
- # declare where to save .mp3 audio
44
- path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
45
-
46
- # extract mp3 audio file from mp4 video video file
47
- clip = VideoFileClip(path_to_video_no_transcript)
48
- clip.audio.write_audiofile(path_to_extracted_audio_file)
49
-
50
- model = whisper.load_model("small")
51
- options = dict(task="translate", best_of=1, language='en')
52
- results = model.transcribe(path_to_extracted_audio_file, **options)
53
-
54
- vtt = getSubs(results["segments"], "vtt")
55
-
56
- # path to save generated transcript of video1
57
- path_to_generated_trans = osp.join(vid_dir, 'generated_video1.vtt')
58
- # write transcription to file
59
- with open(path_to_generated_trans, 'w') as f:
60
- f.write(vtt)
61
-
62
- return extract_meta_data(vid_dir, vid_filepath, path_to_generated_trans)
63
-
64
-
65
-
66
- def ask_llvm(instruction, file_path):
67
- result = ollama.generate(
68
- model='llava',
69
- prompt=instruction,
70
- images=[file_path],
71
- stream=False
72
- )['response']
73
- img=Image.open(file_path, mode='r')
74
- img = img.resize([int(i/1.2) for i in img.size])
75
- display(img)
76
- for i in result.split('.'):
77
- print(i, end='', flush=True)
78
- if __name__ == "__main__":
79
- meta_data = demp_video_input_that_has_transcript()
80
-
81
- meta_data1 = demp_video_input_that_has_no_transcript()
82
- data = meta_data1[1]
83
- caption = data['transcript']
84
- print(f'Generated caption is: "{caption}"')
85
- frame = Image.open(data['extracted_frame_path'])
86
- display(frame)
87
- instruction = "Can you describe the image?"
88
- ask_llvm(instruction, data['extracted_frame_path'])
89
- #print(meta_data)
90
 
 
1
+ from pathlib import Path
2
+ import os
3
+ from os import path as osp
4
+ import whisper
5
+ from moviepy import VideoFileClip
6
+ from PIL import Image
7
+ from utility import download_video, extract_meta_data, get_transcript_vtt, getSubs
8
+ from urllib.request import urlretrieve
9
+ from IPython.display import display
10
+ import ollama
11
+
12
+ def demp_video_input_that_has_transcript():
13
+ # first video's url
14
+ vid_url = "https://www.youtube.com/watch?v=7Hcg-rLYwdM"
15
+
16
+ # download Youtube video to ./shared_data/videos/video1
17
+ vid_dir = "./shared_data/videos/video1"
18
+ vid_filepath = download_video(vid_url, vid_dir)
19
+
20
+ # download Youtube video's subtitle to ./shared_data/videos/video1
21
+ vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)
22
+
23
+ return extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)
24
+
25
+ def demp_video_input_that_has_no_transcript():
26
+ # second video's url
27
+ vid_url=(
28
+ "https://multimedia-commons.s3-us-west-2.amazonaws.com/"
29
+ "data/videos/mp4/010/a07/010a074acb1975c4d6d6e43c1faeb8.mp4"
30
+ )
31
+ vid_dir = "./shared_data/videos/video2"
32
+ vid_name = "toddler_in_playground.mp4"
33
+
34
+ # create folder to which video2 will be downloaded
35
+ Path(vid_dir).mkdir(parents=True, exist_ok=True)
36
+ vid_filepath = urlretrieve(
37
+ vid_url,
38
+ osp.join(vid_dir, vid_name)
39
+ )[0]
40
+
41
+ path_to_video_no_transcript = vid_filepath
42
+
43
+ # declare where to save .mp3 audio
44
+ path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
45
+
46
+ # extract mp3 audio file from mp4 video video file
47
+ clip = VideoFileClip(path_to_video_no_transcript)
48
+ clip.audio.write_audiofile(path_to_extracted_audio_file)
49
+
50
+ model = whisper.load_model("small")
51
+ options = dict(task="translate", best_of=1, language='en')
52
+ results = model.transcribe(path_to_extracted_audio_file, **options)
53
+
54
+ vtt = getSubs(results["segments"], "vtt")
55
+
56
+ # path to save generated transcript of video1
57
+ path_to_generated_trans = osp.join(vid_dir, 'generated_video1.vtt')
58
+ # write transcription to file
59
+ with open(path_to_generated_trans, 'w') as f:
60
+ f.write(vtt)
61
+
62
+ return extract_meta_data(vid_dir, vid_filepath, path_to_generated_trans)
63
+
64
+
65
+
66
+ def ask_llvm(instruction, file_path):
67
+ result = ollama.generate(
68
+ model='llava',
69
+ prompt=instruction,
70
+ images=[file_path],
71
+ stream=False
72
+ )['response']
73
+ img=Image.open(file_path, mode='r')
74
+ img = img.resize([int(i/1.2) for i in img.size])
75
+ display(img)
76
+ for i in result.split('.'):
77
+ print(i, end='', flush=True)
78
+ if __name__ == "__main__":
79
+ meta_data = demp_video_input_that_has_transcript()
80
+
81
+ meta_data1 = demp_video_input_that_has_no_transcript()
82
+ data = meta_data1[1]
83
+ caption = data['transcript']
84
+ print(f'Generated caption is: "{caption}"')
85
+ frame = Image.open(data['extracted_frame_path'])
86
+ display(frame)
87
+ instruction = "Can you describe the image?"
88
+ ask_llvm(instruction, data['extracted_frame_path'])
89
+ #print(meta_data)
90
 
s7_store_in_rag.py CHANGED
@@ -1,105 +1,105 @@
1
- from mm_rag.embeddings.bridgetower_embeddings import (
2
- BridgeTowerEmbeddings
3
- )
4
- from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
5
- import lancedb
6
- import json
7
- import os
8
- from PIL import Image
9
- from utility import load_json_file, display_retrieved_results
10
- import pyarrow as pa
11
-
12
- # declare host file
13
- LANCEDB_HOST_FILE = "./shared_data/.lancedb"
14
- # declare table name
15
- TBL_NAME = "test_tbl"
16
- # initialize vectorstore
17
- db = lancedb.connect(LANCEDB_HOST_FILE)
18
- # initialize an BridgeTower embedder
19
- embedder = BridgeTowerEmbeddings()
20
-
21
-
22
- def return_top_k_most_similar_docs(max_docs=3):
23
- # ask to return top 3 most similar documents
24
- # Creating a LanceDB vector store
25
- vectorstore = MultimodalLanceDB(
26
- uri=LANCEDB_HOST_FILE,
27
- embedding=embedder,
28
- table_name=TBL_NAME)
29
-
30
- # creating a retriever for the vector store
31
- # search_type="similarity"
32
- # declares that the type of search that the Retriever should perform
33
- # is similarity search
34
- # search_kwargs={"k": 1} means returning top-1 most similar document
35
-
36
-
37
- retriever = vectorstore.as_retriever(
38
- search_type='similarity',
39
- search_kwargs={"k": max_docs})
40
- query2 = (
41
- "an astronaut's spacewalk "
42
- "with an amazing view of the earth from space behind"
43
- )
44
- results2 = retriever.invoke(query2)
45
- display_retrieved_results(results2)
46
- query3 = "a group of astronauts"
47
- results3 = retriever.invoke(query3)
48
- display_retrieved_results(results3)
49
-
50
-
51
- def open_table(TBL_NAME):
52
- # open a connection to table TBL_NAME
53
- tbl = db.open_table()
54
-
55
- print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
56
- # display the first 3 rows of the table
57
- tbl.to_pandas()[['text', 'image_path']].head(3)
58
-
59
- def store_in_rag():
60
-
61
- # load metadata files
62
- vid1_metadata_path = './shared_data/videos/video1/metadatas.json'
63
- vid2_metadata_path = './shared_data/videos/video2/metadatas.json'
64
- vid1_metadata = load_json_file(vid1_metadata_path)
65
- vid2_metadata = load_json_file(vid2_metadata_path)
66
-
67
- # collect transcripts and image paths
68
- vid1_trans = [vid['transcript'] for vid in vid1_metadata]
69
- vid1_img_path = [vid['extracted_frame_path'] for vid in vid1_metadata]
70
-
71
- vid2_trans = [vid['transcript'] for vid in vid2_metadata]
72
- vid2_img_path = [vid['extracted_frame_path'] for vid in vid2_metadata]
73
-
74
-
75
- # for video1, we pick n = 7
76
- n = 7
77
- updated_vid1_trans = [
78
- ' '.join(vid1_trans[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
79
- ' '.join(vid1_trans[0 : i + int(n/2)]) for i in range(len(vid1_trans))
80
- ]
81
-
82
- # also need to update the updated transcripts in metadata
83
- for i in range(len(updated_vid1_trans)):
84
- vid1_metadata[i]['transcript'] = updated_vid1_trans[i]
85
-
86
- # you can pass in mode="append"
87
- # to add more entries to the vector store
88
- # in case you want to start with a fresh vector store,
89
- # you can pass in mode="overwrite" instead
90
-
91
- _ = MultimodalLanceDB.from_text_image_pairs(
92
- texts=updated_vid1_trans+vid2_trans,
93
- image_paths=vid1_img_path+vid2_img_path,
94
- embedding=embedder,
95
- metadatas=vid1_metadata+vid2_metadata,
96
- connection=db,
97
- table_name=TBL_NAME,
98
- mode="overwrite",
99
- )
100
-
101
- if __name__ == "__main__":
102
- tbl = db.open_table(TBL_NAME)
103
- print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
104
- #display the first 3 rows of the table
105
  return_top_k_most_similar_docs()
 
1
+ from mm_rag.embeddings.bridgetower_embeddings import (
2
+ BridgeTowerEmbeddings
3
+ )
4
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
5
+ import lancedb
6
+ import json
7
+ import os
8
+ from PIL import Image
9
+ from utility import load_json_file, display_retrieved_results
10
+ import pyarrow as pa
11
+
12
+ # declare host file
13
+ LANCEDB_HOST_FILE = "./shared_data/.lancedb"
14
+ # declare table name
15
+ TBL_NAME = "test_tbl"
16
+ # initialize vectorstore
17
+ db = lancedb.connect(LANCEDB_HOST_FILE)
18
+ # initialize an BridgeTower embedder
19
+ embedder = BridgeTowerEmbeddings()
20
+
21
+
22
+ def return_top_k_most_similar_docs(max_docs=3):
23
+ # ask to return top 3 most similar documents
24
+ # Creating a LanceDB vector store
25
+ vectorstore = MultimodalLanceDB(
26
+ uri=LANCEDB_HOST_FILE,
27
+ embedding=embedder,
28
+ table_name=TBL_NAME)
29
+
30
+ # creating a retriever for the vector store
31
+ # search_type="similarity"
32
+ # declares that the type of search that the Retriever should perform
33
+ # is similarity search
34
+ # search_kwargs={"k": 1} means returning top-1 most similar document
35
+
36
+
37
+ retriever = vectorstore.as_retriever(
38
+ search_type='similarity',
39
+ search_kwargs={"k": max_docs})
40
+ query2 = (
41
+ "an astronaut's spacewalk "
42
+ "with an amazing view of the earth from space behind"
43
+ )
44
+ results2 = retriever.invoke(query2)
45
+ display_retrieved_results(results2)
46
+ query3 = "a group of astronauts"
47
+ results3 = retriever.invoke(query3)
48
+ display_retrieved_results(results3)
49
+
50
+
51
+ def open_table(TBL_NAME):
52
+ # open a connection to table TBL_NAME
53
+ tbl = db.open_table()
54
+
55
+ print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
56
+ # display the first 3 rows of the table
57
+ tbl.to_pandas()[['text', 'image_path']].head(3)
58
+
59
+ def store_in_rag():
60
+
61
+ # load metadata files
62
+ vid1_metadata_path = './shared_data/videos/video1/metadatas.json'
63
+ vid2_metadata_path = './shared_data/videos/video2/metadatas.json'
64
+ vid1_metadata = load_json_file(vid1_metadata_path)
65
+ vid2_metadata = load_json_file(vid2_metadata_path)
66
+
67
+ # collect transcripts and image paths
68
+ vid1_trans = [vid['transcript'] for vid in vid1_metadata]
69
+ vid1_img_path = [vid['extracted_frame_path'] for vid in vid1_metadata]
70
+
71
+ vid2_trans = [vid['transcript'] for vid in vid2_metadata]
72
+ vid2_img_path = [vid['extracted_frame_path'] for vid in vid2_metadata]
73
+
74
+
75
+ # for video1, we pick n = 7
76
+ n = 7
77
+ updated_vid1_trans = [
78
+ ' '.join(vid1_trans[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
79
+ ' '.join(vid1_trans[0 : i + int(n/2)]) for i in range(len(vid1_trans))
80
+ ]
81
+
82
+ # also need to update the updated transcripts in metadata
83
+ for i in range(len(updated_vid1_trans)):
84
+ vid1_metadata[i]['transcript'] = updated_vid1_trans[i]
85
+
86
+ # you can pass in mode="append"
87
+ # to add more entries to the vector store
88
+ # in case you want to start with a fresh vector store,
89
+ # you can pass in mode="overwrite" instead
90
+
91
+ _ = MultimodalLanceDB.from_text_image_pairs(
92
+ texts=updated_vid1_trans+vid2_trans,
93
+ image_paths=vid1_img_path+vid2_img_path,
94
+ embedding=embedder,
95
+ metadatas=vid1_metadata+vid2_metadata,
96
+ connection=db,
97
+ table_name=TBL_NAME,
98
+ mode="overwrite",
99
+ )
100
+
101
+ if __name__ == "__main__":
102
+ tbl = db.open_table(TBL_NAME)
103
+ print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
104
+ #display the first 3 rows of the table
105
  return_top_k_most_similar_docs()
shared_data/videos/yt_video/blackholes101nationalgeographic/blackholes101nationalgeographic.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d9f0d499a1b09e47d6f1e382e3be6666b6c268276f16abd84a680a7eb512b1a0
3
- size 8783737
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e61dc16ee65a8a0ef08316bfb4d8f7110c7f5186d298a7e28e1cafb3bb25c338
3
+ size 132
shared_data/videos/yt_video/blackholes101nationalgeographic/captions.vtt CHANGED
@@ -1,104 +1,104 @@
1
- WEBVTT
2
-
3
- 00:00.000 --> 00:08.760
4
- Black holes are among the most fascinating objects in our universe, and also the most
5
-
6
- 00:08.760 --> 00:13.520
7
- mysterious.
8
-
9
- 00:13.520 --> 00:19.040
10
- A black hole is a region in space where the force of gravity is so strong, not even light,
11
-
12
- 00:19.040 --> 00:23.200
13
- the fastest known entity in our universe can escape.
14
-
15
- 00:23.200 --> 00:28.680
16
- The boundary of a black hole is called the event horizon, a point of no return beyond
17
-
18
- 00:28.680 --> 00:31.840
19
- which we truly cannot see.
20
-
21
- 00:31.840 --> 00:37.040
22
- When something crosses the event horizon, it collapses into the black hole's singularity,
23
-
24
- 00:37.040 --> 00:42.400
25
- an infinitely small, infinitely dense point where space, time, and the laws of physics
26
-
27
- 00:42.400 --> 00:46.200
28
- no longer apply.
29
-
30
- 00:46.200 --> 00:51.400
31
- Scientists have theorized several different types of black holes, with stellar and supermassive
32
-
33
- 00:51.400 --> 00:54.280
34
- black holes being the most common.
35
-
36
- 00:54.280 --> 00:58.640
37
- Stolar black holes form when massive stars die and collapse.
38
-
39
- 00:58.640 --> 01:05.080
40
- They're roughly 10 to 20 times the mass of our sun, and scattered throughout the universe.
41
-
42
- 01:05.080 --> 01:11.040
43
- There could be millions of these stellar black holes in the Milky Way alone.
44
-
45
- 01:11.040 --> 01:16.440
46
- Supermassive black holes are giants by comparison, measuring millions, even billions of times
47
-
48
- 01:16.440 --> 01:19.440
49
- more massive than our sun.
50
-
51
- 01:19.440 --> 01:23.800
52
- Scientists can only guess how they form, but we do know they exist at the center of just
53
-
54
- 01:23.800 --> 01:28.920
55
- about every large galaxy, including our own.
56
-
57
- 01:28.920 --> 01:33.760
58
- Sagittarius A, the supermassive black hole at the center of the Milky Way, has a mass
59
-
60
- 01:33.760 --> 01:39.360
61
- of roughly four million suns, and has a diameter about the distance between the Earth and our
62
-
63
- 01:39.360 --> 01:41.960
64
- sun.
65
-
66
- 01:41.960 --> 01:46.680
67
- Because black holes are invisible, the only way for scientists to detect and study them
68
-
69
- 01:46.680 --> 01:50.040
70
- is to observe their effect on nearby matter.
71
-
72
- 01:50.040 --> 01:55.360
73
- This includes accretion disks, a disk of particles that form when gases and dust fall toward a
74
-
75
- 01:55.360 --> 02:03.920
76
- black hole, and quasars, jets of particles that blast out of supermassive black holes.
77
-
78
- 02:03.920 --> 02:08.720
79
- Black holes remained largely unknown until the 20th century.
80
-
81
- 02:08.720 --> 02:14.840
82
- In 1916, using Einstein's General Theory of Relativity, a German physicist named Karl
83
-
84
- 02:14.840 --> 02:20.280
85
- Schwartzschild calculated that any mass could become a black hole if it were compressed tightly
86
-
87
- 02:20.280 --> 02:22.640
88
- enough.
89
-
90
- 02:22.640 --> 02:27.480
91
- But it wasn't until 1971 when theory became reality.
92
-
93
- 02:27.480 --> 02:34.000
94
- Astronomers, studying the constellation Cygnus, discovered the first black hole.
95
-
96
- 02:34.000 --> 02:39.440
97
- An untold number of black holes are scattered throughout the universe, constantly warping
98
-
99
- 02:39.440 --> 02:45.600
100
- space and time, altering entire galaxies, and endlessly inspiring both scientists and
101
-
102
- 02:45.600 --> 02:47.120
103
- our collective imagination.
104
-
 
1
+ WEBVTT
2
+
3
+ 00:00.000 --> 00:08.760
4
+ Black holes are among the most fascinating objects in our universe, and also the most
5
+
6
+ 00:08.760 --> 00:13.520
7
+ mysterious.
8
+
9
+ 00:13.520 --> 00:19.040
10
+ A black hole is a region in space where the force of gravity is so strong, not even light,
11
+
12
+ 00:19.040 --> 00:23.200
13
+ the fastest known entity in our universe can escape.
14
+
15
+ 00:23.200 --> 00:28.680
16
+ The boundary of a black hole is called the event horizon, a point of no return beyond
17
+
18
+ 00:28.680 --> 00:31.840
19
+ which we truly cannot see.
20
+
21
+ 00:31.840 --> 00:37.040
22
+ When something crosses the event horizon, it collapses into the black hole's singularity,
23
+
24
+ 00:37.040 --> 00:42.400
25
+ an infinitely small, infinitely dense point where space, time, and the laws of physics
26
+
27
+ 00:42.400 --> 00:46.200
28
+ no longer apply.
29
+
30
+ 00:46.200 --> 00:51.400
31
+ Scientists have theorized several different types of black holes, with stellar and supermassive
32
+
33
+ 00:51.400 --> 00:54.280
34
+ black holes being the most common.
35
+
36
+ 00:54.280 --> 00:58.640
37
+ Stolar black holes form when massive stars die and collapse.
38
+
39
+ 00:58.640 --> 01:05.080
40
+ They're roughly 10 to 20 times the mass of our sun, and scattered throughout the universe.
41
+
42
+ 01:05.080 --> 01:11.040
43
+ There could be millions of these stellar black holes in the Milky Way alone.
44
+
45
+ 01:11.040 --> 01:16.440
46
+ Supermassive black holes are giants by comparison, measuring millions, even billions of times
47
+
48
+ 01:16.440 --> 01:19.440
49
+ more massive than our sun.
50
+
51
+ 01:19.440 --> 01:23.800
52
+ Scientists can only guess how they form, but we do know they exist at the center of just
53
+
54
+ 01:23.800 --> 01:28.920
55
+ about every large galaxy, including our own.
56
+
57
+ 01:28.920 --> 01:33.760
58
+ Sagittarius A, the supermassive black hole at the center of the Milky Way, has a mass
59
+
60
+ 01:33.760 --> 01:39.360
61
+ of roughly four million suns, and has a diameter about the distance between the Earth and our
62
+
63
+ 01:39.360 --> 01:41.960
64
+ sun.
65
+
66
+ 01:41.960 --> 01:46.680
67
+ Because black holes are invisible, the only way for scientists to detect and study them
68
+
69
+ 01:46.680 --> 01:50.040
70
+ is to observe their effect on nearby matter.
71
+
72
+ 01:50.040 --> 01:55.360
73
+ This includes accretion disks, a disk of particles that form when gases and dust fall toward a
74
+
75
+ 01:55.360 --> 02:03.920
76
+ black hole, and quasars, jets of particles that blast out of supermassive black holes.
77
+
78
+ 02:03.920 --> 02:08.720
79
+ Black holes remained largely unknown until the 20th century.
80
+
81
+ 02:08.720 --> 02:14.840
82
+ In 1916, using Einstein's General Theory of Relativity, a German physicist named Karl
83
+
84
+ 02:14.840 --> 02:20.280
85
+ Schwartzschild calculated that any mass could become a black hole if it were compressed tightly
86
+
87
+ 02:20.280 --> 02:22.640
88
+ enough.
89
+
90
+ 02:22.640 --> 02:27.480
91
+ But it wasn't until 1971 when theory became reality.
92
+
93
+ 02:27.480 --> 02:34.000
94
+ Astronomers, studying the constellation Cygnus, discovered the first black hole.
95
+
96
+ 02:34.000 --> 02:39.440
97
+ An untold number of black holes are scattered throughout the universe, constantly warping
98
+
99
+ 02:39.440 --> 02:45.600
100
+ space and time, altering entire galaxies, and endlessly inspiring both scientists and
101
+
102
+ 02:45.600 --> 02:47.120
103
+ our collective imagination.
104
+
utility.py CHANGED
@@ -1,764 +1,764 @@
1
- # Add your utilities or helper functions to this file.
2
-
3
- import os
4
- from pathlib import Path
5
- from dotenv import load_dotenv, find_dotenv
6
- from io import StringIO, BytesIO
7
- import textwrap
8
- from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
9
- from enum import auto, Enum
10
- import base64
11
- import glob
12
- from moviepy import VideoFileClip
13
- import requests
14
- from tqdm import tqdm
15
- from pytubefix import YouTube, Stream
16
- import webvtt
17
- import whisper
18
- from youtube_transcript_api import YouTubeTranscriptApi
19
- from youtube_transcript_api.formatters import WebVTTFormatter
20
- from predictionguard import PredictionGuard
21
- import cv2
22
- import re
23
- import json
24
- import PIL
25
- from ollama import chat
26
- from ollama import ChatResponse
27
- from PIL import Image
28
- import dataclasses
29
- import random
30
- from datasets import load_dataset
31
- from os import path as osp
32
- from IPython.display import display
33
- from langchain_core.prompt_values import PromptValue
34
- from langchain_core.messages import (
35
- MessageLikeRepresentation,
36
- )
37
- from transformers import pipeline
38
- from huggingface_hub import InferenceClient
39
-
40
- MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]
41
-
42
- def get_from_dict_or_env(
43
- data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
44
- ) -> str:
45
- """Get a value from a dictionary or an environment variable."""
46
- if key in data and data[key]:
47
- return data[key]
48
- else:
49
- return get_from_env(key, env_key, default=default)
50
-
51
- def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
52
- """Get a value from a dictionary or an environment variable."""
53
- if env_key in os.environ and os.environ[env_key]:
54
- return os.environ[env_key]
55
- else:
56
- return default
57
-
58
- def load_env():
59
- _ = load_dotenv(find_dotenv())
60
-
61
- def get_openai_api_key():
62
- load_env()
63
- openai_api_key = os.getenv("OPENAI_API_KEY")
64
- return openai_api_key
65
-
66
- def get_prediction_guard_api_key():
67
- load_env()
68
- PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
69
- if PREDICTION_GUARD_API_KEY is None:
70
- PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
71
- return PREDICTION_GUARD_API_KEY
72
-
73
- PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com") ###"https://proxy-dl-itdc.predictionguard.com"
74
-
75
- # prompt templates
76
- templates = [
77
- 'a picture of {}',
78
- 'an image of {}',
79
- 'a nice {}',
80
- 'a beautiful {}',
81
- ]
82
-
83
- # function helps to prepare list image-text pairs from the first [test_size] data of a Huggingface dataset
84
- def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
85
- # load Huggingface dataset (download if needed)
86
- dataset = load_dataset(hf_dataset, trust_remote_code=True)
87
- # split dataset with specific test_size
88
- train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
89
- # get the test dataset
90
- test_dataset = train_test_dataset['test']
91
- img_txt_pairs = []
92
- for i in range(len(test_dataset)):
93
- img_txt_pairs.append({
94
- 'caption' : templates[random.randint(0, len(templates)-1)].format(class_name),
95
- 'pil_img' : test_dataset[i]['image']
96
- })
97
- return img_txt_pairs
98
-
99
-
100
- def download_video(video_url, path):
101
- print(f'Getting video information for {video_url}')
102
-
103
- def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
104
- pbar.update(len(data_chunk))
105
- stream = None
106
- try:
107
- yt = YouTube(video_url, on_progress_callback=progress_callback)
108
- stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
109
- if stream is None:
110
- stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
111
- except Exception as e:
112
- print(f"Youtube Exception Occured.Loading from local resource: {e}")
113
-
114
- uncleaned_filename = stream.default_filename.replace(' ', '').lower() if stream else "blackholes101nationalgeographic.mp4"
115
- print(f'Uncleaned filename: {uncleaned_filename}')
116
- filename= re.sub(r'[^a-zA-Z0-9]', '', uncleaned_filename).replace('mp4', '')
117
- filename_without_extension = os.path.splitext(filename)[0]
118
- filename_with_extension = filename+'.mp4'
119
- folder_path = os.path.join(path, filename_without_extension)
120
- print(f'Checking the folder path {folder_path}')
121
- full_file_path = os.path.join(folder_path, filename_with_extension)
122
-
123
- if not os.path.exists(folder_path):
124
- os.makedirs(folder_path, exist_ok=True)
125
-
126
- if os.path.exists(full_file_path):
127
- print('Video already downloaded at the folder path', full_file_path)
128
- is_downloaded = False
129
- return full_file_path, folder_path, is_downloaded
130
-
131
-
132
- is_downloaded = True
133
-
134
- print('Downloading video from YouTube...')
135
- pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
136
- stream.download(folder_path, filename=filename_with_extension)
137
- pbar.close()
138
- return full_file_path, folder_path, is_downloaded
139
-
140
- def get_video_id_from_url(video_url):
141
- """
142
- Examples:
143
- - http://youtu.be/SA2iWivDJiE
144
- - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
145
- - http://www.youtube.com/embed/SA2iWivDJiE
146
- - http://www.youtube.com/v/SA2iWivDJiE?version=3&amp;hl=en_US
147
- """
148
- import urllib.parse
149
- url = urllib.parse.urlparse(video_url)
150
- if url.hostname == 'youtu.be':
151
- return url.path[1:]
152
- if url.hostname in ('www.youtube.com', 'youtube.com'):
153
- if url.path == '/watch':
154
- p = urllib.parse.parse_qs(url.query)
155
- return p['v'][0]
156
- if url.path[:7] == '/embed/':
157
- return url.path.split('/')[2]
158
- if url.path[:3] == '/v/':
159
- return url.path.split('/')[2]
160
-
161
- return video_url
162
-
163
- def generate_transcript_vtt(vid_dir, vid_filepath):
164
- print("Generating transcript for video ", vid_filepath)
165
- # declare where to save .mp3 audio
166
- path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
167
-
168
- # extract mp3 audio file from mp4 video video file
169
- path_to_video_no_transcript = vid_filepath
170
- clip = VideoFileClip(path_to_video_no_transcript)
171
- clip.audio.write_audiofile(path_to_extracted_audio_file)
172
-
173
- model = whisper.load_model("small")
174
- options = dict(task="translate", best_of=1, language='en')
175
- results = model.transcribe(path_to_extracted_audio_file, **options)
176
-
177
- vtt = getSubs(results["segments"], "vtt")
178
-
179
- # path to save generated transcript of video1
180
- path_to_generated_trans = osp.join(vid_dir, 'captions.vtt')
181
- # write transcription to file
182
- with open(path_to_generated_trans, 'w') as f:
183
- f.write(vtt)
184
- return path_to_generated_trans
185
-
186
-
187
- # if this has transcript then download
188
- def get_transcript_vtt(path, video_url, vid_file_path, from_gen=False):
189
- if from_gen:
190
- return generate_transcript_vtt(path,vid_file_path)
191
- video_id = get_video_id_from_url(video_url)
192
- filepath = os.path.join(path,'captions.vtt')
193
- if os.path.exists(filepath):
194
- print('Transcript already exists')
195
- return filepath
196
-
197
- print('Downloading Transcript...')
198
-
199
- transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
200
- formatter = WebVTTFormatter()
201
- webvtt_formatted = formatter.format_transcript(transcript)
202
-
203
- with open(filepath, 'w', encoding='utf-8') as webvtt_file:
204
- webvtt_file.write(webvtt_formatted)
205
- webvtt_file.close()
206
-
207
- return filepath
208
-
209
-
210
- # helper function for convert time in second to time format for .vtt or .srt file
211
- def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
212
- assert seconds >= 0, "non-negative timestamp expected"
213
- milliseconds = round(seconds * 1000.0)
214
-
215
- hours = milliseconds // 3_600_000
216
- milliseconds -= hours * 3_600_000
217
-
218
- minutes = milliseconds // 60_000
219
- milliseconds -= minutes * 60_000
220
-
221
- seconds = milliseconds // 1_000
222
- milliseconds -= seconds * 1_000
223
-
224
- hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
225
- return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
226
-
227
- # a help function that helps to convert a specific time written as a string in format `webvtt` into a time in miliseconds
228
- def str2time(strtime):
229
- # strip character " if exists
230
- strtime = strtime.strip('"')
231
- # get hour, minute, second from time string
232
- hrs, mins, seconds = [float(c) for c in strtime.split(':')]
233
- # get the corresponding time as total seconds
234
- total_seconds = hrs * 60**2 + mins * 60 + seconds
235
- total_miliseconds = total_seconds * 1000
236
- return total_miliseconds
237
-
238
- def _processText(text: str, maxLineWidth=None):
239
- if (maxLineWidth is None or maxLineWidth < 0):
240
- return text
241
-
242
- lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
243
- return '\n'.join(lines)
244
-
245
- # Resizes a image and maintains aspect ratio
246
- def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
247
- # Grab the image size and initialize dimensions
248
- dim = None
249
- (h, w) = image.shape[:2]
250
-
251
- # Return original image if no need to resize
252
- if width is None and height is None:
253
- return image
254
-
255
- # We are resizing height if width is none
256
- if width is None:
257
- # Calculate the ratio of the height and construct the dimensions
258
- r = height / float(h)
259
- dim = (int(w * r), height)
260
- # We are resizing width if height is none
261
- else:
262
- # Calculate the ratio of the width and construct the dimensions
263
- r = width / float(w)
264
- dim = (width, int(h * r))
265
-
266
- # Return the resized image
267
- return cv2.resize(image, dim, interpolation=inter)
268
-
269
- # helper function to convert transcripts generated by whisper to .vtt file
270
- def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
271
- print("WEBVTT\n", file=file)
272
- for segment in transcript:
273
- text = _processText(segment['text'], maxLineWidth).replace('-->', '->')
274
-
275
- print(
276
- f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
277
- f"{text}\n",
278
- file=file,
279
- flush=True,
280
- )
281
-
282
- # helper function to convert transcripts generated by whisper to .srt file
283
- def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
284
- """
285
- Write a transcript to a file in SRT format.
286
- Example usage:
287
- from pathlib import Path
288
- from whisper.utils import write_srt
289
- import requests
290
- result = transcribe(model, audio_path, temperature=temperature, **args)
291
- # save SRT
292
- audio_basename = Path(audio_path).stem
293
- with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
294
- write_srt(result["segments"], file=srt)
295
- """
296
- for i, segment in enumerate(transcript, start=1):
297
- text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
298
-
299
- # write srt lines
300
- print(
301
- f"{i}\n"
302
- f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
303
- f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
304
- f"{text}\n",
305
- file=file,
306
- flush=True,
307
- )
308
-
309
- def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int=-1) -> str:
310
- segmentStream = StringIO()
311
-
312
- if format == 'vtt':
313
- write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
314
- elif format == 'srt':
315
- write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
316
- else:
317
- raise Exception("Unknown format " + format)
318
-
319
- segmentStream.seek(0)
320
- return segmentStream.read()
321
-
322
- # encoding image at given path or PIL Image using base64
323
- def encode_image(image_path_or_PIL_img):
324
- if isinstance(image_path_or_PIL_img, PIL.Image.Image):
325
- # this is a PIL image
326
- buffered = BytesIO()
327
- image_path_or_PIL_img.save(buffered, format="JPEG")
328
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
329
- else:
330
- # this is a image_path
331
- with open(image_path_or_PIL_img, "rb") as image_file:
332
- return base64.b64encode(image_file.read()).decode('utf-8')
333
-
334
- # checking whether the given string is base64 or not
335
- def isBase64(sb):
336
- try:
337
- if isinstance(sb, str):
338
- # If there's any unicode here, an exception will be thrown and the function will return false
339
- sb_bytes = bytes(sb, 'ascii')
340
- elif isinstance(sb, bytes):
341
- sb_bytes = sb
342
- else:
343
- raise ValueError("Argument must be string or bytes")
344
- return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
345
- except Exception:
346
- return False
347
-
348
- def encode_image_from_path_or_url(image_path_or_url):
349
- try:
350
- # try to open the url to check valid url
351
- f = urlopen(image_path_or_url)
352
- # if this is an url
353
- return base64.b64encode(requests.get(image_path_or_url).content).decode('utf-8')
354
- except:
355
- # this is a path to image
356
- with open(image_path_or_url, "rb") as image_file:
357
- return base64.b64encode(image_file.read()).decode('utf-8')
358
-
359
- # helper function to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
360
- def bt_embedding_from_prediction_guard(prompt, base64_image):
361
- # get PredictionGuard client
362
- client = _getPredictionGuardClient()
363
- message = {"text": prompt,}
364
- if base64_image is not None and base64_image != "":
365
- if not isBase64(base64_image):
366
- raise TypeError("image input must be in base64 encoding!")
367
- message['image'] = base64_image
368
- response = client.embeddings.create(
369
- model="bridgetower-large-itm-mlm-itc",
370
- input=[message]
371
- )
372
- return response['data'][0]['embedding']
373
-
374
-
375
- def load_json_file(file_path):
376
- # Open the JSON file in read mode
377
- with open(file_path, 'r') as file:
378
- data = json.load(file)
379
- return data
380
-
381
- def display_retrieved_results(results):
382
- print(f'There is/are {len(results)} retrieved result(s)')
383
- print()
384
- for i, res in enumerate(results):
385
- print(f'The caption of the {str(i+1)}-th retrieved result is:\n"{results[i].page_content}"')
386
- print()
387
- print(results[i])
388
- #display(Image.open(results[i].metadata['metadata']['extracted_frame_path']))
389
- print("------------------------------------------------------------")
390
-
391
- class SeparatorStyle(Enum):
392
- """Different separator style."""
393
- SINGLE = auto()
394
-
395
- @dataclasses.dataclass
396
- class Conversation:
397
- """A class that keeps all conversation history"""
398
- system: str
399
- roles: List[str]
400
- messages: List[List[str]]
401
- map_roles: Dict[str, str]
402
- version: str = "Unknown"
403
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
404
- sep: str = "\n"
405
-
406
- def _get_prompt_role(self, role):
407
- if self.map_roles is not None and role in self.map_roles.keys():
408
- return self.map_roles[role]
409
- else:
410
- return role
411
-
412
- def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
413
- content = []
414
- if len(first_message) != 2:
415
- raise TypeError("First message in Conversation needs to include a prompt and a base64-enconded image!")
416
-
417
- prompt, b64_image = first_message[0], first_message[1]
418
-
419
- # handling prompt
420
- if prompt is None:
421
- raise TypeError("API does not support None prompt yet")
422
- content.append({
423
- "type": "text",
424
- "text": prompt
425
- })
426
- if b64_image is None:
427
- raise TypeError("API does not support text only conversation yet")
428
-
429
- # handling image
430
- if not isBase64(b64_image):
431
- raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")
432
-
433
- content.append({
434
- "type": "image_url",
435
- "image_url": {
436
- "url": b64_image,
437
- }
438
- })
439
- return content
440
-
441
- def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
442
-
443
- if follow_up_message is not None and len(follow_up_message) > 1:
444
- raise TypeError("Follow-up message in Conversation must not include an image!")
445
-
446
- # handling text prompt
447
- if follow_up_message is None or follow_up_message[0] is None:
448
- raise TypeError("Follow-up message in Conversation must include exactly one text message")
449
-
450
- text = follow_up_message[0]
451
- return text
452
-
453
- def get_message(self):
454
- messages = self.messages
455
- api_messages = []
456
- for i, msg in enumerate(messages):
457
- role, message_content = msg
458
- if i == 0:
459
- # get content for very first message in conversation
460
- content = self._build_content_for_first_message_in_conversation(message_content)
461
- else:
462
- # get content for follow-up message in conversation
463
- content = self._build_content_for_follow_up_messages_in_conversation(message_content)
464
-
465
- api_messages.append({
466
- "role": role,
467
- "content": content,
468
- })
469
- return api_messages
470
-
471
- # this method helps represent a multi-turn chat into as a single turn chat format
472
- def serialize_messages(self):
473
- messages = self.messages
474
- ret = ""
475
- if self.sep_style == SeparatorStyle.SINGLE:
476
- if self.system is not None and self.system != "":
477
- ret = self.system + self.sep
478
- for i, (role, message) in enumerate(messages):
479
- role = self._get_prompt_role(role)
480
- if message:
481
- if isinstance(message, List):
482
- # get prompt only
483
- message = message[0]
484
- if i == 0:
485
- # do not include role at the beginning
486
- ret += message
487
- else:
488
- ret += role + ": " + message
489
- if i < len(messages) - 1:
490
- # avoid including sep at the end of serialized message
491
- ret += self.sep
492
- else:
493
- ret += role + ":"
494
- else:
495
- raise ValueError(f"Invalid style: {self.sep_style}")
496
-
497
- return ret
498
-
499
- def append_message(self, role, message):
500
- if len(self.messages) == 0:
501
- # data verification for the very first message
502
- assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
503
- assert len(message) == 2, f"the very first message in conversation must include both prompt and an image"
504
- prompt, image = message[0], message[1]
505
- assert prompt is not None, f"prompt must be not None"
506
- assert isBase64(image), f"image must be under base64 encoding"
507
- else:
508
- # data verification for follow-up message
509
- assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
510
- assert len(message) == 1, f"the follow-up message must consist of one text message only, no image"
511
-
512
- self.messages.append([role, message])
513
-
514
- def copy(self):
515
- return Conversation(
516
- system=self.system,
517
- roles=self.roles,
518
- messages=[[x,y] for x, y in self.messages],
519
- version=self.version,
520
- map_roles=self.map_roles,
521
- )
522
-
523
- def dict(self):
524
- return {
525
- "system": self.system,
526
- "roles": self.roles,
527
- "messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
528
- "version": self.version,
529
- }
530
-
531
- prediction_guard_llava_conv = Conversation(
532
- system="",
533
- roles=("user", "assistant"),
534
- messages=[],
535
- version="Prediction Guard LLaVA enpoint Conversation v0",
536
- sep_style=SeparatorStyle.SINGLE,
537
- map_roles={
538
- "user": "USER",
539
- "assistant": "ASSISTANT"
540
- }
541
- )
542
-
543
- # get PredictionGuard Client
544
- def _getPredictionGuardClient():
545
- PREDICTION_GUARD_API_KEY = get_prediction_guard_api_key()
546
- client = PredictionGuard(
547
- api_key=PREDICTION_GUARD_API_KEY,
548
- url=PREDICTION_GUARD_URL_ENDPOINT,
549
- )
550
- return client
551
-
552
- # helper function to call chat completion endpoint of PredictionGuard given a prompt and an image
553
- def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
554
- # prepare conversation
555
- conversation = prediction_guard_llava_conv.copy()
556
- conversation.append_message(conversation.roles[0], [prompt, image])
557
- return lvlm_inference_with_conversation(conversation, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
558
-
559
-
560
-
561
- def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
562
- # get PredictionGuard client
563
- client = _getPredictionGuardClient()
564
- # get message from conversation
565
- messages = conversation.get_message()
566
- # call chat completion endpoint at Grediction Guard
567
- response = client.chat.completions.create(
568
- model="llava-1.5-7b-hf",
569
- messages=messages,
570
- max_tokens=max_tokens,
571
- temperature=temperature,
572
- top_p=top_p,
573
- top_k=top_k,
574
- )
575
- return response['choices'][-1]['message']['content']
576
-
577
- def get_token():
578
- load_env()
579
- token = os.getenv("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
580
- if token is None:
581
- raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
582
- return token
583
-
584
-
585
- def lvlm_inference_with_phi(prompt):
586
-
587
-
588
- messages = [{"role": "user", "content": prompt}]
589
- client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=get_token())
590
- response = ''
591
- token = client.chat_completion(messages, max_tokens=256)
592
- response = token['choices'][0]['message']['content']
593
- return response
594
-
595
- def lvlm_inference_with_tiny_model(prompt):
596
- classifier = pipeline(
597
- "text-generation",
598
- model="microsoft/phi-2", # Only ~2.7GB
599
- device_map="auto",
600
- torch_dtype="auto",
601
- )
602
-
603
- response = classifier(
604
- prompt,
605
- max_new_tokens=512, # Remove max_length and use only max_new_tokens
606
- temperature=0.7,
607
- do_sample=True,
608
- num_return_sequences=1,
609
- truncation=True, # Add explicit truncation
610
- pad_token_id=classifier.tokenizer.eos_token_id,
611
- eos_token_id=classifier.tokenizer.eos_token_id,
612
- )[0]['generated_text']
613
-
614
- # Remove the input prompt from the response and clean up
615
- return response.replace(prompt, "").strip()
616
-
617
- # function `extract_and_save_frames_and_metadata``:
618
- # receives as input a video and its transcript
619
- # does extracting and saving frames and their metadatas
620
- # returns the extracted metadatas
621
- def extract_and_save_frames_and_metadata(
622
- path_to_video,
623
- path_to_transcript,
624
- path_to_save_extracted_frames,
625
- path_to_save_metadatas):
626
-
627
- # metadatas will store the metadata of all extracted frames
628
- metadatas = []
629
-
630
- # load video using cv2
631
- print(f"Loading video from {path_to_video}")
632
- video = cv2.VideoCapture(path_to_video)
633
- # load transcript using webvtt
634
- print(f"Loading transcript from {path_to_transcript}")
635
- trans = webvtt.read(path_to_transcript)
636
-
637
- # iterate transcript file
638
- # for each video segment specified in the transcript file
639
- for idx, transcript in enumerate(trans):
640
- # get the start time and end time in seconds
641
- start_time_ms = str2time(transcript.start)
642
- end_time_ms = str2time(transcript.end)
643
- # get the time in ms exactly
644
- # in the middle of start time and end time
645
- mid_time_ms = (end_time_ms + start_time_ms) / 2
646
- # get the transcript, remove the next-line symbol
647
- text = transcript.text.replace("\n", ' ')
648
- # get frame at the middle time
649
- video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
650
- print(f"Extracting frame at {mid_time_ms} ms")
651
- success, frame = video.read()
652
- if success:
653
- # if the frame is extracted successfully, resize it
654
- image = maintain_aspect_ratio_resize(frame, height=350)
655
- # save frame as JPEG file
656
- img_fname = f'frame_{idx}.jpg'
657
- img_fpath = osp.join(
658
- path_to_save_extracted_frames, img_fname
659
- )
660
- cv2.imwrite(img_fpath, image)
661
-
662
- # prepare the metadata
663
- metadata = {
664
- 'extracted_frame_path': img_fpath,
665
- 'transcript': text,
666
- 'video_segment_id': idx,
667
- 'video_path': path_to_video,
668
- 'mid_time_ms': mid_time_ms,
669
- }
670
- metadatas.append(metadata)
671
-
672
- else:
673
- print(f"ERROR! Cannot extract frame: idx = {idx}")
674
-
675
- # save metadata of all extracted frames
676
- fn = osp.join(path_to_save_metadatas, 'metadatas.json')
677
- with open(fn, 'w') as outfile:
678
- json.dump(metadatas, outfile)
679
- return metadatas
680
-
681
- def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
682
- # output paths to save extracted frames and their metadata
683
- extracted_frames_path = osp.join(vid_dir, 'extracted_frame')
684
- metadatas_path = vid_dir
685
-
686
- # create these output folders if not existing
687
- print(f"Creating folders {extracted_frames_path} and {metadatas_path}")
688
- Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
689
- Path(metadatas_path).mkdir(parents=True, exist_ok=True)
690
- print("Extracting frames the video path ", vid_filepath)
691
-
692
- # call the function to extract frames and metadatas
693
- metadatas = extract_and_save_frames_and_metadata(
694
- vid_filepath,
695
- vid_transcript_filepath,
696
- extracted_frames_path,
697
- metadatas_path,
698
- )
699
- return metadatas
700
-
701
- # function extract_and_save_frames_and_metadata_with_fps
702
- # receives as input a video
703
- # does extracting and saving frames and their metadatas
704
- # returns the extracted metadatas
705
- def extract_and_save_frames_and_metadata_with_fps(
706
- lvlm_prompt,
707
- path_to_video,
708
- path_to_save_extracted_frames,
709
- path_to_save_metadatas,
710
- num_of_extracted_frames_per_second=1):
711
-
712
- # metadatas will store the metadata of all extracted frames
713
- metadatas = []
714
-
715
- # load video using cv2
716
- video = cv2.VideoCapture(path_to_video)
717
-
718
- # Get the frames per second
719
- fps = video.get(cv2.CAP_PROP_FPS)
720
- # Get hop = the number of frames pass before a frame is extracted
721
- hop = round(fps / num_of_extracted_frames_per_second)
722
- curr_frame = 0
723
- idx = -1
724
- while(True):
725
- # iterate all frames
726
- ret, frame = video.read()
727
- if not ret:
728
- break
729
- if curr_frame % hop == 0:
730
- idx = idx + 1
731
-
732
- # if the frame is extracted successfully, resize it
733
- image = maintain_aspect_ratio_resize(frame, height=350)
734
- # save frame as JPEG file
735
- img_fname = f'frame_{idx}.jpg'
736
- img_fpath = osp.join(
737
- path_to_save_extracted_frames,
738
- img_fname
739
- )
740
- cv2.imwrite(img_fpath, image)
741
-
742
- # generate caption using lvlm_inference
743
- b64_image = encode_image(img_fpath)
744
- caption = lvlm_inference(lvlm_prompt, b64_image)
745
-
746
- # prepare the metadata
747
- metadata = {
748
- 'extracted_frame_path': img_fpath,
749
- 'transcript': caption,
750
- 'video_segment_id': idx,
751
- 'video_path': path_to_video,
752
- }
753
- metadatas.append(metadata)
754
- curr_frame += 1
755
-
756
- # save metadata of all extracted frames
757
- metadatas_path = osp.join(path_to_save_metadatas,'metadatas.json')
758
- with open(metadatas_path, 'w') as outfile:
759
- json.dump(metadatas, outfile)
760
- return metadatas
761
-
762
- if __name__ == "__main__":
763
- res = lvlm_inference_with_phi("Tell me a story")
764
  print(res)
 
1
+ # Add your utilities or helper functions to this file.
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from io import StringIO, BytesIO
7
+ import textwrap
8
+ from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
9
+ from enum import auto, Enum
10
+ import base64
11
+ import glob
12
+ from moviepy import VideoFileClip
13
+ import requests
14
+ from tqdm import tqdm
15
+ from pytubefix import YouTube, Stream
16
+ import webvtt
17
+ import whisper
18
+ from youtube_transcript_api import YouTubeTranscriptApi
19
+ from youtube_transcript_api.formatters import WebVTTFormatter
20
+ from predictionguard import PredictionGuard
21
+ import cv2
22
+ import re
23
+ import json
24
+ import PIL
25
+ from ollama import chat
26
+ from ollama import ChatResponse
27
+ from PIL import Image
28
+ import dataclasses
29
+ import random
30
+ from datasets import load_dataset
31
+ from os import path as osp
32
+ from IPython.display import display
33
+ from langchain_core.prompt_values import PromptValue
34
+ from langchain_core.messages import (
35
+ MessageLikeRepresentation,
36
+ )
37
+ from transformers import pipeline
38
+ from huggingface_hub import InferenceClient
39
+
40
+ MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]
41
+
42
+ def get_from_dict_or_env(
43
+ data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
44
+ ) -> str:
45
+ """Get a value from a dictionary or an environment variable."""
46
+ if key in data and data[key]:
47
+ return data[key]
48
+ else:
49
+ return get_from_env(key, env_key, default=default)
50
+
51
+ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
52
+ """Get a value from a dictionary or an environment variable."""
53
+ if env_key in os.environ and os.environ[env_key]:
54
+ return os.environ[env_key]
55
+ else:
56
+ return default
57
+
58
+ def load_env():
59
+ _ = load_dotenv(find_dotenv())
60
+
61
+ def get_openai_api_key():
62
+ load_env()
63
+ openai_api_key = os.getenv("OPENAI_API_KEY")
64
+ return openai_api_key
65
+
66
+ def get_prediction_guard_api_key():
67
+ load_env()
68
+ PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
69
+ if PREDICTION_GUARD_API_KEY is None:
70
+ PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
71
+ return PREDICTION_GUARD_API_KEY
72
+
73
+ PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com") ###"https://proxy-dl-itdc.predictionguard.com"
74
+
75
+ # prompt templates
76
+ templates = [
77
+ 'a picture of {}',
78
+ 'an image of {}',
79
+ 'a nice {}',
80
+ 'a beautiful {}',
81
+ ]
82
+
83
+ # helper function to prepare a list of image-text pairs from a [test_size]-sample split of a Hugging Face dataset
84
+ def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
85
+ # load Huggingface dataset (download if needed)
86
+ dataset = load_dataset(hf_dataset, trust_remote_code=True)
87
+ # split dataset with specific test_size
88
+ train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
89
+ # get the test dataset
90
+ test_dataset = train_test_dataset['test']
91
+ img_txt_pairs = []
92
+ for i in range(len(test_dataset)):
93
+ img_txt_pairs.append({
94
+ 'caption' : templates[random.randint(0, len(templates)-1)].format(class_name),
95
+ 'pil_img' : test_dataset[i]['image']
96
+ })
97
+ return img_txt_pairs
98
+
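A minimal usage sketch for `prepare_dataset_for_umap_visualization`, assuming the helpers above are in scope; the dataset id below is a placeholder and the call assumes a Hugging Face dataset whose samples expose an 'image' column.

# hypothetical example: build 50 image-text pairs for the class "cat"
pairs = prepare_dataset_for_umap_visualization(
    "some-org/some-image-dataset",  # placeholder dataset id, not a real repo
    class_name="cat",
    test_size=50,
)
print(len(pairs), pairs[0]['caption'])
print(pairs[0]['pil_img'].size)  # each entry holds a PIL image and a templated caption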
99
+
100
+ def download_video(video_url, path):
101
+ print(f'Getting video information for {video_url}')
102
+
103
+ def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
104
+ pbar.update(len(data_chunk))
105
+ stream = None
106
+ try:
107
+ yt = YouTube(video_url, on_progress_callback=progress_callback)
108
+ stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
109
+ if stream is None:
110
+ stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
111
+ except Exception as e:
112
+ print(f"Youtube Exception Occured.Loading from local resource: {e}")
113
+
114
+ uncleaned_filename = stream.default_filename.replace(' ', '').lower() if stream else "blackholes101nationalgeographic.mp4"
115
+ print(f'Uncleaned filename: {uncleaned_filename}')
116
+ filename = re.sub(r'[^a-zA-Z0-9]', '', uncleaned_filename).replace('mp4', '')
117
+ filename_without_extension = os.path.splitext(filename)[0]
118
+ filename_with_extension = filename+'.mp4'
119
+ folder_path = os.path.join(path, filename_without_extension)
120
+ print(f'Checking the folder path {folder_path}')
121
+ full_file_path = os.path.join(folder_path, filename_with_extension)
122
+
123
+ if not os.path.exists(folder_path):
124
+ os.makedirs(folder_path, exist_ok=True)
125
+
126
+ if os.path.exists(full_file_path):
127
+ print('Video already downloaded at the folder path', full_file_path)
128
+ is_downloaded = False
129
+ return full_file_path, folder_path, is_downloaded
130
+
131
+
132
+ is_downloaded = True
133
+
134
+ print('Downloading video from YouTube...')
135
+ pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
136
+ stream.download(folder_path, filename=filename_with_extension)
137
+ pbar.close()
138
+ return full_file_path, folder_path, is_downloaded
139
+
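A usage sketch for `download_video`; the URL and output folder are placeholders, and the call needs network access unless a previously downloaded copy already exists under that folder.

video_path, video_dir, downloaded = download_video(
    "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # placeholder URL
    "./shared_data",                                # placeholder output root
)
print(video_path, video_dir, downloaded)  # full file path, its folder, and whether a download happened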
140
+ def get_video_id_from_url(video_url):
141
+ """
142
+ Examples:
143
+ - http://youtu.be/SA2iWivDJiE
144
+ - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
145
+ - http://www.youtube.com/embed/SA2iWivDJiE
146
+ - http://www.youtube.com/v/SA2iWivDJiE?version=3&amp;hl=en_US
147
+ """
148
+ import urllib.parse
149
+ url = urllib.parse.urlparse(video_url)
150
+ if url.hostname == 'youtu.be':
151
+ return url.path[1:]
152
+ if url.hostname in ('www.youtube.com', 'youtube.com'):
153
+ if url.path == '/watch':
154
+ p = urllib.parse.parse_qs(url.query)
155
+ return p['v'][0]
156
+ if url.path[:7] == '/embed/':
157
+ return url.path.split('/')[2]
158
+ if url.path[:3] == '/v/':
159
+ return url.path.split('/')[2]
160
+
161
+ return video_url
162
+
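The URL shapes listed in the docstring all resolve to the same video id; a quick sketch with the helpers above in scope:

for url in [
    "http://youtu.be/SA2iWivDJiE",
    "http://www.youtube.com/watch?v=SA2iWivDJiE",
    "http://www.youtube.com/embed/SA2iWivDJiE",
    "http://www.youtube.com/v/SA2iWivDJiE?version=3",
]:
    # every variant should yield "SA2iWivDJiE"
    print(get_video_id_from_url(url))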
163
+ def generate_transcript_vtt(vid_dir, vid_filepath):
164
+ print("Generating transcript for video ", vid_filepath)
165
+ # declare where to save .mp3 audio
166
+ path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
167
+
168
+ # extract an mp3 audio file from the mp4 video file
169
+ path_to_video_no_transcript = vid_filepath
170
+ clip = VideoFileClip(path_to_video_no_transcript)
171
+ clip.audio.write_audiofile(path_to_extracted_audio_file)
172
+
173
+ model = whisper.load_model("small")
174
+ options = dict(task="translate", best_of=1, language='en')
175
+ results = model.transcribe(path_to_extracted_audio_file, **options)
176
+
177
+ vtt = getSubs(results["segments"], "vtt")
178
+
179
+ # path to save the generated transcript of the video
180
+ path_to_generated_trans = osp.join(vid_dir, 'captions.vtt')
181
+ # write transcription to file
182
+ with open(path_to_generated_trans, 'w') as f:
183
+ f.write(vtt)
184
+ return path_to_generated_trans
185
+
186
+
187
+ # download the video's transcript if available on YouTube; optionally generate it locally with Whisper instead
188
+ def get_transcript_vtt(path, video_url, vid_file_path, from_gen=False):
189
+ if from_gen:
190
+ return generate_transcript_vtt(path,vid_file_path)
191
+ video_id = get_video_id_from_url(video_url)
192
+ filepath = os.path.join(path,'captions.vtt')
193
+ if os.path.exists(filepath):
194
+ print('Transcript already exists')
195
+ return filepath
196
+
197
+ print('Downloading Transcript...')
198
+
199
+ transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
200
+ formatter = WebVTTFormatter()
201
+ webvtt_formatted = formatter.format_transcript(transcript)
202
+
203
+ with open(filepath, 'w', encoding='utf-8') as webvtt_file:
204
+ webvtt_file.write(webvtt_formatted)
205
+ webvtt_file.close()
206
+
207
+ return filepath
208
+
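A sketch of how the transcript helpers might be chained, assuming `video_dir` and `video_path` came from `download_video` above; the URL is a placeholder, and passing `from_gen=True` falls back to Whisper-based generation, which downloads the Whisper model and processes the audio track.

vtt_path = get_transcript_vtt(
    path=video_dir,                                             # folder returned by download_video
    video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",    # placeholder URL
    vid_file_path=video_path,                                   # file returned by download_video
    from_gen=False,
)
print("captions written to", vtt_path)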
209
+
210
+ # helper function to convert a time in seconds into the timestamp format used by .vtt or .srt files
211
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeparator: str = '.'):
212
+ assert seconds >= 0, "non-negative timestamp expected"
213
+ milliseconds = round(seconds * 1000.0)
214
+
215
+ hours = milliseconds // 3_600_000
216
+ milliseconds -= hours * 3_600_000
217
+
218
+ minutes = milliseconds // 60_000
219
+ milliseconds -= minutes * 60_000
220
+
221
+ seconds = milliseconds // 1_000
222
+ milliseconds -= seconds * 1_000
223
+
224
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
225
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
226
+
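For example, 3661.5 seconds formats as follows (the expected outputs are shown as comments):

print(format_timestamp(3661.5))                           # 01:01:01.500
print(format_timestamp(5.25, always_include_hours=True))  # 00:00:05.250
print(format_timestamp(5.25, False, ','))                 # 00:05,250 (comma separator, as used for .srt)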
227
+ # helper function to convert a webvtt-formatted time string into a time in milliseconds
228
+ def str2time(strtime):
229
+ # strip character " if exists
230
+ strtime = strtime.strip('"')
231
+ # get hour, minute, second from time string
232
+ hrs, mins, seconds = [float(c) for c in strtime.split(':')]
233
+ # get the corresponding time as total seconds
234
+ total_seconds = hrs * 60**2 + mins * 60 + seconds
235
+ total_milliseconds = total_seconds * 1000
236
+ return total_milliseconds
237
+
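For instance, a webvtt timestamp maps to milliseconds like this:

print(str2time("00:01:02.500"))    # 62500.0 (1 minute, 2.5 seconds)
print(str2time('"01:00:00.000"'))  # 3600000.0 (surrounding quotes are stripped)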
238
+ def _processText(text: str, maxLineWidth=None):
239
+ if (maxLineWidth is None or maxLineWidth < 0):
240
+ return text
241
+
242
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
243
+ return '\n'.join(lines)
244
+
245
+ # Resizes an image while maintaining its aspect ratio
246
+ def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
247
+ # Grab the image size and initialize dimensions
248
+ dim = None
249
+ (h, w) = image.shape[:2]
250
+
251
+ # Return original image if no need to resize
252
+ if width is None and height is None:
253
+ return image
254
+
255
+ # We are resizing height if width is none
256
+ if width is None:
257
+ # Calculate the ratio of the height and construct the dimensions
258
+ r = height / float(h)
259
+ dim = (int(w * r), height)
260
+ # We are resizing width if height is none
261
+ else:
262
+ # Calculate the ratio of the width and construct the dimensions
263
+ r = width / float(w)
264
+ dim = (width, int(h * r))
265
+
266
+ # Return the resized image
267
+ return cv2.resize(image, dim, interpolation=inter)
268
+
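A small sketch of the resize helper on a synthetic frame, so it runs without any video on disk:

import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)        # fake 1280x720 BGR frame
resized = maintain_aspect_ratio_resize(frame, height=350)
print(resized.shape)  # (350, 622, 3): the width is scaled by the same ratio as the height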
269
+ # helper function to convert transcripts generated by whisper to .vtt file
270
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
271
+ print("WEBVTT\n", file=file)
272
+ for segment in transcript:
273
+ text = _processText(segment['text'], maxLineWidth).replace('-->', '->')
274
+
275
+ print(
276
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
277
+ f"{text}\n",
278
+ file=file,
279
+ flush=True,
280
+ )
281
+
282
+ # helper function to convert transcripts generated by whisper to .srt file
283
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
284
+ """
285
+ Write a transcript to a file in SRT format.
286
+ Example usage:
287
+ from pathlib import Path
288
+ from whisper.utils import write_srt
289
+ import requests
290
+ result = transcribe(model, audio_path, temperature=temperature, **args)
291
+ # save SRT
292
+ audio_basename = Path(audio_path).stem
293
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
294
+ write_srt(result["segments"], file=srt)
295
+ """
296
+ for i, segment in enumerate(transcript, start=1):
297
+ text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
298
+
299
+ # write srt lines
300
+ print(
301
+ f"{i}\n"
302
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
303
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
304
+ f"{text}\n",
305
+ file=file,
306
+ flush=True,
307
+ )
308
+
309
+ def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int=-1) -> str:
310
+ segmentStream = StringIO()
311
+
312
+ if format == 'vtt':
313
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
314
+ elif format == 'srt':
315
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
316
+ else:
317
+ raise Exception("Unknown format " + format)
318
+
319
+ segmentStream.seek(0)
320
+ return segmentStream.read()
321
+
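A sketch of `getSubs` on hand-written segments shaped like Whisper's output (each segment needs 'start', 'end' and 'text' keys):

fake_segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 5.0, "text": " This is a test."},
]
print(getSubs(fake_segments, "vtt"))  # WEBVTT header plus two timed cues
print(getSubs(fake_segments, "srt"))  # numbered SRT cues with comma fractional separators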
322
+ # encode an image (given as a file path or a PIL Image) using base64
323
+ def encode_image(image_path_or_PIL_img):
324
+ if isinstance(image_path_or_PIL_img, PIL.Image.Image):
325
+ # this is a PIL image
326
+ buffered = BytesIO()
327
+ image_path_or_PIL_img.save(buffered, format="JPEG")
328
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
329
+ else:
330
+ # this is an image path
331
+ with open(image_path_or_PIL_img, "rb") as image_file:
332
+ return base64.b64encode(image_file.read()).decode('utf-8')
333
+
334
+ # checking whether the given string is base64 or not
335
+ def isBase64(sb):
336
+ try:
337
+ if isinstance(sb, str):
338
+ # If there's any unicode here, an exception will be thrown and the function will return false
339
+ sb_bytes = bytes(sb, 'ascii')
340
+ elif isinstance(sb, bytes):
341
+ sb_bytes = sb
342
+ else:
343
+ raise ValueError("Argument must be string or bytes")
344
+ return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
345
+ except Exception:
346
+ return False
347
+
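A quick round-trip showing that `encode_image` output passes the `isBase64` check; the PIL image is synthetic, so nothing is read from disk.

from PIL import Image

tiny = Image.new("RGB", (8, 8), color=(255, 0, 0))
b64 = encode_image(tiny)         # accepts a PIL image (or a file path)
print(isBase64(b64))             # True
print(isBase64("not base64!!"))  # False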
348
+ def encode_image_from_path_or_url(image_path_or_url):
349
+ try:
+ # try to fetch the input as a URL
+ response = requests.get(image_path_or_url)
+ response.raise_for_status()
+ # if this is a URL, encode the downloaded bytes
+ return base64.b64encode(response.content).decode('utf-8')
+ except Exception:
+ # otherwise this is a path to a local image file
+ with open(image_path_or_url, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
358
+
359
+ # helper function to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
360
+ def bt_embedding_from_prediction_guard(prompt, base64_image):
361
+ # get PredictionGuard client
362
+ client = _getPredictionGuardClient()
363
+ message = {"text": prompt,}
364
+ if base64_image is not None and base64_image != "":
365
+ if not isBase64(base64_image):
366
+ raise TypeError("image input must be in base64 encoding!")
367
+ message['image'] = base64_image
368
+ response = client.embeddings.create(
369
+ model="bridgetower-large-itm-mlm-itc",
370
+ input=[message]
371
+ )
372
+ return response['data'][0]['embedding']
373
+
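A sketch of calling the BridgeTower embedding helper; this performs a network call and assumes a valid Prediction Guard API key is available in the environment. The image is synthetic so nothing is read from disk.

from PIL import Image

b64_img = encode_image(Image.new("RGB", (8, 8), color=(0, 128, 255)))
emb = bt_embedding_from_prediction_guard("a small blue square", b64_img)
print(len(emb))  # dimensionality of the returned joint embedding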
374
+
375
+ def load_json_file(file_path):
376
+ # Open the JSON file in read mode
377
+ with open(file_path, 'r') as file:
378
+ data = json.load(file)
379
+ return data
380
+
381
+ def display_retrieved_results(results):
382
+ print(f'There is/are {len(results)} retrieved result(s)')
383
+ print()
384
+ for i, res in enumerate(results):
385
+ print(f'The caption of the {str(i+1)}-th retrieved result is:\n"{results[i].page_content}"')
386
+ print()
387
+ print(results[i])
388
+ #display(Image.open(results[i].metadata['metadata']['extracted_frame_path']))
389
+ print("------------------------------------------------------------")
390
+
391
+ class SeparatorStyle(Enum):
392
+ """Different separator style."""
393
+ SINGLE = auto()
394
+
395
+ @dataclasses.dataclass
396
+ class Conversation:
397
+ """A class that keeps all conversation history"""
398
+ system: str
399
+ roles: List[str]
400
+ messages: List[List[str]]
401
+ map_roles: Dict[str, str]
402
+ version: str = "Unknown"
403
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
404
+ sep: str = "\n"
405
+
406
+ def _get_prompt_role(self, role):
407
+ if self.map_roles is not None and role in self.map_roles.keys():
408
+ return self.map_roles[role]
409
+ else:
410
+ return role
411
+
412
+ def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
413
+ content = []
414
+ if len(first_message) != 2:
415
+ raise TypeError("First message in Conversation needs to include a prompt and a base64-enconded image!")
416
+
417
+ prompt, b64_image = first_message[0], first_message[1]
418
+
419
+ # handling prompt
420
+ if prompt is None:
421
+ raise TypeError("API does not support None prompt yet")
422
+ content.append({
423
+ "type": "text",
424
+ "text": prompt
425
+ })
426
+ if b64_image is None:
427
+ raise TypeError("API does not support text only conversation yet")
428
+
429
+ # handling image
430
+ if not isBase64(b64_image):
431
+ raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")
432
+
433
+ content.append({
434
+ "type": "image_url",
435
+ "image_url": {
436
+ "url": b64_image,
437
+ }
438
+ })
439
+ return content
440
+
441
+ def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
442
+
443
+ if follow_up_message is not None and len(follow_up_message) > 1:
444
+ raise TypeError("Follow-up message in Conversation must not include an image!")
445
+
446
+ # handling text prompt
447
+ if follow_up_message is None or follow_up_message[0] is None:
448
+ raise TypeError("Follow-up message in Conversation must include exactly one text message")
449
+
450
+ text = follow_up_message[0]
451
+ return text
452
+
453
+ def get_message(self):
454
+ messages = self.messages
455
+ api_messages = []
456
+ for i, msg in enumerate(messages):
457
+ role, message_content = msg
458
+ if i == 0:
459
+ # get content for very first message in conversation
460
+ content = self._build_content_for_first_message_in_conversation(message_content)
461
+ else:
462
+ # get content for follow-up message in conversation
463
+ content = self._build_content_for_follow_up_messages_in_conversation(message_content)
464
+
465
+ api_messages.append({
466
+ "role": role,
467
+ "content": content,
468
+ })
469
+ return api_messages
470
+
471
+ # this method serializes a multi-turn chat into a single-turn chat format
472
+ def serialize_messages(self):
473
+ messages = self.messages
474
+ ret = ""
475
+ if self.sep_style == SeparatorStyle.SINGLE:
476
+ if self.system is not None and self.system != "":
477
+ ret = self.system + self.sep
478
+ for i, (role, message) in enumerate(messages):
479
+ role = self._get_prompt_role(role)
480
+ if message:
481
+ if isinstance(message, List):
482
+ # get prompt only
483
+ message = message[0]
484
+ if i == 0:
485
+ # do not include role at the beginning
486
+ ret += message
487
+ else:
488
+ ret += role + ": " + message
489
+ if i < len(messages) - 1:
490
+ # avoid including sep at the end of serialized message
491
+ ret += self.sep
492
+ else:
493
+ ret += role + ":"
494
+ else:
495
+ raise ValueError(f"Invalid style: {self.sep_style}")
496
+
497
+ return ret
498
+
499
+ def append_message(self, role, message):
500
+ if len(self.messages) == 0:
501
+ # data verification for the very first message
502
+ assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
503
+ assert len(message) == 2, f"the very first message in conversation must include both prompt and an image"
504
+ prompt, image = message[0], message[1]
505
+ assert prompt is not None, f"prompt must be not None"
506
+ assert isBase64(image), f"image must be under base64 encoding"
507
+ else:
508
+ # data verification for follow-up message
509
+ assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
510
+ assert len(message) == 1, f"the follow-up message must consist of one text message only, no image"
511
+
512
+ self.messages.append([role, message])
513
+
514
+ def copy(self):
515
+ return Conversation(
516
+ system=self.system,
517
+ roles=self.roles,
518
+ messages=[[x,y] for x, y in self.messages],
519
+ version=self.version,
520
+ map_roles=self.map_roles,
521
+ )
522
+
523
+ def dict(self):
524
+ return {
525
+ "system": self.system,
526
+ "roles": self.roles,
527
+ "messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
528
+ "version": self.version,
529
+ }
530
+
531
+ prediction_guard_llava_conv = Conversation(
532
+ system="",
533
+ roles=("user", "assistant"),
534
+ messages=[],
535
+ version="Prediction Guard LLaVA enpoint Conversation v0",
536
+ sep_style=SeparatorStyle.SINGLE,
537
+ map_roles={
538
+ "user": "USER",
539
+ "assistant": "ASSISTANT"
540
+ }
541
+ )
542
+
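A local sketch of how the Conversation container is used (no API call is made here): the first turn must carry a prompt plus a base64-encoded image, and follow-up turns are text-only.

from PIL import Image

conv = prediction_guard_llava_conv.copy()
img_b64 = encode_image(Image.new("RGB", (8, 8)))
conv.append_message(conv.roles[0], ["What is in this image?", img_b64])
conv.append_message(conv.roles[1], ["A small black square."])
print(conv.get_message())          # messages formatted for the chat completion API
print(conv.serialize_messages())   # the same chat flattened into a single prompt string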
543
+ # get PredictionGuard Client
544
+ def _getPredictionGuardClient():
545
+ PREDICTION_GUARD_API_KEY = get_prediction_guard_api_key()
546
+ client = PredictionGuard(
547
+ api_key=PREDICTION_GUARD_API_KEY,
548
+ url=PREDICTION_GUARD_URL_ENDPOINT,
549
+ )
550
+ return client
551
+
552
+ # helper function to call chat completion endpoint of PredictionGuard given a prompt and an image
553
+ def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
554
+ # prepare conversation
555
+ conversation = prediction_guard_llava_conv.copy()
556
+ conversation.append_message(conversation.roles[0], [prompt, image])
557
+ return lvlm_inference_with_conversation(conversation, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
558
+
559
+
560
+
561
+ def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
562
+ # get PredictionGuard client
563
+ client = _getPredictionGuardClient()
564
+ # get message from conversation
565
+ messages = conversation.get_message()
566
+ # call the chat completion endpoint of Prediction Guard
567
+ response = client.chat.completions.create(
568
+ model="llava-1.5-7b-hf",
569
+ messages=messages,
570
+ max_tokens=max_tokens,
571
+ temperature=temperature,
572
+ top_p=top_p,
573
+ top_k=top_k,
574
+ )
575
+ return response['choices'][-1]['message']['content']
576
+
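A sketch of calling the LLaVA endpoint on a single frame; this performs a network call, assumes a valid Prediction Guard API key, and the frame path below is a placeholder for a previously extracted frame.

frame_b64 = encode_image("frames/frame_0.jpg")  # placeholder path to an extracted frame
caption = lvlm_inference("Describe this video frame in one sentence.", frame_b64)
print(caption)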
577
+ def get_token():
578
+ load_env()
579
+ token = os.getenv("HUGGINGFACE_TOKEN")
580
+ if token is None:
581
+ raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
582
+ return token
583
+
584
+
585
+ # despite its name, this helper runs a text-only chat completion against Meta-Llama-3-8B-Instruct via the Hugging Face Inference API
+ def lvlm_inference_with_phi(prompt):
+ messages = [{"role": "user", "content": prompt}]
+ client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=get_token())
+ completion = client.chat_completion(messages, max_tokens=256)
+ response = completion['choices'][0]['message']['content']
+ return response
594
+
595
+ def lvlm_inference_with_tiny_model(prompt):
596
+ classifier = pipeline(
597
+ "text-generation",
598
+ model="microsoft/phi-2", # Only ~2.7GB
599
+ device_map="auto",
600
+ torch_dtype="auto",
601
+ )
602
+
603
+ response = classifier(
604
+ prompt,
605
+ max_new_tokens=512, # cap the number of newly generated tokens
606
+ temperature=0.7,
607
+ do_sample=True,
608
+ num_return_sequences=1,
609
+ truncation=True, # truncate prompts that exceed the model's context length
610
+ pad_token_id=classifier.tokenizer.eos_token_id,
611
+ eos_token_id=classifier.tokenizer.eos_token_id,
612
+ )[0]['generated_text']
613
+
614
+ # Remove the input prompt from the response and clean up
615
+ return response.replace(prompt, "").strip()
616
+
617
+ # function `extract_and_save_frames_and_metadata`:
+ # receives as input a video and its transcript,
+ # extracts and saves frames together with their metadata,
+ # and returns the extracted metadata
621
+ def extract_and_save_frames_and_metadata(
622
+ path_to_video,
623
+ path_to_transcript,
624
+ path_to_save_extracted_frames,
625
+ path_to_save_metadatas):
626
+
627
+ # metadatas will store the metadata of all extracted frames
628
+ metadatas = []
629
+
630
+ # load video using cv2
631
+ print(f"Loading video from {path_to_video}")
632
+ video = cv2.VideoCapture(path_to_video)
633
+ # load transcript using webvtt
634
+ print(f"Loading transcript from {path_to_transcript}")
635
+ trans = webvtt.read(path_to_transcript)
636
+
637
+ # iterate transcript file
638
+ # for each video segment specified in the transcript file
639
+ for idx, transcript in enumerate(trans):
640
+ # get the start time and end time in seconds
641
+ start_time_ms = str2time(transcript.start)
642
+ end_time_ms = str2time(transcript.end)
643
+ # get the time in ms exactly
644
+ # in the middle of start time and end time
645
+ mid_time_ms = (end_time_ms + start_time_ms) / 2
646
+ # get the transcript, remove the next-line symbol
647
+ text = transcript.text.replace("\n", ' ')
648
+ # get frame at the middle time
649
+ video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
650
+ print(f"Extracting frame at {mid_time_ms} ms")
651
+ success, frame = video.read()
652
+ if success:
653
+ # if the frame is extracted successfully, resize it
654
+ image = maintain_aspect_ratio_resize(frame, height=350)
655
+ # save frame as JPEG file
656
+ img_fname = f'frame_{idx}.jpg'
657
+ img_fpath = osp.join(
658
+ path_to_save_extracted_frames, img_fname
659
+ )
660
+ cv2.imwrite(img_fpath, image)
661
+
662
+ # prepare the metadata
663
+ metadata = {
664
+ 'extracted_frame_path': img_fpath,
665
+ 'transcript': text,
666
+ 'video_segment_id': idx,
667
+ 'video_path': path_to_video,
668
+ 'mid_time_ms': mid_time_ms,
669
+ }
670
+ metadatas.append(metadata)
671
+
672
+ else:
673
+ print(f"ERROR! Cannot extract frame: idx = {idx}")
674
+
675
+ # save metadata of all extracted frames
676
+ fn = osp.join(path_to_save_metadatas, 'metadatas.json')
677
+ with open(fn, 'w') as outfile:
678
+ json.dump(metadatas, outfile)
679
+ return metadatas
680
+
681
+ def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
682
+ # output paths to save extracted frames and their metadata
683
+ extracted_frames_path = osp.join(vid_dir, 'extracted_frame')
684
+ metadatas_path = vid_dir
685
+
686
+ # create these output folders if not existing
687
+ print(f"Creating folders {extracted_frames_path} and {metadatas_path}")
688
+ Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
689
+ Path(metadatas_path).mkdir(parents=True, exist_ok=True)
690
+ print("Extracting frames the video path ", vid_filepath)
691
+
692
+ # call the function to extract frames and metadatas
693
+ metadatas = extract_and_save_frames_and_metadata(
694
+ vid_filepath,
695
+ vid_transcript_filepath,
696
+ extracted_frames_path,
697
+ metadatas_path,
698
+ )
699
+ return metadatas
700
+
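A sketch of the transcript-based extraction path, assuming a video and its .vtt captions are already on disk; the paths are placeholders.

metadatas = extract_meta_data(
    vid_dir="shared_data/videodir",                             # placeholder folder
    vid_filepath="shared_data/videodir/video.mp4",              # placeholder video file
    vid_transcript_filepath="shared_data/videodir/captions.vtt",
)
print(metadatas[0]['transcript'], metadatas[0]['extracted_frame_path'])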
701
+ # function `extract_and_save_frames_and_metadata_with_fps`:
+ # receives as input a video,
+ # extracts frames at a fixed rate, captions them, and saves them with their metadata,
+ # and returns the extracted metadata
705
+ def extract_and_save_frames_and_metadata_with_fps(
706
+ lvlm_prompt,
707
+ path_to_video,
708
+ path_to_save_extracted_frames,
709
+ path_to_save_metadatas,
710
+ num_of_extracted_frames_per_second=1):
711
+
712
+ # metadatas will store the metadata of all extracted frames
713
+ metadatas = []
714
+
715
+ # load video using cv2
716
+ video = cv2.VideoCapture(path_to_video)
717
+
718
+ # Get the frames per second
719
+ fps = video.get(cv2.CAP_PROP_FPS)
720
+ # Get hop = the number of frames to skip between two extracted frames
721
+ hop = round(fps / num_of_extracted_frames_per_second)
722
+ curr_frame = 0
723
+ idx = -1
724
+ while True:
725
+ # iterate all frames
726
+ ret, frame = video.read()
727
+ if not ret:
728
+ break
729
+ if curr_frame % hop == 0:
730
+ idx = idx + 1
731
+
732
+ # resize the sampled frame
733
+ image = maintain_aspect_ratio_resize(frame, height=350)
734
+ # save frame as JPEG file
735
+ img_fname = f'frame_{idx}.jpg'
736
+ img_fpath = osp.join(
737
+ path_to_save_extracted_frames,
738
+ img_fname
739
+ )
740
+ cv2.imwrite(img_fpath, image)
741
+
742
+ # generate caption using lvlm_inference
743
+ b64_image = encode_image(img_fpath)
744
+ caption = lvlm_inference(lvlm_prompt, b64_image)
745
+
746
+ # prepare the metadata
747
+ metadata = {
748
+ 'extracted_frame_path': img_fpath,
749
+ 'transcript': caption,
750
+ 'video_segment_id': idx,
751
+ 'video_path': path_to_video,
752
+ }
753
+ metadatas.append(metadata)
754
+ curr_frame += 1
755
+
756
+ # save metadata of all extracted frames
757
+ metadatas_path = osp.join(path_to_save_metadatas,'metadatas.json')
758
+ with open(metadatas_path, 'w') as outfile:
759
+ json.dump(metadatas, outfile)
760
+ return metadatas
761
+
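A sketch tying the fps-based extraction together; the paths are placeholders, the output folders are created up front, and each sampled frame triggers one LLaVA call through `lvlm_inference` (so an API key and network access are assumed).

from pathlib import Path

Path("shared_data/frames").mkdir(parents=True, exist_ok=True)
metadatas = extract_and_save_frames_and_metadata_with_fps(
    lvlm_prompt="Describe what is happening in this frame.",
    path_to_video="shared_data/video.mp4",          # placeholder video path
    path_to_save_extracted_frames="shared_data/frames",
    path_to_save_metadatas="shared_data",
    num_of_extracted_frames_per_second=1,
)
print(len(metadatas), "frames captioned")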
762
+ if __name__ == "__main__":
763
+ res = lvlm_inference_with_phi("Tell me a story")
764
  print(res)