""" Helion-V2 Inference Script Provides optimized inference with various sampling strategies. """ import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig import argparse from typing import Optional, List, Dict import time class HelionInference: """Inference wrapper for Helion-V2 model.""" def __init__( self, model_name: str = "DeepXR/Helion-V2", device: str = "auto", load_in_4bit: bool = False, load_in_8bit: bool = False, use_flash_attention: bool = True, ): """ Initialize the Helion-V2 model for inference. Args: model_name: HuggingFace model identifier device: Device placement ('auto', 'cuda', 'cpu') load_in_4bit: Use 4-bit quantization load_in_8bit: Use 8-bit quantization use_flash_attention: Enable Flash Attention 2 """ self.model_name = model_name self.device = device print(f"Loading tokenizer from {model_name}...") self.tokenizer = AutoTokenizer.from_pretrained(model_name) # Configure quantization quantization_config = None if load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) elif load_in_8bit: quantization_config = BitsAndBytesConfig(load_in_8bit=True) print(f"Loading model from {model_name}...") model_kwargs = { "device_map": device, "torch_dtype": torch.float16, "quantization_config": quantization_config, } if use_flash_attention and not (load_in_4bit or load_in_8bit): model_kwargs["attn_implementation"] = "flash_attention_2" self.model = AutoModelForCausalLM.from_pretrained( model_name, **model_kwargs ) self.model.eval() print("Model loaded successfully!") def generate( self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.1, do_sample: bool = True, num_return_sequences: int = 1, ) -> List[str]: """ Generate text from a prompt. Args: prompt: Input text prompt max_new_tokens: Maximum tokens to generate temperature: Sampling temperature (higher = more random) top_p: Nucleus sampling threshold top_k: Top-k sampling parameter repetition_penalty: Penalty for repeating tokens do_sample: Use sampling vs greedy decoding num_return_sequences: Number of sequences to generate Returns: List of generated text strings """ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) start_time = time.time() with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, do_sample=do_sample, num_return_sequences=num_return_sequences, pad_token_id=self.tokenizer.eos_token_id, ) generation_time = time.time() - start_time tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1] tokens_per_second = tokens_generated / generation_time results = [] for output in outputs: text = self.tokenizer.decode(output, skip_special_tokens=True) results.append(text) print(f"\nGeneration stats:") print(f" Tokens generated: {tokens_generated}") print(f" Time: {generation_time:.2f}s") print(f" Speed: {tokens_per_second:.2f} tokens/s") return results def chat( self, messages: List[Dict[str, str]], max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, **kwargs ) -> str: """ Generate response in chat format. 
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs,
    ) -> str:
        """
        Generate response in chat format.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            **kwargs: Additional generation parameters

        Returns:
            Generated response text
        """
        input_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        results = self.generate(
            input_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            **kwargs,
        )

        # Extract only the assistant's response
        full_text = results[0]
        if "<|assistant|>" in full_text:
            response = full_text.split("<|assistant|>")[-1].split("<|end|>")[0].strip()
        else:
            response = full_text[len(input_text):].strip()

        return response


def main():
    parser = argparse.ArgumentParser(description="Helion-V2 Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V2",
        help="Model name or path",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling threshold",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="Top-k sampling",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.1,
        help="Repetition penalty",
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision",
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device placement",
    )
    parser.add_argument(
        "--chat-mode",
        action="store_true",
        help="Use chat format",
    )
    args = parser.parse_args()

    # Initialize model
    inference = HelionInference(
        model_name=args.model,
        device=args.device,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
    )

    # Generate response
    if args.chat_mode:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": args.prompt},
        ]
        response = inference.chat(
            messages,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        print(f"\nAssistant: {response}")
    else:
        results = inference.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        print(f"\nGenerated text:\n{results[0]}")


if __name__ == "__main__":
    main()
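
# Example usage (illustrative; assumes this file is saved as inference.py and
# that the DeepXR/Helion-V2 weights are reachable locally or on the Hub):
#
#   CLI:
#     python inference.py --prompt "Explain nucleus sampling." --max-tokens 256
#     python inference.py --prompt "Hi there!" --chat-mode --load-in-4bit
#
#   Python:
#     from inference import HelionInference
#     engine = HelionInference(load_in_4bit=True)
#     reply = engine.chat([{"role": "user", "content": "Summarize RLHF."}])
#     print(reply)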