Spaces:
Running
Running
| import spacy | |
| import gradio as gr | |
| import json | |
| from typing import Dict, List, Tuple, Any | |
| from zshot import PipelineConfig | |
| from zshot.linker import LinkerSMXM | |
| from zshot.utils.data_models import Entity | |
| from spacy.cli import download | |
# Fetch the small English spaCy package only when it is not already installed,
# so a restart does not go back to pip/the network (which fails offline even
# when the package is present).
# NOTE(review): the pipeline below is built from spacy.blank("en"); confirm
# whether zshot actually needs en_core_web_sm before dropping this entirely.
if not spacy.util.is_package("en_core_web_sm"):
    download("en_core_web_sm")
# Function to load the NER model
def load_model(entity_data, *, device="cpu", model_name="disi-unibo-nlp/openbioner-base"):
    """Build a spaCy pipeline that tags the given entity types via zshot.

    Args:
        entity_data: iterable of dicts, each with required "name" and
            "description" keys and an optional "vocabulary" key; each one is
            turned into a zshot ``Entity``.
        device: device string forwarded to zshot's ``PipelineConfig``
            (default "cpu"; presumably a torch device such as "cuda" can be
            used when a GPU is available -- confirm against zshot docs).
        model_name: model identifier forwarded to ``LinkerSMXM``.

    Returns:
        A blank English spaCy pipeline with the "zshot" component appended last.

    Raises:
        KeyError: if an entry lacks "name" or "description".
        TypeError: if ``entity_data`` is not an iterable of mappings.
    """
    entities = [
        Entity(
            name=entity["name"],
            description=entity["description"],
            vocabulary=entity.get("vocabulary"),
        )
        for entity in entity_data
    ]
    nlp = spacy.blank("en")
    nlp_config = PipelineConfig(
        linker=LinkerSMXM(model_name=model_name),
        entities=entities,
        device=device,
    )
    nlp.add_pipe("zshot", config=nlp_config, last=True)
    return nlp
# Default entity configuration shown in the editor: a single BACTERIUM type.
default_entities = [
    dict(
        name="BACTERIUM",
        description=(
            "A bacterium refers to a type of microorganism that can exist as a "
            "single cell and may cause infections or play a role in various "
            "biological processes. Examples include species like "
            "Streptococcus pneumoniae and Streptomyces ahygroscopicus."
        ),
    )
]
# Initialize model with default entities.
# Built once at import time; process_text declares `global nlp` and may rebind
# this module-level pipeline when the user edits the entity definitions.
nlp = load_model(default_entities)
# Function to create HTML visualization of entities
def get_entity_html(doc) -> str:
    """Render ``doc`` as an HTML fragment with its entities highlighted.

    Each entity in ``doc.ents`` is wrapped in a coloured ``<span>`` followed by
    its label; the text between entities is emitted as-is. All document text
    and labels are HTML-escaped. The fragment is wrapped in a dark-styled
    ``<div>``.

    Args:
        doc: a spaCy ``Doc`` (or any object with ``text`` and ``ents`` whose
            items expose ``start_char``, ``end_char`` and ``label_``). Entities
            are assumed non-overlapping and in document order.

    Returns:
        The HTML markup as a string.
    """
    # Local import keeps this block self-contained (not imported at file top).
    import html

    colors = {
        "BACTERIUM": "#8dd3c7",
        "CHEMICAL": "#fb8072",
        "DISEASE": "#80b1d3",
        "GENE": "#fdb462",
        "SPECIES": "#b3de69",
    }
    text = doc.text
    html_parts = []
    last_idx = 0
    for ent in doc.ents:
        # The result is rendered as raw HTML by gr.HTML, so user-supplied text
        # and labels must be escaped or "<", "&" and quotes would break the
        # markup or inject tags.
        html_parts.append(html.escape(text[last_idx:ent.start_char]))
        color = colors.get(ent.label_, "#ddd")
        html_parts.append(
            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
            f'border-radius: 0.35em; margin: 0 0.1em; font-weight: bold; color: #000;">'
            f'{html.escape(text[ent.start_char:ent.end_char])}'
            f'<span style="font-size: 0.8em; font-weight: bold; margin-left: 0.5em">{html.escape(ent.label_)}</span>'
            f'</span>'
        )
        last_idx = ent.end_char
    # Trailing text after the last entity (or the whole text if there are none).
    html_parts.append(html.escape(text[last_idx:]))
    return f'<div style="line-height: 1.5; padding: 10px; background: #222; color: #fff; border-radius: 5px;">{"".join(html_parts)}</div>'
