Fix pulled from the Instruct model

#2
by Seas0 - opened
Files changed (2) hide show
  1. config.json +4 -5
  2. modeling_stable_diffcoder.py +23 -8
config.json CHANGED
@@ -1,8 +1,7 @@
1
  {
2
- "architectures": [
3
- "StableDiffcoderForCausalLM"
4
- ],
5
  "auto_map": {
 
6
  "AutoModelForCausalLM": "modeling_stable_diffcoder.StableDiffcoderForCausalLM"
7
  },
8
  "attention_bias": false,
@@ -21,11 +20,11 @@
21
  "num_hidden_layers": 32,
22
  "num_key_value_heads": 8,
23
  "resid_pdrop": 0.1,
24
- "rms_norm_eps": 1e-06,
25
  "rope_theta": 500000.0,
26
  "tie_word_embeddings": false,
27
  "torch_dtype": "bfloat16",
28
  "transformers_version": "5.3.0",
29
  "use_cache": true,
30
  "vocab_size": 155136
31
- }
 
1
  {
2
+ "architectures": ["StableDiffcoderForCausalLM"],
 
 
3
  "auto_map": {
4
+ "AutoModel": "modeling_stable_diffcoder.StableDiffcoderForCausalLM",
5
  "AutoModelForCausalLM": "modeling_stable_diffcoder.StableDiffcoderForCausalLM"
6
  },
7
  "attention_bias": false,
 
20
  "num_hidden_layers": 32,
21
  "num_key_value_heads": 8,
22
  "resid_pdrop": 0.1,
23
+ "rms_norm_eps": 1e-6,
24
  "rope_theta": 500000.0,
25
  "tie_word_embeddings": false,
26
  "torch_dtype": "bfloat16",
27
  "transformers_version": "5.3.0",
28
  "use_cache": true,
29
  "vocab_size": 155136
30
+ }
modeling_stable_diffcoder.py CHANGED
@@ -137,8 +137,10 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
137
  prompt_length = input_ids.shape[1]
138
  gen_block_list = [block_length for _ in range(gen_blocks)]
139
 
140
- res_block = block_length - (prompt_length % block_length)
141
- if res_block > 0:
 
 
142
  gen_block_list = [res_block] + gen_block_list
143
  gen_block_list[-1] = block_length - res_block
144
  gen_blocks += 1
@@ -156,16 +158,23 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
156
  nfe = 0
157
  final_flag = False
158
  prefill_length = prompt_length // block_length * block_length
 
159
  if prefill_length > 0:
160
  cur_attn_mask = block_diffusion_attention_mask[
161
  ..., :prefill_length, :prefill_length
162
  ]
 
 
 
 
 
163
  self(
164
  x[:, :prefill_length],
165
  past_key_values=past_key_values,
166
  attention_mask=cur_attn_mask,
167
  use_cache=True,
168
- ).past_key_values
 
169
 
170
  for block_id, block_size in enumerate(gen_block_list):
171
  block_start = (
@@ -182,7 +191,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
182
  replace_position[:, block_start:block_end] = True
183
 
184
  for token_count in num_transfer_tokens:
185
- if token_count:
186
  nfe += 1
187
  mask_map = x[:, block_start:block_end] == mask_id
188
  attention_mask = block_diffusion_attention_mask[
@@ -205,22 +214,28 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
205
  remasking,
206
  mask_map,
207
  x[:, block_start:block_end],
208
- token_count if threshold is None else None,
209
  threshold,
210
- shift=False,
211
  )
212
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
213
 
214
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
 
 
 
 
215
  if (
216
  eos_id is not None
217
- and (x[:, block_start:block_end] == eos_id).sum() > 0
 
218
  ):
219
  final_flag = True
220
  x = x[:, :block_end]
221
- eos_pos = (x == eos_id).nonzero(as_tuple=True)[1][0].item()
222
  x[0, eos_pos:] = eos_id
223
  break
 
224
  nfe += 1
225
  self(
226
  x[:, block_start:block_end],
 
137
  prompt_length = input_ids.shape[1]
138
  gen_block_list = [block_length for _ in range(gen_blocks)]
139
 
140
+ # Fix 3: Only handle residual blocks if the prompt length is NOT cleanly divisible
141
+ remainder = prompt_length % block_length
142
+ if remainder != 0:
143
+ res_block = block_length - remainder
144
  gen_block_list = [res_block] + gen_block_list
145
  gen_block_list[-1] = block_length - res_block
146
  gen_blocks += 1
 
158
  nfe = 0
159
  final_flag = False
160
  prefill_length = prompt_length // block_length * block_length
161
+
162
  if prefill_length > 0:
163
  cur_attn_mask = block_diffusion_attention_mask[
164
  ..., :prefill_length, :prefill_length
165
  ]
166
+ # Fix 1: Explicitly pass cache_position for newer transformers prefill
167
+ # actually not necessary since transformers will automatically generate it for prefilling
168
+ # if unspecified, but the official `generate` method does pass it,
169
+ # so we follow that for consistency and to avoid potential issues in future transformers updates
170
+ cache_pos = torch.arange(prefill_length, device=x.device)
171
  self(
172
  x[:, :prefill_length],
173
  past_key_values=past_key_values,
174
  attention_mask=cur_attn_mask,
175
  use_cache=True,
176
+ cache_position=cache_pos,
177
+ )
178
 
179
  for block_id, block_size in enumerate(gen_block_list):
180
  block_start = (
 
191
  replace_position[:, block_start:block_end] = True
192
 
193
  for token_count in num_transfer_tokens:
194
+ if token_count > 0:
195
  nfe += 1
196
  mask_map = x[:, block_start:block_end] == mask_id
197
  attention_mask = block_diffusion_attention_mask[
 
214
  remasking,
215
  mask_map,
216
  x[:, block_start:block_end],
217
+ token_count.item() if threshold is None else None,
218
  threshold,
219
+ shift=shift,
220
  )
221
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
222
 
223
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
224
+
225
+ # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
226
+ gen_start = max(block_start, prompt_length)
227
+
228
  if (
229
  eos_id is not None
230
+ and gen_start < block_end
231
+ and (x[:, gen_start:block_end] == eos_id).sum() > 0
232
  ):
233
  final_flag = True
234
  x = x[:, :block_end]
235
+ eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
236
  x[0, eos_pos:] = eos_id
237
  break
238
+
239
  nfe += 1
240
  self(
241
  x[:, block_start:block_end],