import os

os.environ['MKL_THREADING_LAYER'] = 'GNU'

import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

from .prompts import format_rag_prompt
from .shared import generation_interrupt

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    # "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    # "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
    # "MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
    "Qwen3-0.6b": "qwen/qwen3-0.6b",
    "Qwen3-1.7b": "qwen/qwen3-1.7b",
    "Qwen3-4b": "qwen/qwen3-4b",
    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
}

# Cache loaded tokenizers so repeated inference calls can reuse them
tokenizer_cache = {}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        # Stop generation as soon as the shared interrupt event is set
        return self.interrupt_event.is_set()
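
# Note: InterruptCriteria is not currently wired into the pipeline calls below.
# A minimal sketch of how it could be attached (the text-generation pipeline
# forwards extra keyword arguments to model.generate, so treat this as an
# assumption to verify, not the module's current behavior):
#
#   stopping = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
#   outputs = pipe(prompt, max_new_tokens=512, stopping_criteria=stopping)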


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []

    if "full_contexts" in example and example["full_contexts"]:
        for i, ctx in enumerate(example["full_contexts"]):
            content = ""

            # Extract content from either dict or string
            if isinstance(ctx, dict) and "content" in ctx:
                content = ctx["content"]
            elif isinstance(ctx, str):
                content = ctx

            # Add document number if not already present
            if not content.strip().startswith("Document"):
                content = f"Document {i+1}:\n{content}"

            context_parts.append(content)

        context_text = "\n\n".join(context_parts)
    else:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        context_text = ""

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""

    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b
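
# Hypothetical usage (the example dict shape is inferred from the code above;
# the question/context strings are made up for illustration):
#
#   example = {
#       "question": "How long is the warranty?",
#       "full_contexts": [{"content": "The warranty lasts two years."}],
#   }
#   summary_a, summary_b = generate_summaries(example, "Qwen3-1.7b", "Gemma-2-2b-it")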


def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    Returns the generated text, or an empty string if interrupted.
    """
    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    result = ""

    # Chat-template arguments (applied in apply_chat_template below);
    # enable_thinking is added for Qwen3 so it doesn't emit thinking tokens
    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }
    generation_kwargs = {
        "max_new_tokens": 512,
    }

    if "qwen3" in model_name.lower():
        print(f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False.")
        tokenizer_kwargs["enable_thinking"] = False

    try:
        if model_name in tokenizer_cache:
            tokenizer = tokenizer_cache[model_name]
        else:
            # tokenizer_kwargs are chat-template arguments, not constructor
            # arguments, so they are not passed to from_pretrained here
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left",
                token=True,
            )
            tokenizer_cache[model_name] = tokenizer

        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template
            else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map=device,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            model_kwargs={
                "attn_implementation": "eager",
            },
        )

        text_input = format_rag_prompt(question, context, accepts_sys)

        if "gemma-3" not in model_name.lower():
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )

            # Check interrupt before generation
            if generation_interrupt.is_set():
                return ""

            # return_full_text=False keeps only the newly generated completion
            outputs = pipe(formatted, return_full_text=False, **generation_kwargs)
            result = outputs[0]["generated_text"]
        else:
            # Gemma keeps breaking with apply_chat_template, so pass the chat
            # messages directly and let the pipeline apply its own template
            outputs = pipe(text_input, **generation_kwargs)
            # generated_text is the full message list; the last entry is the assistant reply
            result = outputs[0]["generated_text"][-1]["content"]

    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        result = f"Error generating response: {str(e)[:200]}..."

    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result
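
# Quick, commented-out smoke test (a sketch only: the relative imports above mean
# this module must be run as part of its package, and the models require download access):
#
#   context = "Document 1:\nThe warranty lasts two years."
#   question = "How long is the warranty?"
#   print(run_inference(models["Qwen3-0.6b"], context, question))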