| """ |
| Test script to verify 250K context length support |
| Tests RoPE scaling and long context handling |
| """ |
|
|
import logging
import math
import time
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
|
# Module-level logging: INFO level with the default format; all suite
# output is emitted through this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class LongContextTester:
    """Test long context capabilities of Helion-OSC.

    Only the model *configuration* is loaded (never the weights), so every
    check here is a cheap static analysis: tokenizer capacity, position
    embedding / RoPE settings, attention memory scaling, activation memory
    estimates, and RoPE frequency coverage for a 250K-token context.
    """

    def __init__(self, model_path: str = "./inference"):
        """
        Initialize tester and log basic context-length facts from the config.

        Args:
            model_path: Path to model inference directory
        """
        self.model_path = model_path
        logger.info("Loading model configuration...")

        self.config = AutoConfig.from_pretrained(model_path)

        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")

        if max_pos < 250000:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")

        # Either attribute may be missing OR explicitly set to None in the
        # config, so truthiness (not hasattr) is the right gate here.
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)

        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")

    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that tokenizer supports long sequences.

        Args:
            tokenizer_path: Hub id or local path of the tokenizer to load.

        Returns:
            True if tokenizing a long sample succeeds, False otherwise.
        """
        logger.info("\n" + "="*80)
        logger.info("TEST 1: Tokenizer Capacity")
        logger.info("="*80)

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")

            if max_length >= 250000:
                logger.info("✓ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"✗ Tokenizer max length only {max_length:,}")

            # "Hello world! " is roughly 2 tokens per repetition, so this
            # produces approximately `test_tokens` tokens.
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)

            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]

            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
            logger.info("✓ Tokenization test passed")

            return True

        except Exception as e:
            # Best-effort check: report the failure instead of aborting the suite.
            logger.error(f"✗ Tokenization test failed: {e}")
            return False

    def test_position_embeddings(self):
        """Test position embedding capacity.

        Returns:
            True when the configuration can represent a 250K context,
            False otherwise.
        """
        logger.info("\n" + "="*80)
        logger.info("TEST 2: Position Embeddings")
        logger.info("="*80)

        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size

        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")

        # BUGFIX: hasattr() alone is unsafe — HF configs routinely define
        # rope_theta / rope_scaling with the value None, which crashed the
        # "{:,}" format and the .get() calls below.  Test the value itself.
        rope_theta = getattr(self.config, 'rope_theta', None)
        if rope_theta is not None:
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("✓ RoPE scales efficiently to long contexts")

            logger.info(f"RoPE Theta: {rope_theta:,}")

        scaling = getattr(self.config, 'rope_scaling', None)
        if scaling is not None:
            logger.info("RoPE Scaling Configuration:")
            logger.info(f"  Type: {scaling.get('type', 'N/A')}")
            logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")

            if scaling.get('factor', 0) >= 32:
                logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
            else:
                logger.warning("✗ RoPE scaling factor may be insufficient")

            return True
        else:
            # Learned absolute embeddings: a [max_pos, hidden] fp16 table
            # (2 bytes per element).
            pos_emb_size = max_pos * hidden_size * 2
            pos_emb_gb = pos_emb_size / (1024**3)
            logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")

            if max_pos >= 250000:
                logger.info("✓ Sufficient position embeddings for 250K context")
                return True
            else:
                logger.warning("✗ Insufficient position embeddings")
                return False

    def test_attention_computation(self, sequence_lengths: Optional[list] = None):
        """Test attention computation at various lengths.

        Args:
            sequence_lengths: Sequence lengths (in tokens) to analyse.
                Defaults to [1024, 8192, 32768, 131072].

        Returns:
            True (informational test; always succeeds).
        """
        # BUGFIX: the default used to be a mutable list literal shared
        # across calls; use the None-sentinel idiom instead.
        if sequence_lengths is None:
            sequence_lengths = [1024, 8192, 32768, 131072]

        logger.info("\n" + "="*80)
        logger.info("TEST 3: Attention Computation Scaling")
        logger.info("="*80)

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")

        for seq_len in sequence_lengths:
            # Full attention matrix for batch=1 in fp16:
            # heads * seq_len^2 * 2 bytes.
            attn_size = 1 * num_heads * seq_len * seq_len * 2
            attn_gb = attn_size / (1024**3)

            logger.info(f"\nSequence length: {seq_len:,} tokens")
            logger.info(f"  Attention matrix: {attn_gb:.2f} GB")

            if seq_len <= 32768:
                logger.info("  ✓ Manageable size")
            elif seq_len <= 131072:
                logger.info("  ⚠ Large - may need Flash Attention")
            else:
                logger.info("  ⚠ Very large - requires optimizations")

        use_flash = getattr(self.config, 'use_flash_attention_2', False)
        if use_flash:
            logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts")
        else:
            logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")

        return True

    def test_memory_requirements(self):
        """Calculate memory requirements for 250K context.

        Returns:
            True (informational test; always succeeds).
        """
        logger.info("\n" + "="*80)
        logger.info("TEST 4: Memory Requirements")
        logger.info("="*80)

        context_length = 250000
        batch_size = 1
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers

        logger.info("Configuration:")
        logger.info(f"  Context: {context_length:,} tokens")
        logger.info(f"  Batch size: {batch_size}")
        logger.info(f"  Hidden size: {hidden_size:,}")
        logger.info(f"  Layers: {num_layers}")

        # Hidden states in fp16 (2 bytes per element).
        hidden_states_size = batch_size * context_length * hidden_size * 2
        hidden_states_gb = hidden_states_size / (1024**3)

        # Rule of thumb: ~2x the hidden-state size per layer for
        # intermediate activations.
        layer_memory_gb = hidden_states_gb * 2
        total_activation_gb = layer_memory_gb * num_layers

        logger.info("\nMemory estimates:")
        logger.info(f"  Hidden states per layer: {hidden_states_gb:.2f} GB")
        logger.info(f"  Total activation memory: {total_activation_gb:.2f} GB")
        logger.info(f"  Model weights: ~349 GB")
        logger.info(f"  Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")

        logger.info("\nRecommendations:")
        if total_activation_gb < 50:
            logger.info("  ✓ Should fit on 8x A100 (80GB) GPUs")
        elif total_activation_gb < 100:
            logger.info("  ⚠ May need gradient checkpointing")
        else:
            logger.info("  ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")

        return True

    def test_rope_frequencies(self):
        """Test RoPE frequency calculations for long context.

        Returns:
            True (informational test; always succeeds).
        """
        logger.info("\n" + "="*80)
        logger.info("TEST 5: RoPE Frequency Analysis")
        logger.info("="*80)

        rope_theta = getattr(self.config, 'rope_theta', None)
        if rope_theta is None:
            # BUGFIX: also fall back when the attribute exists but is None;
            # 10000 is the classic RoPE base.
            rope_theta = 10000
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")

        # RoPE inverse frequencies are theta^(-2i/d) for i in [0, d/2).
        # BUGFIX: the smallest frequency occurs at i = d/2 - 1, giving
        # exponent -(d-2)/d — the old code used -2(d-1)/d, which is
        # roughly the square of the true minimum.
        min_freq = rope_theta ** (-(head_dim - 2) / head_dim)
        max_freq = rope_theta ** 0

        logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")

        # Sample wavelengths (2*pi / freq, in tokens) across the spectrum.
        # max(1, ...) guards against a zero step when head_dim < 8.
        step = max(1, head_dim // 8)
        wavelengths = [2 * math.pi / (rope_theta ** (-2 * i / head_dim))
                       for i in range(0, head_dim // 2, step)]

        logger.info(f"\nWavelengths (in tokens):")
        for i, wl in enumerate(wavelengths):
            logger.info(f"  Frequency {i}: {wl:,.0f} tokens")

        max_wavelength = max(wavelengths)
        if max_wavelength >= 250000:
            logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
        else:
            logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")

        return True

    def run_all_tests(self):
        """Run all context length tests.

        Returns:
            True when every individual test passed, False otherwise.
        """
        logger.info("\n" + "="*80)
        logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        logger.info("="*80)

        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies()
        }

        logger.info("\n" + "="*80)
        logger.info("TEST SUMMARY")
        logger.info("="*80)

        for test_name, passed in results.items():
            status = "✓ PASS" if passed else "✗ FAIL"
            logger.info(f"{test_name}: {status}")

        all_passed = all(results.values())

        if all_passed:
            logger.info("\n✓ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")

        return all_passed
|
|
|
|
def main():
    """Command-line entry point: parse args and run the selected test."""
    import argparse

    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory"
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run"
    )

    args = parser.parse_args()

    tester = LongContextTester(args.model_path)

    # Map each CLI choice onto the bound method that implements it;
    # argparse's `choices` guarantees args.test is a valid key.
    dispatch = {
        "all": tester.run_all_tests,
        "tokenization": tester.test_tokenization_capacity,
        "position": tester.test_position_embeddings,
        "attention": tester.test_attention_computation,
        "memory": tester.test_memory_requirements,
        "rope": tester.test_rope_frequencies,
    }
    dispatch[args.test]()
|
|
|
|
# Run the test suite only when executed as a script (not on import).
if __name__ == "__main__":
    main()