Upload OLMoForCausalLM
Browse files- README.md +199 -0
- config.json +56 -0
- configuration_olmo.py +43 -0
- generation_config.json +6 -0
- model-00001-of-00006.safetensors +3 -0
- model-00002-of-00006.safetensors +3 -0
- model-00003-of-00006.safetensors +3 -0
- model-00004-of-00006.safetensors +3 -0
- model-00005-of-00006.safetensors +3 -0
- model-00006-of-00006.safetensors +3 -0
- model.safetensors.index.json +137 -0
- modeling_olmo.py +228 -0
README.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags: []
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "/home/itay.itzhak/projects/proj2/finetuning/open-instruct/output/allenai/tulu-v2-sft-mixture_allenai/OLMo-7B_lora_r128_alpha256_LR2e-5_seed_1/merged",
|
| 3 |
+
"activation_type": "swiglu",
|
| 4 |
+
"alibi": false,
|
| 5 |
+
"alibi_bias_max": 8.0,
|
| 6 |
+
"architectures": [
|
| 7 |
+
"OLMoForCausalLM"
|
| 8 |
+
],
|
| 9 |
+
"attention_dropout": 0.0,
|
| 10 |
+
"attention_layer_norm": false,
|
| 11 |
+
"attention_layer_norm_with_affine": false,
|
| 12 |
+
"auto_map": {
|
| 13 |
+
"AutoConfig": "configuration_olmo.OLMoConfig",
|
| 14 |
+
"AutoModelForCausalLM": "modeling_olmo.OLMoForCausalLM",
|
| 15 |
+
"AutoTokenizer": [
|
| 16 |
+
"allenai/OLMo-7B--tokenization_olmo_fast.OLMoTokenizerFast",
|
| 17 |
+
"allenai/OLMo-7B--tokenization_olmo_fast.OLMoTokenizerFast"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
"bias_for_layer_norm": false,
|
| 21 |
+
"block_group_size": 1,
|
| 22 |
+
"block_type": "sequential",
|
| 23 |
+
"clip_qkv": null,
|
| 24 |
+
"d_model": 4096,
|
| 25 |
+
"embedding_dropout": 0.0,
|
| 26 |
+
"embedding_size": 50304,
|
| 27 |
+
"eos_token_id": 50279,
|
| 28 |
+
"flash_attention": true,
|
| 29 |
+
"include_bias": false,
|
| 30 |
+
"init_cutoff_factor": null,
|
| 31 |
+
"init_device": "meta",
|
| 32 |
+
"init_fn": "mitchell",
|
| 33 |
+
"init_std": 0.02,
|
| 34 |
+
"layer_norm_eps": 1e-05,
|
| 35 |
+
"layer_norm_type": "default",
|
| 36 |
+
"layer_norm_with_affine": false,
|
| 37 |
+
"max_sequence_length": 2048,
|
| 38 |
+
"mlp_hidden_size": 22016,
|
| 39 |
+
"mlp_ratio": 4,
|
| 40 |
+
"model_type": "hf_olmo",
|
| 41 |
+
"multi_query_attention": false,
|
| 42 |
+
"n_heads": 32,
|
| 43 |
+
"n_kv_heads": null,
|
| 44 |
+
"n_layers": 32,
|
| 45 |
+
"pad_token_id": 1,
|
| 46 |
+
"precision": "amp_bf16",
|
| 47 |
+
"residual_dropout": 0.0,
|
| 48 |
+
"rope": true,
|
| 49 |
+
"rope_full_precision": true,
|
| 50 |
+
"scale_logits": false,
|
| 51 |
+
"torch_dtype": "float32",
|
| 52 |
+
"transformers_version": "4.42.4",
|
| 53 |
+
"use_cache": true,
|
| 54 |
+
"vocab_size": 50280,
|
| 55 |
+
"weight_tying": false
|
| 56 |
+
}
|
configuration_olmo.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OLMo configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 6 |
+
from transformers.utils import logging
|
| 7 |
+
|
| 8 |
+
from olmo.config import ModelConfig
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OLMoConfig(PretrainedConfig):
|
| 14 |
+
model_type = "hf_olmo"
|
| 15 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
| 16 |
+
|
| 17 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
| 18 |
+
model_config = ModelConfig()
|
| 19 |
+
all_kwargs = model_config.asdict()
|
| 20 |
+
all_kwargs.update(kwargs)
|
| 21 |
+
all_kwargs.update({"use_cache": use_cache})
|
| 22 |
+
all_kwargs.update(
|
| 23 |
+
{"architectures": all_kwargs.get("architectures", ["OLMoForCausalLM"]) or ["OLMoForCausalLM"]}
|
| 24 |
+
)
|
| 25 |
+
super().__init__(**all_kwargs)
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def num_attention_heads(self):
|
| 29 |
+
return self.n_heads
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def num_hidden_layers(self):
|
| 33 |
+
return self.n_layers
|
| 34 |
+
|
| 35 |
+
@property
|
| 36 |
+
def hidden_size(self):
|
| 37 |
+
return self.d_model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
| 41 |
+
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
|
| 42 |
+
# may not support the newest architectures we create.
|
| 43 |
+
AutoConfig.register("hf_olmo", OLMoConfig)
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"eos_token_id": 50279,
|
| 4 |
+
"pad_token_id": 1,
|
| 5 |
+
"transformers_version": "4.42.4"
|
| 6 |
+
}
|
model-00001-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f293930eb7caea3c358d578dabfc062ef209405d0c9e127867bc13c27a6aff4
|
| 3 |
+
size 4938795616
|
model-00002-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a166764d29ff96f88677ee8aff712c3cae35f4a93682d7b2fb8689f762f2c089
|
| 3 |
+
size 4857006944
|
model-00003-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be1ab9adc2940a3e3f96ad4b75b651f25cd14d7e3a289a30933b3d5e6967376e
|
| 3 |
+
size 4857006960
|
model-00004-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47be792f0e8d526ae36084958aa89f470dd054d4db3342e0159ae6b68405a201
|
| 3 |
+
size 4857006960
|
model-00005-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1bd25c8b6a2e982e440b30e2eb0c4aee84a31c76337c8ac22291039e92e89b14
|
| 3 |
+
size 4857006960
|
model-00006-of-00006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d1c060b0d24a50d9edceb38cc51f982ae8aa68356afa249c001df9fb38ab819
|
| 3 |
+
size 3185575352
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 27552382976
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"model.transformer.blocks.0.att_proj.weight": "model-00001-of-00006.safetensors",
|
| 7 |
+
"model.transformer.blocks.0.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 8 |
+
"model.transformer.blocks.0.ff_out.weight": "model-00001-of-00006.safetensors",
|
| 9 |
+
"model.transformer.blocks.0.ff_proj.weight": "model-00001-of-00006.safetensors",
|
| 10 |
+
"model.transformer.blocks.1.att_proj.weight": "model-00001-of-00006.safetensors",
|
| 11 |
+
"model.transformer.blocks.1.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 12 |
+
"model.transformer.blocks.1.ff_out.weight": "model-00001-of-00006.safetensors",
|
| 13 |
+
"model.transformer.blocks.1.ff_proj.weight": "model-00001-of-00006.safetensors",
|
| 14 |
+
"model.transformer.blocks.10.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 15 |
+
"model.transformer.blocks.10.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 16 |
+
"model.transformer.blocks.10.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 17 |
+
"model.transformer.blocks.10.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 18 |
+
"model.transformer.blocks.11.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 19 |
+
"model.transformer.blocks.11.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 20 |
+
"model.transformer.blocks.11.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 21 |
+
"model.transformer.blocks.11.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 22 |
+
"model.transformer.blocks.12.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 23 |
+
"model.transformer.blocks.12.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 24 |
+
"model.transformer.blocks.12.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 25 |
+
"model.transformer.blocks.12.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 26 |
+
"model.transformer.blocks.13.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 27 |
+
"model.transformer.blocks.13.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 28 |
+
"model.transformer.blocks.13.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 29 |
+
"model.transformer.blocks.13.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 30 |
+
"model.transformer.blocks.14.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 31 |
+
"model.transformer.blocks.14.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 32 |
+
"model.transformer.blocks.14.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 33 |
+
"model.transformer.blocks.14.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 34 |
+
"model.transformer.blocks.15.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 35 |
+
"model.transformer.blocks.15.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 36 |
+
"model.transformer.blocks.15.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 37 |
+
"model.transformer.blocks.15.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 38 |
+
"model.transformer.blocks.16.att_proj.weight": "model-00003-of-00006.safetensors",
|
| 39 |
+
"model.transformer.blocks.16.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 40 |
+
"model.transformer.blocks.16.ff_out.weight": "model-00003-of-00006.safetensors",
|
| 41 |
+
"model.transformer.blocks.16.ff_proj.weight": "model-00003-of-00006.safetensors",
|
| 42 |
+
"model.transformer.blocks.17.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 43 |
+
"model.transformer.blocks.17.attn_out.weight": "model-00003-of-00006.safetensors",
|
| 44 |
+
"model.transformer.blocks.17.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 45 |
+
"model.transformer.blocks.17.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 46 |
+
"model.transformer.blocks.18.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 47 |
+
"model.transformer.blocks.18.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 48 |
+
"model.transformer.blocks.18.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 49 |
+
"model.transformer.blocks.18.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 50 |
+
"model.transformer.blocks.19.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 51 |
+
"model.transformer.blocks.19.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 52 |
+
"model.transformer.blocks.19.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 53 |
+
"model.transformer.blocks.19.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 54 |
+
"model.transformer.blocks.2.att_proj.weight": "model-00001-of-00006.safetensors",
|
| 55 |
+
"model.transformer.blocks.2.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 56 |
+
"model.transformer.blocks.2.ff_out.weight": "model-00001-of-00006.safetensors",
|
| 57 |
+
"model.transformer.blocks.2.ff_proj.weight": "model-00001-of-00006.safetensors",
|
| 58 |
+
"model.transformer.blocks.20.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 59 |
+
"model.transformer.blocks.20.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 60 |
+
"model.transformer.blocks.20.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 61 |
+
"model.transformer.blocks.20.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 62 |
+
"model.transformer.blocks.21.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 63 |
+
"model.transformer.blocks.21.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 64 |
+
"model.transformer.blocks.21.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 65 |
+
"model.transformer.blocks.21.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 66 |
+
"model.transformer.blocks.22.att_proj.weight": "model-00004-of-00006.safetensors",
|
| 67 |
+
"model.transformer.blocks.22.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 68 |
+
"model.transformer.blocks.22.ff_out.weight": "model-00004-of-00006.safetensors",
|
| 69 |
+
"model.transformer.blocks.22.ff_proj.weight": "model-00004-of-00006.safetensors",
|
| 70 |
+
"model.transformer.blocks.23.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 71 |
+
"model.transformer.blocks.23.attn_out.weight": "model-00004-of-00006.safetensors",
|
| 72 |
+
"model.transformer.blocks.23.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 73 |
+
"model.transformer.blocks.23.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 74 |
+
"model.transformer.blocks.24.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 75 |
+
"model.transformer.blocks.24.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 76 |
+
"model.transformer.blocks.24.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 77 |
+
"model.transformer.blocks.24.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 78 |
+
"model.transformer.blocks.25.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 79 |
+
"model.transformer.blocks.25.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 80 |
+
"model.transformer.blocks.25.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 81 |
+
"model.transformer.blocks.25.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 82 |
+
"model.transformer.blocks.26.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 83 |
+
"model.transformer.blocks.26.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 84 |
+
"model.transformer.blocks.26.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 85 |
+
"model.transformer.blocks.26.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 86 |
+
"model.transformer.blocks.27.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 87 |
+
"model.transformer.blocks.27.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 88 |
+
"model.transformer.blocks.27.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 89 |
+
"model.transformer.blocks.27.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 90 |
+
"model.transformer.blocks.28.att_proj.weight": "model-00005-of-00006.safetensors",
|
| 91 |
+
"model.transformer.blocks.28.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 92 |
+
"model.transformer.blocks.28.ff_out.weight": "model-00005-of-00006.safetensors",
|
| 93 |
+
"model.transformer.blocks.28.ff_proj.weight": "model-00005-of-00006.safetensors",
|
| 94 |
+
"model.transformer.blocks.29.att_proj.weight": "model-00006-of-00006.safetensors",
|
| 95 |
+
"model.transformer.blocks.29.attn_out.weight": "model-00005-of-00006.safetensors",
|
| 96 |
+
"model.transformer.blocks.29.ff_out.weight": "model-00006-of-00006.safetensors",
|
| 97 |
+
"model.transformer.blocks.29.ff_proj.weight": "model-00006-of-00006.safetensors",
|
| 98 |
+
"model.transformer.blocks.3.att_proj.weight": "model-00001-of-00006.safetensors",
|
| 99 |
+
"model.transformer.blocks.3.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 100 |
+
"model.transformer.blocks.3.ff_out.weight": "model-00001-of-00006.safetensors",
|
| 101 |
+
"model.transformer.blocks.3.ff_proj.weight": "model-00001-of-00006.safetensors",
|
| 102 |
+
"model.transformer.blocks.30.att_proj.weight": "model-00006-of-00006.safetensors",
|
| 103 |
+
"model.transformer.blocks.30.attn_out.weight": "model-00006-of-00006.safetensors",
|
| 104 |
+
"model.transformer.blocks.30.ff_out.weight": "model-00006-of-00006.safetensors",
|
| 105 |
+
"model.transformer.blocks.30.ff_proj.weight": "model-00006-of-00006.safetensors",
|
| 106 |
+
"model.transformer.blocks.31.att_proj.weight": "model-00006-of-00006.safetensors",
|
| 107 |
+
"model.transformer.blocks.31.attn_out.weight": "model-00006-of-00006.safetensors",
|
| 108 |
+
"model.transformer.blocks.31.ff_out.weight": "model-00006-of-00006.safetensors",
|
| 109 |
+
"model.transformer.blocks.31.ff_proj.weight": "model-00006-of-00006.safetensors",
|
| 110 |
+
"model.transformer.blocks.4.att_proj.weight": "model-00001-of-00006.safetensors",
|
| 111 |
+
"model.transformer.blocks.4.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 112 |
+
"model.transformer.blocks.4.ff_out.weight": "model-00001-of-00006.safetensors",
|
| 113 |
+
"model.transformer.blocks.4.ff_proj.weight": "model-00001-of-00006.safetensors",
|
| 114 |
+
"model.transformer.blocks.5.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 115 |
+
"model.transformer.blocks.5.attn_out.weight": "model-00001-of-00006.safetensors",
|
| 116 |
+
"model.transformer.blocks.5.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 117 |
+
"model.transformer.blocks.5.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 118 |
+
"model.transformer.blocks.6.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 119 |
+
"model.transformer.blocks.6.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 120 |
+
"model.transformer.blocks.6.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 121 |
+
"model.transformer.blocks.6.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 122 |
+
"model.transformer.blocks.7.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 123 |
+
"model.transformer.blocks.7.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 124 |
+
"model.transformer.blocks.7.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 125 |
+
"model.transformer.blocks.7.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 126 |
+
"model.transformer.blocks.8.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 127 |
+
"model.transformer.blocks.8.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 128 |
+
"model.transformer.blocks.8.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 129 |
+
"model.transformer.blocks.8.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 130 |
+
"model.transformer.blocks.9.att_proj.weight": "model-00002-of-00006.safetensors",
|
| 131 |
+
"model.transformer.blocks.9.attn_out.weight": "model-00002-of-00006.safetensors",
|
| 132 |
+
"model.transformer.blocks.9.ff_out.weight": "model-00002-of-00006.safetensors",
|
| 133 |
+
"model.transformer.blocks.9.ff_proj.weight": "model-00002-of-00006.safetensors",
|
| 134 |
+
"model.transformer.ff_out.weight": "model-00006-of-00006.safetensors",
|
| 135 |
+
"model.transformer.wte.weight": "model-00001-of-00006.safetensors"
|
| 136 |
+
}
|
| 137 |
+
}
|
modeling_olmo.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import fields
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import PreTrainedModel
|
| 7 |
+
from transformers.cache_utils import Cache
|
| 8 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 9 |
+
from transformers.models.auto import AutoModelForCausalLM
|
| 10 |
+
|
| 11 |
+
from olmo.config import ModelConfig
|
| 12 |
+
from olmo.model import OLMo
|
| 13 |
+
|
| 14 |
+
from .configuration_olmo import OLMoConfig
|
| 15 |
+
|
| 16 |
+
log = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def create_model_config_from_pretrained_config(config: OLMoConfig):
|
| 20 |
+
"""
|
| 21 |
+
Utility function
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
kwargs = {}
|
| 25 |
+
for field in fields(ModelConfig):
|
| 26 |
+
kwargs[field.name] = getattr(config, field.name)
|
| 27 |
+
|
| 28 |
+
model_config = ModelConfig(**kwargs)
|
| 29 |
+
return model_config
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class OLMoForCausalLM(PreTrainedModel):
|
| 33 |
+
"""
|
| 34 |
+
Extremely barebones HF model wrapper.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
config_class = OLMoConfig
|
| 38 |
+
base_model_prefix = "model"
|
| 39 |
+
_no_split_modules = ["OLMoBlock"]
|
| 40 |
+
|
| 41 |
+
def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
|
| 42 |
+
super().__init__(config)
|
| 43 |
+
|
| 44 |
+
if not model:
|
| 45 |
+
model_config = create_model_config_from_pretrained_config(config)
|
| 46 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
| 47 |
+
model_config.init_device = "cpu"
|
| 48 |
+
self.model = OLMo(model_config, init_params=init_params)
|
| 49 |
+
else:
|
| 50 |
+
self.model = model
|
| 51 |
+
|
| 52 |
+
def forward(
|
| 53 |
+
self,
|
| 54 |
+
input_ids: torch.LongTensor = None,
|
| 55 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 56 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 57 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 58 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 59 |
+
labels: Optional[torch.LongTensor] = None,
|
| 60 |
+
use_cache: Optional[bool] = None,
|
| 61 |
+
output_attentions: Optional[bool] = None,
|
| 62 |
+
output_hidden_states: Optional[bool] = None,
|
| 63 |
+
return_dict: Optional[bool] = None,
|
| 64 |
+
cache_position: Optional[
|
| 65 |
+
Cache
|
| 66 |
+
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
|
| 67 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 68 |
+
if use_cache is None:
|
| 69 |
+
use_cache = self.config.use_cache
|
| 70 |
+
|
| 71 |
+
if output_attentions:
|
| 72 |
+
raise ValueError("output_attentions is not yet supported in OLMo")
|
| 73 |
+
|
| 74 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 75 |
+
|
| 76 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 77 |
+
outputs = self.model.forward(
|
| 78 |
+
input_ids=input_ids,
|
| 79 |
+
input_embeddings=inputs_embeds,
|
| 80 |
+
attention_mask=attention_mask,
|
| 81 |
+
attention_bias=attention_bias,
|
| 82 |
+
past_key_values=past_key_values,
|
| 83 |
+
use_cache=use_cache,
|
| 84 |
+
output_hidden_states=output_hidden_states,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
logits = outputs.logits
|
| 88 |
+
hidden_states = outputs.hidden_states
|
| 89 |
+
|
| 90 |
+
loss = None
|
| 91 |
+
if labels is not None:
|
| 92 |
+
# Shift so that tokens < n predict n
|
| 93 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 94 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 95 |
+
# Flatten the tokens
|
| 96 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
| 97 |
+
shift_logits = shift_logits.view(-1, self.config.embedding_size)
|
| 98 |
+
shift_labels = shift_labels.view(-1)
|
| 99 |
+
# Enable model parallelism
|
| 100 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 101 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 102 |
+
|
| 103 |
+
if not return_dict:
|
| 104 |
+
output = (logits,) + outputs[1:]
|
| 105 |
+
return (loss,) + output if loss is not None else output
|
| 106 |
+
|
| 107 |
+
return CausalLMOutputWithPast(
|
| 108 |
+
loss=loss,
|
| 109 |
+
logits=logits,
|
| 110 |
+
past_key_values=outputs.attn_key_values,
|
| 111 |
+
hidden_states=hidden_states,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def can_generate(self) -> bool:
|
| 115 |
+
return True
|
| 116 |
+
|
| 117 |
+
def prepare_inputs_for_generation(
|
| 118 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 119 |
+
):
|
| 120 |
+
if past_key_values:
|
| 121 |
+
# This is because we want the model to only process the last generated token.
|
| 122 |
+
input_ids = input_ids[:, -1:]
|
| 123 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 124 |
+
|
| 125 |
+
model_inputs.update(kwargs)
|
| 126 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 127 |
+
return model_inputs
|
| 128 |
+
|
| 129 |
+
# TODO: these are required to make the implementation complete.
|
| 130 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
| 131 |
+
# pass
|
| 132 |
+
#
|
| 133 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
| 134 |
+
# pass
|
| 135 |
+
#
|
| 136 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
| 137 |
+
# pass
|
| 138 |
+
|
| 139 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 140 |
+
return self.model.transformer.wte
|
| 141 |
+
|
| 142 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 143 |
+
self.model.transformer.wte = value
|
| 144 |
+
|
| 145 |
+
def get_output_embeddings(self):
|
| 146 |
+
if self.config.weight_tying:
|
| 147 |
+
return self.model.transformer.wte
|
| 148 |
+
else:
|
| 149 |
+
return self.model.transformer.ff_out
|
| 150 |
+
|
| 151 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 152 |
+
if self.config.weight_tying:
|
| 153 |
+
self.model.transformer.wte = value
|
| 154 |
+
else:
|
| 155 |
+
self.model.transformer.ff_out = value
|
| 156 |
+
|
| 157 |
+
def tie_weights(self):
|
| 158 |
+
"""
|
| 159 |
+
This function is intentionally left as a no-op.
|
| 160 |
+
|
| 161 |
+
Weight tying is handled as follows:
|
| 162 |
+
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
|
| 163 |
+
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
|
| 164 |
+
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
|
| 165 |
+
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
|
| 166 |
+
|
| 167 |
+
Therefore, there is no need to explicitly tie the weights in this function.
|
| 168 |
+
"""
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
+
def resize_token_embeddings(
|
| 172 |
+
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
|
| 173 |
+
) -> torch.nn.Embedding:
|
| 174 |
+
"""
|
| 175 |
+
Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
|
| 176 |
+
|
| 177 |
+
Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
| 178 |
+
|
| 179 |
+
Arguments:
|
| 180 |
+
new_num_tokens (`int`, *optional*):
|
| 181 |
+
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
|
| 182 |
+
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
|
| 183 |
+
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
|
| 184 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 185 |
+
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
|
| 186 |
+
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
|
| 187 |
+
|
| 188 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 189 |
+
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
|
| 190 |
+
details about this, or help on choosing the correct value for resizing, refer to this guide:
|
| 191 |
+
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
|
| 192 |
+
|
| 193 |
+
Return:
|
| 194 |
+
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
| 195 |
+
|
| 196 |
+
Note:
|
| 197 |
+
This method differs from the base class implementation by resizing the `embedding_size` attribute of the
|
| 198 |
+
model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
|
| 199 |
+
is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
|
| 200 |
+
embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
|
| 201 |
+
"""
|
| 202 |
+
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 203 |
+
if new_num_tokens is None and pad_to_multiple_of is None:
|
| 204 |
+
return model_embeds
|
| 205 |
+
|
| 206 |
+
# Update base model and current model config
|
| 207 |
+
self.config.embedding_size = model_embeds.weight.shape[0]
|
| 208 |
+
self.model.config.embedding_size = model_embeds.weight.shape[0]
|
| 209 |
+
|
| 210 |
+
# Check if the embedding size is less than the vocab size
|
| 211 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 212 |
+
warning_message = (
|
| 213 |
+
f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
|
| 214 |
+
f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
|
| 215 |
+
"size is less than or equal to the new token embedding size."
|
| 216 |
+
)
|
| 217 |
+
log.warning(warning_message)
|
| 218 |
+
|
| 219 |
+
# Tie weights again if needed
|
| 220 |
+
self.tie_weights()
|
| 221 |
+
|
| 222 |
+
return model_embeds
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
| 226 |
+
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
|
| 227 |
+
# may not support the newest architectures we create.
|
| 228 |
+
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
|