| | """ |
| | Simple example usage of CosmicFish model (local model) |
| | """ |
| | import torch |
| | from transformers import GPT2Tokenizer |
| | from modeling_cosmicfish import CosmicFish, CosmicConfig |
| | from safetensors.torch import load_file |
| | import json |
| |
|
def load_cosmicfish(model_dir):
    """Load a CosmicFish model and its GPT-2 tokenizer from *model_dir*.

    Reads ``config.json`` to rebuild the architecture, loads the
    safetensors weights, and puts the model in eval mode.
    Returns a ``(model, tokenizer)`` pair.
    """
    with open(f"{model_dir}/config.json", "r") as fh:
        cfg = json.load(fh)

    # Rebuild the architecture from the saved config; dropout is pinned
    # to 0.0 since this loader is for inference only.
    model_config = CosmicConfig(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["block_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embd=cfg["n_embd"],
        bias=cfg["bias"],
        dropout=0.0,
        use_rotary=cfg["use_rotary"],
        use_swiglu=cfg["use_swiglu"],
        use_gqa=cfg["use_gqa"],
        n_query_groups=cfg["n_query_groups"],
        use_qk_norm=cfg.get("use_qk_norm", False),
    )

    net = CosmicFish(model_config)
    weights = load_file(f"{model_dir}/model.safetensors")

    # Checkpoints saved with tied embeddings omit lm_head.weight; restore
    # the tie by aliasing the token-embedding matrix before loading.
    if 'lm_head.weight' not in weights and 'transformer.wte.weight' in weights:
        weights['lm_head.weight'] = weights['transformer.wte.weight']

    net.load_state_dict(weights)
    net.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return net, tokenizer
| |
|
def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7):
    """Sample a continuation of *prompt* and return the decoded text.

    Encodes the prompt, generates up to *max_tokens* new tokens with
    top-k (k=40) sampling at the given *temperature*, and decodes the
    full sequence (prompt included) back to a string.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
| |
|
if __name__ == "__main__":
    # Demo: load the model from the current directory and run a few prompts.
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

    demo_prompts = (
        "What is climate change?",
        "Write a poem",
        "Define ML",
    )

    for question in demo_prompts:
        print(f"\nPrompt: {question}")
        answer = simple_generate(model, tokenizer, question, max_tokens=30)
        print(f"Response: {answer}")
|