Upload folder using huggingface_hub
Changed files:
- .gitattributes +36 -36
- .gitignore +16 -12
- Dockerfile +17 -0
- README.md +182 -182
- app.py +385 -376
- gradio_utils.py +483 -483
- mm_rag/MLM/client.py +134 -134
- mm_rag/MLM/lvlm.py +300 -300
- mm_rag/embeddings/bridgetower_embeddings.py +88 -88
- mm_rag/vectorstores/multimodal_lancedb.py +130 -130
- requirements.txt +25 -25
- s6_prepare_video_input.py +89 -89
- s7_store_in_rag.py +104 -104
- shared_data/videos/yt_video/blackholes101nationalgeographic/blackholes101nationalgeographic.mp4 +2 -2
- shared_data/videos/yt_video/blackholes101nationalgeographic/captions.vtt +104 -104
- utility.py +763 -763
.gitattributes
CHANGED
@@ -1,36 +1,36 @@
The old and new sides show identical content (all 36 LFS rules rewritten in place); listed once:

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -1,12 +1,16 @@
Four entries were added at the end (mm_rag/embeddings/__pycache__/, mm_rag/embeddings/__pycache__/**, .DS_Store, .devcontainer/). New content:

myenv
__pycache__
.gradio
.venv
.env
.env.*
!.env.example
.github
# LanceDB files
shared_data/.lancedb/
shared_data/.lancedb/**/*
shared_data/videos/yt_video/blackholes101nationalgeographic/audio.mp3
mm_rag/embeddings/__pycache__/
mm_rag/embeddings/__pycache__/**
.DS_Store
.devcontainer/
Dockerfile
ADDED
@@ -0,0 +1,17 @@
FROM mcr.microsoft.com/devcontainers/python:3.11

# Install system dependencies
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
    && apt-get -y install --no-install-recommends \
    ffmpeg \
    && apt-get clean -y && rm -rf /var/lib/apt/lists/*

# Install the required system libraries for OpenCV
RUN apt-get update && apt-get install -y libgl1-mesa-glx

# Copy the project into the image (these two lines are not in the committed file,
# but the next two steps cannot work without them; assumes the build context is
# the repository root)
WORKDIR /app
COPY . /app

# Install PyTorch and other dependencies
RUN pip install -r requirements.txt

# Run the application
CMD ["python", "app.py"]
README.md
CHANGED
@@ -1,182 +1,182 @@
The old and new README content is identical as shown (a line-by-line rewrite with the same text); it appears once below.

---
title: multimodel-rag-chat-with-videos
app_file: app.py
sdk: gradio
sdk_version: 5.17.1
---

# Demo
## Sample Video
- https://www.youtube.com/watch?v=kOEDG3j1bjs
- https://www.youtube.com/watch?v=7Hcg-rLYwdM
## Questions
- Event Horizon
- Show me a group of astronauts; astronaut name
# Re-Architecting the Multimodal RAG System Pipeline: A Journey
I ported the course project locally and isolated each concept into its own runnable Python step. It is now simplified, refactored, and bug-fixed, and I migrated from Prediction Guard to Hugging Face.

[**Interactive Video Chat Demo and Multimodal RAG System Architecture**](https://learn.deeplearning.ai/courses/multimodal-rag-chat-with-videos/lesson/2/interactive-demo-and-multimodal-rag-system-architecture)

### A multimodal AI system should be able to understand both text and video content.

## Setup
```bash
python -m venv venv
source venv/bin/activate
```
For Fish:
```bash
source venv/bin/activate.fish
```

## Step 1 - Learn Gradio (UI) (30 mins)

Gradio is a powerful Python library for quickly building browser-based UIs. It supports hot reloading for fast development.

### Key Concepts:
- **fn**: The function wrapped by the UI.
- **inputs**: The Gradio components used for input (should match the function's arguments).
- **outputs**: The Gradio components used for output (should match the return values).

📖 [**Gradio Documentation**](https://www.gradio.app/docs/gradio/introduction)

Gradio includes **30+ built-in components**.

💡 **Tip** (see the sketch just below): For `inputs` and `outputs`, you can pass either:
- The **component name** as a string (e.g., `"textbox"`)
- An **instance of the component class** (e.g., `gr.Textbox()`)
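
A minimal sketch of those three arguments; the `greet` function and the component choices are illustrative, not taken from this repo:

```python
import gradio as gr

def greet(name: str) -> str:
    # fn: the Python function the UI wraps
    return f"Hello, {name}!"

# inputs can be a component name as a string, outputs a component instance
demo = gr.Interface(fn=greet, inputs="textbox", outputs=gr.Textbox(label="Greeting"))
demo.launch()
```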

### Sharing Your Demo
```python
demo.launch(share=True)  # Share your demo with just one extra parameter.
```

## Gradio Advanced Features

### **Gradio.Blocks**
Gradio provides `gr.Blocks`, a flexible way to design web apps with **custom layouts and complex interactions**:
- Arrange components freely on the page.
- Handle multiple data flows.
- Use outputs as inputs for other components.
- Dynamically update components based on user interaction.

### **Gradio.ChatInterface**
- Always set `type="messages"` in `gr.ChatInterface`.
- The default (`type="tuples"`) is **deprecated** and will be removed in a future version.
- For more UI flexibility, use `gr.Chatbot` directly.
- `gr.ChatInterface` supports **Markdown** (not tested yet).
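
For example, a minimal messages-style chat UI might look like this (the echo function is a placeholder, not code from this repo):

```python
import gradio as gr

def respond(message, history):
    # with type="messages", history arrives as a list of {"role": ..., "content": ...} dicts
    return f"You said: {message}"

gr.ChatInterface(fn=respond, type="messages").launch()
```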

---

## Step 2 - Learn the Bridge Tower Embedding Model (Multimodal Learning) (15 mins)

Developed in collaboration with Intel, this model maps image-caption pairs into **512-dimensional vectors**.

### Measuring Similarity
- **Cosine Similarity** → Measures how close two embeddings are in vector space (**efficient & commonly used**).
- **Euclidean Distance** → Uses `cv2.NORM_L2` to compute the distance between two images.
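
A rough sketch of the two measures on a pair of 512-dimensional vectors (random placeholders, not real BridgeTower outputs):

```python
import numpy as np

emb_a = np.random.rand(512)  # stand-in for one image-caption embedding
emb_b = np.random.rand(512)  # stand-in for another

cosine_sim = emb_a @ emb_b / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b))
euclidean_dist = np.linalg.norm(emb_a - emb_b)  # the L2 distance cv2.NORM_L2 measures
print(cosine_sim, euclidean_dist)
```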

### Converting to 2D for Visualization
- **UMAP** reduces the 512-D embeddings to **2D for display purposes**.
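
A minimal sketch of that reduction, assuming the `umap-learn` package and a matrix of 512-D embeddings:

```python
import numpy as np
import umap  # provided by the umap-learn package

embeddings = np.random.rand(100, 512)  # placeholder: 100 embeddings of size 512
coords_2d = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
print(coords_2d.shape)  # (100, 2), ready for a scatter plot
```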

## Preprocessing Videos for Multimodal RAG

### **Case 1: WEBVTT → Extracting Text Segments from Video**
- Converts video + text into structured metadata.
- Splits the content into multiple segments.

### **Case 2: Whisper (Small) → Video Only**
- Extracts the **audio** → `model.transcribe()`.
- Applies the `getSubs()` helper function to produce **WEBVTT** subtitles.
- Then uses **Case 1** processing.
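
A sketch of the Whisper step; `getSubs()` is this repo's own helper, so here the segments are simply printed, assuming an already-extracted `audio.mp3`:

```python
import whisper  # install from git+https://github.com/openai/whisper.git

model = whisper.load_model("small")
result = model.transcribe("audio.mp3")
for seg in result["segments"]:
    # each segment carries start/end times (seconds) and text, enough to build WEBVTT cues
    print(seg["start"], seg["end"], seg["text"])
```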

### **Case 3: LvLM → Video + Silent/Music Extraction**
- Uses **LLaVA (an LVLM)** for **frame-based captioning**.
- Encodes each frame as a **Base64 image**.
- Extracts context and captions from the video frames.
- Then uses **Case 1** processing.

# Step 4 - What is LLaVA?
LLaVA (Large Language-and-Vision Assistant) is a large multimodal model that connects a vision encoder to a language model. It doesn't just see images: it understands them, reads the text embedded in them, and reasons about their context.

# Step 5 - What is a Vector Store?
A vector store is a specialized database designed to:

- Store and manage high-dimensional vector data efficiently
- Perform similarity-based searches, where k=1 returns the single most similar result
- In LanceDB specifically, store multiple data types per row:
  - Text content (captions)
  - Image file paths
  - Metadata
  - Vector embeddings

```python
_ = MultimodalLanceDB.from_text_image_pairs(
    texts=updated_vid1_trans+vid2_trans,
    image_paths=vid1_img_path+vid2_img_path,
    embedding=BridgeTowerEmbeddings(),
    metadatas=vid1_metadata+vid2_metadata,
    connection=db,
    table_name=TBL_NAME,
    mode="overwrite",
)
```
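
Reading the table back mirrors what `app.py` does: open the same table and wrap it in a retriever, where `k=1` returns the single closest caption + frame pair:

```python
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB

TBL_NAME = "demo_tbl"  # must match the table name used above; gradio_utils.py uses "demo_tbl"

vectorstore = MultimodalLanceDB(uri="./shared_data/.lancedb",
                                embedding=BridgeTowerEmbeddings(),
                                table_name=TBL_NAME)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 1})
results = retriever.invoke("show me a group of astronauts")
print(results[0].page_content, results[0].metadata["extracted_frame_path"])
```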

# Gotchas and Solutions
- Image Processing: when working with base64-encoded images, convert them to `PIL.Image` before processing with BridgeTower.
- Model Selection: using `BridgeTowerForContrastiveLearning` instead of Prediction Guard due to API access limitations.
- Model Size: the BridgeTower model requires a ~3.5 GB download.
- Image Downloads: some Flickr images may be unavailable; implement robust error handling.
- Token Decoding: the BridgeTower contrastive-learning model works with embeddings, not token predictions.
- Whisper: install from git+https://github.com/openai/whisper.git

# Install ffmpeg using brew
```bash
brew install ffmpeg
brew link ffmpeg
```

# Learning and Skills

## Technical Skills:
- Basic machine learning and deep learning
- Vector embeddings and similarity search
- Multimodal data processing

## Framework & Library Expertise:
- Hugging Face Transformers
- Gradio UI development
- LangChain integration (basic)
- PyTorch basics
- LanceDB vector storage

## AI/ML Concepts:
- Multimodal RAG system architecture
- Vector embeddings and similarity search
- Large Language Models (LLaVA)
- Image-text pair processing
- Dimensionality reduction techniques

## Multimedia Processing:
- Video frame extraction
- Audio transcription (Whisper)
- Image processing (PIL)
- Base64 encoding/decoding
- WebVTT handling

## System Design:
- Client-server architecture
- API endpoint design
- Data pipeline construction
- Vector store implementation
- Multimodal system integration
app.py
CHANGED
@@ -1,376 +1,385 @@
(Previous version, 376 lines: only scattered fragments of it survive in this extract, and they match the corresponding parts of the new version, so it is not reproduced here. The new 385-line version follows in full.)
New version (385 lines):

from pathlib import Path
import gradio as gr
import os
from PIL import Image
import ollama
from utility import download_video, get_transcript_vtt, extract_meta_data, lvlm_inference_with_phi, lvlm_inference_with_tiny_model, lvlm_inference_with_tiny_model
from mm_rag.embeddings.bridgetower_embeddings import (
    BridgeTowerEmbeddings
)
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
import lancedb
import json
import os
from PIL import Image
from utility import load_json_file, display_retrieved_results
import pyarrow as pa

# declare host file
LANCEDB_HOST_FILE = "./shared_data/.lancedb"
# declare table name
# initialize vectorstore
db = lancedb.connect(LANCEDB_HOST_FILE)
# initialize an BridgeTower embedder
embedder = BridgeTowerEmbeddings()
video_processed = False
base_dir = "./shared_data/videos/yt_video"
Path(base_dir).mkdir(parents=True, exist_ok=True)


def open_table(table_name):
    # open a connection to table TBL_NAME
    tbl = db.open_table(table_name)

    print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
    # display the first 3 rows of the table
    tbl.to_pandas()[['text', 'image_path']].head(3)


def check_if_table_exists(table_name):
    return table_name in db.table_names()


def store_in_rag(vid_table_name, vid_metadata_path):

    # load metadata files
    vid_metadata = load_json_file(vid_metadata_path)

    vid_subs = [vid['transcript'] for vid in vid_metadata]
    vid_img_path = [vid['extracted_frame_path'] for vid in vid_metadata]

    # for video1, we pick n = 7
    n = 7
    updated_vid_subs = [
        ' '.join(vid_subs[i-int(n/2): i+int(n/2)]) if i-int(n/2) >= 0 else
        ' '.join(vid_subs[0: i + int(n/2)]) for i in range(len(vid_subs))
    ]

    # also need to update the updated transcripts in metadata
    for i in range(len(updated_vid_subs)):
        vid_metadata[i]['transcript'] = updated_vid_subs[i]

    # you can pass in mode="append"
    # to add more entries to the vector store
    # in case you want to start with a fresh vector store,
    # you can pass in mode="overwrite" instead

    print("Creating vid_table_name ", vid_table_name)
    _ = MultimodalLanceDB.from_text_image_pairs(
        texts=updated_vid_subs,
        image_paths=vid_img_path,
        embedding=embedder,
        metadatas=vid_metadata,
        connection=db,
        table_name=vid_table_name,
        mode="overwrite",
    )
    open_table(vid_table_name)

    return vid_table_name


def get_metadata_of_yt_video_with_captions(vid_url, from_gen=False):
    vid_filepath, vid_folder_path, is_downloaded = download_video(
        vid_url, base_dir)
    if is_downloaded:
        print("Video downloaded at ", vid_filepath)
    if from_gen:
        # Delete existing caption and metadata files if they exist
        caption_file = f"{vid_folder_path}/captions.vtt"
        metadata_file = f"{vid_folder_path}/metadatas.json"
        if os.path.exists(caption_file):
            os.remove(caption_file)
            print(f"Deleted existing caption file: {caption_file}")
        if os.path.exists(metadata_file):
            os.remove(metadata_file)
            print(f"Deleted existing metadata file: {metadata_file}")

    print("checking transcript")
    vid_transcript_filepath = get_transcript_vtt(
        vid_folder_path, vid_url, vid_filepath, from_gen)
    vid_metadata_path = f"{vid_folder_path}/metadatas.json"
    print("checking metadatas at", vid_metadata_path)
    if os.path.exists(vid_metadata_path):
        print('Metadatas already exists')
    else:
        print("Downloading metadatas for the video ", vid_filepath)
        # should return lowercase file name without spaces
        extract_meta_data(vid_folder_path, vid_filepath,
                          vid_transcript_filepath)

    parent_dir_name = os.path.basename(os.path.dirname(vid_metadata_path))
    vid_table_name = f"{parent_dir_name}_table"
    print("Checking db and Table name ", vid_table_name)
    if not check_if_table_exists(vid_table_name):
        print("Table does not exists Storing in RAG")
    else:
        print("Table exists")

        def delete_table(table_name):
            db.drop_table(table_name)
            print(f"Deleted table {table_name}")
        delete_table(vid_table_name)

    store_in_rag(vid_table_name, vid_metadata_path)
    return vid_filepath, vid_table_name


def return_top_k_most_similar_docs(vid_table_name, query, use_llm=False):
    if not video_processed:
        gr.Error("Please process the video first in Step 1")
    # Initialize results variable outside the if condition
    max_docs = 2
    print("Querying ", vid_table_name)
    vectorstore = MultimodalLanceDB(
        uri=LANCEDB_HOST_FILE,
        embedding=embedder,
        table_name=vid_table_name
    )

    retriever = vectorstore.as_retriever(
        search_type='similarity',
        search_kwargs={"k": max_docs}
    )

    # Get results first
    results = retriever.invoke(query)

    if use_llm:
        # Read captions.vtt file
        def read_vtt_file(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()

        vid_table_name = vid_table_name.split('_table')[0]
        caption_file = 'shared_data/videos/yt_video/' + vid_table_name + '/captions.vtt'
        print("Caption file path ", caption_file)
        captions = read_vtt_file(caption_file)
        prompt = "Answer this query : " + query + " from the content " + captions
        print("Prompt ", prompt)
        all_page_content = lvlm_inference_with_phi(prompt)
    else:
        all_page_content = "\n\n".join(
            [result.page_content for result in results])

    page_content = gr.Textbox(all_page_content, label="Response",
                              elem_id='chat-response', visible=True, interactive=False)
    image1 = Image.open(results[0].metadata['extracted_frame_path'])
    image2_path = results[1].metadata['extracted_frame_path']

    if results[0].metadata['extracted_frame_path'] == image2_path:
        image2 = gr.update(visible=False)
    else:
        image2 = Image.open(image2_path)
        image2 = gr.update(value=image2, visible=True)

    return page_content, image1, image2


def process_url_and_init(youtube_url, from_gen=False):
    video_processed = True
    url_input = gr.update(visible=False)
    submit_btn = gr.update(visible=True)
    chatbox = gr.update(visible=True)
    submit_btn2 = gr.update(visible=True)
    frame1 = gr.update(visible=True)
    frame2 = gr.update(visible=False)
    chatbox_llm, submit_btn_chat = gr.update(
        visible=True), gr.update(visible=True)
    vid_filepath, vid_table_name = get_metadata_of_yt_video_with_captions(
        youtube_url, from_gen)
    video = gr.Video(vid_filepath, render=True)
    return url_input, submit_btn, video, vid_table_name, chatbox, submit_btn2, frame1, frame2, chatbox_llm, submit_btn_chat


def test_btn():
    text = "hi"
    res = lvlm_inference_with_phi(text)
    response = gr.Textbox(res, visible=True, interactive=False)
    return response


def init_ui():
    with gr.Blocks() as demo:

        gr.Markdown("Welcome to video chat demo - Initial processing can take up to 2 minutes, and responses may be slow. Please be patient and avoid clicking repeatedly.")
        url_input = gr.Textbox(label="Enter YouTube URL", visible=False, elem_id='url-inp',
                               value="https://www.youtube.com/watch?v=kOEDG3j1bjs", interactive=True)
        vid_table_name = gr.Textbox(
            label="Enter Table Name", visible=False, interactive=False)
        video = gr.Video()
        with gr.Row():
            submit_btn = gr.Button("Process Video By Download Subtitles")
            submit_btn_gen = gr.Button("Process Video By Generating Subtitles")

        with gr.Row():
            chatbox = gr.Textbox(label="Enter the keyword/s and AI will get related captions and images",
                                 visible=False, value="event horizan", scale=4)
            submit_btn_whisper = gr.Button(
                "Submit", elem_id='chat-submit', visible=False, scale=1)
        with gr.Row():
            chatbox_llm = gr.Textbox(
                label="Ask a Question", visible=False, value="what this video is about?", scale=4)
            submit_btn_chat = gr.Button("Ask", visible=False, scale=1)

        response = gr.Textbox(
            label="Response", elem_id='chat-response', visible=False, interactive=False)

        with gr.Row():
            frame1 = gr.Image(visible=False, interactive=False, scale=2)
            frame2 = gr.Image(visible=False, interactive=False, scale=2)
        submit_btn.click(fn=process_url_and_init, inputs=[url_input], outputs=[
            url_input, submit_btn, video, vid_table_name, chatbox, submit_btn_whisper, frame1, frame2, chatbox_llm, submit_btn_chat])
        submit_btn_gen.click(fn=lambda x: process_url_and_init(x, from_gen=True), inputs=[url_input], outputs=[
            url_input, submit_btn, video, vid_table_name, chatbox, submit_btn_whisper, frame1, frame2, chatbox_llm, submit_btn_chat])
        submit_btn_whisper.click(fn=return_top_k_most_similar_docs, inputs=[
            vid_table_name, chatbox], outputs=[response, frame1, frame2])

        submit_btn_chat.click(
            fn=lambda table_name, query: return_top_k_most_similar_docs(
                vid_table_name=table_name,
                query=query,
                use_llm=True
            ),
            inputs=[vid_table_name, chatbox_llm],
            outputs=[response, frame1, frame2]
        )
        reset_btn = gr.Button("Reload Page")
        reset_btn.click(None, js="() => { location.reload(); }")

        test_llama = gr.Button("Test Llama")
        test_llama.click(test_btn, None, outputs=[response])
    return demo


def init_improved_ui():

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Header Section with Introduction
        with gr.Accordion(label=" # 🎬 Video Analysis Assistant", open=True):
            gr.Markdown("""
            ## How it Works:
            1. 📥 Provide a YouTube URL.
            2. 🔄 Choose a processing method:
               - Download the video and its captions/subtitles from YouTube.
               - Download the video and generate captions using Whisper AI.
               The system will load the video in the video player for preview, process the video, and extract frames from it.
               It will then pass the captions and images to the RAG model to store them in the database.
               The RAG (LanceDB) uses a pre-trained BridgeTower model to generate embeddings that provide pairs of captions and related images.
            3. 🤖 Analyze video content through:
               - Keyword Search - Use this functionality to search for keywords in the video. Our RAG model will return the most relevant captions and images.
               - AI-powered Q&A - Use this functionality to ask questions about the video content. Our system will use the Meta/LLaMA model to analyze the captions and images and provide detailed answers.
            4. 📊 Results will be displayed in the response section with related images.

            > **Note**: Initial processing takes several minutes. Please be patient and monitor the logs for progress updates.
            """)

        # Video Input Section
        with gr.Group():
            url_input = gr.Textbox(
                label="YouTube URL",
                value="https://www.youtube.com/watch?v=kOEDG3j1bjs",
                visible=True,
                interactive=False
            )
            vid_table_name = gr.Textbox(label="Table Name", visible=False)
            video = gr.Video(label="Video Preview")

            with gr.Row():
                submit_btn = gr.Button(
                    "📥 Step 1: Process with Existing Subtitles", variant="primary", size='md')
                submit_btn_gen = gr.Button(
                    "🎯 Generate New Subtitles", variant="secondary", visible=False)

        # Analysis Tools Section
        with gr.Group():
            gr.Markdown("### 🔍 Step 2: Chat AI about the video")

            with gr.Row():
                chatbox = gr.Textbox(
                    label="Step 2: Search Keywords",
                    value="event horizon, black holes, space",
                    visible=False
                )
                submit_btn_whisper = gr.Button(
                    "🔎 Search",
                    visible=False,
                    variant="primary"
                )

            with gr.Row():
                chatbox_llm = gr.Textbox(
                    label="",
                    value="What is this video about?",
                    visible=True
                )
                submit_btn_chat = gr.Button(
                    "🤖 Ask",
                    visible=True,
                    scale=1
                )

        # Results Display Section
        with gr.Group():
            gr.Markdown("### 📊 AI Response")
            response = gr.Textbox(
                label="AI Response",
                visible=True,
                interactive=False
            )

            with gr.Row():
                frame1 = gr.Image(
                    visible=False, label="Related Frame 1", scale=1)
                frame2 = gr.Image(
                    visible=False, label="Related Frame 2", scale=2)

        # Control Buttons
        with gr.Row():
            reset_btn = gr.Button("🔄 Start Over", variant="secondary")
            test_llama = gr.Button("🧪 Say Hi to Llama",
                                   visible=False, variant="secondary")

        # Event Handlers
        submit_btn.click(
            fn=process_url_and_init,
            inputs=[url_input],
            outputs=[url_input, submit_btn, video, vid_table_name,
                     chatbox, submit_btn_whisper, frame1, frame2,
                     chatbox_llm, submit_btn_chat]
        )

        submit_btn_gen.click(
            fn=lambda x: process_url_and_init(x, from_gen=True),
            inputs=[url_input],
            outputs=[url_input, submit_btn, video, vid_table_name,
                     chatbox, submit_btn_whisper, frame1, frame2,
                     chatbox_llm, submit_btn_chat]
        )

        submit_btn_whisper.click(
            fn=return_top_k_most_similar_docs,
            inputs=[vid_table_name, chatbox],
            outputs=[response, frame1, frame2]
        )

        submit_btn_chat.click(
            fn=lambda table_name, query: return_top_k_most_similar_docs(
                vid_table_name=table_name,
                query=query,
                use_llm=True
            ),
            inputs=[vid_table_name, chatbox_llm],
            outputs=[response, frame1, frame2]
        )

        reset_btn.click(None, js="() => { location.reload(); }")
        test_llama.click(test_btn, None, outputs=[response])

    return demo


if __name__ == '__main__':
    demo = init_improved_ui()  # Updated function name here
    demo.launch(share=True, debug=True)
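One detail worth calling out in `store_in_rag`: each frame's transcript is widened to a sliding window of n = 7 neighbouring segments before embedding, so a retrieved frame carries surrounding context. A toy illustration of that comprehension (not repo code):

```python
subs = [f"s{i}" for i in range(10)]
n = 7
windowed = [
    ' '.join(subs[i - n // 2: i + n // 2]) if i - n // 2 >= 0 else
    ' '.join(subs[0: i + n // 2]) for i in range(len(subs))
]
print(windowed[5])  # "s2 s3 s4 s5 s6 s7": segment 5 plus its neighbours
```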
gradio_utils.py
CHANGED
@@ -1,483 +1,483 @@
Only the old side of this rewrite is captured here, and the extract cuts off at line 430 of 483:

import gradio as gr
import io
import sys
import time
import dataclasses
from pathlib import Path
import os
from enum import auto, Enum
from typing import List, Tuple, Any
from utility import prediction_guard_llava_conv
import lancedb
from utility import load_json_file
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from PIL import Image
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from moviepy.video.io.VideoFileClip import VideoFileClip
from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation

server_error_msg="**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

# function to split video at a timestamp
def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str="video_tmp.mp4", play_before_sec: int=3, play_after_sec: int=3):
    timestamp_in_sec = int(timestamp_in_ms / 1000)
    # create output_video_name folder if not exist:
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    output_video = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as video:
        duration = video.duration
        start_time = max(timestamp_in_sec - play_before_sec, 0)
        end_time = min(timestamp_in_sec + play_after_sec, duration)
        new = video.subclip(start_time, end_time)
        new.write_videofile(output_video, audio_codec='aac')
    return output_video


prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

# define default rag_chain
def get_default_rag_chain():
    # declare host file
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    # declare table name
    TBL_NAME = "demo_tbl"

    # initialize vectorstore
    db = lancedb.connect(LANCEDB_HOST_FILE)

    # initialize an BridgeTower embedder
    embedder = BridgeTowerEmbeddings()

    ## Creating a LanceDB vector store
    vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
    ### creating a retriever for the vector store
    retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})

    # initialize a client as PredictionGuardClient
    client = PredictionGuardClient()
    # initialize LVLM with the given client
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        # get the retrieved results and user's query
        retrieved_results, user_query = input['retrieved_results'], input['user_query']
        # get the first retrieved result by default
        retrieved_result = retrieved_results[0]
        # prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

        # get all metadata of the retrieved video segment
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']

        # get the frame and the corresponding transcript, path to extracted frame, path to whole video, and time stamp of the retrieved video segment.
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
            'image' : frame_path,
            'metadata' : metadata_retrieved_video_segment,
        }
    # initialize prompt processing module as a Langchain RunnableLambda of function prompt_processing
    prompt_processing_module = RunnableLambda(prompt_processing)

    # the output of this new chain will be a dictionary
    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({"retrieved_results": retriever_module ,
                          "user_query": RunnablePassthrough()})
        | prompt_processing_module
        | RunnableParallel({'final_text_output': lvlm_inference_module,
                            'input_to_lvlm' : RunnablePassthrough()})
    )
    return mm_rag_chain_with_retrieved_image

class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()

@dataclasses.dataclass
class GradioInstance:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None
    video_title: str = None
    path_to_video: str = None
    caption: str = None
    mm_rag_chain: Any = None

    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == len(self.messages[self.offset:]) - 2:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title" : self.video_title,
            "path_to_video": self.path_to_video,
            "caption" : self.caption,
        }
    def get_path_to_subvideos(self):
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None

def get_gradio_instance(mm_rag_chain=None):
    if mm_rag_chain is None:
        mm_rag_chain = get_default_rag_chain()

    instance = GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=mm_rag_chain,
    )
    return instance

gr.set_static_paths(paths=["./assets/"])
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500'
)

css='''
@font-face {
    font-family: IntelOne;
    src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
    border-collapse: collapse;
    border: none;
}
'''

## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>

# html_title = ''' ... (commented-out variant of the header table with Intel Labs, Gaudi, IDC, and Prediction Guard logos) ... '''

html_title = '''
<table style="bordercolor=#0a0c2b; border=0">
<tr style="height:150px; border:0">
<td style="border:0"><img src="/file=./assets/header.png"></td>
</tr>
</table>

'''

#<td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>
dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

def clear_history(state, request: gr.Request):
    state = get_gradio_instance(state.mm_rag_chain)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1

def add_text(state, text, request: gr.Request):
    if len(text) <= 0 :
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1

    text = text[:1536]  # Hard cut-off

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1

def http_bot(
    state, request: gr.Request
):
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # Make requests
    is_very_first_query = True
    if len(all_images) == 0:
        # first query need to do RAG
        # Construct prompt
        prompt_or_conversation = state.get_prompt_for_rag()
    else:
        # subsequence queries, no need to do Retrieval
        is_very_first_query = False
        prompt_or_conversation = state.get_conversation_for_lvlm()

    if is_very_first_query:
        executor = state.mm_rag_chain
    else:
        executor = lvlm_inference_with_conversation

    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1

    try:
        if is_very_first_query:
            # get response by invoke executor chain
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                if (state.path_to_img is None
                        and 'input_to_lvlm' in response
                        and 'image' in response['input_to_lvlm']
                        ):
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    splited_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = splited_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response's format is changed")
        else:
            # get the response message by directly call PredictionGuardAPI
            message = executor(prompt_or_conversation)

    except Exception as e:
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (
            enable_btn,
        )
        return

    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()
    # path_to_image = state.path_to_img
    # caption = state.caption
    # # print(path_to_sub_videos)
    # # print(path_to_image)
    # # print('caption: ', caption)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    return

def get_demo(rag_chain=None):
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        # gr.Markdown(description)
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        demo.load(
|
431 |
-
None,
|
432 |
-
None,
|
433 |
-
js="""
|
434 |
-
() => {
|
435 |
-
const params = new URLSearchParams(window.location.search);
|
436 |
-
if (!params.has('__theme')) {
|
437 |
-
params.set('__theme', 'dark');
|
438 |
-
window.location.search = params.toString();
|
439 |
-
}
|
440 |
-
}""",
|
441 |
-
)
|
442 |
-
gr.HTML(value=html_title)
|
443 |
-
with gr.Row():
|
444 |
-
with gr.Column(scale=4):
|
445 |
-
video = gr.Video(height=512, width=512, elem_id="video", interactive=False )
|
446 |
-
with gr.Column(scale=7):
|
447 |
-
chatbot = gr.Chatbot(
|
448 |
-
elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
|
449 |
-
)
|
450 |
-
with gr.Row():
|
451 |
-
with gr.Column(scale=8):
|
452 |
-
# textbox.render()
|
453 |
-
textbox = gr.Dropdown(
|
454 |
-
dropdown_list,
|
455 |
-
allow_custom_value=True,
|
456 |
-
# show_label=False,
|
457 |
-
# container=False,
|
458 |
-
label="Query",
|
459 |
-
info="Enter your query here or choose a sample from the dropdown list!"
|
460 |
-
)
|
461 |
-
with gr.Column(scale=1, min_width=50):
|
462 |
-
submit_btn = gr.Button(
|
463 |
-
value="Send", variant="primary", interactive=True
|
464 |
-
)
|
465 |
-
with gr.Row(elem_id="buttons") as button_row:
|
466 |
-
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
467 |
-
|
468 |
-
btn_list = [clear_btn]
|
469 |
-
|
470 |
-
clear_btn.click(
|
471 |
-
clear_history, [state], [state, chatbot, textbox, video] + btn_list
|
472 |
-
)
|
473 |
-
submit_btn.click(
|
474 |
-
add_text,
|
475 |
-
[state, textbox],
|
476 |
-
[state, chatbot, textbox,] + btn_list,
|
477 |
-
).then(
|
478 |
-
http_bot,
|
479 |
-
[state],
|
480 |
-
[state, chatbot, video] + btn_list,
|
481 |
-
)
|
482 |
-
return demo
|
483 |
-
|
|
|
import gradio as gr
import io
import sys
import time
import dataclasses
from pathlib import Path
import os
from enum import auto, Enum
from typing import List, Tuple, Any
from utility import prediction_guard_llava_conv
import lancedb
from utility import load_json_file
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from PIL import Image
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from moviepy.video.io.VideoFileClip import VideoFileClip
from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

# function to split a video around a timestamp
def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str = "video_tmp.mp4", play_before_sec: int = 3, play_after_sec: int = 3):
    timestamp_in_sec = int(timestamp_in_ms / 1000)
    # create the output folder if it does not exist
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    output_video = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as video:
        duration = video.duration
        start_time = max(timestamp_in_sec - play_before_sec, 0)
        end_time = min(timestamp_in_sec + play_after_sec, duration)
        new = video.subclip(start_time, end_time)
        new.write_videofile(output_video, audio_codec='aac')
    return output_video


prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

# define the default rag_chain
def get_default_rag_chain():
    # declare host file
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    # declare table name
    TBL_NAME = "demo_tbl"

    # initialize vectorstore
    db = lancedb.connect(LANCEDB_HOST_FILE)

    # initialize a BridgeTower embedder
    embedder = BridgeTowerEmbeddings()

    ## Creating a LanceDB vector store
    vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
    ### creating a retriever for the vector store
    retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})

    # initialize a client as PredictionGuardClient
    client = PredictionGuardClient()
    # initialize LVLM with the given client
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        # get the retrieved results and user's query
        retrieved_results, user_query = input['retrieved_results'], input['user_query']
        # get the first retrieved result by default
        retrieved_result = retrieved_results[0]
        # prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

        # get all metadata of the retrieved video segment
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']

        # get the transcript and the path to the extracted frame of the retrieved video segment
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
            'image': frame_path,
            'metadata': metadata_retrieved_video_segment,
        }
    # initialize the prompt processing module as a LangChain RunnableLambda of function prompt_processing
    prompt_processing_module = RunnableLambda(prompt_processing)

    # the output of this new chain will be a dictionary
    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({"retrieved_results": retriever_module,
                          "user_query": RunnablePassthrough()})
        | prompt_processing_module
        | RunnableParallel({'final_text_output': lvlm_inference_module,
                            'input_to_lvlm': RunnablePassthrough()})
    )
    return mm_rag_chain_with_retrieved_image

class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()

@dataclasses.dataclass
class GradioInstance:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None
    video_title: str = None
    path_to_video: str = None
    caption: str = None
    mm_rag_chain: Any = None

    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == len(self.messages[self.offset:]) - 2:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title": self.video_title,
            "path_to_video": self.path_to_video,
            "caption": self.caption,
        }

    def get_path_to_subvideos(self):
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None

def get_gradio_instance(mm_rag_chain=None):
    if mm_rag_chain is None:
        mm_rag_chain = get_default_rag_chain()

    instance = GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=mm_rag_chain,
    )
    return instance

gr.set_static_paths(paths=["./assets/"])
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500'
)

css = '''
@font-face {
    font-family: IntelOne;
    src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
    border-collapse: collapse;
    border: none;
}
'''

## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>

# html_title = '''
# <table style="bordercolor=#0a0c2b; border=0">
#     <tr style="height:150px; border:0">
#         <td style="border:0"><img src="/file=../assets/intel-labs.png" height="100" width="100"></td>
#         <td style="vertical-align:bottom; border:0">
#             <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
#             Multimodal RAG:
#             <br>
#             Chat with Videos
#             </p>
#         </td>
#         <td style="border:0"><img src="/file=../assets/gaudi.png" width="100" height="100"></td>
#         <td style="border:0"><img src="/file=../assets/IDC7.png" width="300" height="350"></td>
#         <td style="border:0"><img src="/file=../assets/prediction_guard3.png" width="120" height="120"></td>
#     </tr>
# </table>
# '''

html_title = '''
<table style="bordercolor=#0a0c2b; border=0">
    <tr style="height:150px; border:0">
        <td style="border:0"><img src="/file=./assets/header.png"></td>
    </tr>
</table>
'''

# <td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>
dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

def clear_history(state, request: gr.Request):
    state = get_gradio_instance(state.mm_rag_chain)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1

def add_text(state, text, request: gr.Request):
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1

    text = text[:1536]  # Hard cut-off

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1

def http_bot(
    state, request: gr.Request
):
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # Make requests
    is_very_first_query = True
    if len(all_images) == 0:
        # the first query needs to do RAG
        # Construct prompt
        prompt_or_conversation = state.get_prompt_for_rag()
    else:
        # subsequent queries do not need retrieval
        is_very_first_query = False
        prompt_or_conversation = state.get_conversation_for_lvlm()

    if is_very_first_query:
        executor = state.mm_rag_chain
    else:
        executor = lvlm_inference_with_conversation

    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1

    try:
        if is_very_first_query:
            # get the response by invoking the executor chain
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                if (state.path_to_img is None
                    and 'input_to_lvlm' in response
                    and 'image' in response['input_to_lvlm']
                ):
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    splited_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = splited_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response's format is changed")
        else:
            # get the response message by calling the Prediction Guard API directly
            message = executor(prompt_or_conversation)

    except Exception as e:
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (
            enable_btn,
        )
        return

    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()
    # path_to_image = state.path_to_img
    # caption = state.caption
    # # print(path_to_sub_videos)
    # # print(path_to_image)
    # # print('caption: ', caption)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    return

def get_demo(rag_chain=None):
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        # gr.Markdown(description)
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        demo.load(
            None,
            None,
            js="""
            () => {
              const params = new URLSearchParams(window.location.search);
              if (!params.has('__theme')) {
                params.set('__theme', 'dark');
                window.location.search = params.toString();
              }
            }""",
        )
        gr.HTML(value=html_title)
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
                )
        with gr.Row():
            with gr.Column(scale=8):
                # textbox.render()
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    # show_label=False,
                    # container=False,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!"
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(
                    value="Send", variant="primary", interactive=True
                )
        with gr.Row(elem_id="buttons") as button_row:
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        clear_btn.click(
            clear_history, [state], [state, chatbot, textbox, video] + btn_list
        )
        submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [state],
            [state, chatbot, video] + btn_list,
        )
    return demo
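A minimal launch sketch for the module above. The import path gradio_utils, the queue() call, and the host/port values are assumptions, not part of this commit; it also assumes a LanceDB table has already been populated (for example by s7_store_in_rag.py).

from gradio_utils import get_demo

demo = get_demo()                      # builds the Blocks UI with the default RAG chain
demo.queue()                           # queuing lets the generator-based http_bot stream partial updates
demo.launch(server_name="0.0.0.0", server_port=7860)   # host and port are assumptions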
mm_rag/MLM/client.py
CHANGED
@@ -1,135 +1,135 @@
"""Base interface for client making requests/call to visual language model provider API"""

from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Union, Iterator
import requests
import json
from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference

class BaseClient(ABC):
    def __init__(self,
                 hostname: str = "127.0.0.1",
                 port: int = 8090,
                 timeout: int = 60,
                 url: Optional[str] = None):
        self.connection_url = f"http://{hostname}:{port}" if url is None else url
        self.timeout = timeout
        # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        self.headers = {'Content-Type': 'application/json'}

    def root(self):
        """Request for showing welcome message"""
        connection_route = f"{self.connection_url}/"
        return requests.get(connection_route)

    @abstractmethod
    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send request to visual language model API
        and return generated text that was returned by the visual language model API

        Use this method when you want to call the visual language model API to generate text without streaming

        Args:
            prompt: A prompt.
            image: A string that can be either a path to an image or base64 of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Text returned from the visual language model provider API call
        """

    def generate_stream(
        self,
        prompt: str,
        image: str,
        **kwargs
    ) -> Iterator[str]:
        """Send request to visual language model API
        and return an iterator of streaming text that was returned from the visual language model API call

        Use this method when you want to call the visual language model API to stream generated text.

        Args:
            prompt: A prompt.
            image: A string that can be either a path to an image or base64 of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Iterator of text streamed from the visual language model provider API call
        """
        raise NotImplementedError()

    def generate_batch(
        self,
        prompt: List[str],
        image: List[str],
        **kwargs
    ) -> List[str]:
        """Send a request to visual language model API for multi-batch generation
        and return a list of generated text that was returned by the visual language model API

        Use this method when you want to call the visual language model API to generate text for multiple batches.
        Multi-batch generation does not support streaming.

        Args:
            prompt: List of prompts.
            image: List of strings; each of which can be either a path to an image or base64 of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            List of texts returned from the visual language model provider API call
        """
        raise NotImplementedError()

class PredictionGuardClient(BaseClient):

    generate_kwargs = ['max_tokens',
                       'temperature',
                       'top_p',
                       'top_k']

    def filter_accepted_genkwargs(self, kwargs):
        gen_args = {}
        if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
            gen_args = {k: kwargs["generate_kwargs"][k]
                        for k in self.generate_kwargs
                        if k in kwargs["generate_kwargs"]}
        return gen_args

    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send request to PredictionGuard's API
        and return generated text that was returned by the LLAVA model

        Use this method when you want to call the LLAVA model API to generate text without streaming

        Args:
            prompt: A prompt.
            image: A string that can be either a path/URL to an image or base64 of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Text returned from the visual language model provider API call
        """

        assert image is not None and len(image) != 0, "the input image cannot be None, it must be either base64-encoded image or path/URL to image"
        if isBase64(image):
            base64_image = image
        else:  # this is a path to an image or a URL to an image
            base64_image = encode_image_from_path_or_url(image)

        args = self.filter_accepted_genkwargs(kwargs)
        return lvlm_inference(prompt=prompt, image=base64_image, **args)
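A hedged usage sketch of the client above. The frame path is hypothetical, and the exact hyperparameters accepted by the backend are an assumption; only keys listed in PredictionGuardClient.generate_kwargs are forwarded.

from mm_rag.MLM.client import PredictionGuardClient

client = PredictionGuardClient()
# image may be a base64 string, a local path, or a URL; a hypothetical frame path is used here
answer = client.generate(
    prompt="What is shown in this frame?",
    image="./shared_data/frames/frame_0001.jpg",
    generate_kwargs={"max_tokens": 128, "temperature": 0.6},
)
print(answer)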
mm_rag/MLM/lvlm.py
CHANGED
@@ -1,301 +1,301 @@
1 |
+
from .client import PredictionGuardClient
|
2 |
+
from langchain_core.language_models.llms import LLM
|
3 |
+
from langchain_core.pydantic_v1 import Extra, root_validator
|
4 |
+
from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
|
5 |
+
from langchain_core.callbacks import CallbackManagerForLLMRun
|
6 |
+
from utility import get_from_dict_or_env, MultimodalModelInput
|
7 |
+
|
8 |
+
from langchain_core.runnables import RunnableConfig, ensure_config
|
9 |
+
from langchain_core.language_models.base import LanguageModelInput
|
10 |
+
from langchain_core.prompt_values import StringPromptValue
|
11 |
+
# from langchain_core.outputs import GenerationChunk, LLMResult
|
12 |
+
from langchain_core.language_models.llms import BaseLLM
|
13 |
+
from langchain_core.callbacks import (
|
14 |
+
# CallbackManager,
|
15 |
+
CallbackManagerForLLMRun,
|
16 |
+
)
|
17 |
+
# from langchain_core.load import dumpd
|
18 |
+
from langchain_core.runnables.config import run_in_executor
|
19 |
+
|
20 |
+
class LVLM(LLM):
|
21 |
+
"""This class extends LLM class for implementing a custom request to LVLM provider API"""
|
22 |
+
|
23 |
+
|
24 |
+
client: Any = None #: :meta private:
|
25 |
+
hostname: Optional[str] = None
|
26 |
+
port: Optional[int] = None
|
27 |
+
url: Optional[str] = None
|
28 |
+
max_new_tokens: Optional[int] = 200
|
29 |
+
temperature: Optional[float] = 0.6
|
30 |
+
top_k: Optional[float] = 0
|
31 |
+
stop: Optional[List[str]] = None
|
32 |
+
ignore_eos: Optional[bool] = False
|
33 |
+
do_sample: Optional[bool] = True
|
34 |
+
lazy_mode: Optional[bool] = True
|
35 |
+
hpu_graphs: Optional[bool] = True
|
36 |
+
|
37 |
+
@root_validator()
|
38 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
39 |
+
"""Validate that the access token and python package exists in environment if needed"""
|
40 |
+
if values['client'] is None:
|
41 |
+
# check if url of API is provided
|
42 |
+
url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
|
43 |
+
if url is None:
|
44 |
+
hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
|
45 |
+
port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
|
46 |
+
if hostname is not None and port is not None:
|
47 |
+
values['client'] = PredictionGuardClient(hostname=hostname, port=port)
|
48 |
+
else:
|
49 |
+
# using default hostname and port to create Client
|
50 |
+
values['client'] = PredictionGuardClient()
|
51 |
+
else:
|
52 |
+
values['client'] = PredictionGuardClient(url=url)
|
53 |
+
return values
|
54 |
+
|
55 |
+
@property
|
56 |
+
def _llm_type(self) -> str:
|
57 |
+
"""Return type of llm"""
|
58 |
+
return "Large Vision Language Model"
|
59 |
+
|
60 |
+
@property
|
61 |
+
def _default_params(self) -> Dict[str, Any]:
|
62 |
+
"""Get the default parameters for calling the Prediction Guard API."""
|
63 |
+
return {
|
64 |
+
"max_tokens": self.max_new_tokens,
|
65 |
+
"temperature": self.temperature,
|
66 |
+
"top_k": self.top_k,
|
67 |
+
"ignore_eos": self.ignore_eos,
|
68 |
+
"do_sample": self.do_sample,
|
69 |
+
"stop" : self.stop,
|
70 |
+
}
|
71 |
+
|
72 |
+
def get_params(self, **kwargs):
|
73 |
+
params = self._default_params
|
74 |
+
params.update(kwargs)
|
75 |
+
return params
|
76 |
+
|
77 |
+
|
78 |
+
def _call(
|
79 |
+
self,
|
80 |
+
prompt: str,
|
81 |
+
image: str,
|
82 |
+
stop: Optional[List[str]] = None,
|
83 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
84 |
+
**kwargs: Any,
|
85 |
+
) -> str:
|
86 |
+
"""Run the VLM on the given input.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
prompt: The prompt to generate from.
|
90 |
+
image: This can be either path to image or base64 encode of the image.
|
91 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
92 |
+
first occurrence of any of the stop substrings.
|
93 |
+
If stop tokens are not supported consider raising NotImplementedError.
|
94 |
+
Returns:
|
95 |
+
The model output as a string. Actual completions DOES NOT include the prompt
|
96 |
+
Example: TBD
|
97 |
+
"""
|
98 |
+
params = {}
|
99 |
+
if stop is not None:
|
100 |
+
raise ValueError("stop kwargs are not permitted.")
|
101 |
+
params['generate_kwargs'] = self.get_params(**kwargs)
|
102 |
+
response = self.client.generate(prompt=prompt, image=image, **params)
|
103 |
+
return response
|
104 |
+
|
105 |
+
def _stream(
|
106 |
+
self,
|
107 |
+
prompt: str,
|
108 |
+
image: str,
|
109 |
+
stop: Optional[List[str]] = None,
|
110 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
111 |
+
**kwargs: Any,
|
112 |
+
) -> Iterator[str]:
|
113 |
+
"""Stream the VLM on the given prompt and image.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
prompt: The prompt to generate from.
|
117 |
+
image: This can be either path to image or base64 encode of the image.
|
118 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
119 |
+
first occurrence of any of the stop substrings.
|
120 |
+
If stop tokens are not supported consider raising NotImplementedError.
|
121 |
+
Returns:
|
122 |
+
The model outputs an iterator of string. Actual completions DOES NOT include the prompt
|
123 |
+
Example: TBD
|
124 |
+
"""
|
125 |
+
params = {}
|
126 |
+
params['generate_kwargs'] = self.get_params(**kwargs)
|
127 |
+
for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
|
128 |
+
yield chunk
|
129 |
+
|
130 |
+
async def _astream(
|
131 |
+
self,
|
132 |
+
prompt: str,
|
133 |
+
image: str,
|
134 |
+
stop: Optional[List[str]] = None,
|
135 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
136 |
+
**kwargs: Any,
|
137 |
+
) -> AsyncIterator[str]:
|
138 |
+
"""An async version of _stream method that stream the VLM on the given prompt and image.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
prompt: The prompt to generate from.
|
142 |
+
image: This can be either path to image or base64 encode of the image.
|
143 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
144 |
+
first occurrence of any of the stop substrings.
|
145 |
+
If stop tokens are not supported consider raising NotImplementedError.
|
146 |
+
Returns:
|
147 |
+
The model outputs an async iterator of string. Actual completions DOES NOT include the prompt
|
148 |
+
Example: TBD
|
149 |
+
"""
|
150 |
+
        iterator = await run_in_executor(
            None,
            self._stream,
            prompt,
            image,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        done = object()
        while True:
            item = await run_in_executor(
                None,
                next,
                iterator,
                done,  # type: ignore[call-arg, arg-type]
            )
            if item is done:
                break
            yield item  # type: ignore[misc]

    def invoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            return (
                self.generate_prompt(
                    [self._convert_input(StringPromptValue(text=input['prompt']))],
                    stop=stop,
                    callbacks=config.get("callbacks"),
                    tags=config.get("tags"),
                    metadata=config.get("metadata"),
                    run_name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                    image=input['image'],
                    **kwargs,
                )
                .generations[0][0]
                .text
            )
        return (
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
            .generations[0][0]
            .text
        )

    async def ainvoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            llm_result = await self.agenerate_prompt(
                [self._convert_input(StringPromptValue(text=input['prompt']))],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                image=input['image'],
                **kwargs,
            )
        else:
            llm_result = await self.agenerate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
        return llm_result.generations[0][0].text

    def stream(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        if type(self)._stream == BaseLLM._stream:
            # model doesn't implement streaming, so use default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            prompt = None
            if isinstance(input, dict) and 'prompt' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
            else:
                raise ValueError("prompt must be provided")
            if isinstance(input, dict) and 'image' in input.keys():
                image = input['image']

            for chunk in self._stream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        if (
            type(self)._astream is BaseLLM._astream
            and type(self)._stream is BaseLLM._stream
        ):
            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
            return
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
                image = input['image']
            else:
                raise ValueError("missing image is not permitted")
                prompt = self._convert_input(input).to_string()

            async for chunk in self._astream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk
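For orientation, a minimal sketch of how the invoke/stream interface above is meant to be driven: the model expects a dict carrying both a prompt and a base64-encoded image. The class name `LVLM` and the frame path are hypothetical placeholders; the actual wrapper class is defined earlier in lvlm.py and is not shown in this hunk.

from utility import encode_image

# Hypothetical: replace `LVLM` with the wrapper class defined earlier in lvlm.py.
# lvlm_model = LVLM(...)
# b64_frame = encode_image("shared_data/videos/video1/frame_0001.jpg")  # hypothetical frame path
# answer = lvlm_model.invoke({"prompt": "What is happening in this frame?", "image": b64_frame})
# for chunk in lvlm_model.stream({"prompt": "Describe the frame.", "image": b64_frame}):
#     print(chunk, end="")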
mm_rag/embeddings/bridgetower_embeddings.py
CHANGED
@@ -1,89 +1,89 @@
(Both sides of this hunk are identical: the file was re-uploaded without content changes. It is listed once below.)

from typing import List
from langchain_core.embeddings import Embeddings
import torch
from transformers import (
    BridgeTowerProcessor,
    BridgeTowerForContrastiveLearning
)
from langchain_core.pydantic_v1 import (
    BaseModel,
)
from lrn_vector_embeddings import bt_embeddings_from_local
from utility import encode_image, bt_embedding_from_prediction_guard
from tqdm import tqdm
from PIL import Image


class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """BridgeTower embedding model"""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        img = Image.new('RGB', (100, 100))
        for text in texts:
            embedding = bt_embeddings_from_local(text, img)
            embeddings.append(embedding)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text as a flat list of floats.
        """
        # Get embeddings
        embeddings = self.embed_documents([text])[0]

        # If embeddings is a dict, extract the text embeddings
        if isinstance(embeddings, dict):
            embeddings = embeddings["text_embeddings"]

        # If embeddings is a nested list or tensor, flatten it
        if isinstance(embeddings, (list, torch.Tensor)) and len(embeddings) == 1:
            embeddings = embeddings[0]

        # Convert tensor to list if needed
        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().tolist()

        return embeddings

    def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Args:
            texts: The list of texts to embed.
            images: The list of paths to the images to embed.
            batch_size: The batch size to process, defaults to 2.

        Returns:
            List of embeddings, one for each image-text pair.
        """
        # the length of texts must be equal to the length of images
        assert len(texts) == len(images), "the len of captions should be equal to the len of images"

        processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
        model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

        embeddings = []
        for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
            inputs = processor(text=[text], images=[Image.open(path_to_img)], return_tensors="pt")
            outputs = model(**inputs)
            # Get embeddings and convert to list
            embedding = outputs.text_embeds.detach().numpy().tolist()[0]
            embeddings.append(embedding)

        return embeddings
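A small usage sketch of the class above (illustrative only: the frame path is a hypothetical placeholder, and `embed_image_text_pairs` pulls the BridgeTower checkpoint from the Hugging Face Hub on first use):

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings

embedder = BridgeTowerEmbeddings()

# Text-only embeddings: each text is paired internally with a dummy 100x100 RGB image.
text_vecs = embedder.embed_documents(["a black hole warps spacetime"])

# Query embedding, flattened to a plain list of floats.
query_vec = embedder.embed_query("an astronaut on a spacewalk")

# Joint caption/frame embeddings, one vector per pair.
pair_vecs = embedder.embed_image_text_pairs(
    texts=["a toddler on a playground"],
    images=["shared_data/videos/video2/frame_0001.jpg"],  # hypothetical frame path
)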
mm_rag/vectorstores/multimodal_lancedb.py
CHANGED
@@ -1,131 +1,131 @@
(Both sides of this hunk are identical: the file was re-uploaded without content changes. It is listed once below.)

from typing import Any, Iterable, List, Optional
from langchain_core.embeddings import Embeddings
import uuid
from langchain_community.vectorstores.lancedb import LanceDB


class MultimodalLanceDB(LanceDB):
    """`LanceDB` vector store to process multimodal data

    To use, you should have the ``lancedb`` python package installed.
    You can install it with ``pip install lancedb``.

    Args:
        connection: LanceDB connection to use. If not provided, a new connection
            will be created.
        embedding: Embedding to use for the vectorstore.
        vector_key: Key to use for the vector in the database. Defaults to ``vector``.
        id_key: Key to use for the id in the database. Defaults to ``id``.
        text_key: Key to use for the text in the database. Defaults to ``text``.
        image_path_key: Key to use for the path to image in the database. Defaults to ``image_path``.
        table_name: Name of the table to use. Defaults to ``vectorstore``.
        api_key: API key to use for LanceDB cloud database.
        region: Region to use for LanceDB cloud database.
        mode: Mode to use for adding data to the table. Defaults to ``append``.

    Example:
        .. code-block:: python
            vectorstore = MultimodalLanceDB(uri='/lancedb', embedding=embedding_function)
            vectorstore.add_texts(['text1', 'text2'])
            result = vectorstore.similarity_search('text1')
    """

    def __init__(
        self,
        connection: Optional[Any] = None,
        embedding: Optional[Embeddings] = None,
        uri: Optional[str] = "/tmp/lancedb",
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "append",
    ):
        super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
        self._image_path_key = image_path_key

    def add_text_image_pairs(
        self,
        texts: Iterable[str],
        image_paths: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Turn text-image pairs into embeddings and add them to the database.

        Args:
            texts: Iterable of strings to combine with corresponding images to add to the vectorstore.
            image_paths: Iterable of image paths as strings to combine with corresponding texts to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.

        Returns:
            List of ids of the added text-image pairs.
        """
        # the length of texts must be equal to the length of images
        assert len(texts) == len(image_paths), "the len of transcripts should be equal to the len of images"

        # Embed texts and create documents
        docs = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths))  # type: ignore
        for idx, text in enumerate(texts):
            embedding = embeddings[idx]
            metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
            docs.append(
                {
                    self._vector_key: embedding,
                    self._id_key: ids[idx],
                    self._text_key: text,
                    self._image_path_key: image_paths[idx],
                    "metadata": metadata,
                }
            )

        if 'mode' in kwargs:
            mode = kwargs['mode']
        else:
            mode = self.mode
        if self._table_name in self._connection.table_names():
            tbl = self._connection.open_table(self._table_name)
            if self.api_key is None:
                tbl.add(docs, mode=mode)
            else:
                tbl.add(docs)
        else:
            self._connection.create_table(self._table_name, data=docs)
        return ids

    @classmethod
    def from_text_image_pairs(
        cls,
        texts: List[str],
        image_paths: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection: Any = None,
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        **kwargs: Any,
    ):
        instance = MultimodalLanceDB(
            connection=connection,
            embedding=embedding,
            vector_key=vector_key,
            id_key=id_key,
            text_key=text_key,
            image_path_key=image_path_key,
            table_name=table_name,
        )
        instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)

        return instance
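A condensed sketch of how the vector store above is populated and queried; this mirrors the flow that s7_store_in_rag.py (further down) runs on real video metadata, with hypothetical captions and frame paths:

import lancedb
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB

db = lancedb.connect("./shared_data/.lancedb")
embedder = BridgeTowerEmbeddings()

# Index a few caption/frame pairs (hypothetical examples).
vectorstore = MultimodalLanceDB.from_text_image_pairs(
    texts=["a rocket lifting off", "a group of astronauts"],
    image_paths=["frames/launch.jpg", "frames/crew.jpg"],
    embedding=embedder,
    connection=db,
    table_name="demo_tbl",
    mode="overwrite",
)

# Retrieve through the standard LangChain retriever interface.
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 2})
docs = retriever.invoke("astronauts floating in space")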
requirements.txt
CHANGED
@@ -1,25 +1,25 @@
(Both sides of this hunk are identical: the file was re-uploaded without content changes. It is listed once below.)

gradio
langchain-predictionguard
IPython
umap-learn
pytubefix
youtube_transcript_api
torch
transformers
matplotlib
seaborn
datasets
moviepy
whisper
webvtt-py
tqdm
lancedb
langchain-core
langchain-community
ollama
opencv-python
openai-whisper
huggingface_hub[cli]
huggingface_hub
pillow
accelerate>=0.26.0
s6_prepare_video_input.py
CHANGED
@@ -1,90 +1,90 @@
(Both sides of this hunk are identical: the file was re-uploaded without content changes. It is listed once below.)

from pathlib import Path
import os
from os import path as osp
import whisper
from moviepy import VideoFileClip
from PIL import Image
from utility import download_video, extract_meta_data, get_transcript_vtt, getSubs
from urllib.request import urlretrieve
from IPython.display import display
import ollama


def demo_video_input_that_has_transcript():
    # first video's url
    vid_url = "https://www.youtube.com/watch?v=7Hcg-rLYwdM"

    # download Youtube video to ./shared_data/videos/video1
    vid_dir = "./shared_data/videos/video1"
    vid_filepath = download_video(vid_url, vid_dir)

    # download Youtube video's subtitle to ./shared_data/videos/video1
    vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)

    return extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)


def demo_video_input_that_has_no_transcript():
    # second video's url
    vid_url = (
        "https://multimedia-commons.s3-us-west-2.amazonaws.com/"
        "data/videos/mp4/010/a07/010a074acb1975c4d6d6e43c1faeb8.mp4"
    )
    vid_dir = "./shared_data/videos/video2"
    vid_name = "toddler_in_playground.mp4"

    # create folder to which video2 will be downloaded
    Path(vid_dir).mkdir(parents=True, exist_ok=True)
    vid_filepath = urlretrieve(
        vid_url,
        osp.join(vid_dir, vid_name)
    )[0]

    path_to_video_no_transcript = vid_filepath

    # declare where to save .mp3 audio
    path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')

    # extract mp3 audio file from the mp4 video file
    clip = VideoFileClip(path_to_video_no_transcript)
    clip.audio.write_audiofile(path_to_extracted_audio_file)

    model = whisper.load_model("small")
    options = dict(task="translate", best_of=1, language='en')
    results = model.transcribe(path_to_extracted_audio_file, **options)

    vtt = getSubs(results["segments"], "vtt")

    # path to save the generated transcript of video2
    path_to_generated_trans = osp.join(vid_dir, 'generated_video1.vtt')
    # write transcription to file
    with open(path_to_generated_trans, 'w') as f:
        f.write(vtt)

    return extract_meta_data(vid_dir, vid_filepath, path_to_generated_trans)


def ask_llvm(instruction, file_path):
    result = ollama.generate(
        model='llava',
        prompt=instruction,
        images=[file_path],
        stream=False
    )['response']
    img = Image.open(file_path, mode='r')
    img = img.resize([int(i / 1.2) for i in img.size])
    display(img)
    for i in result.split('.'):
        print(i, end='', flush=True)


if __name__ == "__main__":
    meta_data = demo_video_input_that_has_transcript()

    meta_data1 = demo_video_input_that_has_no_transcript()
    data = meta_data1[1]
    caption = data['transcript']
    print(f'Generated caption is: "{caption}"')
    frame = Image.open(data['extracted_frame_path'])
    display(frame)
    instruction = "Can you describe the image?"
    ask_llvm(instruction, data['extracted_frame_path'])
    # print(meta_data)
|
s7_store_in_rag.py
CHANGED
@@ -1,105 +1,105 @@
|
(Both sides of this hunk are identical: the file was re-uploaded without content changes. It is listed once below.)

from mm_rag.embeddings.bridgetower_embeddings import (
    BridgeTowerEmbeddings
)
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
import lancedb
import json
import os
from PIL import Image
from utility import load_json_file, display_retrieved_results
import pyarrow as pa

# declare host file
LANCEDB_HOST_FILE = "./shared_data/.lancedb"
# declare table name
TBL_NAME = "test_tbl"
# initialize vectorstore
db = lancedb.connect(LANCEDB_HOST_FILE)
# initialize a BridgeTower embedder
embedder = BridgeTowerEmbeddings()


def return_top_k_most_similar_docs(max_docs=3):
    # ask to return the top max_docs most similar documents
    # Creating a LanceDB vector store
    vectorstore = MultimodalLanceDB(
        uri=LANCEDB_HOST_FILE,
        embedding=embedder,
        table_name=TBL_NAME)

    # creating a retriever for the vector store
    # search_type="similarity" declares that the type of search
    # the Retriever should perform is similarity search
    # search_kwargs={"k": max_docs} means returning the top max_docs most similar documents
    retriever = vectorstore.as_retriever(
        search_type='similarity',
        search_kwargs={"k": max_docs})
    query2 = (
        "an astronaut's spacewalk "
        "with an amazing view of the earth from space behind"
    )
    results2 = retriever.invoke(query2)
    display_retrieved_results(results2)
    query3 = "a group of astronauts"
    results3 = retriever.invoke(query3)
    display_retrieved_results(results3)


def open_table(TBL_NAME):
    # open a connection to table TBL_NAME
    tbl = db.open_table(TBL_NAME)

    print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
    # display the first 3 rows of the table
    tbl.to_pandas()[['text', 'image_path']].head(3)


def store_in_rag():
    # load metadata files
    vid1_metadata_path = './shared_data/videos/video1/metadatas.json'
    vid2_metadata_path = './shared_data/videos/video2/metadatas.json'
    vid1_metadata = load_json_file(vid1_metadata_path)
    vid2_metadata = load_json_file(vid2_metadata_path)

    # collect transcripts and image paths
    vid1_trans = [vid['transcript'] for vid in vid1_metadata]
    vid1_img_path = [vid['extracted_frame_path'] for vid in vid1_metadata]

    vid2_trans = [vid['transcript'] for vid in vid2_metadata]
    vid2_img_path = [vid['extracted_frame_path'] for vid in vid2_metadata]

    # for video1, we pick n = 7
    n = 7
    updated_vid1_trans = [
        ' '.join(vid1_trans[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
        ' '.join(vid1_trans[0 : i + int(n/2)]) for i in range(len(vid1_trans))
    ]

    # also need to update the updated transcripts in metadata
    for i in range(len(updated_vid1_trans)):
        vid1_metadata[i]['transcript'] = updated_vid1_trans[i]

    # you can pass in mode="append"
    # to add more entries to the vector store
    # in case you want to start with a fresh vector store,
    # you can pass in mode="overwrite" instead
    _ = MultimodalLanceDB.from_text_image_pairs(
        texts=updated_vid1_trans+vid2_trans,
        image_paths=vid1_img_path+vid2_img_path,
        embedding=embedder,
        metadatas=vid1_metadata+vid2_metadata,
        connection=db,
        table_name=TBL_NAME,
        mode="overwrite",
    )


if __name__ == "__main__":
    tbl = db.open_table(TBL_NAME)
    print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
    # display the first 3 rows of the table
    return_top_k_most_similar_docs()
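To make the transcript-window arithmetic in store_in_rag() concrete, here is a toy run of the same list comprehension on dummy segments (purely illustrative):

# Each segment is replaced by a window of neighbouring segments: with n = 7 this is
# three segments of leading context, the segment itself, and two of trailing context,
# clipped to forward-only context at the start of the video.
n = 7
segments = [f"s{i}" for i in range(10)]
windowed = [
    ' '.join(segments[i - n // 2 : i + n // 2]) if i - n // 2 >= 0
    else ' '.join(segments[0 : i + n // 2])
    for i in range(len(segments))
]
print(windowed[0])  # "s0 s1 s2"
print(windowed[5])  # "s2 s3 s4 s5 s6 s7"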
shared_data/videos/yt_video/blackholes101nationalgeographic/blackholes101nationalgeographic.mp4
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e61dc16ee65a8a0ef08316bfb4d8f7110c7f5186d298a7e28e1cafb3bb25c338
+size 132
shared_data/videos/yt_video/blackholes101nationalgeographic/captions.vtt
CHANGED
@@ -1,104 +1,104 @@
(Both sides of this hunk are identical: the caption file was re-uploaded without content changes. It is listed once below.)

WEBVTT

00:00.000 --> 00:08.760
Black holes are among the most fascinating objects in our universe, and also the most

00:08.760 --> 00:13.520
mysterious.

00:13.520 --> 00:19.040
A black hole is a region in space where the force of gravity is so strong, not even light,

00:19.040 --> 00:23.200
the fastest known entity in our universe can escape.

00:23.200 --> 00:28.680
The boundary of a black hole is called the event horizon, a point of no return beyond

00:28.680 --> 00:31.840
which we truly cannot see.

00:31.840 --> 00:37.040
When something crosses the event horizon, it collapses into the black hole's singularity,

00:37.040 --> 00:42.400
an infinitely small, infinitely dense point where space, time, and the laws of physics

00:42.400 --> 00:46.200
no longer apply.

00:46.200 --> 00:51.400
Scientists have theorized several different types of black holes, with stellar and supermassive

00:51.400 --> 00:54.280
black holes being the most common.

00:54.280 --> 00:58.640
Stolar black holes form when massive stars die and collapse.

00:58.640 --> 01:05.080
They're roughly 10 to 20 times the mass of our sun, and scattered throughout the universe.

01:05.080 --> 01:11.040
There could be millions of these stellar black holes in the Milky Way alone.

01:11.040 --> 01:16.440
Supermassive black holes are giants by comparison, measuring millions, even billions of times

01:16.440 --> 01:19.440
more massive than our sun.

01:19.440 --> 01:23.800
Scientists can only guess how they form, but we do know they exist at the center of just

01:23.800 --> 01:28.920
about every large galaxy, including our own.

01:28.920 --> 01:33.760
Sagittarius A, the supermassive black hole at the center of the Milky Way, has a mass

01:33.760 --> 01:39.360
of roughly four million suns, and has a diameter about the distance between the Earth and our

01:39.360 --> 01:41.960
sun.

01:41.960 --> 01:46.680
Because black holes are invisible, the only way for scientists to detect and study them

01:46.680 --> 01:50.040
is to observe their effect on nearby matter.

01:50.040 --> 01:55.360
This includes accretion disks, a disk of particles that form when gases and dust fall toward a

01:55.360 --> 02:03.920
black hole, and quasars, jets of particles that blast out of supermassive black holes.

02:03.920 --> 02:08.720
Black holes remained largely unknown until the 20th century.

02:08.720 --> 02:14.840
In 1916, using Einstein's General Theory of Relativity, a German physicist named Karl

02:14.840 --> 02:20.280
Schwartzschild calculated that any mass could become a black hole if it were compressed tightly

02:20.280 --> 02:22.640
enough.

02:22.640 --> 02:27.480
But it wasn't until 1971 when theory became reality.

02:27.480 --> 02:34.000
Astronomers, studying the constellation Cygnus, discovered the first black hole.

02:34.000 --> 02:39.440
An untold number of black holes are scattered throughout the universe, constantly warping

02:39.440 --> 02:45.600
space and time, altering entire galaxies, and endlessly inspiring both scientists and

02:45.600 --> 02:47.120
our collective imagination.
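The committed caption file above can be read back with webvtt-py (already listed in requirements.txt); a brief sketch:

import webvtt

captions_path = "shared_data/videos/yt_video/blackholes101nationalgeographic/captions.vtt"
for caption in webvtt.read(captions_path):
    # Each caption exposes its start/end timestamps and the spoken text.
    print(caption.start, "-->", caption.end)
    print(caption.text)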
utility.py
CHANGED
@@ -1,764 +1,764 @@
# Add your utilities or helper functions to this file.

import os
from pathlib import Path
from dotenv import load_dotenv, find_dotenv
from io import StringIO, BytesIO
import textwrap
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
from enum import auto, Enum
import base64
import glob
from moviepy import VideoFileClip
import requests
from urllib.request import urlopen  # needed by encode_image_from_path_or_url below
from tqdm import tqdm
from pytubefix import YouTube, Stream
import webvtt
import whisper
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import WebVTTFormatter
from predictionguard import PredictionGuard
import cv2
import re
import json
import PIL
from ollama import chat
from ollama import ChatResponse
from PIL import Image
import dataclasses
import random
from datasets import load_dataset
from os import path as osp
from IPython.display import display
from langchain_core.prompt_values import PromptValue
from langchain_core.messages import (
    MessageLikeRepresentation,
)
from transformers import pipeline
from huggingface_hub import InferenceClient

MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]

def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Get a value from a dictionary or an environment variable."""
    if key in data and data[key]:
        return data[key]
    else:
        return get_from_env(key, env_key, default=default)

def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Get a value from an environment variable, or return the default."""
    if env_key in os.environ and os.environ[env_key]:
        return os.environ[env_key]
    else:
        return default

def load_env():
    _ = load_dotenv(find_dotenv())

def get_openai_api_key():
    load_env()
    openai_api_key = os.getenv("OPENAI_API_KEY")
    return openai_api_key

def get_prediction_guard_api_key():
    load_env()
    PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
    if PREDICTION_GUARD_API_KEY is None:
        PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
    return PREDICTION_GUARD_API_KEY

PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com")  ### "https://proxy-dl-itdc.predictionguard.com"

# prompt templates
templates = [
    'a picture of {}',
    'an image of {}',
    'a nice {}',
    'a beautiful {}',
]

# helper to prepare a list of image-text pairs from the first [test_size] items of a Huggingface dataset
def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
    # load Huggingface dataset (download if needed)
    dataset = load_dataset(hf_dataset, trust_remote_code=True)
    # split dataset with specific test_size
    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
    # get the test dataset
    test_dataset = train_test_dataset['test']
    img_txt_pairs = []
    for i in range(len(test_dataset)):
        img_txt_pairs.append({
            'caption' : templates[random.randint(0, len(templates)-1)].format(class_name),
            'pil_img' : test_dataset[i]['image']
        })
    return img_txt_pairs


def download_video(video_url, path):
    print(f'Getting video information for {video_url}')

    def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
        pbar.update(len(data_chunk))
    stream = None
    try:
        yt = YouTube(video_url, on_progress_callback=progress_callback)
        stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
        if stream is None:
            stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    except Exception as e:
        print(f"YouTube exception occurred. Loading from local resource: {e}")

    uncleaned_filename = stream.default_filename.replace(' ', '').lower() if stream else "blackholes101nationalgeographic.mp4"
    print(f'Uncleaned filename: {uncleaned_filename}')
    filename = re.sub(r'[^a-zA-Z0-9]', '', uncleaned_filename).replace('mp4', '')
    filename_without_extension = os.path.splitext(filename)[0]
    filename_with_extension = filename + '.mp4'
    folder_path = os.path.join(path, filename_without_extension)
    print(f'Checking the folder path {folder_path}')
    full_file_path = os.path.join(folder_path, filename_with_extension)

    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)

    if os.path.exists(full_file_path):
        print('Video already downloaded at the folder path', full_file_path)
        is_downloaded = False
        return full_file_path, folder_path, is_downloaded

    is_downloaded = True

    print('Downloading video from YouTube...')
    pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
    stream.download(folder_path, filename=filename_with_extension)
    pbar.close()
    return full_file_path, folder_path, is_downloaded

def get_video_id_from_url(video_url):
    """
    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
    """
    import urllib.parse
    url = urllib.parse.urlparse(video_url)
    if url.hostname == 'youtu.be':
        return url.path[1:]
    if url.hostname in ('www.youtube.com', 'youtube.com'):
        if url.path == '/watch':
            p = urllib.parse.parse_qs(url.query)
            return p['v'][0]
        if url.path[:7] == '/embed/':
            return url.path.split('/')[2]
        if url.path[:3] == '/v/':
            return url.path.split('/')[2]

    return video_url

def generate_transcript_vtt(vid_dir, vid_filepath):
    print("Generating transcript for video ", vid_filepath)
    # declare where to save .mp3 audio
    path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')

    # extract mp3 audio file from the mp4 video file
    path_to_video_no_transcript = vid_filepath
    clip = VideoFileClip(path_to_video_no_transcript)
    clip.audio.write_audiofile(path_to_extracted_audio_file)

    model = whisper.load_model("small")
    options = dict(task="translate", best_of=1, language='en')
    results = model.transcribe(path_to_extracted_audio_file, **options)

    vtt = getSubs(results["segments"], "vtt")

    # path to save the generated transcript
    path_to_generated_trans = osp.join(vid_dir, 'captions.vtt')
    # write transcription to file
    with open(path_to_generated_trans, 'w') as f:
        f.write(vtt)
    return path_to_generated_trans


# download the video's transcript if one is available; optionally generate one with Whisper instead
def get_transcript_vtt(path, video_url, vid_file_path, from_gen=False):
    if from_gen:
        return generate_transcript_vtt(path, vid_file_path)
    video_id = get_video_id_from_url(video_url)
    filepath = os.path.join(path, 'captions.vtt')
    if os.path.exists(filepath):
        print('Transcript already exists')
        return filepath

    print('Downloading Transcript...')

    transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
    formatter = WebVTTFormatter()
    webvtt_formatted = formatter.format_transcript(transcript)

    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)
        webvtt_file.close()

    return filepath


# helper function to convert a time in seconds to the timestamp format used in .vtt or .srt files
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"

# helper function that converts a time string in `webvtt` format into a time in milliseconds
def str2time(strtime):
    # strip character " if exists
    strtime = strtime.strip('"')
    # get hour, minute, second from time string
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    # get the corresponding time as total seconds
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    total_miliseconds = total_seconds * 1000
    return total_miliseconds

def _processText(text: str, maxLineWidth=None):
    if (maxLineWidth is None or maxLineWidth < 0):
        return text

    lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
    return '\n'.join(lines)

# Resizes an image while maintaining its aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # Grab the image size and initialize dimensions
    dim = None
    (h, w) = image.shape[:2]

    # Return original image if no need to resize
    if width is None and height is None:
        return image

    # We are resizing height if width is none
    if width is None:
        # Calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    # We are resizing width if height is none
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # Return the resized image
    return cv2.resize(image, dim, interpolation=inter)

# helper function to convert transcripts generated by whisper to a .vtt file
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    print("WEBVTT\n", file=file)
    for segment in transcript:
        text = _processText(segment['text'], maxLineWidth).replace('-->', '->')

        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )

# helper function to convert transcripts generated by whisper to a .srt file
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """
    Write a transcript to a file in SRT format.
    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt
        import requests
        result = transcribe(model, audio_path, temperature=temperature, **args)
        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for i, segment in enumerate(transcript, start=1):
        text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')

        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )

def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int = -1) -> str:
    segmentStream = StringIO()

    if format == 'vtt':
        write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    elif format == 'srt':
        write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    else:
        raise Exception("Unknown format " + format)

    segmentStream.seek(0)
    return segmentStream.read()

# encode an image (given as a path or a PIL Image) using base64
def encode_image(image_path_or_PIL_img):
    if isinstance(image_path_or_PIL_img, PIL.Image.Image):
        # this is a PIL image
        buffered = BytesIO()
        image_path_or_PIL_img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    else:
        # this is an image path
        with open(image_path_or_PIL_img, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

# check whether the given string is base64 or not
def isBase64(sb):
    try:
        if isinstance(sb, str):
            # If there's any unicode here, an exception will be thrown and the function will return false
            sb_bytes = bytes(sb, 'ascii')
        elif isinstance(sb, bytes):
            sb_bytes = sb
        else:
            raise ValueError("Argument must be string or bytes")
        return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
    except Exception:
        return False

def encode_image_from_path_or_url(image_path_or_url):
    try:
        # try to open the url to check valid url
        f = urlopen(image_path_or_url)
        # if this is an url
        return base64.b64encode(requests.get(image_path_or_url).content).decode('utf-8')
    except:
        # this is a path to an image
        with open(image_path_or_url, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

# helper function to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
def bt_embedding_from_prediction_guard(prompt, base64_image):
    # get PredictionGuard client
    client = _getPredictionGuardClient()
    message = {"text": prompt,}
    if base64_image is not None and base64_image != "":
        if not isBase64(base64_image):
            raise TypeError("image input must be in base64 encoding!")
        message['image'] = base64_image
    response = client.embeddings.create(
        model="bridgetower-large-itm-mlm-itc",
        input=[message]
    )
    return response['data'][0]['embedding']


def load_json_file(file_path):
    # Open the JSON file in read mode
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data

def display_retrieved_results(results):
    print(f'There is/are {len(results)} retrieved result(s)')
    print()
    for i, res in enumerate(results):
-
print(f'The caption of the {str(i+1)}-th retrieved result is:\n"{results[i].page_content}"')
|
386 |
-
print()
|
387 |
-
print(results[i])
|
388 |
-
#display(Image.open(results[i].metadata['metadata']['extracted_frame_path']))
|
389 |
-
print("------------------------------------------------------------")
|
390 |
-
|
391 |
-
class SeparatorStyle(Enum):
|
392 |
-
"""Different separator style."""
|
393 |
-
SINGLE = auto()
|
394 |
-
|
395 |
-
@dataclasses.dataclass
|
396 |
-
class Conversation:
|
397 |
-
"""A class that keeps all conversation history"""
|
398 |
-
system: str
|
399 |
-
roles: List[str]
|
400 |
-
messages: List[List[str]]
|
401 |
-
map_roles: Dict[str, str]
|
402 |
-
version: str = "Unknown"
|
403 |
-
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
404 |
-
sep: str = "\n"
|
405 |
-
|
406 |
-
def _get_prompt_role(self, role):
|
407 |
-
if self.map_roles is not None and role in self.map_roles.keys():
|
408 |
-
return self.map_roles[role]
|
409 |
-
else:
|
410 |
-
return role
|
411 |
-
|
412 |
-
def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
|
413 |
-
content = []
|
414 |
-
if len(first_message) != 2:
|
415 |
-
raise TypeError("First message in Conversation needs to include a prompt and a base64-enconded image!")
|
416 |
-
|
417 |
-
prompt, b64_image = first_message[0], first_message[1]
|
418 |
-
|
419 |
-
# handling prompt
|
420 |
-
if prompt is None:
|
421 |
-
raise TypeError("API does not support None prompt yet")
|
422 |
-
content.append({
|
423 |
-
"type": "text",
|
424 |
-
"text": prompt
|
425 |
-
})
|
426 |
-
if b64_image is None:
|
427 |
-
raise TypeError("API does not support text only conversation yet")
|
428 |
-
|
429 |
-
# handling image
|
430 |
-
if not isBase64(b64_image):
|
431 |
-
raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")
|
432 |
-
|
433 |
-
content.append({
|
434 |
-
"type": "image_url",
|
435 |
-
"image_url": {
|
436 |
-
"url": b64_image,
|
437 |
-
}
|
438 |
-
})
|
439 |
-
return content
|
440 |
-
|
441 |
-
def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
|
442 |
-
|
443 |
-
if follow_up_message is not None and len(follow_up_message) > 1:
|
444 |
-
raise TypeError("Follow-up message in Conversation must not include an image!")
|
445 |
-
|
446 |
-
# handling text prompt
|
447 |
-
if follow_up_message is None or follow_up_message[0] is None:
|
448 |
-
raise TypeError("Follow-up message in Conversation must include exactly one text message")
|
449 |
-
|
450 |
-
text = follow_up_message[0]
|
451 |
-
return text
|
452 |
-
|
453 |
-
def get_message(self):
|
454 |
-
messages = self.messages
|
455 |
-
api_messages = []
|
456 |
-
for i, msg in enumerate(messages):
|
457 |
-
role, message_content = msg
|
458 |
-
if i == 0:
|
459 |
-
# get content for very first message in conversation
|
460 |
-
content = self._build_content_for_first_message_in_conversation(message_content)
|
461 |
-
else:
|
462 |
-
# get content for follow-up message in conversation
|
463 |
-
content = self._build_content_for_follow_up_messages_in_conversation(message_content)
|
464 |
-
|
465 |
-
api_messages.append({
|
466 |
-
"role": role,
|
467 |
-
"content": content,
|
468 |
-
})
|
469 |
-
return api_messages
|
470 |
-
|
471 |
-
# this method helps represent a multi-turn chat into as a single turn chat format
|
472 |
-
def serialize_messages(self):
|
473 |
-
messages = self.messages
|
474 |
-
ret = ""
|
475 |
-
if self.sep_style == SeparatorStyle.SINGLE:
|
476 |
-
if self.system is not None and self.system != "":
|
477 |
-
ret = self.system + self.sep
|
478 |
-
for i, (role, message) in enumerate(messages):
|
479 |
-
role = self._get_prompt_role(role)
|
480 |
-
if message:
|
481 |
-
if isinstance(message, List):
|
482 |
-
# get prompt only
|
483 |
-
message = message[0]
|
484 |
-
if i == 0:
|
485 |
-
# do not include role at the beginning
|
486 |
-
ret += message
|
487 |
-
else:
|
488 |
-
ret += role + ": " + message
|
489 |
-
if i < len(messages) - 1:
|
490 |
-
# avoid including sep at the end of serialized message
|
491 |
-
ret += self.sep
|
492 |
-
else:
|
493 |
-
ret += role + ":"
|
494 |
-
else:
|
495 |
-
raise ValueError(f"Invalid style: {self.sep_style}")
|
496 |
-
|
497 |
-
return ret
|
498 |
-
|
499 |
-
def append_message(self, role, message):
|
500 |
-
if len(self.messages) == 0:
|
501 |
-
# data verification for the very first message
|
502 |
-
assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
|
503 |
-
assert len(message) == 2, f"the very first message in conversation must include both prompt and an image"
|
504 |
-
prompt, image = message[0], message[1]
|
505 |
-
assert prompt is not None, f"prompt must be not None"
|
506 |
-
assert isBase64(image), f"image must be under base64 encoding"
|
507 |
-
else:
|
508 |
-
# data verification for follow-up message
|
509 |
-
assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
|
510 |
-
assert len(message) == 1, f"the follow-up message must consist of one text message only, no image"
|
511 |
-
|
512 |
-
self.messages.append([role, message])
|
513 |
-
|
514 |
-
def copy(self):
|
515 |
-
return Conversation(
|
516 |
-
system=self.system,
|
517 |
-
roles=self.roles,
|
518 |
-
messages=[[x,y] for x, y in self.messages],
|
519 |
-
version=self.version,
|
520 |
-
map_roles=self.map_roles,
|
521 |
-
)
|
522 |
-
|
523 |
-
def dict(self):
|
524 |
-
return {
|
525 |
-
"system": self.system,
|
526 |
-
"roles": self.roles,
|
527 |
-
"messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
|
528 |
-
"version": self.version,
|
529 |
-
}
|
530 |
-
|
531 |
-
prediction_guard_llava_conv = Conversation(
|
532 |
-
system="",
|
533 |
-
roles=("user", "assistant"),
|
534 |
-
messages=[],
|
535 |
-
version="Prediction Guard LLaVA enpoint Conversation v0",
|
536 |
-
sep_style=SeparatorStyle.SINGLE,
|
537 |
-
map_roles={
|
538 |
-
"user": "USER",
|
539 |
-
"assistant": "ASSISTANT"
|
540 |
-
}
|
541 |
-
)
|
542 |
-
|
543 |
-
# get PredictionGuard Client
|
544 |
-
def _getPredictionGuardClient():
|
545 |
-
PREDICTION_GUARD_API_KEY = get_prediction_guard_api_key()
|
546 |
-
client = PredictionGuard(
|
547 |
-
api_key=PREDICTION_GUARD_API_KEY,
|
548 |
-
url=PREDICTION_GUARD_URL_ENDPOINT,
|
549 |
-
)
|
550 |
-
return client
|
551 |
-
|
552 |
-
# helper function to call chat completion endpoint of PredictionGuard given a prompt and an image
|
553 |
-
def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
|
554 |
-
# prepare conversation
|
555 |
-
conversation = prediction_guard_llava_conv.copy()
|
556 |
-
conversation.append_message(conversation.roles[0], [prompt, image])
|
557 |
-
return lvlm_inference_with_conversation(conversation, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
|
562 |
-
# get PredictionGuard client
|
563 |
-
client = _getPredictionGuardClient()
|
564 |
-
# get message from conversation
|
565 |
-
messages = conversation.get_message()
|
566 |
-
# call chat completion endpoint at Grediction Guard
|
567 |
-
response = client.chat.completions.create(
|
568 |
-
model="llava-1.5-7b-hf",
|
569 |
-
messages=messages,
|
570 |
-
max_tokens=max_tokens,
|
571 |
-
temperature=temperature,
|
572 |
-
top_p=top_p,
|
573 |
-
top_k=top_k,
|
574 |
-
)
|
575 |
-
return response['choices'][-1]['message']['content']
|
576 |
-
|
577 |
-
def get_token():
|
578 |
-
load_env()
|
579 |
-
token = os.getenv("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
580 |
-
if token is None:
|
581 |
-
raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
|
582 |
-
return token
|
583 |
-
|
584 |
-
|
585 |
-
def lvlm_inference_with_phi(prompt):
|
586 |
-
|
587 |
-
|
588 |
-
messages = [{"role": "user", "content": prompt}]
|
589 |
-
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=get_token())
|
590 |
-
response = ''
|
591 |
-
token = client.chat_completion(messages, max_tokens=256)
|
592 |
-
response = token['choices'][0]['message']['content']
|
593 |
-
return response
|
594 |
-
|
595 |
-
def lvlm_inference_with_tiny_model(prompt):
|
596 |
-
classifier = pipeline(
|
597 |
-
"text-generation",
|
598 |
-
model="microsoft/phi-2", # Only ~2.7GB
|
599 |
-
device_map="auto",
|
600 |
-
torch_dtype="auto",
|
601 |
-
)
|
602 |
-
|
603 |
-
response = classifier(
|
604 |
-
prompt,
|
605 |
-
max_new_tokens=512, # Remove max_length and use only max_new_tokens
|
606 |
-
temperature=0.7,
|
607 |
-
do_sample=True,
|
608 |
-
num_return_sequences=1,
|
609 |
-
truncation=True, # Add explicit truncation
|
610 |
-
pad_token_id=classifier.tokenizer.eos_token_id,
|
611 |
-
eos_token_id=classifier.tokenizer.eos_token_id,
|
612 |
-
)[0]['generated_text']
|
613 |
-
|
614 |
-
# Remove the input prompt from the response and clean up
|
615 |
-
return response.replace(prompt, "").strip()
|
616 |
-
|
617 |
-
# function `extract_and_save_frames_and_metadata``:
|
618 |
-
# receives as input a video and its transcript
|
619 |
-
# does extracting and saving frames and their metadatas
|
620 |
-
# returns the extracted metadatas
|
621 |
-
def extract_and_save_frames_and_metadata(
|
622 |
-
path_to_video,
|
623 |
-
path_to_transcript,
|
624 |
-
path_to_save_extracted_frames,
|
625 |
-
path_to_save_metadatas):
|
626 |
-
|
627 |
-
# metadatas will store the metadata of all extracted frames
|
628 |
-
metadatas = []
|
629 |
-
|
630 |
-
# load video using cv2
|
631 |
-
print(f"Loading video from {path_to_video}")
|
632 |
-
video = cv2.VideoCapture(path_to_video)
|
633 |
-
# load transcript using webvtt
|
634 |
-
print(f"Loading transcript from {path_to_transcript}")
|
635 |
-
trans = webvtt.read(path_to_transcript)
|
636 |
-
|
637 |
-
# iterate transcript file
|
638 |
-
# for each video segment specified in the transcript file
|
639 |
-
for idx, transcript in enumerate(trans):
|
640 |
-
# get the start time and end time in seconds
|
641 |
-
start_time_ms = str2time(transcript.start)
|
642 |
-
end_time_ms = str2time(transcript.end)
|
643 |
-
# get the time in ms exactly
|
644 |
-
# in the middle of start time and end time
|
645 |
-
mid_time_ms = (end_time_ms + start_time_ms) / 2
|
646 |
-
# get the transcript, remove the next-line symbol
|
647 |
-
text = transcript.text.replace("\n", ' ')
|
648 |
-
# get frame at the middle time
|
649 |
-
video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
|
650 |
-
print(f"Extracting frame at {mid_time_ms} ms")
|
651 |
-
success, frame = video.read()
|
652 |
-
if success:
|
653 |
-
# if the frame is extracted successfully, resize it
|
654 |
-
image = maintain_aspect_ratio_resize(frame, height=350)
|
655 |
-
# save frame as JPEG file
|
656 |
-
img_fname = f'frame_{idx}.jpg'
|
657 |
-
img_fpath = osp.join(
|
658 |
-
path_to_save_extracted_frames, img_fname
|
659 |
-
)
|
660 |
-
cv2.imwrite(img_fpath, image)
|
661 |
-
|
662 |
-
# prepare the metadata
|
663 |
-
metadata = {
|
664 |
-
'extracted_frame_path': img_fpath,
|
665 |
-
'transcript': text,
|
666 |
-
'video_segment_id': idx,
|
667 |
-
'video_path': path_to_video,
|
668 |
-
'mid_time_ms': mid_time_ms,
|
669 |
-
}
|
670 |
-
metadatas.append(metadata)
|
671 |
-
|
672 |
-
else:
|
673 |
-
print(f"ERROR! Cannot extract frame: idx = {idx}")
|
674 |
-
|
675 |
-
# save metadata of all extracted frames
|
676 |
-
fn = osp.join(path_to_save_metadatas, 'metadatas.json')
|
677 |
-
with open(fn, 'w') as outfile:
|
678 |
-
json.dump(metadatas, outfile)
|
679 |
-
return metadatas
|
680 |
-
|
681 |
-
def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
|
682 |
-
# output paths to save extracted frames and their metadata
|
683 |
-
extracted_frames_path = osp.join(vid_dir, 'extracted_frame')
|
684 |
-
metadatas_path = vid_dir
|
685 |
-
|
686 |
-
# create these output folders if not existing
|
687 |
-
print(f"Creating folders {extracted_frames_path} and {metadatas_path}")
|
688 |
-
Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
|
689 |
-
Path(metadatas_path).mkdir(parents=True, exist_ok=True)
|
690 |
-
print("Extracting frames the video path ", vid_filepath)
|
691 |
-
|
692 |
-
# call the function to extract frames and metadatas
|
693 |
-
metadatas = extract_and_save_frames_and_metadata(
|
694 |
-
vid_filepath,
|
695 |
-
vid_transcript_filepath,
|
696 |
-
extracted_frames_path,
|
697 |
-
metadatas_path,
|
698 |
-
)
|
699 |
-
return metadatas
|
700 |
-
|
701 |
-
# function extract_and_save_frames_and_metadata_with_fps
|
702 |
-
# receives as input a video
|
703 |
-
# does extracting and saving frames and their metadatas
|
704 |
-
# returns the extracted metadatas
|
705 |
-
def extract_and_save_frames_and_metadata_with_fps(
|
706 |
-
lvlm_prompt,
|
707 |
-
path_to_video,
|
708 |
-
path_to_save_extracted_frames,
|
709 |
-
path_to_save_metadatas,
|
710 |
-
num_of_extracted_frames_per_second=1):
|
711 |
-
|
712 |
-
# metadatas will store the metadata of all extracted frames
|
713 |
-
metadatas = []
|
714 |
-
|
715 |
-
# load video using cv2
|
716 |
-
video = cv2.VideoCapture(path_to_video)
|
717 |
-
|
718 |
-
# Get the frames per second
|
719 |
-
fps = video.get(cv2.CAP_PROP_FPS)
|
720 |
-
# Get hop = the number of frames pass before a frame is extracted
|
721 |
-
hop = round(fps / num_of_extracted_frames_per_second)
|
722 |
-
curr_frame = 0
|
723 |
-
idx = -1
|
724 |
-
while(True):
|
725 |
-
# iterate all frames
|
726 |
-
ret, frame = video.read()
|
727 |
-
if not ret:
|
728 |
-
break
|
729 |
-
if curr_frame % hop == 0:
|
730 |
-
idx = idx + 1
|
731 |
-
|
732 |
-
# if the frame is extracted successfully, resize it
|
733 |
-
image = maintain_aspect_ratio_resize(frame, height=350)
|
734 |
-
# save frame as JPEG file
|
735 |
-
img_fname = f'frame_{idx}.jpg'
|
736 |
-
img_fpath = osp.join(
|
737 |
-
path_to_save_extracted_frames,
|
738 |
-
img_fname
|
739 |
-
)
|
740 |
-
cv2.imwrite(img_fpath, image)
|
741 |
-
|
742 |
-
# generate caption using lvlm_inference
|
743 |
-
b64_image = encode_image(img_fpath)
|
744 |
-
caption = lvlm_inference(lvlm_prompt, b64_image)
|
745 |
-
|
746 |
-
# prepare the metadata
|
747 |
-
metadata = {
|
748 |
-
'extracted_frame_path': img_fpath,
|
749 |
-
'transcript': caption,
|
750 |
-
'video_segment_id': idx,
|
751 |
-
'video_path': path_to_video,
|
752 |
-
}
|
753 |
-
metadatas.append(metadata)
|
754 |
-
curr_frame += 1
|
755 |
-
|
756 |
-
# save metadata of all extracted frames
|
757 |
-
metadatas_path = osp.join(path_to_save_metadatas,'metadatas.json')
|
758 |
-
with open(metadatas_path, 'w') as outfile:
|
759 |
-
json.dump(metadatas, outfile)
|
760 |
-
return metadatas
|
761 |
-
|
762 |
-
if __name__ == "__main__":
|
763 |
-
res = lvlm_inference_with_phi("Tell me a story")
|
764 |
print(res)
|
|
|
1 |
+
# Add your utilities or helper functions to this file.
|
2 |
+
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from dotenv import load_dotenv, find_dotenv
|
6 |
+
from io import StringIO, BytesIO
|
7 |
+
import textwrap
|
8 |
+
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
|
9 |
+
from enum import auto, Enum
|
10 |
+
import base64
|
11 |
+
import glob
|
12 |
+
from moviepy import VideoFileClip
|
13 |
+
import requests
|
14 |
+
from tqdm import tqdm
|
15 |
+
from pytubefix import YouTube, Stream
|
16 |
+
import webvtt
|
17 |
+
import whisper
|
18 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
19 |
+
from youtube_transcript_api.formatters import WebVTTFormatter
|
20 |
+
from predictionguard import PredictionGuard
|
21 |
+
import cv2
|
22 |
+
import re
|
23 |
+
import json
|
24 |
+
import PIL
|
25 |
+
from ollama import chat
|
26 |
+
from ollama import ChatResponse
|
27 |
+
from PIL import Image
|
28 |
+
import dataclasses
|
29 |
+
import random
|
30 |
+
from datasets import load_dataset
|
31 |
+
from os import path as osp
|
32 |
+
from IPython.display import display
|
33 |
+
from langchain_core.prompt_values import PromptValue
|
34 |
+
from langchain_core.messages import (
|
35 |
+
MessageLikeRepresentation,
|
36 |
+
)
|
37 |
+
from transformers import pipeline
|
38 |
+
from huggingface_hub import InferenceClient
|
39 |
+
|
40 |
+
MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]
|
41 |
+
|
42 |
+
def get_from_dict_or_env(
|
43 |
+
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
44 |
+
) -> str:
|
45 |
+
"""Get a value from a dictionary or an environment variable."""
|
46 |
+
if key in data and data[key]:
|
47 |
+
return data[key]
|
48 |
+
else:
|
49 |
+
return get_from_env(key, env_key, default=default)
|
50 |
+
|
51 |
+
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
|
52 |
+
"""Get a value from a dictionary or an environment variable."""
|
53 |
+
if env_key in os.environ and os.environ[env_key]:
|
54 |
+
return os.environ[env_key]
|
55 |
+
else:
|
56 |
+
return default
|
57 |
+
|
58 |
+
def load_env():
|
59 |
+
_ = load_dotenv(find_dotenv())
|
60 |
+
|
61 |
+
def get_openai_api_key():
|
62 |
+
load_env()
|
63 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
64 |
+
return openai_api_key
|
65 |
+
|
66 |
+
def get_prediction_guard_api_key():
|
67 |
+
load_env()
|
68 |
+
PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
|
69 |
+
if PREDICTION_GUARD_API_KEY is None:
|
70 |
+
PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
|
71 |
+
return PREDICTION_GUARD_API_KEY
|
72 |
+
|
73 |
+
PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com") ###"https://proxy-dl-itdc.predictionguard.com"
|
74 |
+
|
75 |
+
# prompt templates
|
76 |
+
templates = [
|
77 |
+
'a picture of {}',
|
78 |
+
'an image of {}',
|
79 |
+
'a nice {}',
|
80 |
+
'a beautiful {}',
|
81 |
+
]
|
82 |
+
|
83 |
+
# helper function to prepare a list of image-text pairs from a random [test_size]-sample split of a Hugging Face dataset
|
84 |
+
def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
|
85 |
+
# load Huggingface dataset (download if needed)
|
86 |
+
dataset = load_dataset(hf_dataset, trust_remote_code=True)
|
87 |
+
# split dataset with specific test_size
|
88 |
+
train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
|
89 |
+
# get the test dataset
|
90 |
+
test_dataset = train_test_dataset['test']
|
91 |
+
img_txt_pairs = []
|
92 |
+
for i in range(len(test_dataset)):
|
93 |
+
img_txt_pairs.append({
|
94 |
+
'caption' : templates[random.randint(0, len(templates)-1)].format(class_name),
|
95 |
+
'pil_img' : test_dataset[i]['image']
|
96 |
+
})
|
97 |
+
return img_txt_pairs
|
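# Example usage (illustrative sketch; the dataset id below is a placeholder, not a verified
# Hugging Face dataset):
#   cat_img_txt_pairs = prepare_dataset_for_umap_visualization("some-org/cat-images", "cat", test_size=50)
#   # -> a list of 50 dicts, each with a templated 'caption' and a 'pil_img'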
98 |
+
|
99 |
+
|
100 |
+
def download_video(video_url, path):
|
101 |
+
print(f'Getting video information for {video_url}')
|
102 |
+
|
103 |
+
def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
|
104 |
+
pbar.update(len(data_chunk))
|
105 |
+
stream = None
|
106 |
+
try:
|
107 |
+
yt = YouTube(video_url, on_progress_callback=progress_callback)
|
108 |
+
stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
|
109 |
+
if stream is None:
|
110 |
+
stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
|
111 |
+
except Exception as e:
|
112 |
+
print(f"Youtube Exception Occured.Loading from local resource: {e}")
|
113 |
+
|
114 |
+
uncleaned_filename = stream.default_filename.replace(' ', '').lower() if stream else "blackholes101nationalgeographic.mp4"
|
115 |
+
print(f'Uncleaned filename: {uncleaned_filename}')
|
116 |
+
filename = re.sub(r'[^a-zA-Z0-9]', '', uncleaned_filename).replace('mp4', '')
|
117 |
+
filename_without_extension = os.path.splitext(filename)[0]
|
118 |
+
filename_with_extension = filename+'.mp4'
|
119 |
+
folder_path = os.path.join(path, filename_without_extension)
|
120 |
+
print(f'Checking the folder path {folder_path}')
|
121 |
+
full_file_path = os.path.join(folder_path, filename_with_extension)
|
122 |
+
|
123 |
+
if not os.path.exists(folder_path):
|
124 |
+
os.makedirs(folder_path, exist_ok=True)
|
125 |
+
|
126 |
+
if os.path.exists(full_file_path):
|
127 |
+
print('Video already downloaded at the folder path', full_file_path)
|
128 |
+
is_downloaded = False
|
129 |
+
return full_file_path, folder_path, is_downloaded
|
130 |
+
|
131 |
+
|
132 |
+
is_downloaded = True
|
133 |
+
|
134 |
+
print('Downloading video from YouTube...')
|
135 |
+
pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
|
136 |
+
stream.download(folder_path, filename=filename_with_extension)
|
137 |
+
pbar.close()
|
138 |
+
return full_file_path, folder_path, is_downloaded
|
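# Example usage (illustrative; the output directory is hypothetical and the URL is the one
# used in the get_video_id_from_url docstring below):
#   full_path, folder, downloaded = download_video(
#       "http://www.youtube.com/watch?v=SA2iWivDJiE", "path/to/videos")
#   # -> path to the downloaded .mp4, its folder, and whether a new download happened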
139 |
+
|
140 |
+
def get_video_id_from_url(video_url):
|
141 |
+
"""
|
142 |
+
Examples:
|
143 |
+
- http://youtu.be/SA2iWivDJiE
|
144 |
+
- http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
|
145 |
+
- http://www.youtube.com/embed/SA2iWivDJiE
|
146 |
+
- http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
|
147 |
+
"""
|
148 |
+
import urllib.parse
|
149 |
+
url = urllib.parse.urlparse(video_url)
|
150 |
+
if url.hostname == 'youtu.be':
|
151 |
+
return url.path[1:]
|
152 |
+
if url.hostname in ('www.youtube.com', 'youtube.com'):
|
153 |
+
if url.path == '/watch':
|
154 |
+
p = urllib.parse.parse_qs(url.query)
|
155 |
+
return p['v'][0]
|
156 |
+
if url.path[:7] == '/embed/':
|
157 |
+
return url.path.split('/')[2]
|
158 |
+
if url.path[:3] == '/v/':
|
159 |
+
return url.path.split('/')[2]
|
160 |
+
|
161 |
+
return video_url
|
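# Example (illustrative): get_video_id_from_url("http://www.youtube.com/watch?v=SA2iWivDJiE")
# returns "SA2iWivDJiE"; a string that is not a recognized YouTube URL is returned unchanged.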
162 |
+
|
163 |
+
def generate_transcript_vtt(vid_dir, vid_filepath):
|
164 |
+
print("Generating transcript for video ", vid_filepath)
|
165 |
+
# declare where to save .mp3 audio
|
166 |
+
path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
|
167 |
+
|
168 |
+
# extract mp3 audio file from the mp4 video file
|
169 |
+
path_to_video_no_transcript = vid_filepath
|
170 |
+
clip = VideoFileClip(path_to_video_no_transcript)
|
171 |
+
clip.audio.write_audiofile(path_to_extracted_audio_file)
|
172 |
+
|
173 |
+
model = whisper.load_model("small")
|
174 |
+
options = dict(task="translate", best_of=1, language='en')
|
175 |
+
results = model.transcribe(path_to_extracted_audio_file, **options)
|
176 |
+
|
177 |
+
vtt = getSubs(results["segments"], "vtt")
|
178 |
+
|
179 |
+
# path to save generated transcript of video1
|
180 |
+
path_to_generated_trans = osp.join(vid_dir, 'captions.vtt')
|
181 |
+
# write transcription to file
|
182 |
+
with open(path_to_generated_trans, 'w') as f:
|
183 |
+
f.write(vtt)
|
184 |
+
return path_to_generated_trans
|
185 |
+
|
186 |
+
|
187 |
+
# download the video's transcript from YouTube, or generate it with Whisper when from_gen=True
|
188 |
+
def get_transcript_vtt(path, video_url, vid_file_path, from_gen=False):
|
189 |
+
if from_gen:
|
190 |
+
return generate_transcript_vtt(path,vid_file_path)
|
191 |
+
video_id = get_video_id_from_url(video_url)
|
192 |
+
filepath = os.path.join(path,'captions.vtt')
|
193 |
+
if os.path.exists(filepath):
|
194 |
+
print('Transcript already exists')
|
195 |
+
return filepath
|
196 |
+
|
197 |
+
print('Downloading Transcript...')
|
198 |
+
|
199 |
+
transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
|
200 |
+
formatter = WebVTTFormatter()
|
201 |
+
webvtt_formatted = formatter.format_transcript(transcript)
|
202 |
+
|
203 |
+
with open(filepath, 'w', encoding='utf-8') as webvtt_file:
|
204 |
+
webvtt_file.write(webvtt_formatted)
|
205 |
+
webvtt_file.close()
|
206 |
+
|
207 |
+
return filepath
|
208 |
+
|
209 |
+
|
210 |
+
# helper function to convert a time in seconds into the timestamp format used by .vtt or .srt files
|
211 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
|
212 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
213 |
+
milliseconds = round(seconds * 1000.0)
|
214 |
+
|
215 |
+
hours = milliseconds // 3_600_000
|
216 |
+
milliseconds -= hours * 3_600_000
|
217 |
+
|
218 |
+
minutes = milliseconds // 60_000
|
219 |
+
milliseconds -= minutes * 60_000
|
220 |
+
|
221 |
+
seconds = milliseconds // 1_000
|
222 |
+
milliseconds -= seconds * 1_000
|
223 |
+
|
224 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
225 |
+
return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
|
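# Illustrative examples:
#   format_timestamp(3661.5)                                                    -> "01:01:01.500"
#   format_timestamp(59.0, always_include_hours=True, fractionalSeperator=',')  -> "00:00:59,000"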
226 |
+
|
227 |
+
# helper function to convert a webvtt-formatted time string into a time in milliseconds
|
228 |
+
def str2time(strtime):
|
229 |
+
# strip character " if exists
|
230 |
+
strtime = strtime.strip('"')
|
231 |
+
# get hour, minute, second from time string
|
232 |
+
hrs, mins, seconds = [float(c) for c in strtime.split(':')]
|
233 |
+
# get the corresponding time as total seconds
|
234 |
+
total_seconds = hrs * 60**2 + mins * 60 + seconds
|
235 |
+
total_miliseconds = total_seconds * 1000
|
236 |
+
return total_miliseconds
|
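# Illustrative example: str2time("00:01:05.500") -> 65500.0 (milliseconds)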
237 |
+
|
238 |
+
def _processText(text: str, maxLineWidth=None):
|
239 |
+
if (maxLineWidth is None or maxLineWidth < 0):
|
240 |
+
return text
|
241 |
+
|
242 |
+
lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
|
243 |
+
return '\n'.join(lines)
|
244 |
+
|
245 |
+
# Resizes an image while maintaining its aspect ratio
|
246 |
+
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
|
247 |
+
# Grab the image size and initialize dimensions
|
248 |
+
dim = None
|
249 |
+
(h, w) = image.shape[:2]
|
250 |
+
|
251 |
+
# Return original image if no need to resize
|
252 |
+
if width is None and height is None:
|
253 |
+
return image
|
254 |
+
|
255 |
+
# We are resizing height if width is none
|
256 |
+
if width is None:
|
257 |
+
# Calculate the ratio of the height and construct the dimensions
|
258 |
+
r = height / float(h)
|
259 |
+
dim = (int(w * r), height)
|
260 |
+
# We are resizing width if height is none
|
261 |
+
else:
|
262 |
+
# Calculate the ratio of the width and construct the dimensions
|
263 |
+
r = width / float(w)
|
264 |
+
dim = (width, int(h * r))
|
265 |
+
|
266 |
+
# Return the resized image
|
267 |
+
return cv2.resize(image, dim, interpolation=inter)
|
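# Illustrative example: a 1080x1920 (HxW) frame resized with height=350 uses
# r = 350/1080 ~ 0.324, giving dim = (622, 350), so the aspect ratio is preserved.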
268 |
+
|
269 |
+
# helper function to convert transcripts generated by whisper to .vtt file
|
270 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
271 |
+
print("WEBVTT\n", file=file)
|
272 |
+
for segment in transcript:
|
273 |
+
text = _processText(segment['text'], maxLineWidth).replace('-->', '->')
|
274 |
+
|
275 |
+
print(
|
276 |
+
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
277 |
+
f"{text}\n",
|
278 |
+
file=file,
|
279 |
+
flush=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
# helper function to convert transcripts generated by whisper to .srt file
|
283 |
+
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
284 |
+
"""
|
285 |
+
Write a transcript to a file in SRT format.
|
286 |
+
Example usage:
|
287 |
+
from pathlib import Path
|
288 |
+
from whisper.utils import write_srt
|
289 |
+
import requests
|
290 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
291 |
+
# save SRT
|
292 |
+
audio_basename = Path(audio_path).stem
|
293 |
+
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
294 |
+
write_srt(result["segments"], file=srt)
|
295 |
+
"""
|
296 |
+
for i, segment in enumerate(transcript, start=1):
|
297 |
+
text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
|
298 |
+
|
299 |
+
# write srt lines
|
300 |
+
print(
|
301 |
+
f"{i}\n"
|
302 |
+
f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
|
303 |
+
f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
|
304 |
+
f"{text}\n",
|
305 |
+
file=file,
|
306 |
+
flush=True,
|
307 |
+
)
|
308 |
+
|
309 |
+
def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int=-1) -> str:
|
310 |
+
segmentStream = StringIO()
|
311 |
+
|
312 |
+
if format == 'vtt':
|
313 |
+
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
314 |
+
elif format == 'srt':
|
315 |
+
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
316 |
+
else:
|
317 |
+
raise Exception("Unknown format " + format)
|
318 |
+
|
319 |
+
segmentStream.seek(0)
|
320 |
+
return segmentStream.read()
|
321 |
+
|
322 |
+
# encode an image, given either a file path or a PIL Image, using base64
|
323 |
+
def encode_image(image_path_or_PIL_img):
|
324 |
+
if isinstance(image_path_or_PIL_img, PIL.Image.Image):
|
325 |
+
# this is a PIL image
|
326 |
+
buffered = BytesIO()
|
327 |
+
image_path_or_PIL_img.save(buffered, format="JPEG")
|
328 |
+
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
329 |
+
else:
|
330 |
+
# this is an image path
|
331 |
+
with open(image_path_or_PIL_img, "rb") as image_file:
|
332 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
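# Example usage (illustrative; the frame path is hypothetical):
#   b64 = encode_image("frame_0.jpg")              # from a file path
#   b64 = encode_image(Image.open("frame_0.jpg"))  # from a PIL image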
333 |
+
|
334 |
+
# checking whether the given string is base64 or not
|
335 |
+
def isBase64(sb):
|
336 |
+
try:
|
337 |
+
if isinstance(sb, str):
|
338 |
+
# If there are non-ASCII characters, an exception will be thrown and the function will return False
|
339 |
+
sb_bytes = bytes(sb, 'ascii')
|
340 |
+
elif isinstance(sb, bytes):
|
341 |
+
sb_bytes = sb
|
342 |
+
else:
|
343 |
+
raise ValueError("Argument must be string or bytes")
|
344 |
+
return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
|
345 |
+
except Exception:
|
346 |
+
return False
|
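# Illustrative examples: isBase64("aGVsbG8=") -> True; isBase64(12345) -> False
# (non-string/bytes input or an invalid base64 string both return False).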
347 |
+
|
348 |
+
def encode_image_from_path_or_url(image_path_or_url):
|
349 |
+
try:
|
350 |
+
# try to fetch the input as a URL
|
351 |
+
response = requests.get(image_path_or_url)
|
352 |
+
# if this is a valid URL, encode the downloaded content
|
353 |
+
return base64.b64encode(response.content).decode('utf-8')
|
354 |
+
except Exception:
|
355 |
+
# this is a path to image
|
356 |
+
with open(image_path_or_url, "rb") as image_file:
|
357 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
358 |
+
|
359 |
+
# helper function to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
|
360 |
+
def bt_embedding_from_prediction_guard(prompt, base64_image):
|
361 |
+
# get PredictionGuard client
|
362 |
+
client = _getPredictionGuardClient()
|
363 |
+
message = {"text": prompt,}
|
364 |
+
if base64_image is not None and base64_image != "":
|
365 |
+
if not isBase64(base64_image):
|
366 |
+
raise TypeError("image input must be in base64 encoding!")
|
367 |
+
message['image'] = base64_image
|
368 |
+
response = client.embeddings.create(
|
369 |
+
model="bridgetower-large-itm-mlm-itc",
|
370 |
+
input=[message]
|
371 |
+
)
|
372 |
+
return response['data'][0]['embedding']
|
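# Example usage (illustrative sketch; assumes a valid Prediction Guard API key and a
# hypothetical frame file):
#   emb = bt_embedding_from_prediction_guard("a black hole", encode_image("frame_0.jpg"))
#   # -> a list of floats: the joint BridgeTower embedding of the text-image pair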
373 |
+
|
374 |
+
|
375 |
+
def load_json_file(file_path):
|
376 |
+
# Open the JSON file in read mode
|
377 |
+
with open(file_path, 'r') as file:
|
378 |
+
data = json.load(file)
|
379 |
+
return data
|
380 |
+
|
381 |
+
def display_retrieved_results(results):
|
382 |
+
print(f'There is/are {len(results)} retrieved result(s)')
|
383 |
+
print()
|
384 |
+
for i, res in enumerate(results):
|
385 |
+
print(f'The caption of the {str(i+1)}-th retrieved result is:\n"{results[i].page_content}"')
|
386 |
+
print()
|
387 |
+
print(results[i])
|
388 |
+
#display(Image.open(results[i].metadata['metadata']['extracted_frame_path']))
|
389 |
+
print("------------------------------------------------------------")
|
390 |
+
|
391 |
+
class SeparatorStyle(Enum):
|
392 |
+
"""Different separator style."""
|
393 |
+
SINGLE = auto()
|
394 |
+
|
395 |
+
@dataclasses.dataclass
|
396 |
+
class Conversation:
|
397 |
+
"""A class that keeps all conversation history"""
|
398 |
+
system: str
|
399 |
+
roles: List[str]
|
400 |
+
messages: List[List[str]]
|
401 |
+
map_roles: Dict[str, str]
|
402 |
+
version: str = "Unknown"
|
403 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
404 |
+
sep: str = "\n"
|
405 |
+
|
406 |
+
def _get_prompt_role(self, role):
|
407 |
+
if self.map_roles is not None and role in self.map_roles.keys():
|
408 |
+
return self.map_roles[role]
|
409 |
+
else:
|
410 |
+
return role
|
411 |
+
|
412 |
+
def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
|
413 |
+
content = []
|
414 |
+
if len(first_message) != 2:
|
415 |
+
raise TypeError("First message in Conversation needs to include a prompt and a base64-enconded image!")
|
416 |
+
|
417 |
+
prompt, b64_image = first_message[0], first_message[1]
|
418 |
+
|
419 |
+
# handling prompt
|
420 |
+
if prompt is None:
|
421 |
+
raise TypeError("API does not support None prompt yet")
|
422 |
+
content.append({
|
423 |
+
"type": "text",
|
424 |
+
"text": prompt
|
425 |
+
})
|
426 |
+
if b64_image is None:
|
427 |
+
raise TypeError("API does not support text only conversation yet")
|
428 |
+
|
429 |
+
# handling image
|
430 |
+
if not isBase64(b64_image):
|
431 |
+
raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")
|
432 |
+
|
433 |
+
content.append({
|
434 |
+
"type": "image_url",
|
435 |
+
"image_url": {
|
436 |
+
"url": b64_image,
|
437 |
+
}
|
438 |
+
})
|
439 |
+
return content
|
440 |
+
|
441 |
+
def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
|
442 |
+
|
443 |
+
if follow_up_message is not None and len(follow_up_message) > 1:
|
444 |
+
raise TypeError("Follow-up message in Conversation must not include an image!")
|
445 |
+
|
446 |
+
# handling text prompt
|
447 |
+
if follow_up_message is None or follow_up_message[0] is None:
|
448 |
+
raise TypeError("Follow-up message in Conversation must include exactly one text message")
|
449 |
+
|
450 |
+
text = follow_up_message[0]
|
451 |
+
return text
|
452 |
+
|
453 |
+
def get_message(self):
|
454 |
+
messages = self.messages
|
455 |
+
api_messages = []
|
456 |
+
for i, msg in enumerate(messages):
|
457 |
+
role, message_content = msg
|
458 |
+
if i == 0:
|
459 |
+
# get content for very first message in conversation
|
460 |
+
content = self._build_content_for_first_message_in_conversation(message_content)
|
461 |
+
else:
|
462 |
+
# get content for follow-up message in conversation
|
463 |
+
content = self._build_content_for_follow_up_messages_in_conversation(message_content)
|
464 |
+
|
465 |
+
api_messages.append({
|
466 |
+
"role": role,
|
467 |
+
"content": content,
|
468 |
+
})
|
469 |
+
return api_messages
|
470 |
+
|
471 |
+
# this method serializes a multi-turn chat into a single-turn prompt string
|
472 |
+
def serialize_messages(self):
|
473 |
+
messages = self.messages
|
474 |
+
ret = ""
|
475 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
476 |
+
if self.system is not None and self.system != "":
|
477 |
+
ret = self.system + self.sep
|
478 |
+
for i, (role, message) in enumerate(messages):
|
479 |
+
role = self._get_prompt_role(role)
|
480 |
+
if message:
|
481 |
+
if isinstance(message, List):
|
482 |
+
# get prompt only
|
483 |
+
message = message[0]
|
484 |
+
if i == 0:
|
485 |
+
# do not include role at the beginning
|
486 |
+
ret += message
|
487 |
+
else:
|
488 |
+
ret += role + ": " + message
|
489 |
+
if i < len(messages) - 1:
|
490 |
+
# avoid including sep at the end of serialized message
|
491 |
+
ret += self.sep
|
492 |
+
else:
|
493 |
+
ret += role + ":"
|
494 |
+
else:
|
495 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
496 |
+
|
497 |
+
return ret
|
498 |
+
|
499 |
+
def append_message(self, role, message):
|
500 |
+
if len(self.messages) == 0:
|
501 |
+
# data verification for the very first message
|
502 |
+
assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
|
503 |
+
assert len(message) == 2, f"the very first message in conversation must include both prompt and an image"
|
504 |
+
prompt, image = message[0], message[1]
|
505 |
+
assert prompt is not None, f"prompt must be not None"
|
506 |
+
assert isBase64(image), f"image must be under base64 encoding"
|
507 |
+
else:
|
508 |
+
# data verification for follow-up message
|
509 |
+
assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
|
510 |
+
assert len(message) == 1, f"the follow-up message must consist of one text message only, no image"
|
511 |
+
|
512 |
+
self.messages.append([role, message])
|
513 |
+
|
514 |
+
def copy(self):
|
515 |
+
return Conversation(
|
516 |
+
system=self.system,
|
517 |
+
roles=self.roles,
|
518 |
+
messages=[[x,y] for x, y in self.messages],
|
519 |
+
version=self.version,
|
520 |
+
map_roles=self.map_roles,
|
521 |
+
)
|
522 |
+
|
523 |
+
def dict(self):
|
524 |
+
return {
|
525 |
+
"system": self.system,
|
526 |
+
"roles": self.roles,
|
527 |
+
"messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
|
528 |
+
"version": self.version,
|
529 |
+
}
|
530 |
+
|
531 |
+
prediction_guard_llava_conv = Conversation(
|
532 |
+
system="",
|
533 |
+
roles=("user", "assistant"),
|
534 |
+
messages=[],
|
535 |
+
version="Prediction Guard LLaVA enpoint Conversation v0",
|
536 |
+
sep_style=SeparatorStyle.SINGLE,
|
537 |
+
map_roles={
|
538 |
+
"user": "USER",
|
539 |
+
"assistant": "ASSISTANT"
|
540 |
+
}
|
541 |
+
)
|
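# Example (illustrative sketch; the frame path is hypothetical): build a one-turn LLaVA
# conversation and inspect the payload sent to the API.
#   conv = prediction_guard_llava_conv.copy()
#   conv.append_message("user", ["What is shown in this frame?", encode_image("frame_0.jpg")])
#   payload = conv.get_message()  # [{"role": "user", "content": [text part, image_url part]}]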
542 |
+
|
543 |
+
# get PredictionGuard Client
|
544 |
+
def _getPredictionGuardClient():
|
545 |
+
PREDICTION_GUARD_API_KEY = get_prediction_guard_api_key()
|
546 |
+
client = PredictionGuard(
|
547 |
+
api_key=PREDICTION_GUARD_API_KEY,
|
548 |
+
url=PREDICTION_GUARD_URL_ENDPOINT,
|
549 |
+
)
|
550 |
+
return client
|
551 |
+
|
552 |
+
# helper function to call chat completion endpoint of PredictionGuard given a prompt and an image
|
553 |
+
def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
|
554 |
+
# prepare conversation
|
555 |
+
conversation = prediction_guard_llava_conv.copy()
|
556 |
+
conversation.append_message(conversation.roles[0], [prompt, image])
|
557 |
+
return lvlm_inference_with_conversation(conversation, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
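# Example usage (illustrative; assumes Prediction Guard access and a hypothetical frame file):
#   caption = lvlm_inference("Describe this frame in one sentence.", encode_image("frame_0.jpg"))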
558 |
+
|
559 |
+
|
560 |
+
|
561 |
+
def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
|
562 |
+
# get PredictionGuard client
|
563 |
+
client = _getPredictionGuardClient()
|
564 |
+
# get message from conversation
|
565 |
+
messages = conversation.get_message()
|
566 |
+
# call the chat completion endpoint at Prediction Guard
|
567 |
+
response = client.chat.completions.create(
|
568 |
+
model="llava-1.5-7b-hf",
|
569 |
+
messages=messages,
|
570 |
+
max_tokens=max_tokens,
|
571 |
+
temperature=temperature,
|
572 |
+
top_p=top_p,
|
573 |
+
top_k=top_k,
|
574 |
+
)
|
575 |
+
return response['choices'][-1]['message']['content']
|
576 |
+
|
577 |
+
def get_token():
|
578 |
+
load_env()
|
579 |
+
token = os.getenv("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
580 |
+
if token is None:
|
581 |
+
raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
|
582 |
+
return token
|
583 |
+
|
584 |
+
|
585 |
+
def lvlm_inference_with_phi(prompt):
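# NOTE: despite the "phi" in its name, this helper performs a text-only chat completion
# through the Hugging Face InferenceClient with meta-llama/Meta-Llama-3-8B-Instruct.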
|
586 |
+
|
587 |
+
|
588 |
+
messages = [{"role": "user", "content": prompt}]
|
589 |
+
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=get_token())
|
590 |
+
response = ''
|
591 |
+
token = client.chat_completion(messages, max_tokens=256)
|
592 |
+
response = token['choices'][0]['message']['content']
|
593 |
+
return response
|
594 |
+
|
595 |
+
def lvlm_inference_with_tiny_model(prompt):
|
596 |
+
classifier = pipeline(
|
597 |
+
"text-generation",
|
598 |
+
model="microsoft/phi-2", # Only ~2.7GB
|
599 |
+
device_map="auto",
|
600 |
+
torch_dtype="auto",
|
601 |
+
)
|
602 |
+
|
603 |
+
response = classifier(
|
604 |
+
prompt,
|
605 |
+
max_new_tokens=512, # Remove max_length and use only max_new_tokens
|
606 |
+
temperature=0.7,
|
607 |
+
do_sample=True,
|
608 |
+
num_return_sequences=1,
|
609 |
+
truncation=True, # Add explicit truncation
|
610 |
+
pad_token_id=classifier.tokenizer.eos_token_id,
|
611 |
+
eos_token_id=classifier.tokenizer.eos_token_id,
|
612 |
+
)[0]['generated_text']
|
613 |
+
|
614 |
+
# Remove the input prompt from the response and clean up
|
615 |
+
return response.replace(prompt, "").strip()
|
616 |
+
|
617 |
+
# function `extract_and_save_frames_and_metadata`:
|
618 |
+
# receives as input a video and its transcript
|
619 |
+
# extracts and saves frames and their metadata
|
620 |
+
# returns the list of extracted metadata
|
621 |
+
def extract_and_save_frames_and_metadata(
|
622 |
+
path_to_video,
|
623 |
+
path_to_transcript,
|
624 |
+
path_to_save_extracted_frames,
|
625 |
+
path_to_save_metadatas):
|
626 |
+
|
627 |
+
# metadatas will store the metadata of all extracted frames
|
628 |
+
metadatas = []
|
629 |
+
|
630 |
+
# load video using cv2
|
631 |
+
print(f"Loading video from {path_to_video}")
|
632 |
+
video = cv2.VideoCapture(path_to_video)
|
633 |
+
# load transcript using webvtt
|
634 |
+
print(f"Loading transcript from {path_to_transcript}")
|
635 |
+
trans = webvtt.read(path_to_transcript)
|
636 |
+
|
637 |
+
# iterate transcript file
|
638 |
+
# for each video segment specified in the transcript file
|
639 |
+
for idx, transcript in enumerate(trans):
|
640 |
+
# get the start time and end time in seconds
|
641 |
+
start_time_ms = str2time(transcript.start)
|
642 |
+
end_time_ms = str2time(transcript.end)
|
643 |
+
# get the time in ms exactly
|
644 |
+
# in the middle of start time and end time
|
645 |
+
mid_time_ms = (end_time_ms + start_time_ms) / 2
|
646 |
+
# get the transcript, remove the next-line symbol
|
647 |
+
text = transcript.text.replace("\n", ' ')
|
648 |
+
# get frame at the middle time
|
649 |
+
video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
|
650 |
+
print(f"Extracting frame at {mid_time_ms} ms")
|
651 |
+
success, frame = video.read()
|
652 |
+
if success:
|
653 |
+
# if the frame is extracted successfully, resize it
|
654 |
+
image = maintain_aspect_ratio_resize(frame, height=350)
|
655 |
+
# save frame as JPEG file
|
656 |
+
img_fname = f'frame_{idx}.jpg'
|
657 |
+
img_fpath = osp.join(
|
658 |
+
path_to_save_extracted_frames, img_fname
|
659 |
+
)
|
660 |
+
cv2.imwrite(img_fpath, image)
|
661 |
+
|
662 |
+
# prepare the metadata
|
663 |
+
metadata = {
|
664 |
+
'extracted_frame_path': img_fpath,
|
665 |
+
'transcript': text,
|
666 |
+
'video_segment_id': idx,
|
667 |
+
'video_path': path_to_video,
|
668 |
+
'mid_time_ms': mid_time_ms,
|
669 |
+
}
|
670 |
+
metadatas.append(metadata)
|
671 |
+
|
672 |
+
else:
|
673 |
+
print(f"ERROR! Cannot extract frame: idx = {idx}")
|
674 |
+
|
675 |
+
# save metadata of all extracted frames
|
676 |
+
fn = osp.join(path_to_save_metadatas, 'metadatas.json')
|
677 |
+
with open(fn, 'w') as outfile:
|
678 |
+
json.dump(metadatas, outfile)
|
679 |
+
return metadatas
|
680 |
+
|
681 |
+
def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
|
682 |
+
# output paths to save extracted frames and their metadata
|
683 |
+
extracted_frames_path = osp.join(vid_dir, 'extracted_frame')
|
684 |
+
metadatas_path = vid_dir
|
685 |
+
|
686 |
+
# create these output folders if not existing
|
687 |
+
print(f"Creating folders {extracted_frames_path} and {metadatas_path}")
|
688 |
+
Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
|
689 |
+
Path(metadatas_path).mkdir(parents=True, exist_ok=True)
|
690 |
+
print("Extracting frames the video path ", vid_filepath)
|
691 |
+
|
692 |
+
# call the function to extract frames and metadatas
|
693 |
+
metadatas = extract_and_save_frames_and_metadata(
|
694 |
+
vid_filepath,
|
695 |
+
vid_transcript_filepath,
|
696 |
+
extracted_frames_path,
|
697 |
+
metadatas_path,
|
698 |
+
)
|
699 |
+
return metadatas
|
700 |
+
|
701 |
+
# function extract_and_save_frames_and_metadata_with_fps
|
702 |
+
# receives as input an LVLM prompt and a video
|
703 |
+
# extracts and saves frames (captioned with the LVLM) and their metadata
|
704 |
+
# returns the list of extracted metadata
|
705 |
+
def extract_and_save_frames_and_metadata_with_fps(
|
706 |
+
lvlm_prompt,
|
707 |
+
path_to_video,
|
708 |
+
path_to_save_extracted_frames,
|
709 |
+
path_to_save_metadatas,
|
710 |
+
num_of_extracted_frames_per_second=1):
|
711 |
+
|
712 |
+
# metadatas will store the metadata of all extracted frames
|
713 |
+
metadatas = []
|
714 |
+
|
715 |
+
# load video using cv2
|
716 |
+
video = cv2.VideoCapture(path_to_video)
|
717 |
+
|
718 |
+
# Get the frames per second
|
719 |
+
fps = video.get(cv2.CAP_PROP_FPS)
|
720 |
+
# Get hop = the number of frames to skip before the next frame is extracted
|
721 |
+
hop = round(fps / num_of_extracted_frames_per_second)
|
722 |
+
curr_frame = 0
|
723 |
+
idx = -1
|
724 |
+
while True:
|
725 |
+
# iterate all frames
|
726 |
+
ret, frame = video.read()
|
727 |
+
if not ret:
|
728 |
+
break
|
729 |
+
if curr_frame % hop == 0:
|
730 |
+
idx = idx + 1
|
731 |
+
|
732 |
+
# resize the extracted frame
|
733 |
+
image = maintain_aspect_ratio_resize(frame, height=350)
|
734 |
+
# save frame as JPEG file
|
735 |
+
img_fname = f'frame_{idx}.jpg'
|
736 |
+
img_fpath = osp.join(
|
737 |
+
path_to_save_extracted_frames,
|
738 |
+
img_fname
|
739 |
+
)
|
740 |
+
cv2.imwrite(img_fpath, image)
|
741 |
+
|
742 |
+
# generate caption using lvlm_inference
|
743 |
+
b64_image = encode_image(img_fpath)
|
744 |
+
caption = lvlm_inference(lvlm_prompt, b64_image)
|
745 |
+
|
746 |
+
# prepare the metadata
|
747 |
+
metadata = {
|
748 |
+
'extracted_frame_path': img_fpath,
|
749 |
+
'transcript': caption,
|
750 |
+
'video_segment_id': idx,
|
751 |
+
'video_path': path_to_video,
|
752 |
+
}
|
753 |
+
metadatas.append(metadata)
|
754 |
+
curr_frame += 1
|
755 |
+
|
756 |
+
# save metadata of all extracted frames
|
757 |
+
metadatas_path = osp.join(path_to_save_metadatas,'metadatas.json')
|
758 |
+
with open(metadatas_path, 'w') as outfile:
|
759 |
+
json.dump(metadatas, outfile)
|
760 |
+
return metadatas
|
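# Example usage (illustrative sketch; all paths and the prompt are hypothetical):
#   metadatas = extract_and_save_frames_and_metadata_with_fps(
#       "Describe this frame in one sentence.",
#       "path/to/video.mp4",
#       "path/to/extracted_frames",
#       "path/to/metadata_dir",
#       num_of_extracted_frames_per_second=1,
#   )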
761 |
+
|
762 |
+
if __name__ == "__main__":
|
763 |
+
res = lvlm_inference_with_phi("Tell me a story")
|
764 |
print(res)
|