"""
BSG CyLlama Demo Script: Biomedical Summary Generation through Cyclical Llama

Demonstrates the cyclical embedding-averaging methodology with named entity integration.
"""

import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from sentence_transformers import SentenceTransformer
from typing import List, Tuple, Optional


class BSGCyLlamaInference:
    """
    BSG CyLlama: Biomedical Summary Generation through Cyclical Llama

    Corpus-level summarization in four steps:
    1. Cyclical embedding averaging across the document corpus
    2. Named entity concatenation with the averaged embedding
    3. Approximation embedding document generation
    4. Corpus-level summary synthesis
    """

    def __init__(self, model_repo: str = "jimnoneill/BSG_CyLlama"):
        """
        Initialize BSG CyLlama with the gte-large sentence transformer.

        Args:
            model_repo: Hugging Face model repository
        """
        print("Loading BSG CyLlama and gte-large models...")

        # gte-large produces the 1024-dimensional embeddings used for cyclical averaging
        self.sbert_model = SentenceTransformer("thenlper/gte-large")
        print("Loaded gte-large sentence transformer")

        # Tokenizer and base model that the BSG CyLlama adapter is applied to
        base_model_name = "meta-llama/Llama-3.2-1B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

        # Load the fine-tuned PEFT adapter on top of the base model
        self.model = PeftModel.from_pretrained(base_model, model_repo)
        print("Loaded BSG CyLlama model")

    def create_cluster_embedding(self, cluster_abstracts: List[str], keywords: List[str]) -> np.ndarray:
        """
        BSG CyLlama core innovation: cyclical embedding averaging.

        Creates an approximation embedding document through cyclical averaging of the
        corpus embeddings, with named entity concatenation. Each document i of n is
        weighted by (cos(2*pi*i/n) + 1) / 2 before the sum is divided by n.

        Args:
            cluster_abstracts: List of scientific abstracts (the corpus)
            keywords: List of named entities for concatenation

        Returns:
            1024-dimensional cyclically averaged embedding with entity integration
        """
        if not cluster_abstracts:
            # No corpus available: embed the keywords (or a generic phrase) instead
            combined_text = " ".join(keywords) if keywords else "scientific research analysis"
            return self.sbert_model.encode([combined_text])[0]

        # Embed each document in the corpus with gte-large
        document_embeddings = []
        for abstract in cluster_abstracts:
            embedding = self.sbert_model.encode([abstract])
            document_embeddings.append(embedding[0])

        # Cyclical averaging: weight each document by its phase on the unit circle
        n_docs = len(document_embeddings)
        cyclically_averaged = np.zeros_like(document_embeddings[0])

        for i, embedding in enumerate(document_embeddings):
            phase = 2 * np.pi * i / n_docs
            cycle_weight = (np.cos(phase) + 1) / 2
            cyclically_averaged += embedding * cycle_weight

        cyclically_averaged = cyclically_averaged / n_docs

        if keywords:
            # Named entity concatenation with the averaged embedding
            entity_text = " ".join(keywords)
            entity_embedding = self.sbert_model.encode([entity_text])[0]

            concatenated_embedding = np.concatenate([cyclically_averaged, entity_embedding])

            # Truncate or zero-pad so the result is always 1024-dimensional
            if len(concatenated_embedding) > 1024:
                concatenated_embedding = concatenated_embedding[:1024]
            elif len(concatenated_embedding) < 1024:
                padding = np.zeros(1024 - len(concatenated_embedding))
                concatenated_embedding = np.concatenate([concatenated_embedding, padding])

            return concatenated_embedding

        return cyclically_averaged
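
    # A minimal sketch of exercising the cyclical averaging step on its own. The
    # abstracts and keywords below are illustrative placeholders, not data from the
    # training set:
    #
    #   engine = BSGCyLlamaInference()
    #   emb = engine.create_cluster_embedding(
    #       ["First abstract about tumor immunology...",
    #        "Second abstract about checkpoint inhibitors..."],
    #       ["PD-1", "immunotherapy"],
    #   )
    #   assert emb.shape == (1024,)  # gte-large vectors are 1024-dimensional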

    def generate_research_analysis(self, embedding_context: Optional[np.ndarray] = None,
                                   source_text: str = "", max_length: int = 300) -> Tuple[str, str, str]:
        """
        Generate a research analysis with the fine-tuned model.

        Args:
            embedding_context: Optional embedding for context (from gte-large); accepted
                for API compatibility but not injected into the prompt in this demo
            source_text: Source text to summarize
            max_length: Maximum number of new tokens to generate

        Returns:
            Tuple of (abstract, short_summary, title)
        """
        if source_text:
            prompt = f"""Summarize the following scientific research:

{source_text[:1000]}

Provide:
1. A comprehensive abstract
2. A concise summary
3. An informative title

Abstract:"""
        else:
            prompt = """Generate a scientific research analysis including:

1. Abstract: A comprehensive overview
2. Summary: Key findings and implications
3. Title: Descriptive research title

Abstract:"""

        inputs = self.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1
            )

        # Strip the prompt from the decoded output, keeping only the generated analysis
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        analysis = generated_text[len(self.tokenizer.decode(inputs[0], skip_special_tokens=True)):].strip()

        lines = [line.strip() for line in analysis.split('\n') if line.strip()]

        # Heuristically split the generation into abstract, short summary, and title
        abstract = ""
        short_summary = ""
        title = ""

        for line in lines:
            if len(line) > 20 and not any(keyword in line.lower() for keyword in ['summary:', 'title:', 'abstract:']):
                if not abstract:
                    abstract = line
                elif not short_summary and len(line) < len(abstract):
                    short_summary = line
                elif not title and len(line) < 100:
                    title = line
                    break

        # Fallbacks if the heuristic parsing misses a field
        if not abstract:
            abstract = lines[0] if lines else "Scientific research analysis focusing on advanced methodologies and findings."

        if not short_summary:
            short_summary = abstract[:150] + "..." if len(abstract) > 150 else abstract

        if not title:
            words = abstract.split()[:8]
            title = "Scientific Research: " + " ".join(words)

        return abstract, short_summary, title

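
# A minimal sketch of calling the generator directly on a single text (the input is a
# placeholder; actual output depends on sampling and the loaded adapter weights):
#
#   engine = BSGCyLlamaInference()
#   abstract, summary, title = engine.generate_research_analysis(
#       source_text="Convolutional neural networks improve diagnostic accuracy in radiology...",
#       max_length=200,
#   )
#   print(title)
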
def generate_cluster_content(flat_tokens: List[str], cluster_abstracts: Optional[List[str]] = None,
                             cluster_name: str = "") -> Tuple[str, str, str]:
    """
    BSG CyLlama corpus-level content generation.

    Implements the complete BSG CyLlama methodology:
    1. Cyclical embedding averaging across corpus documents
    2. Named entity concatenation with the averaged embedding
    3. Approximation embedding document creation
    4. Corpus-level summary generation

    Args:
        flat_tokens: Named entities/keywords for concatenation
        cluster_abstracts: Corpus of related scientific documents
        cluster_name: Cluster identifier for error reporting

    Returns:
        Tuple of (corpus_overview, corpus_title, corpus_abstract)
    """
    global model_inference

    # Lazily create a single shared inference engine the first time this is called
    if 'model_inference' not in globals():
        try:
            model_inference = BSGCyLlamaInference()
        except Exception as e:
            print(f"Warning: failed to load BSG CyLlama: {e}")
            model_inference = None

    if model_inference is not None and cluster_abstracts:
        try:
            print(f"Processing corpus with {len(cluster_abstracts)} documents using cyclical averaging...")

            # Steps 1-3: build the approximation embedding document for the corpus
            cyclical_embedding = model_inference.create_cluster_embedding(cluster_abstracts, flat_tokens)

            # Step 4: generate the corpus-level summary, seeded with the first few abstracts
            corpus_text = " | ".join(cluster_abstracts[:3]) if cluster_abstracts else ""
            abstract, overview, title = model_inference.generate_research_analysis(cyclical_embedding, corpus_text)

            print(f"Generated corpus-level analysis for cluster {cluster_name}")
            return overview, title, abstract

        except Exception as e:
            print(f"Warning: BSG CyLlama cyclical generation failed for {cluster_name}: {e}, using fallback")

    # Keyword-based fallback when the model is unavailable or generation fails
    try:
        title = f"Research on {', '.join(flat_tokens[:3])}"
        summary = f"Analysis of research focusing on {', '.join(flat_tokens[:10])}"
        abstract = f"Comprehensive investigation of {', '.join(flat_tokens[:5])} and related scientific topics"
        return summary, title, abstract
    except Exception as e:
        print(f"Warning: all generation methods failed for {cluster_name}: {e}")
        title = "Research Cluster Analysis"
        summary = "Research cluster analysis"
        abstract = "Comprehensive analysis of research cluster"
        return summary, title, abstract

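
# A minimal sketch of the corpus-level entry point. The keyword list and abstracts are
# illustrative placeholders, not data from the training set:
#
#   overview, title, abstract = generate_cluster_content(
#       flat_tokens=["glioblastoma", "temozolomide", "blood-brain barrier"],
#       cluster_abstracts=["Abstract one...", "Abstract two...", "Abstract three..."],
#       cluster_name="demo_cluster",
#   )
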
def demo_with_training_data():
    """Demonstrate BSG CyLlama using the training dataset."""
    print("BSG CyLlama Demo with Training Data")
    print("=" * 50)

    try:
        # Pull a few sample rows from the public training dataset
        dataset_url = "https://huggingface.co/datasets/jimnoneill/BSG_CyLlama-training/resolve/main/bsg_training_data_complete_aligned.tsv"
        print(f"Loading training dataset from: {dataset_url}")

        df = pd.read_csv(dataset_url, sep='\t', nrows=5)
        print(f"Loaded {len(df)} sample records")

        print("\nInitializing BSG CyLlama...")
        # Assign the shared global so generate_cluster_content reuses this instance
        # instead of loading the models a second time
        global model_inference
        model_inference = BSGCyLlamaInference()

        for i, row in df.head(2).iterrows():
            print(f"\nSample {i+1}:")
            print("-" * 30)

            original_text = row['OriginalText'] if pd.notna(row['OriginalText']) else ""
            training_summary = row['AbstractSummary'] if pd.notna(row['AbstractSummary']) else ""
            keywords = str(row['TopKeywords']).split() if pd.notna(row['TopKeywords']) else []

            print(f"Original Abstract: {original_text[:200]}...")
            print(f"Training Summary: {training_summary[:200]}...")

            # Treat each sample abstract as a one-document corpus
            cluster_abstracts = [original_text] if original_text else None
            overview, title, abstract = generate_cluster_content(keywords, cluster_abstracts, f"sample_{i}")

            print("\nGenerated Results:")
            print(f"Title: {title}")
            print(f"Overview: {overview[:200]}...")
            print(f"Abstract: {abstract[:200]}...")

        print("\nDemo completed successfully!")

    except Exception as e:
        print(f"Demo failed: {e}")
        print("Tip: make sure you have internet access to download the model and dataset")


def simple_summarization_demo():
    """Simple demonstration of text summarization."""
    print("\nSimple Summarization Demo")
    print("=" * 40)

    sample_text = """
    Deep learning models have revolutionized medical image analysis by providing
    unprecedented accuracy in disease detection and diagnosis. Convolutional neural
    networks (CNNs) have been particularly successful in analyzing radiological
    images, including X-rays, CT scans, and MRI images. Recent advances in
    transformer architectures have further improved the ability to understand
    complex spatial relationships in medical imagery. These developments have
    significant implications for clinical practice, potentially reducing diagnostic
    errors and improving patient outcomes.
    """

    try:
        model_inference = BSGCyLlamaInference()
        abstract, summary, title = model_inference.generate_research_analysis(
            source_text=sample_text
        )

        print(f"Original Text: {sample_text.strip()[:200]}...")
        print("\nGenerated Results:")
        print(f"Title: {title}")
        print(f"Summary: {summary}")
        print(f"Abstract: {abstract}")

    except Exception as e:
        print(f"Summarization failed: {e}")


if __name__ == "__main__":
    print("BSG CyLlama Demo Script")
    print("Specialized Scientific Summarization with gte-large Integration")
    print("=" * 60)

    try:
        demo_with_training_data()

        simple_summarization_demo()

    except KeyboardInterrupt:
        print("\nDemo stopped by user")
    except Exception as e:
        print(f"\nDemo failed: {e}")
        print("Please ensure you have the required dependencies installed:")
        print("    pip install torch transformers peft sentence-transformers pandas")