jimnoneill committed on
Commit
c23b225
·
verified ·
1 Parent(s): 5adec8e

Add comprehensive demo script with gte-large integration

Browse files
Files changed (1) hide show
  1. bsg_cyllama_demo.py +294 -0
bsg_cyllama_demo.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BSG CyLlama Demo Script
4
+ Simplified demonstration of BSG CyLlama with gte-large integration
5
+ """
6
+
7
+ import torch
8
+ import pandas as pd
9
+ import numpy as np
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from peft import PeftModel
12
+ from sentence_transformers import SentenceTransformer
13
+ from typing import List, Tuple, Optional
14
+
15
class BSGCyLlamaInference:
    """BSG CyLlama inference with gte-large integration.

    Loads three components on construction:

    * the ``thenlper/gte-large`` sentence transformer (``self.sbert_model``),
    * the ``meta-llama/Llama-3.2-1B-Instruct`` tokenizer (``self.tokenizer``),
    * the same base model wrapped with the PEFT adapter found in
      ``model_repo`` (``self.model``).
    """

    # A line containing any of these markers is treated as an echoed section
    # header and is skipped by the parsing heuristics.
    _SECTION_MARKERS = ('summary:', 'title:', 'abstract:')

    # Used as the abstract when the model produced no usable text at all.
    _DEFAULT_ABSTRACT = "Scientific research analysis focusing on advanced methodologies and findings."

    def __init__(self, model_repo: str = "jimnoneill/BSG_CyLlama"):
        """
        Initialize BSG CyLlama with gte-large sentence transformer

        Args:
            model_repo: Hugging Face model repository holding the adapter
        """
        print("🔄 Loading BSG CyLlama and gte-large models...")

        # Load the embedding model (REQUIRED for optimal performance)
        self.sbert_model = SentenceTransformer("thenlper/gte-large")
        print("✅ Loaded gte-large sentence transformer")

        # Load BSG CyLlama
        base_model_name = "meta-llama/Llama-3.2-1B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if self.tokenizer.pad_token is None:
            # Reuse EOS for padding when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to run locally; confirm it is really needed here.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Load the LoRA adapter
        self.model = PeftModel.from_pretrained(base_model, model_repo)
        print("✅ Loaded BSG CyLlama model")

    def create_cluster_embedding(self, cluster_abstracts: List[str], keywords: List[str]) -> np.ndarray:
        """
        Create embeddings for cluster content using gte-large

        Args:
            cluster_abstracts: List of scientific abstracts
            keywords: List of keywords/topics (may be empty)

        Returns:
            The embedding of the combined text, i.e. the first row returned by
            ``self.sbert_model.encode`` (1024-dimensional according to the
            original documentation of this script).
        """
        # Combine abstracts and keywords for rich context
        combined_text = " ".join(cluster_abstracts)
        if keywords:
            combined_text += " Keywords: " + " ".join(keywords)

        # encode() receives a one-element batch; return its single row.
        embedding = self.sbert_model.encode([combined_text])
        return embedding[0]

    def generate_research_analysis(self, embedding_context: Optional[np.ndarray] = None,
                                   source_text: str = "", max_length: int = 300) -> Tuple[str, str, str]:
        """
        Generate an abstract, a short summary and a title with the language model

        Args:
            embedding_context: Accepted for interface compatibility; this
                method currently does not use it when building the prompt.
            source_text: Source text to summarize (only the first 1000
                characters are placed in the prompt). When empty, a generic
                prompt is used instead.
            max_length: Maximum number of newly generated tokens

        Returns:
            Tuple of (abstract, short_summary, title)
        """
        # Create enhanced prompt
        if source_text:
            prompt = (
                "Summarize the following scientific research:\n"
                "\n"
                f"{source_text[:1000]}\n"
                "\n"
                "Provide:\n"
                "1. A comprehensive abstract\n"
                "2. A concise summary\n"
                "3. An informative title\n"
                "\n"
                "Abstract:"
            )
        else:
            prompt = (
                "Generate a scientific research analysis including:\n"
                "\n"
                "1. Abstract: A comprehensive overview\n"
                "2. Summary: Key findings and implications\n"
                "3. Title: Descriptive research title\n"
                "\n"
                "Abstract:"
            )

        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which matters because pad and EOS may be the same.
        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        # The model is loaded with device_map="auto", so it may not live on
        # the CPU; move the prompt tensors to wherever the model reports.
        device = getattr(self.model, "device", None)
        if device is not None:
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        prompt_token_count = encoded["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
            )

        # Decode only the continuation. Slicing the decoded string by the
        # decoded prompt's length breaks if decoding does not round-trip.
        analysis = self.tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        return self._parse_analysis(analysis)

    @staticmethod
    def _parse_analysis(analysis: str) -> Tuple[str, str, str]:
        """Split generated text into (abstract, short_summary, title).

        Heuristics, applied to the non-empty stripped lines in order:

        * lines of 20 characters or fewer, or containing ``summary:``,
          ``title:`` or ``abstract:`` (case-insensitive), are ignored;
        * the first remaining line is the abstract;
        * the next line shorter than the abstract is the short summary;
        * after that, the first line under 100 characters that was not taken
          as the summary is the title, and scanning stops.

        Missing pieces are filled in from the abstract (or from a default
        sentence when there is no text at all).
        """
        lines = [stripped for stripped in (raw.strip() for raw in analysis.split('\n')) if stripped]

        abstract = ""
        short_summary = ""
        title = ""

        for line in lines:
            if len(line) <= 20:
                continue
            lowered = line.lower()
            if any(marker in lowered for marker in BSGCyLlamaInference._SECTION_MARKERS):
                continue

            if not abstract:
                abstract = line
            elif not short_summary and len(line) < len(abstract):
                short_summary = line
            elif not title and len(line) < 100:
                title = line
                break

        # Fallback generation if parsing fails
        if not abstract:
            abstract = lines[0] if lines else BSGCyLlamaInference._DEFAULT_ABSTRACT

        if not short_summary:
            short_summary = abstract[:150] + "..." if len(abstract) > 150 else abstract

        if not title:
            # Generate title from the first eight words of the abstract
            title = "Scientific Research: " + " ".join(abstract.split()[:8])

        return abstract, short_summary, title
150
+
151
def generate_cluster_content(flat_tokens: List[str], cluster_abstracts: Optional[List[str]] = None,
                             cluster_name: str = "") -> Tuple[str, str, str]:
    """
    Generate content using trained BSG CyLlama model with gte-large embeddings

    The model is created lazily on first use and cached in the module-level
    global ``model_inference`` (set to ``None`` if loading fails, so loading
    is not retried). The model is only consulted when ``cluster_abstracts``
    is non-empty; otherwise, or if generation fails, template text built from
    the keywords is returned.

    Args:
        flat_tokens: List of keywords/tokens
        cluster_abstracts: Optional list of abstracts for context
        cluster_name: Name of the cluster for error reporting

    Returns:
        Tuple of (overview, title, abstract)
    """
    global model_inference

    if cluster_abstracts:
        # Only pay the cost of loading the model when it can actually be used.
        if 'model_inference' not in globals():
            try:
                model_inference = BSGCyLlamaInference()
            except Exception as e:
                print(f"⚠️ Failed to load BSG CyLlama: {e}")
                model_inference = None

        if model_inference is not None:
            try:
                # Use trained model with abstracts and keywords
                embedding = model_inference.create_cluster_embedding(cluster_abstracts, flat_tokens)

                # Generate content using the first abstract as context
                source_text = cluster_abstracts[0]
                abstract, overview, title = model_inference.generate_research_analysis(embedding, source_text)

                return overview, title, abstract

            except Exception as e:
                # Best effort: report and fall through to the template text.
                print(f"⚠️ Model generation failed for {cluster_name}: {e}, using fallback")

    # Fallback method for when model is not available.
    # Coerce to str so non-string tokens cannot break str.join.
    tokens = [str(token) for token in (flat_tokens or [])]

    if not tokens:
        # Without keywords the templates would end in dangling text
        # ("Research on "), so return the generic descriptions instead.
        return (
            "Research cluster analysis",
            "Research Cluster Analysis",
            "Comprehensive analysis of research cluster",
        )

    title = f"Research on {', '.join(tokens[:3])}"
    summary = f"Analysis of research focusing on {', '.join(tokens[:10])}"
    abstract = f"Comprehensive investigation of {', '.join(tokens[:5])} and related scientific topics"
    return summary, title, abstract
199
+
200
def demo_with_training_data():
    """Demonstrate BSG CyLlama using the training dataset.

    Downloads the first rows of the training TSV from Hugging Face, loads the
    model once, and prints generated results for the first two records next
    to the stored training summaries. Any failure is reported and swallowed
    so that the caller can continue with other demos.
    """
    # Publish the loaded model under the module-level name that
    # generate_cluster_content() checks, so it is not loaded a second time.
    global model_inference

    print("🔬 BSG CyLlama Demo with Training Data")
    print("=" * 50)

    try:
        # Load the training dataset from Hugging Face
        dataset_url = "https://huggingface.co/datasets/jimnoneill/BSG_CyLlama-training/resolve/main/bsg_training_data_complete_aligned.tsv"
        print(f"📊 Loading training dataset from: {dataset_url}")

        df = pd.read_csv(dataset_url, sep='\t', nrows=5)  # Load first 5 rows for demo
        print(f"✅ Loaded {len(df)} sample records")

        # Initialize the model
        print("\n🤖 Initializing BSG CyLlama...")
        model_inference = BSGCyLlamaInference()

        # Process a sample; number by position so the output does not depend
        # on the DataFrame's index labels being integers.
        for position, (_, row) in enumerate(df.head(2).iterrows()):  # Demo with first 2 records
            print(f"\n📄 Sample {position + 1}:")
            print("-" * 30)

            # Extract data (missing cells become empty values)
            original_text = str(row['OriginalText']) if pd.notna(row['OriginalText']) else ""
            training_summary = str(row['AbstractSummary']) if pd.notna(row['AbstractSummary']) else ""
            keywords = str(row['TopKeywords']).split() if pd.notna(row['TopKeywords']) else []

            print(f"Original Abstract: {original_text[:200]}...")
            print(f"Training Summary: {training_summary[:200]}...")

            # Generate new summary using our model
            cluster_abstracts = [original_text] if original_text else None
            overview, title, abstract = generate_cluster_content(
                keywords, cluster_abstracts, f"sample_{position}"
            )

            print("\n🔮 Generated Results:")
            print(f"Title: {title}")
            print(f"Overview: {overview[:200]}...")
            print(f"Abstract: {abstract[:200]}...")

        print("\n✅ Demo completed successfully!")

    except Exception as e:
        # Top-level demo boundary: report the problem instead of crashing.
        print(f"❌ Demo failed: {e}")
        print("💡 Make sure you have internet access to download the model and dataset")
244
+
245
def simple_summarization_demo():
    """Simple demonstration of text summarization.

    Summarizes a fixed sample paragraph and prints the generated title,
    summary and abstract. A model instance already stored in the module-level
    global ``model_inference`` is reused when present; otherwise a new one is
    loaded. Any failure is reported and swallowed.
    """
    print("\n🔬 Simple Summarization Demo")
    print("=" * 40)

    sample_text = """
    Deep learning models have revolutionized medical image analysis by providing
    unprecedented accuracy in disease detection and diagnosis. Convolutional neural
    networks (CNNs) have been particularly successful in analyzing radiological
    images, including X-rays, CT scans, and MRI images. Recent advances in
    transformer architectures have further improved the ability to understand
    complex spatial relationships in medical imagery. These developments have
    significant implications for clinical practice, potentially reducing diagnostic
    errors and improving patient outcomes.
    """

    try:
        # Avoid loading another copy of the model if one is already available.
        inference = globals().get("model_inference")
        if inference is None:
            inference = BSGCyLlamaInference()

        abstract, summary, title = inference.generate_research_analysis(
            source_text=sample_text
        )

        print(f"📄 Original Text: {sample_text.strip()[:200]}...")
        print("\n🔮 Generated Results:")
        print(f"Title: {title}")
        print(f"Summary: {summary}")
        print(f"Abstract: {abstract}")

    except Exception as e:
        # Top-level demo boundary: report the problem instead of crashing.
        print(f"❌ Summarization failed: {e}")
275
+
276
def main():
    """Run both demos in order, reporting interruption or unexpected errors."""
    print("🚀 BSG CyLlama Demo Script")
    print("Specialized Scientific Summarization with gte-large Integration")
    print("=" * 60)

    # Run demos
    try:
        # Demo 1: With training data
        demo_with_training_data()

        # Demo 2: Simple summarization
        simple_summarization_demo()

    except KeyboardInterrupt:
        print("\n⏹️ Demo stopped by user")
    except Exception as e:
        # Last-resort boundary for anything the demos did not handle themselves.
        print(f"\n❌ Demo failed: {e}")
        print("💡 Please ensure you have the required dependencies installed:")
        print("   pip install torch transformers peft sentence-transformers pandas")


if __name__ == "__main__":
    main()