Add comprehensive demo script with gte-large integration
bsg_cyllama_demo.py
ADDED +294 -0
@@ -0,0 +1,294 @@
#!/usr/bin/env python3
"""
BSG CyLlama Demo Script
Simplified demonstration of BSG CyLlama with gte-large integration
"""

import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from sentence_transformers import SentenceTransformer
from typing import List, Tuple, Optional

class BSGCyLlamaInference:
    """BSG CyLlama inference with gte-large integration"""

    def __init__(self, model_repo: str = "jimnoneill/BSG_CyLlama"):
        """
        Initialize BSG CyLlama with the gte-large sentence transformer

        Args:
            model_repo: Hugging Face model repository
        """
        print("🚀 Loading BSG CyLlama and gte-large models...")

        # Load the embedding model (required for optimal performance)
        self.sbert_model = SentenceTransformer("thenlper/gte-large")
        print("✅ Loaded gte-large sentence transformer")

        # Load BSG CyLlama
        base_model_name = "meta-llama/Llama-3.2-1B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

        # Load the LoRA adapter
        self.model = PeftModel.from_pretrained(base_model, model_repo)
        print("✅ Loaded BSG CyLlama model")

    def create_cluster_embedding(self, cluster_abstracts: List[str], keywords: List[str]) -> np.ndarray:
        """
        Create an embedding for cluster content using gte-large

        Args:
            cluster_abstracts: List of scientific abstracts
            keywords: List of keywords/topics

        Returns:
            1024-dimensional embedding vector
        """
        # Combine abstracts and keywords for rich context
        combined_text = " ".join(cluster_abstracts)
        if keywords:
            combined_text += " Keywords: " + " ".join(keywords)

        # Generate embedding using gte-large (1024 dimensions)
        embedding = self.sbert_model.encode([combined_text])
        return embedding[0]

    def generate_research_analysis(self, embedding_context: Optional[np.ndarray] = None,
                                   source_text: str = "", max_length: int = 300) -> Tuple[str, str, str]:
        """
        Generate a research analysis using embedding context

        Args:
            embedding_context: Optional embedding for context (from gte-large)
            source_text: Source text to summarize
            max_length: Maximum generation length

        Returns:
            Tuple of (abstract, short_summary, title)
        """
        # Note: embedding_context is computed for parity with the training
        # pipeline, but this simplified demo does not inject it into the prompt.

        # Create the prompt
        if source_text:
            prompt = f"""Summarize the following scientific research:

{source_text[:1000]}

Provide:
1. A comprehensive abstract
2. A concise summary
3. An informative title

Abstract:"""
        else:
            prompt = """Generate a scientific research analysis including:

1. Abstract: A comprehensive overview
2. Summary: Key findings and implications
3. Title: Descriptive research title

Abstract:"""

        inputs = self.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
        # Keep the input ids on the same device as the model (device_map="auto")
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=len(inputs[0]) + max_length,
                num_return_sequences=1,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        analysis = generated_text[len(self.tokenizer.decode(inputs[0], skip_special_tokens=True)):].strip()

        # Parse the generated content
        lines = [line.strip() for line in analysis.split('\n') if line.strip()]

        # Extract abstract (first substantial line), then summary and title
        abstract = ""
        short_summary = ""
        title = ""

        for line in lines:
            if len(line) > 20 and not any(keyword in line.lower() for keyword in ['summary:', 'title:', 'abstract:']):
                if not abstract:
                    abstract = line
                elif not short_summary and len(line) < len(abstract):
                    short_summary = line
                elif not title and len(line) < 100:
                    title = line
                    break

        # Fallbacks if parsing fails
        if not abstract:
            abstract = lines[0] if lines else "Scientific research analysis focusing on advanced methodologies and findings."

        if not short_summary:
            short_summary = abstract[:150] + "..." if len(abstract) > 150 else abstract

        if not title:
            # Derive a title from the abstract
            words = abstract.split()[:8]
            title = "Scientific Research: " + " ".join(words)

        return abstract, short_summary, title

def generate_cluster_content(flat_tokens: List[str], cluster_abstracts: Optional[List[str]] = None,
                             cluster_name: str = "") -> Tuple[str, str, str]:
    """
    Generate content using the trained BSG CyLlama model with gte-large embeddings

    Args:
        flat_tokens: List of keywords/tokens
        cluster_abstracts: Optional list of abstracts for context
        cluster_name: Name of the cluster for error reporting

    Returns:
        Tuple of (overview, title, abstract)
    """
    global model_inference

    # Lazily create a single shared model instance
    if 'model_inference' not in globals():
        try:
            model_inference = BSGCyLlamaInference()
        except Exception as e:
            print(f"⚠️ Failed to load BSG CyLlama: {e}")
            model_inference = None

    if model_inference is not None and cluster_abstracts:
        try:
            # Use the trained model with abstracts and keywords
            embedding = model_inference.create_cluster_embedding(cluster_abstracts, flat_tokens)

            # Generate content using the first abstract as context
            source_text = cluster_abstracts[0] if cluster_abstracts else ""
            abstract, overview, title = model_inference.generate_research_analysis(embedding, source_text)

            return overview, title, abstract

        except Exception as e:
            print(f"⚠️ Model generation failed for {cluster_name}: {e}; using fallback")

    # Keyword-based fallback for when the model is unavailable or generation failed
    try:
        title = f"Research on {', '.join(flat_tokens[:3])}"
        summary = f"Analysis of research focusing on {', '.join(flat_tokens[:10])}"
        abstract = f"Comprehensive investigation of {', '.join(flat_tokens[:5])} and related scientific topics"
        return summary, title, abstract
    except Exception as e:
        print(f"⚠️ All generation methods failed for {cluster_name}: {e}")
        title = "Research Cluster Analysis"
        summary = "Research cluster analysis"
        abstract = "Comprehensive analysis of research cluster"
        return summary, title, abstract

def demo_with_training_data():
    """Demonstrate BSG CyLlama using the training dataset"""
    # Share one model instance with generate_cluster_content instead of loading twice
    global model_inference

    print("🔬 BSG CyLlama Demo with Training Data")
    print("=" * 50)

    try:
        # Load the training dataset from Hugging Face
        dataset_url = "https://huggingface.co/datasets/jimnoneill/BSG_CyLlama-training/resolve/main/bsg_training_data_complete_aligned.tsv"
        print(f"📊 Loading training dataset from: {dataset_url}")

        df = pd.read_csv(dataset_url, sep='\t', nrows=5)  # Load first 5 rows for the demo
        print(f"✅ Loaded {len(df)} sample records")

        # Initialize the model
        print("\n🤖 Initializing BSG CyLlama...")
        model_inference = BSGCyLlamaInference()

        # Process samples
        for i, row in df.head(2).iterrows():  # Demo with the first 2 records
            print(f"\n📄 Sample {i+1}:")
            print("-" * 30)

            # Extract data
            original_text = row['OriginalText'] if pd.notna(row['OriginalText']) else ""
            training_summary = row['AbstractSummary'] if pd.notna(row['AbstractSummary']) else ""
            keywords = str(row['TopKeywords']).split() if pd.notna(row['TopKeywords']) else []

            print(f"Original Abstract: {original_text[:200]}...")
            print(f"Training Summary: {training_summary[:200]}...")

            # Generate a new summary using our model
            cluster_abstracts = [original_text] if original_text else None
            overview, title, abstract = generate_cluster_content(keywords, cluster_abstracts, f"sample_{i}")

            print(f"\n🔮 Generated Results:")
            print(f"Title: {title}")
            print(f"Overview: {overview[:200]}...")
            print(f"Abstract: {abstract[:200]}...")

        print(f"\n✅ Demo completed successfully!")

    except Exception as e:
        print(f"❌ Demo failed: {e}")
        print("💡 Make sure you have internet access to download the model and dataset")

def simple_summarization_demo():
    """Simple demonstration of text summarization"""
    print("\n🔬 Simple Summarization Demo")
    print("=" * 40)

    sample_text = """
    Deep learning models have revolutionized medical image analysis by providing
    unprecedented accuracy in disease detection and diagnosis. Convolutional neural
    networks (CNNs) have been particularly successful in analyzing radiological
    images, including X-rays, CT scans, and MRI images. Recent advances in
    transformer architectures have further improved the ability to understand
    complex spatial relationships in medical imagery. These developments have
    significant implications for clinical practice, potentially reducing diagnostic
    errors and improving patient outcomes.
    """

    try:
        model_inference = BSGCyLlamaInference()
        abstract, summary, title = model_inference.generate_research_analysis(
            source_text=sample_text
        )

        print(f"📝 Original Text: {sample_text.strip()[:200]}...")
        print(f"\n🔮 Generated Results:")
        print(f"Title: {title}")
        print(f"Summary: {summary}")
        print(f"Abstract: {abstract}")

    except Exception as e:
        print(f"❌ Summarization failed: {e}")

if __name__ == "__main__":
    print("🚀 BSG CyLlama Demo Script")
    print("Specialized Scientific Summarization with gte-large Integration")
    print("=" * 60)

    # Run demos
    try:
        # Demo 1: With training data
        demo_with_training_data()

        # Demo 2: Simple summarization
        simple_summarization_demo()

    except KeyboardInterrupt:
        print("\nℹ️ Demo stopped by user")
    except Exception as e:
        print(f"\n❌ Demo failed: {e}")
        print("💡 Please ensure you have the required dependencies installed:")
        print(" pip install torch transformers peft sentence-transformers pandas")
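
For quick reference, here is a minimal sketch of using the class above programmatically instead of running the bundled demos. BSGCyLlamaInference, create_cluster_embedding, and generate_research_analysis come from the script itself; the sample abstract and keyword inputs are illustrative placeholders, not taken from the training data.

# Hypothetical usage sketch; assumes bsg_cyllama_demo.py is on the import path.
from bsg_cyllama_demo import BSGCyLlamaInference

inference = BSGCyLlamaInference()  # downloads gte-large and the BSG CyLlama LoRA adapter

# Build a 1024-dim gte-large embedding from abstracts plus keywords (illustrative inputs)
embedding = inference.create_cluster_embedding(
    ["Transformer models improve protein structure prediction accuracy."],
    ["transformers", "protein-folding"],
)

# Returns (abstract, short_summary, title); note this demo's prompt only uses source_text
abstract, short_summary, title = inference.generate_research_analysis(
    embedding_context=embedding,
    source_text="Transformer models improve protein structure prediction accuracy.",
)
print(title)
print(short_summary)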