# Function to get entity details including spans
def get_entity_details(doc) -> List[Dict[str, Any]]:
    """Summarise each entity of ``doc`` as a JSON-friendly dict.

    One dict per item of ``doc.ents``, in iteration order, with keys "text",
    "type" (the entity label), "start" and "end" (character offsets).
    """
    return [
        {
            "text": span.text,
            "type": span.label_,
            "start": span.start_char,
            "end": span.end_char,
        }
        for span in doc.ents
    ]
# Main processing function
def process_text(text: str, entities_json: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Run NER on ``text`` using the entity types described by ``entities_json``.

    ``entities_json`` must decode to the list-of-dicts format accepted by
    ``load_model``. The module-level pipeline ``nlp`` is rebuilt only when the
    decoded configuration differs from the one it was last built with, because
    every rebuild constructs a new ``LinkerSMXM``.

    Returns:
        ``(html, details)``: the highlighted HTML from ``get_entity_html`` and the
        per-entity dicts from ``get_entity_details``. If the configuration is not
        valid JSON or has the wrong shape, an "Error: ..." string and an empty
        list are returned instead.
    """
    global nlp
    try:
        entities = json.loads(entities_json)
    except json.JSONDecodeError:
        return "Error: Invalid JSON in entity configuration", []

    # `nlp` is first built from `default_entities` at import time, so that is
    # the starting point for change detection.
    loaded = getattr(process_text, "_loaded_entities", default_entities)
    if entities != loaded:
        try:
            nlp = load_model(entities)
        except (KeyError, TypeError, ValueError) as err:
            # Valid JSON but wrong shape: missing "name"/"description", a
            # non-list / non-dict entry, or (presumably) a field-validation
            # error from zshot's Entity model.
            return f"Error: Invalid entity configuration ({err!r})", []
        # Record only after a successful build so a failed edit is retried.
        process_text._loaded_entities = entities

    doc = nlp(text)
    return get_entity_html(doc), get_entity_details(doc)
# Dark colour overrides layered on top of the Soft base theme.
_DARK_THEME_OVERRIDES = dict(
    body_background_fill="#1a1a1a",
    body_text_color="#fff",
    background_fill_primary="#222",
    background_fill_secondary="#333",
    border_color_primary="#444",
    block_background_fill="#222",
    block_label_background_fill="#333",
    block_label_text_color="#fff",
    block_title_text_color="#fff",
    panel_background_fill="#222",
    input_background_fill="#333",
    input_border_color="#555",
    input_placeholder_color="#888",
    button_primary_background_fill="#2563eb",
    button_primary_text_color="#fff",
    slider_color="#2563eb",
)

# Set theme to dark: Soft palette (blue / slate hues, medium text), then recoloured.
_base_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size=gr.themes.sizes.text_md,
)
theme = _base_theme.set(**_DARK_THEME_OVERRIDES)
# (button caption, sample text) pairs for the quick-example buttons; the first
# sample also pre-fills the input box.
_EXAMPLE_TEXTS = (
    ("E. coli research",
     "Impact of cofactor - binding loop mutations on thermotolerance and activity of E. coli transketolase"),
    ("Bacterial infection case",
     "The patient was diagnosed with pneumonia caused by Streptococcus pneumoniae and treated with antibiotics for 7 days."),
    ("Multiple bacterial species",
     "We compared growth rates of E. coli, B. subtilis and S. aureus in various media containing different carbon sources."),
)


def _constant(value):
    """Return a zero-argument callback that always yields ``value``."""
    return lambda: value


# Create Gradio interface with dark theme
with gr.Blocks(title="Named Entity Recognition", theme=theme) as demo:
    gr.Markdown("# OpenBioNER - Demo")

    # Row 1: editable entity definitions, pre-filled with the defaults.
    with gr.Row():
        entities_input = gr.Code(
            value=json.dumps(default_entities, indent=2),
            label="Entity Definitions (JSON)",
            language="json",
            lines=6,
        )

    # Row 2: input text (left) and quick-example buttons (right).
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to analyze",
                placeholder="Enter text to analyze...",
                value=_EXAMPLE_TEXTS[0][1],
                lines=3,
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")
        with gr.Column():
            gr.Markdown("### Quick Examples")
            example1_btn, example2_btn, example3_btn = (
                gr.Button(_caption) for _caption, _ in _EXAMPLE_TEXTS
            )

    # Row 3: highlighted text (left) and raw entity spans (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Recognized Entities")
            result_html = gr.HTML()
        with gr.Column():
            gr.Markdown("### Entity Details with Spans")
            entity_details = gr.JSON()

    # Wire the analyze button to the NER handler.
    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, entities_input],
        outputs=[result_html, entity_details],
    )
    # Each example button just drops its sample sentence into the input box.
    for _btn, (_, _sample) in zip((example1_btn, example2_btn, example3_btn), _EXAMPLE_TEXTS):
        _btn.click(fn=_constant(_sample), inputs=None, outputs=text_input)

if __name__ == "__main__":
    demo.launch()