Spaces:
Running
Running
| /* | |
| * ,; | |
| * \@@#\: :/. .:;;: | |
| * _@@@@@@#+\|/!;;!-@@@--; ,@@@@@; | |
| * .!_*@@@@@@@@@@@@@@@@@@@; |@@@@@\ | |
| * .:!|+@@@@@##@@@@@@@#! -@@@@@#, | |
| * .\@@@*;,\@@@@@@@@+,*@@@@@@+. | |
| * :*#@@@@@@@@@@@@@@-+@@@@@@@\@@@@-. | |
| * .#@@@@@#@@@@#*@@@+ /@@@@@@;\@@@@+. | |
| * ;\/:, -@@@@;|@@@\ ,+@@@@!.+@@@@*: | |
| * ,@@@@#*@@@@@#+__!. ,*@@@@@/ | |
| * \##+_@@@@@@@@, ,+@@@_: | |
| * ;;,,..,: !;. | |
| */ | |
var __defProp = Object.defineProperty;

/** Set `fn.name` to `label` (configurable) and return `fn`. */
var __name = (fn, label) => __defProp(fn, "name", { value: label, configurable: true });

/**
 * Install every entry of `getters` on `ns` as an enumerable accessor, so reads of
 * `ns[key]` call `getters[key]()` each time.
 */
var __export = (ns, getters) => {
  for (var key in getters) {
    __defProp(ns, key, { enumerable: true, get: getters[key] });
  }
};
// src/config.js
/*
 * Architecture hyper-parameters for the Qwen2.5-3B model.
 * headDim equals hiddenSize / numHeads (2048 / 16); numKVHeads < numHeads means
 * several query heads share each key/value head.
 */
var QWEN25_3B = {
  hiddenSize: 2_048,
  numLayers: 36,
  numHeads: 16,
  numKVHeads: 2,
  headDim: 128,
  intermediateSize: 11_008,
  vocabSize: 151_936,
  rmsNormEps: 0.000_001,
  ropeTheta: 1_000_000,
  /*
   * TECHNIQUE: Tie word embeddings
   * The input embedding matrix doubles as the output head, which keeps loading
   * (one tensor), the schema, and the final projection math simple.
   * Required by the current model_uploader + schema.
   */
  tieWordEmbeddings: true,
  // Qwen2.5 puts a bias on the Q/K/V projections only; o_proj and the MLP have none.
  attentionBias: true
};
// src/readers.js
/**
 * Build a reader that fetches files relative to `baseUrl` over HTTP.
 *
 * @param {string} baseUrl - Base URL; a trailing "/" is appended when missing.
 * @param {Record<string, string>} [headers] - Extra headers sent with every request.
 * @returns {{range(path: string, start: number, end: number): Promise<ArrayBuffer>,
 *            text(path: string): Promise<string>}}
 */
function urlReader(baseUrl, headers = {}) {
  const base = baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`;
  return {
    /**
     * Fetch bytes [start, end) of `path` using an HTTP Range request.
     *
     * An empty or inverted window resolves to an empty buffer without touching the
     * network (same result as Blob#slice in fileReader); otherwise the header would be
     * the invalid `bytes=start-(start-1)`. A 206 response carries exactly the requested
     * bytes. Any other 2xx (typically 200) means the server ignored the Range header
     * and sent the whole resource, so the requested window is sliced out locally.
     *
     * @throws {Error} `range <path> <start>-<end>: <status>` on a non-2xx response.
     */
    async range(path, start, end) {
      if (end <= start) return new ArrayBuffer(0);
      const r = await fetch(base + path, {
        headers: { ...headers, Range: `bytes=${start}-${end - 1}` }
      });
      if (!r.ok) {
        throw new Error(`range ${path} ${start}-${end}: ${r.status}`);
      }
      const buf = await r.arrayBuffer();
      return r.status === 206 ? buf : buf.slice(start, end);
    },
    /**
     * Fetch the whole of `path` as text.
     * @throws {Error} `fetch <path>: <status>` on a non-2xx response.
     */
    async text(path) {
      const r = await fetch(base + path, { headers });
      if (!r.ok) throw new Error(`fetch ${path}: ${r.status}`);
      return await r.text();
    }
  };
}
__name(urlReader, "urlReader");
/**
 * Reader for files in a Hugging Face Hub repository, built on urlReader.
 *
 * @param {string} repo - Repo id such as "owner/name"; its "/" is kept as a path separator.
 * @param {string} [token] - Optional access token, sent as `Authorization: Bearer <token>`.
 * @param {string} [rev] - Raw (unencoded) branch, tag, commit hash, or ref such as
 *   "refs/pr/1". It is percent-encoded (as huggingface_hub's hf_hub_url does) so a ref
 *   containing "/" stays a single segment of the `/resolve/<rev>/` URL.
 * @returns {ReturnType<typeof urlReader>}
 */
function hfReader(repo, token = "", rev = "main") {
  return urlReader(
    `https://huggingface.co/${repo}/resolve/${encodeURIComponent(rev)}`,
    token ? { Authorization: `Bearer ${token}` } : {}
  );
}
__name(hfReader, "hfReader");
/**
 * Reader over caller-supplied Blob-like objects (anything with `slice()` and `text()`),
 * exposing the same `range`/`text` interface as urlReader.
 *
 * @param {Record<string, Blob>} fileMap - Map from path (or bare file name) to Blob.
 *   A lookup tries the exact path first, then the path's last "/" segment.
 * @returns {{range(path: string, start: number, end: number): Promise<ArrayBuffer>,
 *            text(path: string): Promise<string>}}
 */
function fileReader(fileMap) {
  // Consult own keys only, so names like "constructor" or "toString" report
  // "file not provided" instead of resolving to Object.prototype members.
  const own = (key) => (Object.hasOwn(fileMap, key) ? fileMap[key] : undefined);
  const pick = /* @__PURE__ */ __name((path) => own(path) || own(path.split("/").pop()), "pick");
  // Resolve `path` or throw the reader's standard "file not provided" error.
  const need = (path) => {
    const f = pick(path);
    if (!f) throw new Error(`file not provided: ${path}`);
    return f;
  };
  return {
    /** Bytes [start, end) of the file, via Blob#slice. */
    async range(path, start, end) {
      return await need(path).slice(start, end).arrayBuffer();
    },
    /** Entire file contents via Blob#text. */
    async text(path) {
      return await need(path).text();
    }
  };
}
__name(fileReader, "fileReader");
// src/services/adapter_registry.js
/**
 * Named registry of LoRA adapters that can be switched on a runtime.
 * The reserved name "none" maps to null, meaning "no adapter".
 */
var AdapterRegistry = class {
  static {
    __name(this, "AdapterRegistry");
  }
  constructor() {
    this.adapters = { none: null };
  }
  /**
   * Register (or replace) adapter `name`.
   * @param {string} name
   * @param {*} modules - Adapter payload; stored as `{ modules }`.
   * @returns {{modules: *}} The stored record.
   */
  add(name, modules) {
    this.adapters[name] = { modules };
    return this.adapters[name];
  }
  /**
   * Look up adapter `name`.
   * Only own entries count, so inherited keys such as "toString" or "constructor"
   * are not mistaken for adapters.
   * @returns {{modules: *} | null} The record, or null when unknown or "none".
   */
  get(name) {
    return Object.hasOwn(this.adapters, name) ? this.adapters[name] || null : null;
  }
  /*
   * TECHNIQUE: Runtime adapter swapping via setLora
   * Registry holds pre-uploaded A/B buffers. applyToRuntime calls
   * rt.setLora which just swaps references — no weight reload.
   */
  /**
   * Activate adapter `name` on `rt`: calls `rt.setLora(record)` when found,
   * otherwise `rt.clearLora()`.
   * @returns {{modules: *} | null} The applied record, or null if cleared.
   */
  applyToRuntime(name, rt) {
    const adapter = this.get(name);
    if (adapter) rt.setLora(adapter);
    else rt.clearLora();
    return adapter;
  }
};
// src/qwgpu/kernels.js
/*
 * GEMV (WGSL): int8 matrix-vector product for a single input vector x.
 *   y[n] = scale[n] * sum_k x[k] * w_int8[n][k]
 *          (+ bias[n]                              if hasBias == 1)
 *          (+ scaleLo * sum_r loraD[r] * loraB[r*N + n] if hasLora == 1)
 * Dispatch: one 64-invocation workgroup per output row, n = wid.x + wid.y * gridX, so N
 * can be spread over a 2-D grid; workgroups with n >= N return before any barrier.
 * Weights: each u32 of w packs 4 int8 values (unpack4xI8) with K/4 words per row, so K
 * is assumed to be a multiple of 4. Meta.gpr is declared but not read here.
 * Reduction: subgroupAdd per subgroup, one workgroup slot per subgroup in `part`,
 * then invocation 0 sums the slots and applies scale/bias/LoRA.
 * NOTE(review): the slot index is tid / subgroup_size, which assumes subgroups cover
 * contiguous local_invocation_id ranges; WGSL does not guarantee that mapping. The
 * `subgroup_id` builtin (`sgroup`) that `requires subgroup_id` enables is declared
 * but unused.
 */
var GEMV = `
enable subgroups;
requires immediate_address_space;
requires subgroup_id;
struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 };
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> w: array<u32>; // [N][K/4] int8
@group(0) @binding(2) var<storage,read> scale: array<f32>; // [N]
@group(0) @binding(3) var<storage,read> bias: array<f32>; // [N] or dummy
@group(0) @binding(4) var<storage,read> loraD: array<f32>; // [rank] precomputed x@A (or dummy)
@group(0) @binding(5) var<storage,read> loraB: array<f32>; // [rank][N] (or dummy)
@group(0) @binding(6) var<storage,read_write> y: array<f32>; // [N]
var<immediate> m: Meta;
var<workgroup> part: array<f32,64>; // one slot per subgroup
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32,
@builtin(subgroup_id) sgroup: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; } // workgroup-uniform: whole group exits together
let K4 = m.K/4u; let rb = n*K4;
var acc = 0.0;
for (var k = tid; k < K4; k = k + 64u) {
let p = w[rb+k];
let v = unpack4xI8(p); // vec4<i32>
let kk = k*4u;
acc = acc + x[kk]*f32(v.x) + x[kk+1u]*f32(v.y) + x[kk+2u]*f32(v.z) + x[kk+3u]*f32(v.w);
}
let ssum = subgroupAdd(acc); // reduce within subgroup (no barrier)
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var red = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { red = red + part[i]; }
var o = red * scale[n];
if (m.hasBias == 1u) { o = o + bias[n]; }
if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; }
y[n] = o;
}
}`;
/*
 * LORA_A (WGSL): down-projection half of a LoRA update for one input vector,
 *   d[r] = sum_k x[k] * A[r*K + k]   for r < rank.
 * A is stored [rank][K] (transposed relative to x@A), so each rank row is contiguous.
 * Dispatch: one 64-invocation workgroup per rank index r = wid.x; r >= rank returns.
 * Immediates m = (K, rank).
 * NOTE(review): partial sums are slotted by lid.x / subgroup_size, which assumes
 * subgroups cover contiguous invocation ranges; WGSL does not guarantee that mapping.
 */
var LORA_A = `
enable subgroups;
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> x: array<f32>; // [K]
@group(0) @binding(1) var<storage,read> A: array<f32>; // [rank][K] (transposed)
@group(0) @binding(2) var<storage,read_write> d: array<f32>; // [rank]
var<immediate> m: vec2<u32>; // K, rank
var<workgroup> part: array<f32,64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let r = wid.x; let K = m.x; if (r >= m.y) { return; }
let rb = r*K; var acc = 0.0;
for (var k = lid.x; k < K; k = k + 64u) { acc = acc + x[k]*A[rb + k]; }
let s = subgroupAdd(acc);
if (sgid == 0u) { part[lid.x / sgsz] = s; }
workgroupBarrier();
if (lid.x == 0u) { let nsg=(64u+sgsz-1u)/sgsz; var o=0.0; for(var i=0u;i<nsg;i=i+1u){o=o+part[i];} d[r]=o; }
}`;
/*
 * LORA_A_BATCH (WGSL): LoRA down-projection for T input rows at once,
 *   d[t*rank + r] = sum_k x[t*K + k] * A[r*K + k].
 * Dispatch: workgroup (wid.x, wid.y) = (r, t), 64 invocations each; r >= rank or
 * t >= T returns. Immediates m = (K, rank, T, unused).
 * NOTE(review): partial sums are slotted by lid.x / subgroup_size, which assumes
 * subgroups cover contiguous invocation ranges; WGSL does not guarantee that mapping.
 */
var LORA_A_BATCH = `
enable subgroups;
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> x: array<f32>; // [T][K]
@group(0) @binding(1) var<storage,read> A: array<f32>; // [rank][K]
@group(0) @binding(2) var<storage,read_write> d: array<f32>; // [T][rank]
var<immediate> m: vec4<u32>; // K, rank, T, _
var<workgroup> part: array<f32,64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let r = wid.x; let t = wid.y; let K = m.x; let rank = m.y; if (r >= rank || t >= m.z) { return; }
let xb = t*K; let ab = r*K; var acc = 0.0;
for (var k = lid.x; k < K; k = k + 64u) { acc = acc + x[xb + k]*A[ab + k]; }
let s = subgroupAdd(acc);
if (sgid == 0u) { part[lid.x / sgsz] = s; }
workgroupBarrier();
if (lid.x == 0u) { let nsg=(64u+sgsz-1u)/sgsz; var o=0.0; for(var i=0u;i<nsg;i=i+1u){o=o+part[i];} d[t*rank + r]=o; }
}`;
/*
 * LORA_B_ADD_T (WGSL): LoRA up-projection for T rows, accumulated into Y in place:
 *   Y[t*N + n] += scale * sum_r d[t*rank + r] * B[r*N + n].
 * One invocation per (t, n) element, 256 per workgroup. The flat index is
 * i = gid.y * (gx * 256) + gid.x, so gx is expected to be the x workgroup count of a
 * 2-D dispatch (assumption; confirm against the host code). Meta.p1..p3 are not read.
 */
var LORA_B_ADD_T = `
requires immediate_address_space;
struct Meta { T:u32, N:u32, rank:u32, gx:u32, scale:f32, p1:f32, p2:f32, p3:f32 };
@group(0) @binding(0) var<storage,read> d: array<f32>; // [T][rank]
@group(0) @binding(1) var<storage,read> B: array<f32>; // [rank][N]
@group(0) @binding(2) var<storage,read_write> Y: array<f32>; // [T][N]
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y * (m.gx * 256u) + gid.x;
if (i >= m.T * m.N) { return; }
let t = i / m.N; let n = i % m.N; var acc = 0.0;
for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + d[t*m.rank + r] * B[r*m.N + n]; }
Y[i] = Y[i] + m.scale * acc;
}`;
/*
 * LORA_B_ADD (WGSL): LoRA up-projection for a single row, accumulated in place:
 *   y[n] += scale * sum_r d[r] * B[r*N + n].
 * One invocation per n = gid.x, 256 per workgroup. Meta.p0, p1, f0..f2 are not read.
 */
var LORA_B_ADD = `
requires immediate_address_space;
struct Meta { N:u32, rank:u32, p0:u32, p1:u32, scale:f32, f0:f32, f1:f32, f2:f32 };
@group(0) @binding(0) var<storage,read> d: array<f32>; // [rank]
@group(0) @binding(1) var<storage,read> B: array<f32>; // [rank][N]
@group(0) @binding(2) var<storage,read_write> y: array<f32>; // [N]
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let n = gid.x;
if (n >= m.N) { return; }
var acc = 0.0;
for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + d[r] * B[r*m.N + n]; }
y[n] = y[n] + m.scale * acc;
}`;
/*
 * RMSNORM (WGSL): y[k] = x[k] * inverseSqrt(mean(x^2) + eps) * g[k] for one vector of K.
 * Only local_invocation_id is used, so one workgroup covers the whole vector; any extra
 * workgroups would recompute the same output.
 * Immediates m = (K, eps), both f32; K is recovered with u32(m.x).
 * The tree reduction halves WG each step and `part` has 256 slots, so an overridden
 * WG must be a power of two no larger than 256.
 */
var RMSNORM = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> g: array<f32>;
@group(0) @binding(2) var<storage,read_write> y: array<f32>;
var<immediate> m: vec2<f32>; // K, eps
var<workgroup> part: array<f32,256>;
@compute @workgroup_size(WG)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let K = u32(m.x);
var s = 0.0; for (var k = tid; k < K; k = k + WG) { let v = x[k]; s = s + v*v; }
part[tid] = s; workgroupBarrier();
for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); }
let inv = inverseSqrt(part[0]/m.x + m.y);
for (var k = tid; k < K; k = k + WG) { y[k] = x[k]*inv*g[k]; }
}`;
/*
 * RMSNORM_F16 (WGSL): single-vector RMS norm whose final scaling x * inv * g runs in
 * f16 (x, inv and g rounded to f16, result widened back to f32). The sum of squares
 * stays in f32 (see the in-shader comment).
 * Immediates m = (K, eps) as f32; one workgroup covers the vector; an overridden WG
 * must be a power of two <= 256 (tree reduction over a 256-slot `part`).
 */
var RMSNORM_F16 = `
requires immediate_address_space;
enable f16;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> g: array<f32>;
@group(0) @binding(2) var<storage,read_write> y: array<f32>;
var<immediate> m: vec2<f32>; // K, eps
// Reduction accumulates in f32 even though the normalize is f16: summing v*v over
// thousands of dims overflows f16 (>65504) at high-magnitude tokens (the attention
// sink), which collapses inv to 0. Keeping the sum in f32 is the overflow-safe path.
var<workgroup> part: array<f32,256>;
@compute @workgroup_size(WG)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let K = u32(m.x);
var s = 0.0;
for (var k = tid; k < K; k = k + WG) { let v = f32(x[k]); s = s + v*v; }
part[tid] = s; workgroupBarrier();
for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); }
let inv = f16(inverseSqrt(part[0]/m.x + m.y));
for (var k = tid; k < K; k = k + WG) { y[k] = f32( f16(x[k]) * inv * f16(g[k]) ); }
}`;
/*
 * ROPE (WGSL): in-place rotate-half rotary embedding on x laid out [nHeads][headDim].
 * Invocation g rotates the pair (h*D + j, h*D + j + D/2) for head h = g / (D/2) and
 * j = g % (D/2). cos/sin are read at pos*D + j, i.e. the tables hold D entries per
 * position. Immediates m = (nHeads, headDim, pos); one invocation per pair, 256 per workgroup.
 */
var ROPE = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read_write> x: array<f32>;
@group(0) @binding(1) var<storage,read> cosT: array<f32>;
@group(0) @binding(2) var<storage,read> sinT: array<f32>;
var<immediate> m: vec3<u32>; // nHeads, headDim, pos
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let H = m.x; let D = m.y; let pos = m.z; let half = D/2u;
if (g >= H*half) { return; }
let h = g / half; let j = g % half;
let lo = h*D + j; let hi = lo + half; let off = pos*D + j;
let c = cosT[off]; let s = sinT[off];
let xl = x[lo]; let xh = x[hi];
// EXACT rotate-half: separately-rounded products (fma(a,b,0)) prevent the
// compiler from contracting x*c - x*s into a single fma, matching the PyTorch
// reference rounding exactly.
x[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0);
x[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0);
}`;
/*
 * ROPE_F16 (WGSL): rotate-half rotary embedding on x ([nHeads][headDim]) with the
 * rotation computed in f16: x, cos and sin are rounded to f16 and the result is
 * widened back to f32. Tables indexed at pos*D + j.
 * Immediates m = (nHeads, headDim, pos); one invocation per rotated pair.
 */
var ROPE_F16 = `
requires immediate_address_space;
enable f16;
@group(0) @binding(0) var<storage,read_write> x: array<f32>;
@group(0) @binding(1) var<storage,read> cosT: array<f32>;
@group(0) @binding(2) var<storage,read> sinT: array<f32>;
var<immediate> m: vec3<u32>; // nHeads, headDim, pos
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let H = m.x; let D = m.y; let pos = m.z; let half = D/2u;
if (g >= H*half) { return; }
let h = g / half; let j = g % half;
let lo = h*D + j; let hi = lo + half; let off = pos*D + j;
let c = f16(cosT[off]); let s = f16(sinT[off]);
let xl = f16(x[lo]); let xh = f16(x[hi]);
x[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) );
x[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) );
}`;
/*
 * ROPE_QK (WGSL): rotate-half rotary embedding applied to q and k in one dispatch.
 * Invocations [0, qHeads*D/2) rotate q ([qHeads][headDim]); the next kvHeads*D/2
 * rotate k ([kvHeads][headDim]). Both use the cos/sin row for `pos` (index pos*D + j).
 * Immediates m = (qHeads, kvHeads, headDim, pos).
 */
var ROPE_QK = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read_write> q: array<f32>;
@group(0) @binding(1) var<storage,read_write> k: array<f32>;
@group(0) @binding(2) var<storage,read> cosT: array<f32>;
@group(0) @binding(3) var<storage,read> sinT: array<f32>;
var<immediate> m: vec4<u32>; // qHeads, kvHeads, headDim, pos
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let qH = m.x; let kH = m.y; let D = m.z; let pos = m.w; let half = D/2u;
let qPairs = qH * half; let kPairs = kH * half; let total = qPairs + kPairs;
if (g >= total) { return; }
let isK = g >= qPairs;
var r = g;
if (isK) { r = g - qPairs; }
let h = r / half; let j = r % half;
let lo = h*D + j; let hi = lo + half; let off = pos*D + j;
let c = cosT[off]; let s = sinT[off];
if (isK) {
let xl = k[lo]; let xh = k[hi];
k[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); k[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0);
} else {
let xl = q[lo]; let xh = q[hi];
q[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); q[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0);
}
}`;
/*
 * ROPE_QK_F16 (WGSL): fused q/k rotary embedding with the rotation computed in f16
 * (values, cos and sin rounded to f16, results widened back to f32).
 * Invocations [0, qHeads*D/2) rotate q; the next kvHeads*D/2 rotate k.
 * Immediates m = (qHeads, kvHeads, headDim, pos).
 */
var ROPE_QK_F16 = `
requires immediate_address_space;
enable f16;
@group(0) @binding(0) var<storage,read_write> q: array<f32>;
@group(0) @binding(1) var<storage,read_write> k: array<f32>;
@group(0) @binding(2) var<storage,read> cosT: array<f32>;
@group(0) @binding(3) var<storage,read> sinT: array<f32>;
var<immediate> m: vec4<u32>; // qHeads, kvHeads, headDim, pos
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let qH = m.x; let kH = m.y; let D = m.z; let pos = m.w; let half = D/2u;
let qPairs = qH * half; let kPairs = kH * half; let total = qPairs + kPairs;
if (g >= total) { return; }
let isK = g >= qPairs;
var r = g;
if (isK) { r = g - qPairs; }
let h = r / half; let j = r % half;
let lo = h*D + j; let hi = lo + half; let off = pos*D + j;
let c = f16(cosT[off]); let s = f16(sinT[off]);
if (isK) {
let xl = f16(k[lo]); let xh = f16(k[hi]);
k[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); k[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) );
} else {
let xl = f16(q[lo]); let xh = f16(q[hi]);
q[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); q[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) );
}
}`;
/*
 * ATTN_PARTIAL (WGSL): split-context attention for one query token.
 * Workgroup (h, s) = (wid.x, wid.y) handles query head h over key positions
 * [s*chunk, min(s*chunk + chunk, ctx)). It writes:
 *   pm[h*nsplit + s] = max score, pz[...] = sum exp(score - max), and
 *   po[(h*nsplit + s)*hd + d] = unnormalised exp-weighted sum of V.
 * This pm/pz/po layout is the one ATTN_COMBINE reads.
 * Grouped KV: kv head = h / (nHeads / nKV); cache rows have stride nKV*hd.
 * Each invocation scores one position and `sc` has 128 slots, so chunk must not
 * exceed min(WG, 128).
 * Fixes:
 *  - t1 is clamped into [t0, ctx], so a split that starts at or past ctx gets
 *    len = 0 instead of the u32 underflow in t1 - t0 (a near-endless V loop).
 *  - The subgroup count and d-loop stride use WG (as ATTN_PARTIAL_F16 does) instead
 *    of a hard-coded 128. Unchanged at the default WG = 128.
 */
var ATTN_PARTIAL = `
requires immediate_address_space;
enable subgroups;
override WG: u32 = 128u;
struct AttnP { nHeads: u32, nKV: u32, ctx: u32, hd: u32, nsplit: u32, chunk: u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> pm: array<f32>; // [nHeads*nsplit] per-split max
@group(0) @binding(4) var<storage,read_write> pz: array<f32>; // [nHeads*nsplit] per-split sum
@group(0) @binding(5) var<storage,read_write> po: array<f32>; // [nHeads*nsplit*hd] unnorm weighted V
var<immediate> m: AttnP;
var<workgroup> sc: array<f32,128>;
var<workgroup> red: array<f32,32>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let s = wid.y; let tid = lid.x;
let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd; let nsplit = m.nsplit; let chunk = m.chunk;
let kvh = h / (nHeads / nKV);
let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0/sqrt(f32(hd));
let nsg = (WG + sgsz - 1u) / sgsz;
// Clamp t1 into [t0, ctx]: an empty split (t0 >= ctx) yields len = 0, not a u32 underflow.
let t0 = s*chunk; let t1 = max(t0, min(t0 + chunk, ctx));
let t = t0 + tid; var sv = -1e30;
if (t < t1) { var dot = 0.0; let kb = t*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } sv = dot*scale; }
let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; }
workgroupBarrier();
var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); }
workgroupBarrier();
var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev;
let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; }
workgroupBarrier();
var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; }
workgroupBarrier();
let len = t1 - t0; let pbase = (h*nsplit + s)*hd;
for (var d = tid; d < hd; d = d + WG) {
var acc = 0.0; for (var tt = 0u; tt < len; tt = tt + 1u) { acc = acc + sc[tt]*vc[(t0+tt)*stride + hoff + d]; }
po[pbase + d] = acc;
}
if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; }
}`;
/*
 * ATTN_PARTIAL_F16 (WGSL): split-context attention for one query token in f16
 * "staging" mode (see the in-shader comment). Q/K/V are rounded through f16, and all
 * reductions run in f32.
 * Workgroup (h, s) = (wid.x, wid.y) handles query head h over key positions
 * [s*chunk, min(s*chunk + chunk, ctx)) and writes per-split max (pm), exp-sum (pz)
 * and unnormalised weighted V (po), the layout ATTN_COMBINE(_F16) reads.
 * Grouped KV: kv head = h / (nHeads / nKV); cache rows have stride nKV*hd.
 * Each invocation scores one position and `sc` has 128 slots, so chunk must not
 * exceed min(WG, 128).
 * Fix: t1 is clamped into [t0, ctx], so a split that starts at or past ctx gets
 * len = 0 instead of the u32 underflow in t1 - t0 (a near-endless V loop).
 */
var ATTN_PARTIAL_F16 = `
requires immediate_address_space;
enable subgroups;
enable f16;
override WG: u32 = 128u;
struct AttnP { nHeads: u32, nKV: u32, ctx: u32, hd: u32, nsplit: u32, chunk: u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> pm: array<f32>; // [nHeads*nsplit] per-split max
@group(0) @binding(4) var<storage,read_write> pz: array<f32>; // [nHeads*nsplit] per-split sum
@group(0) @binding(5) var<storage,read_write> po: array<f32>; // [nHeads*nsplit*hd] unnorm weighted V
var<immediate> m: AttnP;
// f16 "staging" mode: Q/K/V values are read through f16 (so they carry f16 rounding,
// modelling an f16 KV cache), but every REDUCTION \u2014 the QK dot, the softmax max/sum,
// and the weighted-V accumulation \u2014 runs in f32. Accumulating scores in f16 overflows
// at long context / high-magnitude tokens; f32 accumulation is the overflow-safe path
// (matches the Gemma-4 "scores/PV accumulate in f32, only K/V carry f16 rounding").
var<workgroup> sc: array<f32,128>;
var<workgroup> red: array<f32,32>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let s = wid.y; let tid = lid.x;
let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd; let nsplit = m.nsplit; let chunk = m.chunk;
let kvh = h / (nHeads / nKV);
let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0 / sqrt(f32(hd));
let nsg = (WG + sgsz - 1u) / sgsz;
// Clamp t1 into [t0, ctx]: an empty split (t0 >= ctx) yields len = 0, not a u32 underflow.
let t0 = s*chunk; let t1 = max(t0, min(t0 + chunk, ctx));
let t = t0 + tid; var sv = -1e30;
if (t < t1) { var dot = 0.0; let kb = t*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + f32(f16(q[qbase+d])) * f32(f16(kc[kb+d])); } sv = dot*scale; }
let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; }
workgroupBarrier();
var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); }
workgroupBarrier();
var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev;
let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; }
workgroupBarrier();
var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; }
workgroupBarrier();
let len = t1 - t0; let pbase = (h*nsplit + s)*hd;
for (var d = tid; d < hd; d = d + WG) {
var acc = 0.0; for (var tt = 0u; tt < len; tt = tt + 1u) { acc = acc + sc[tt] * f32(f16(vc[(t0+tt)*stride + hoff + d])); }
po[pbase + d] = acc;
}
if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; }
}`;
/*
 * ATTN_COMBINE (WGSL): merge the nsplit partial softmax results of each head into
 * the final attention output. There is one workgroup per head h = wid.x.
 *   M = max_s pm[h*nsplit + s]
 *   Z = sum_s pz[h*nsplit + s] * exp(pm[h*nsplit + s] - M)
 *   o[h*hd + d] = sum_s exp(pm[...] - M) * po[(h*nsplit + s)*hd + d] / Z
 * Immediates m = (nHeads, hd, nsplit, unused); m.x is not read. 1.0 / Z is unguarded.
 */
var ATTN_COMBINE = `
requires immediate_address_space;
override WG: u32 = 128u;
@group(0) @binding(0) var<storage,read> pm: array<f32>;
@group(0) @binding(1) var<storage,read> pz: array<f32>;
@group(0) @binding(2) var<storage,read> po: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>;
var<immediate> m: vec4<u32>; // nHeads, hd, nsplit, _
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let h = wid.x; let tid = lid.x; let hd = m.y; let nsplit = m.z; let base = h*nsplit;
var M = -1e30; for (var s = 0u; s < nsplit; s = s + 1u) { M = max(M, pm[base+s]); }
var Z = 0.0; for (var s = 0u; s < nsplit; s = s + 1u) { Z = Z + pz[base+s]*exp(pm[base+s]-M); }
let invZ = 1.0 / Z;
for (var d = tid; d < hd; d = d + WG) {
var acc = 0.0;
for (var s = 0u; s < nsplit; s = s + 1u) { acc = acc + exp(pm[base+s]-M)*po[(base+s)*hd + d]; }
o[h*hd + d] = acc * invZ;
}
}`;
/*
 * ATTN_COMBINE_F16 (WGSL): ATTN_COMBINE variant in which each partial weighted-V value
 * po[...] is rounded through f16 before being weighted; max, sum and accumulation stay f32.
 * One workgroup per head h = wid.x; immediates m = (nHeads, hd, nsplit, unused), m.x unread.
 */
var ATTN_COMBINE_F16 = `
requires immediate_address_space;
enable f16;
override WG: u32 = 128u;
@group(0) @binding(0) var<storage,read> pm: array<f32>;
@group(0) @binding(1) var<storage,read> pz: array<f32>;
@group(0) @binding(2) var<storage,read> po: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>;
var<immediate> m: vec4<u32>; // nHeads, hd, nsplit, _
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let h = wid.x; let tid = lid.x; let hd = m.y; let nsplit = m.z; let base = h*nsplit;
// Cross-split softmax merge accumulates max/sum in f32 (overflow-safe); only the
// final per-element weighting carries f16 rounding.
var M = -1e30; for (var s = 0u; s < nsplit; s = s + 1u) { M = max(M, pm[base+s]); }
var Z = 0.0; for (var s = 0u; s < nsplit; s = s + 1u) { Z = Z + pz[base+s] * exp(pm[base+s] - M); }
let invZ = 1.0 / Z;
for (var d = tid; d < hd; d = d + WG) {
var acc = 0.0;
for (var s = 0u; s < nsplit; s = s + 1u) { acc = acc + exp(pm[base+s] - M) * f32(f16(po[(base+s)*hd + d])); }
o[h*hd + d] = acc * invZ;
}
}`;
/*
 * GEMM4 (WGSL): multi-row int4 matmul, Y[t*N + n] = sum_k A[t*K + k] * w4[n][k] * s (+ bias[n]).
 * Weights: 8 signed 4-bit values per u32. Nibble i occupies bits 4i..4i+3 and is
 * sign-extended by the `i32(word << (28 - 4i)) >> 28` shift pair. Each row has K/8 words,
 * so K is assumed to be a multiple of 8.
 * Scales: s = scale[n*gpr + k/128], i.e. one scale per 128 consecutive K elements of a row.
 * Tiling: workgroup (wid.x, wid.y) computes columns wid.x*64 + [0, 64) for rows
 * wid.y*16 + [0, 16). For every 8-wide K chunk, the 16x8 slice of A is staged in `As`.
 * Out-of-range columns still join the staging loads and barriers; out-of-range rows
 * are zero-filled and not stored. bias is added when hasBias == 1; Meta.p0..p2 are unread.
 */
var GEMM4 = `
requires immediate_address_space;
struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 };
@group(0) @binding(0) var<storage,read> A: array<f32>; // [T][K]
@group(0) @binding(1) var<storage,read> W: array<u32>; // [N][K/8] int4
@group(0) @binding(2) var<storage,read> scale: array<f32>; // [N][gpr]
@group(0) @binding(3) var<storage,read> bias: array<f32>; // [N] or dummy
@group(0) @binding(4) var<storage,read_write> Y: array<f32>; // [T][N]
var<immediate> m: Meta;
const BM = 16u; const BN = 64u;
var<workgroup> As: array<f32, 128>; // BM*8 \u2014 A staged for one 8-wide K chunk
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N;
let K8 = m.K/8u; let rb = col*K8;
var acc: array<f32, 16>;
for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; }
for (var c = 0u; c < K8; c = c + 1u) {
for (var l = lid.x; l < BM*8u; l = l + 64u) {
let tt = l / 8u; let trow = tTile + tt;
As[l] = select(0.0, A[trow*m.K + c*8u + (l % 8u)], trow < m.T);
}
workgroupBarrier();
if (valid) {
let word = W[rb + c]; let sc = scale[col*m.gpr + ((c*8u) >> 7u)];
let w0=f32(i32(word<<28u)>>28u)*sc; let w1=f32(i32(word<<24u)>>28u)*sc;
let w2=f32(i32(word<<20u)>>28u)*sc; let w3=f32(i32(word<<16u)>>28u)*sc;
let w4=f32(i32(word<<12u)>>28u)*sc; let w5=f32(i32(word<<8u)>>28u)*sc;
let w6=f32(i32(word<<4u)>>28u)*sc; let w7=f32(i32(word)>>28u)*sc;
for (var t = 0u; t < BM; t = t + 1u) {
let b = t*8u;
acc[t] = acc[t] + As[b]*w0+As[b+1u]*w1+As[b+2u]*w2+As[b+3u]*w3+As[b+4u]*w4+As[b+5u]*w5+As[b+6u]*w6+As[b+7u]*w7;
}
}
workgroupBarrier();
}
if (valid) {
let bv = select(0.0, bias[col], m.hasBias == 1u);
for (var t = 0u; t < BM; t = t + 1u) { let trow = tTile + t; if (trow < m.T) { Y[trow*m.N + col] = acc[t] + bv; } }
}
}`;
/*
 * GEMM4_ADD_T (WGSL): same int4 tiled matmul as GEMM4 (8 sign-extended nibbles per u32,
 * one scale per 128 K elements, 16-row x 64-column tiles), but the result is
 * accumulated into the existing contents of Y:
 *   Y[t*N + n] += sum_k A[t*K + k] * w4[n][k] * s (+ bias[n] when hasBias == 1).
 */
var GEMM4_ADD_T = `
requires immediate_address_space;
struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 };
@group(0) @binding(0) var<storage,read> A: array<f32>;
@group(0) @binding(1) var<storage,read> W: array<u32>;
@group(0) @binding(2) var<storage,read> scale: array<f32>;
@group(0) @binding(3) var<storage,read> bias: array<f32>;
@group(0) @binding(4) var<storage,read_write> Y: array<f32>;
var<immediate> m: Meta;
const BM = 16u; const BN = 64u;
var<workgroup> As: array<f32, 128>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N;
let K8 = m.K/8u; let rb = col*K8;
var acc: array<f32, 16>;
for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; }
for (var c = 0u; c < K8; c = c + 1u) {
for (var l = lid.x; l < BM*8u; l = l + 64u) {
let tt = l / 8u; let trow = tTile + tt;
As[l] = select(0.0, A[trow*m.K + c*8u + (l % 8u)], trow < m.T);
}
workgroupBarrier();
if (valid) {
let word = W[rb + c]; let sc = scale[col*m.gpr + ((c*8u) >> 7u)];
let w0=f32(i32(word<<28u)>>28u)*sc; let w1=f32(i32(word<<24u)>>28u)*sc;
let w2=f32(i32(word<<20u)>>28u)*sc; let w3=f32(i32(word<<16u)>>28u)*sc;
let w4=f32(i32(word<<12u)>>28u)*sc; let w5=f32(i32(word<<8u)>>28u)*sc;
let w6=f32(i32(word<<4u)>>28u)*sc; let w7=f32(i32(word)>>28u)*sc;
for (var t = 0u; t < BM; t = t + 1u) {
let b = t*8u;
acc[t] = acc[t] + As[b]*w0+As[b+1u]*w1+As[b+2u]*w2+As[b+3u]*w3+As[b+4u]*w4+As[b+5u]*w5+As[b+6u]*w6+As[b+7u]*w7;
}
}
workgroupBarrier();
}
if (valid) {
let bv = select(0.0, bias[col], m.hasBias == 1u);
for (var t = 0u; t < BM; t = t + 1u) {
let trow = tTile + t;
if (trow < m.T) { Y[trow*m.N + col] = Y[trow*m.N + col] + acc[t] + bv; }
}
}
}`;
/*
 * ADD (WGSL): in-place elementwise y[i] += a[i] for i < n (immediate).
 * global_invocation_index (linear_indexing) linearizes the whole dispatch grid, so the
 * loop stride is the total invocation count nwg.x * nwg.y * nwg.z * WG. Striding by
 * nwg.x * WG alone made invocations of a multi-row/-layer dispatch overlap and add some
 * elements more than once. 1-D dispatches behave exactly as before.
 */
var ADD = `
requires immediate_address_space;
requires linear_indexing;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> a: array<f32>;
@group(0) @binding(1) var<storage,read_write> y: array<f32>;
var<immediate> n: u32;
@compute @workgroup_size(WG)
fn main(@builtin(global_invocation_index) gid: u32, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * nwg.y * nwg.z * WG; // total invocations, matching the linear gid
for (var i = gid; i < n; i = i + stride) { y[i] = y[i] + a[i]; }
}`;
/*
 * ADD_F16 (WGSL): in-place y[i] = f32(f16(y[i]) + f16(a[i])) for i < n (immediate),
 * i.e. the residual add rounded through f16.
 * global_invocation_index (linear_indexing) linearizes the whole dispatch grid, so the
 * loop stride is the total invocation count nwg.x * nwg.y * nwg.z * WG. Striding by
 * nwg.x * WG alone made invocations of a multi-row/-layer dispatch overlap and add some
 * elements more than once. 1-D dispatches behave exactly as before.
 */
var ADD_F16 = `
requires immediate_address_space;
requires linear_indexing;
enable f16;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> a: array<f32>;
@group(0) @binding(1) var<storage,read_write> y: array<f32>;
var<immediate> n: u32;
@compute @workgroup_size(WG)
fn main(@builtin(global_invocation_index) gid: u32, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * nwg.y * nwg.z * WG; // total invocations, matching the linear gid
for (var i = gid; i < n; i = i + stride) { y[i] = f32(f16(y[i]) + f16(a[i])); }
}`;
/*
 * SILUMUL_F16 (WGSL): in place, gate[i] = f32(f16(silu(gate[i])) * f16(up[i])) for i < n,
 * with silu(v) = v / (1 + exp(-v)) evaluated in f32. Only the product is rounded through f16.
 * NOTE(review): indexes with global_invocation_id.x and strides by nwg.x * WG, so it
 * assumes a 1-D dispatch. With nwg.y or nwg.z > 1, invocations sharing the same x would
 * re-apply this non-idempotent update to the same elements.
 */
var SILUMUL_F16 = `
requires immediate_address_space;
enable f16;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read_write> gate: array<f32>;
@group(0) @binding(1) var<storage,read> up: array<f32>;
var<immediate> n: u32;
@compute @workgroup_size(WG)
fn main(@builtin(global_invocation_id) g: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * WG;
// Activation (silu) in f32 to avoid the f16 exp(-v) -> Inf intermediate for very
// negative v; only the bandwidth-bound elementwise multiply carries f16 rounding.
for (var i = g.x; i < n; i = i + stride) { let v = gate[i]; let sg = v / (1.0 + exp(-v)); gate[i] = f32( f16(sg) * f16(up[i]) ); }
}`;
/*
 * SILUMUL (WGSL): in place, gate[i] = silu(gate[i]) * up[i] for i < n, all in f32,
 * with silu(v) = v / (1 + exp(-v)).
 * NOTE(review): indexes with global_invocation_id.x and strides by nwg.x * WG, so it
 * assumes a 1-D dispatch. With nwg.y or nwg.z > 1, invocations sharing the same x would
 * re-apply this non-idempotent update to the same elements.
 */
var SILUMUL = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read_write> gate: array<f32>;
@group(0) @binding(1) var<storage,read> up: array<f32>;
var<immediate> n: u32;
@compute @workgroup_size(WG)
fn main(@builtin(global_invocation_id) g: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * WG;
for (var i = g.x; i < n; i = i + stride) { let v = gate[i]; gate[i] = (v/(1.0+exp(-v)))*up[i]; }
}`;
/*
 * EMBED (WGSL): dequantize one int8 embedding row, out[k] = w_int8[id][k] * scale[id] for k < H.
 * Row `id` starts at word id*(H/4) (4 int8 per u32, so H is assumed a multiple of 4).
 * unpack4xI8 returns the word's bytes lowest-first as x, y, z, w, and lane k & 3 selects one.
 * One scale per row. Immediates m = (token id, H); one invocation per k, 256 per workgroup.
 */
var EMBED = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> w: array<u32>;
@group(0) @binding(1) var<storage,read> scale: array<f32>;
@group(0) @binding(2) var<storage,read_write> out: array<f32>;
var<immediate> m: vec2<u32>; // id, hidden
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) g: vec3<u32>) {
let k = g.x; let id = m.x; let H = m.y; if (k >= H) { return; }
let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u;
var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;}
out[k] = f32(b) * scale[id];
}`;
/*
 * EMBED_BUF (WGSL): same int8 embedding lookup as EMBED, but the token id is read from
 * the storage buffer idbuf[0] rather than an immediate, so it can be produced on the GPU.
 * Immediate H = hidden size; out[k] = w_int8[id][k] * scale[id] for k < H.
 */
var EMBED_BUF = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> w: array<u32>;
@group(0) @binding(1) var<storage,read> scale: array<f32>;
@group(0) @binding(2) var<storage,read_write> out: array<f32>;
@group(0) @binding(3) var<storage,read> idbuf: array<u32>; // idbuf[0] = token id
var<immediate> H: u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) g: vec3<u32>) {
let k = g.x; let id = idbuf[0]; if (k >= H) { return; }
let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u;
var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;}
out[k] = f32(b) * scale[id];
}`;
/*
 * RMSNORM_T (WGSL): row-wise RMS norm over a [rows][K] buffer. Workgroup wid.x normalises
 * row wid.x (offset wid.x*K): y = x * inverseSqrt(mean(x^2) + eps) * g, with the gain g
 * shared by all rows. Immediates m = (K, eps) as f32. An overridden WG must be a power of
 * two <= 256 (tree reduction over a 256-slot `part`).
 */
var RMSNORM_T = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> g: array<f32>;
@group(0) @binding(2) var<storage,read_write> y: array<f32>;
var<immediate> m: vec2<f32>; // K, eps
var<workgroup> part: array<f32,256>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let K = u32(m.x); let base = wid.x * K;
var s = 0.0; for (var k = tid; k < K; k = k + WG) { let v = x[base+k]; s = s + v*v; }
part[tid] = s; workgroupBarrier();
for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); }
let inv = inverseSqrt(part[0]/m.x + m.y);
for (var k = tid; k < K; k = k + WG) { y[base+k] = x[base+k]*inv*g[k]; }
}`;
// RMSNORM_T_F16: RMSNorm with the same indexing as RMSNORM_T (one workgroup per
// row, y = x * g / rms). The sum of squares and 1/rms are computed in f32. The
// final product x * inv * g is formed in f16 and widened back to f32.
// Requires the device's shader-f16 feature (`enable f16`).
// NOTE(review): f16 saturates above 65504, so inputs or outputs larger than
// that become inf in the final product.
// WG is pipeline-overridable and must be a power of two (tree reduction). The
// shared partial-sum array is sized by WG, so overriding WG above 256 no longer
// indexes past a hard-coded 256-entry array.
var RMSNORM_T_F16 = `
requires immediate_address_space;
enable f16;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> g: array<f32>;
@group(0) @binding(2) var<storage,read_write> y: array<f32>;
var<immediate> m: vec2<f32>; // K, eps
// f32 reduction (see RMSNORM_F16): overflow-safe sum-of-squares, f16 normalize.
var<workgroup> part: array<f32,WG>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let K = u32(m.x); let base = wid.x * K;
var s = 0.0;
for (var k = tid; k < K; k = k + WG) { let v = f32(x[base+k]); s = s + v*v; }
part[tid] = s; workgroupBarrier();
for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); }
let inv = f16(inverseSqrt(part[0]/m.x + m.y));
for (var k = tid; k < K; k = k + WG) { y[base+k] = f32( f16(x[base+k]) * inv * f16(g[k]) ); }
}`;
// ROPE_T: in-place rotary position embedding for T consecutive token rows.
// x is indexed as [T][nHeads][headDim] (row stride nHeads*headDim). Each
// invocation rotates one pair (j, j + headDim/2) of one head ("rotate-half"
// pairing) by the angle for absolute position pos0 + row:
//   x_lo' = x_lo*cos - x_hi*sin,   x_hi' = x_hi*cos + x_lo*sin
// cosT/sinT are read at (pos0+row)*headDim + j: tables with a row stride of
// headDim, of which only the first headDim/2 columns are used here.
// One invocation per pair; invocations beyond T*nHeads*headDim/2 exit early.
// Immediates: m = (nHeads, headDim, T, pos0).
// NOTE(review): the fma(a, b, 0.0) forms look intended to pin the rounding of
// each product (e.g. to stop the compiler contracting the expression) — confirm.
var ROPE_T = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read_write> x: array<f32>;
@group(0) @binding(1) var<storage,read> cosT: array<f32>;
@group(0) @binding(2) var<storage,read> sinT: array<f32>;
var<immediate> m: vec4<u32>; // nHeads, headDim, T, pos0
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u;
let perRow = H*half; if (g >= T*perRow) { return; }
let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half;
let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j;
let c = cosT[off]; let s = sinT[off]; let xl = x[lo]; let xh = x[hi];
x[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); x[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0);
}`;
// ROPE_T_F16: same rotary embedding and indexing as ROPE_T (pairs (j, j+headDim/2)
// of each head, tables read at (pos0+row)*headDim + j). The inputs are cast to
// f16, the rotation is computed in f16, and the result is stored back as f32.
// Requires the device's shader-f16 feature (`enable f16`).
// Immediates: m = (nHeads, headDim, T, pos0).
var ROPE_T_F16 = `
requires immediate_address_space;
enable f16;
@group(0) @binding(0) var<storage,read_write> x: array<f32>;
@group(0) @binding(1) var<storage,read> cosT: array<f32>;
@group(0) @binding(2) var<storage,read> sinT: array<f32>;
var<immediate> m: vec4<u32>; // nHeads, headDim, T, pos0
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u;
let perRow = H*half; if (g >= T*perRow) { return; }
let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half;
let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j;
let c = f16(cosT[off]); let s = f16(sinT[off]); let xl = f16(x[lo]); let xh = f16(x[hi]);
x[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); x[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) );
}`;
// EMBED_T: batched version of EMBED for T tokens. Output element i of the
// [T][H] result takes token t = i / H with id = ids[idOffset + t]. It then
// dequantizes int8 element k = i % H of row id: out[i] = int8 * scale[id].
// A grid-stride loop (stride = num_workgroups.x * 256) lets any x-dispatch size
// cover all T*H outputs. H is assumed to be a multiple of 4 (four int8 per u32).
// Immediates: m = (T, H, idOffset, unused).
var EMBED_T = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> w: array<u32>;
@group(0) @binding(1) var<storage,read> scale: array<f32>;
@group(0) @binding(2) var<storage,read_write> out: array<f32>;
@group(0) @binding(3) var<storage,read> ids: array<u32>;
var<immediate> m: vec4<u32>; // T, H, idOffset, _
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let T = m.x; let H = m.y; let N = T*H; let stride = nwg.x * 256u;
for (var i = gid.x; i < N; i = i + stride) {
let t = i / H; let k = i % H; let id = ids[m.z + t];
let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u;
var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;}
out[i] = f32(b) * scale[id];
}
}`;
// ATTN_PREFILL: causal self-attention over T query rows. One 256-thread
// workgroup handles one (head h = workgroup_id.x, query t = workgroup_id.y)
// pair. Query t attends to keys 0..t of kc/vc (ctx = t + 1).
// Grouped-query attention: head h reads KV head h / (nHeads / nKV).
// Layouts (from the indexing): q and o are [T][nHeads*hd]; kc and vc are
// [ctx][nKV*hd].
// Keys are processed in blocks of 256, one key per thread, using an online
// softmax. A running max mrun and denominator lrun are kept, and the
// weighted-V accumulator acc is rescaled by exp(mrun - mnew) whenever the
// max grows.
// Limits implied by the shared arrays: hd <= 128 (acc) and at most 64
// subgroups per workgroup (red), i.e. subgroup size >= 4.
// NOTE(review): red[tid/sgsz] assumes subgroups are contiguous ranges of
// local_invocation_id. WGSL does not guarantee that mapping.
// Immediates: m = (nHeads, nKV, hd, T).
var ATTN_PREFILL = `
enable subgroups;
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> q: array<f32>; // [T][nHeads*hd]
@group(0) @binding(1) var<storage,read> kc: array<f32>; // [ctx][nKV*hd]
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>; // [T][nHeads*hd]
var<immediate> m: vec4<u32>; // nHeads, nKV, hd, T
var<workgroup> ps: array<f32,256>; // exp-scores for the current key block
var<workgroup> acc: array<f32,128>; // running weighted-V accumulator (hd<=128)
var<workgroup> red: array<f32,64>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let t = wid.y; let tid = lid.x; let nHeads = m.x; let nKV = m.y; let hd = m.z;
let ctx = t + 1u; let kvh = h / (nHeads / nKV);
let qbase = t*nHeads*hd + h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scl = 1.0/sqrt(f32(hd));
let nsg = (256u + sgsz - 1u) / sgsz;
for (var d = tid; d < hd; d = d + 256u) { acc[d] = 0.0; }
var mrun = -1e30; var lrun = 0.0;
let nblk = (ctx + 255u) / 256u;
for (var blk = 0u; blk < nblk; blk = blk + 1u) {
let kbase = blk*256u; let kk = kbase + tid;
var s = -1e30;
if (kk < ctx) { var dot = 0.0; let kb = kk*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } s = dot*scl; }
let sgm = subgroupMax(s); if (sgid == 0u) { red[tid/sgsz] = sgm; }
workgroupBarrier(); // A: block-max partials visible
var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[i]); }
let mnew = max(mrun, bm); let corr = exp(mrun - mnew);
var p = 0.0; if (kk < ctx) { p = exp(s - mnew); }
ps[tid] = p;
workgroupBarrier(); // B: bm reads done + ps visible
let sgs = subgroupAdd(p); if (sgid == 0u) { red[tid/sgsz] = sgs; }
workgroupBarrier(); // C: block-sum partials visible
var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[i]; }
lrun = lrun*corr + bs;
let bcount = min(256u, ctx - kbase);
for (var d = tid; d < hd; d = d + 256u) {
var aa = acc[d]*corr;
for (var j = 0u; j < bcount; j = j + 1u) { aa = aa + ps[j]*vc[(kbase+j)*stride + hoff + d]; }
acc[d] = aa;
}
mrun = mnew;
workgroupBarrier(); // D: acc's ps reads done before next block
}
let invL = 1.0/lrun;
for (var d = tid; d < hd; d = d + 256u) { o[qbase + d] = acc[d]*invL; }
}`;
// ATTN_PREFILL_BLOCK: blocked causal attention. Each 128-thread workgroup
// handles BQ = 4 query rows of one head (head = workgroup_id.x, query block =
// workgroup_id.y). Keys stream in blocks of BK = 128, one key per thread, and
// each K/V load is shared across the four rows.
// Query row qt (< T) has absolute position qStart + qt. It attends to keys kk
// with kk < ctx and kk <= qStart + qt. NOTE(review): presumably this lets a
// chunk of new tokens attend to earlier keys already in kc/vc — confirm.
// Per-row online-softmax state (mrun, lrun, corrRun) is kept in private
// per-invocation arrays. acc holds the BQ x hd weighted-V accumulators in
// workgroup memory.
// Limits implied by the shared arrays:
//   - hd <= 128 (acc = BQ*128)
//   - at most 32 subgroups per workgroup (red is indexed r*32 + subgroup)
//   - BQ*BK = 512 (ps)
// NOTE(review): red[r*32 + tid/sgsz] assumes subgroups are contiguous ranges of
// local_invocation_id. WGSL does not guarantee that mapping.
// Meta.p0/p1 are not read by the shader (presumably padding).
var ATTN_PREFILL_BLOCK = `
enable subgroups;
requires immediate_address_space;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, qStart:u32, ctx:u32, p0:u32, p1:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>;
var<immediate> m: Meta;
const BQ = 4u; const BK = 128u;
var<workgroup> ps: array<f32, 512>; // BQ*BK
var<workgroup> acc: array<f32, 512>; // BQ*hd (hd<=128)
var<workgroup> red: array<f32, 128>; // BQ*subgroup-count
@compute @workgroup_size(128)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let qBlock = wid.y; let tid = lid.x; let hd = m.hd;
let kvh = h / (m.nHeads / m.nKV); let stride = m.nKV * hd; let hoff = kvh * hd;
let nsg = (128u + sgsz - 1u) / sgsz; let scl = 1.0 / sqrt(f32(hd));
var mrun: array<f32, 4>; var lrun: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) { mrun[r] = -1e30; lrun[r] = 0.0; }
for (var i = tid; i < BQ*hd; i = i + 128u) { acc[i] = 0.0; }
workgroupBarrier();
let nblk = (m.ctx + BK - 1u) / BK;
for (var blk = 0u; blk < nblk; blk = blk + 1u) {
let kbase = blk * BK; let kk = kbase + tid;
var score: array<f32, 4>;
var validQ: array<bool, 4>;
var dot: array<f32, 4>;
var corrRun: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r; let absQ = m.qStart + qt;
validQ[r] = qt < m.T && kk < m.ctx && kk <= absQ;
dot[r] = 0.0; score[r] = -1e30;
}
if (kk < m.ctx) {
let kb = kk*stride + hoff;
for (var d = 0u; d < hd; d = d + 1u) {
let kval = kc[kb+d];
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r;
if (validQ[r]) { dot[r] = dot[r] + q[qt*m.nHeads*hd + h*hd + d] * kval; }
}
}
for (var r = 0u; r < BQ; r = r + 1u) {
if (validQ[r]) { score[r] = dot[r] * scl; }
}
}
for (var r = 0u; r < BQ; r = r + 1u) {
let s = score[r];
let sgm = subgroupMax(s);
if (sgid == 0u) { red[r*32u + tid/sgsz] = sgm; }
workgroupBarrier();
var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[r*32u+i]); }
let mnew = max(mrun[r], bm); let corr = exp(mrun[r] - mnew);
corrRun[r] = corr;
var p = 0.0; if (validQ[r]) { p = exp(s - mnew); }
ps[r*BK + tid] = p;
workgroupBarrier();
let sgs = subgroupAdd(p);
if (sgid == 0u) { red[r*32u + tid/sgsz] = sgs; }
workgroupBarrier();
var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[r*32u+i]; }
lrun[r] = lrun[r] * corr + bs;
mrun[r] = mnew;
workgroupBarrier();
}
let bcount = min(BK, m.ctx - kbase);
for (var d = tid; d < hd; d = d + 128u) {
var aa: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = acc[r*hd+d] * corrRun[r]; }
for (var j = 0u; j < bcount; j = j + 1u) {
let vv = vc[(kbase+j)*stride + hoff + d];
for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = aa[r] + ps[r*BK+j] * vv; }
}
for (var r = 0u; r < BQ; r = r + 1u) { acc[r*hd+d] = aa[r]; }
}
workgroupBarrier();
}
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r;
if (qt < m.T) {
let invL = 1.0 / lrun[r]; let ob = qt*m.nHeads*hd + h*hd;
for (var d = tid; d < hd; d = d + 128u) { o[ob+d] = acc[r*hd+d] * invL; }
}
}
}`;
// ARGMAX: writes to out[0] the index of the largest of the first n logits
// (presumably for greedy decoding). Only local_invocation_id is used, so a
// single 256-thread workgroup computes the whole result through a shared-memory
// tree reduction. Ties resolve to the lowest index.
// The running best starts at -1e30 and is replaced only by strictly greater
// values (NaN compares false). So out[0] stays 0xffffffff when no logit
// exceeds -1e30, e.g. all NaN.
// NOTE(review): callers should treat 0xffffffff as "no token".
// Immediate: n (number of logits).
var ARGMAX = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> logits: array<f32>;
@group(0) @binding(1) var<storage,read_write> out: array<u32>;
var<immediate> n: u32;
var<workgroup> bv: array<f32,256>; var<workgroup> bi: array<u32,256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; var v = -1e30; var idx = 0xffffffffu;
for (var i = tid; i < n; i = i + 256u) { let x = logits[i]; if (x > v || (x == v && i < idx)) { v = x; idx = i; } }
bv[tid] = v; bi[tid] = idx; workgroupBarrier();
for (var s = 128u; s > 0u; s = s/2u) { if (tid < s) { let ov = bv[tid+s]; let oi = bi[tid+s]; if (ov > bv[tid] || (ov == bv[tid] && oi < bi[tid])) { bv[tid] = ov; bi[tid] = oi; } } workgroupBarrier(); }
if (tid == 0u) { out[0] = bi[0]; }
}`;
// TOPK_SELECT: one step of iterative top-k selection over the first m.x logits.
// It finds the largest logit whose index is not already in ids[0..selected),
// with ties resolved to the lowest index. It appends that candidate at
// ids[selected] / vals[selected]. Running it for selected = 0, 1, ..., k-1
// therefore yields ids/vals in descending logit order.
// Single-workgroup kernel (only local_invocation_id is used).
// Each pass costs O(vocabSize * selected) because alreadySelected() scans the
// previously chosen ids linearly for every logit.
// Immediates: m = (vocabSize, selectedCount).
var TOPK_SELECT = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> logits: array<f32>;
@group(0) @binding(1) var<storage,read_write> ids: array<u32>;
@group(0) @binding(2) var<storage,read_write> vals: array<f32>;
var<immediate> m: vec2<u32>; // vocabSize, selectedCount
var<workgroup> bv: array<f32,256>; var<workgroup> bi: array<u32,256>;
fn alreadySelected(id: u32, n: u32) -> bool {
for (var j = 0u; j < n; j = j + 1u) { if (ids[j] == id) { return true; } }
return false;
}
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let n = m.x; let selected = m.y;
var v = -1e30; var idx = 0xffffffffu;
for (var i = tid; i < n; i = i + 256u) {
let x = logits[i];
if (!alreadySelected(i, selected) && (x > v || (x == v && i < idx))) { v = x; idx = i; }
}
bv[tid] = v; bi[tid] = idx; workgroupBarrier();
for (var s = 128u; s > 0u; s = s/2u) {
if (tid < s) {
let ov = bv[tid+s]; let oi = bi[tid+s];
if (ov > bv[tid] || (ov == bv[tid] && oi < bi[tid])) { bv[tid] = ov; bi[tid] = oi; }
}
workgroupBarrier();
}
if (tid == 0u) { ids[selected] = bi[0]; vals[selected] = bv[0]; }
}`;
// SAMPLE_TOPK: draws one token from a top-k candidate list by inverse-CDF
// sampling. ids/vals are e.g. as produced by repeated TOPK_SELECT passes.
// Immediates: m = { k, pad (unused), temp, r }. r is the caller's uniform
// random number; temp <= 0 is treated as 1.0.
//   p_j = softmax(vals[j] / temp) over j < k
//   outId[0] = ids[first j with p_0 + ... + p_j >= r], or ids[k-1] when
//   rounding leaves the total below r. The result is 0 when k == 0.
// The softmax subtracts the largest scaled logit before exp(). Without that,
// a scaled logit above ~88 overflows f32 (e.g. logit 40 at temp 0.4). The sum
// becomes inf, every probability becomes inf * 0 = NaN, no `r <= p` test can
// pass, and the last (least likely) candidate would always be returned.
// The shared arrays have 64 slots, so k is clamped to 64. This keeps the
// selection loop from reading past s[63].
var SAMPLE_TOPK = `
requires immediate_address_space;
struct Meta { k:u32, pad:u32, temp:f32, r:f32 };
@group(0) @binding(0) var<storage,read> ids: array<u32>;
@group(0) @binding(1) var<storage,read> vals: array<f32>;
@group(0) @binding(2) var<storage,read_write> outId: array<u32>; // [1] the chosen token
var<immediate> m: Meta;
var<workgroup> s: array<f32, 64>; // working softmax probs / prefix sums (small k)
var<workgroup> red: array<f32, 64>; // reduction scratch: max, then softmax denominator
@compute @workgroup_size(64)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x;
let k = min(m.k, 64u); // s/red hold at most 64 candidates
let temp = m.temp;
let r = m.r;
let t = select(temp, 1.0, temp <= 0.0);
// Load + temperature scale (one thread per slot)
var v = -1e30;
if (tid < k) {
let lv = vals[tid];
v = lv;
if (t != 1.0) { v = lv / t; }
}
// max of the scaled logits, so every exp() argument below is <= 0
red[tid] = v;
workgroupBarrier();
for (var stride = 32u; stride > 0u; stride = stride / 2u) {
if (tid < stride) { red[tid] = max(red[tid], red[tid + stride]); }
workgroupBarrier();
}
let vmax = red[0];
workgroupBarrier(); // every thread has read red[0] before red is reused below
let ev = select(0.0, exp(v - vmax), tid < k);
s[tid] = ev;
red[tid] = ev;
workgroupBarrier();
// sum
for (var stride = 32u; stride > 0u; stride = stride / 2u) {
if (tid < stride && (tid + stride) < 64u) { red[tid] = red[tid] + red[tid + stride]; }
workgroupBarrier();
}
let sum = red[0];
let invSum = select(0.0, 1.0 / sum, sum > 0.0);
// normalize + prefix sum for nucleus / categorical pick
if (tid < k) {
s[tid] = s[tid] * invSum;
} else {
s[tid] = 0.0;
}
workgroupBarrier();
// prefix sum (small k, simple scan)
for (var stride = 1u; stride < 64u; stride = stride * 2u) {
var add = 0.0;
if (tid >= stride && tid < 64u) {
add = s[tid - stride];
}
workgroupBarrier();
if (tid >= stride && tid < 64u) {
s[tid] = s[tid] + add;
}
workgroupBarrier();
}
// find the smallest j such that prefix[j] >= r (or last if r>=1)
if (tid == 0u) {
var chosen = select(0u, k - 1u, k > 0u);
if (sum > 0.0) {
for (var j = 0u; j < k; j = j + 1u) {
let pj = s[j];
if (r <= pj) { chosen = j; break; }
}
}
outId[0] = select(0u, ids[chosen], k > 0u);
}
}`;
// GEMV4: y[n] = sum_k x[k] * W[n][k] (+ bias[n]) (+ LoRA delta) for one
// vector x, with W stored as signed int4.
// Weight layout:
//   - Row n is K/8 u32 words. Element 8c+i sits in bits 4i..4i+3 of word c and
//     is sign-extended by the shift pair.
//   - One f32 scale per 128 columns, at scale[n*gpr + k/128].
// K is assumed to be a multiple of 8.
// One 64-thread workgroup computes output row n = wid.x + wid.y*gridX. The 2-D
// grid is presumably a workaround for the per-dimension dispatch limit.
// Partial sums are reduced with subgroupAdd, then across subgroups via `part`.
// NOTE(review): part[tid/sgsz] assumes subgroups are contiguous ranges of
// local_invocation_id.
// LoRA (hasLora == 1): o += scaleLo * sum_r loraD[r] * loraB[r*N + n].
// NOTE(review): loraD is presumably the rank-sized product x·A precomputed for
// this x — confirm at the call site.
var GEMV4 = `
enable subgroups;
requires immediate_address_space;
struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 };
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> w: array<u32>;
@group(0) @binding(2) var<storage,read> scale: array<f32>;
@group(0) @binding(3) var<storage,read> bias: array<f32>;
@group(0) @binding(4) var<storage,read> loraD: array<f32>;
@group(0) @binding(5) var<storage,read> loraB: array<f32>;
@group(0) @binding(6) var<storage,read_write> y: array<f32>;
var<immediate> m: Meta;
var<workgroup> part: array<f32,64>; // one slot per subgroup
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; } // workgroup-uniform: whole group exits together
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + 64u) {
let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)];
var p = 0.0;
p = p + x[bk] * f32(i32(word << 28u) >> 28u);
p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u);
p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u);
p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u);
p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u);
p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u);
p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u);
p = p + x[bk+7u] * f32(i32(word) >> 28u);
acc = acc + p * sc;
}
let ssum = subgroupAdd(acc); // reduce within subgroup (no barrier)
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
if (m.hasBias == 1u) { o = o + bias[n]; }
if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; }
y[n] = o;
}
}`;
// GEMV4_ADD: identical to GEMV4 (same int4 weight layout, 128-column scale
// groups, optional bias and LoRA), except that the result is accumulated into
// the output, y[n] += o, instead of overwriting it.
// NOTE(review): presumably used to fuse a residual add into the projection.
// One 64-thread workgroup per output row n = wid.x + wid.y*gridX.
var GEMV4_ADD = `
enable subgroups;
requires immediate_address_space;
struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 };
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> w: array<u32>;
@group(0) @binding(2) var<storage,read> scale: array<f32>;
@group(0) @binding(3) var<storage,read> bias: array<f32>;
@group(0) @binding(4) var<storage,read> loraD: array<f32>;
@group(0) @binding(5) var<storage,read> loraB: array<f32>;
@group(0) @binding(6) var<storage,read_write> y: array<f32>;
var<immediate> m: Meta;
var<workgroup> part: array<f32,64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; }
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + 64u) {
let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)];
var p = 0.0;
p = p + x[bk] * f32(i32(word << 28u) >> 28u);
p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u);
p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u);
p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u);
p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u);
p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u);
p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u);
p = p + x[bk+7u] * f32(i32(word) >> 28u);
acc = acc + p * sc;
}
let ssum = subgroupAdd(acc);
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
if (m.hasBias == 1u) { o = o + bias[n]; }
if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; }
y[n] = y[n] + o;
}
}`;
// QKV_GEMV4: fused Q/K/V projection. It runs one int4 GEMV (GEMV4 weight
// layout) over a stacked matrix of totalN rows and routes each output row:
//   - rows [0, qN)        -> qOut[n]
//   - rows [qN, qN + kN)  -> kOut[n - qN]
//   - the remaining rows  -> vOut[n - qN - kN]
// NOTE(review): this assumes totalN == qN + kN + vN. vN itself is not read.
// bias[n] is always added, so the bias buffer must cover all totalN rows.
// There is no LoRA path. Meta.p0 is unused.
// One 64-thread workgroup per row n = wid.x + wid.y*gridX.
var QKV_GEMV4 = `
enable subgroups;
requires immediate_address_space;
struct Meta { K:u32, totalN:u32, qN:u32, kN:u32, vN:u32, gpr:u32, gridX:u32, p0:u32 };
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> w: array<u32>;
@group(0) @binding(2) var<storage,read> scale: array<f32>;
@group(0) @binding(3) var<storage,read> bias: array<f32>;
@group(0) @binding(4) var<storage,read_write> qOut: array<f32>;
@group(0) @binding(5) var<storage,read_write> kOut: array<f32>;
@group(0) @binding(6) var<storage,read_write> vOut: array<f32>;
var<immediate> m: Meta;
var<workgroup> part: array<f32,64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.totalN) { return; }
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + 64u) {
let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)];
var p = 0.0;
p = p + x[bk] * f32(i32(word << 28u) >> 28u);
p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u);
p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u);
p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u);
p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u);
p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u);
p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u);
p = p + x[bk+7u] * f32(i32(word) >> 28u);
acc = acc + p * sc;
}
let ssum = subgroupAdd(acc);
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
o = o + bias[n];
if (n < m.qN) {
qOut[n] = o;
} else if (n < m.qN + m.kN) {
kOut[n - m.qN] = o;
} else {
vOut[n - m.qN - m.kN] = o;
}
}
}`;
// GATE_UP_SILU_GEMV4: fused MLP gate/up projection plus SiLU.
// Rows 0..N-1 of w/scale are the gate matrix and rows N..2N-1 are the up
// matrix; both use the GEMV4 int4 layout with 128-column scale groups.
// For each output n:
//   y[n] = silu(gate_n) * up_n,   with silu(v) = v / (1 + exp(-v))
// Gate and up can each get an optional LoRA delta, computed as in GEMV4
// (scaleLo * sum_r D[r] * B[r*N + n]). Meta.p0/p1 are unused.
// One 64-thread workgroup per output n = wid.x + wid.y*gridX. x is read once
// per chunk and shared by both dot products.
var GATE_UP_SILU_GEMV4 = `
enable subgroups;
requires immediate_address_space;
struct Meta { K:u32, N:u32, gpr:u32, gridX:u32, gateRank:u32, upRank:u32, hasGateLora:u32, hasUpLora:u32, gateScaleLo:f32, upScaleLo:f32, p0:f32, p1:f32 };
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read> w: array<u32>;
@group(0) @binding(2) var<storage,read> scale: array<f32>;
@group(0) @binding(3) var<storage,read_write> y: array<f32>;
@group(0) @binding(4) var<storage,read> gateD: array<f32>;
@group(0) @binding(5) var<storage,read> gateB: array<f32>;
@group(0) @binding(6) var<storage,read> upD: array<f32>;
@group(0) @binding(7) var<storage,read> upB: array<f32>;
var<immediate> m: Meta;
var<workgroup> partG: array<f32,64>;
var<workgroup> partU: array<f32,64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; }
let K8 = m.K/8u; let rbG = n*K8; let rbU = (m.N + n)*K8;
let sbG = n*m.gpr; let sbU = (m.N + n)*m.gpr;
var accG = 0.0; var accU = 0.0;
for (var c = tid; c < K8; c = c + 64u) {
let bk = c*8u; let wg = w[rbG+c]; let wu = w[rbU+c];
let scG = scale[sbG + (bk >> 7u)]; let scU = scale[sbU + (bk >> 7u)];
let x0=x[bk]; let x1=x[bk+1u]; let x2=x[bk+2u]; let x3=x[bk+3u];
let x4=x[bk+4u]; let x5=x[bk+5u]; let x6=x[bk+6u]; let x7=x[bk+7u];
var pg = 0.0; var pu = 0.0;
pg = pg + x0*f32(i32(wg<<28u)>>28u) + x1*f32(i32(wg<<24u)>>28u) + x2*f32(i32(wg<<20u)>>28u) + x3*f32(i32(wg<<16u)>>28u);
pg = pg + x4*f32(i32(wg<<12u)>>28u) + x5*f32(i32(wg<<8u)>>28u) + x6*f32(i32(wg<<4u)>>28u) + x7*f32(i32(wg)>>28u);
pu = pu + x0*f32(i32(wu<<28u)>>28u) + x1*f32(i32(wu<<24u)>>28u) + x2*f32(i32(wu<<20u)>>28u) + x3*f32(i32(wu<<16u)>>28u);
pu = pu + x4*f32(i32(wu<<12u)>>28u) + x5*f32(i32(wu<<8u)>>28u) + x6*f32(i32(wu<<4u)>>28u) + x7*f32(i32(wu)>>28u);
accG = accG + pg * scG; accU = accU + pu * scU;
}
let sg = subgroupAdd(accG); let su = subgroupAdd(accU);
if (sgid == 0u) { partG[tid / sgsz] = sg; partU[tid / sgsz] = su; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var gate = 0.0; var up = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { gate = gate + partG[i]; up = up + partU[i]; }
if (m.hasGateLora == 1u) {
var dl = 0.0; for (var r = 0u; r < m.gateRank; r = r + 1u) { dl = dl + gateD[r] * gateB[r*m.N + n]; }
gate = gate + m.gateScaleLo * dl;
}
if (m.hasUpLora == 1u) {
var dl = 0.0; for (var r = 0u; r < m.upRank; r = r + 1u) { dl = dl + upD[r] * upB[r*m.N + n]; }
up = up + m.upScaleLo * dl;
}
y[n] = (gate / (1.0 + exp(-gate))) * up;
}
}`;
// DYN_QUANT_X: dynamic per-group int8 quantization of one activation vector
// x[0..K), producing the x_q / scale_x inputs of the W4A8 GEMV kernels.
// One 64-thread workgroup handles one group of 128 elements (g = workgroup_id.x):
//   scale_x[g]      = max|x| / 127   (1.0 for an all-zero group)
//   x_q[g*32 + i]   = x[g*128 + 4i .. 4i+3] as four int8, low byte first
// A group is exactly 32 packed words, so only threads 0..31 quantize. A thread
// with tid >= 32 would write words g*32+32..63. Those belong to the next group,
// so it would store them quantized with this group's scale and race with the
// workgroup that owns them.
// Assumes K is a multiple of 4, so a packed word never straddles the end of x.
// Immediate: K.
var DYN_QUANT_X = `
requires immediate_address_space;
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> x_q: array<u32>;
@group(0) @binding(2) var<storage, read_write> scale_x: array<f32>;
var<immediate> K: u32;
var<workgroup> sh_max: array<f32, 64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let g = wid.x; let tid = lid.x; let base = g * 128u;
var local_max = 0.0;
let idx0 = base + tid; let idx1 = base + tid + 64u;
if (idx0 < K) { local_max = max(local_max, abs(x[idx0])); }
if (idx1 < K) { local_max = max(local_max, abs(x[idx1])); }
sh_max[tid] = local_max;
workgroupBarrier();
for (var s = 32u; s > 0u; s = s / 2u) {
if (tid < s) { sh_max[tid] = max(sh_max[tid], sh_max[tid + s]); }
workgroupBarrier();
}
let gmax = sh_max[0]; let scale = select(gmax / 127.0, 1.0, gmax == 0.0);
if (tid == 0u) { scale_x[g] = scale; }
// 128 elements = 32 packed words: only threads 0..31 own a word of this group.
let pidx = base + tid * 4u;
if (tid < 32u && pidx < K) {
let q0 = clamp(i32(round(x[pidx] / scale)), -128, 127) & 0xff;
let q1 = clamp(i32(round(x[pidx + 1u] / scale)), -128, 127) & 0xff;
let q2 = clamp(i32(round(x[pidx + 2u] / scale)), -128, 127) & 0xff;
let q3 = clamp(i32(round(x[pidx + 3u] / scale)), -128, 127) & 0xff;
x_q[g * 32u + tid] = u32(q0 | (q1 << 8u) | (q2 << 16u) | (q3 << 24u));
}
}`;
// DYN_QUANT_X_T: batched DYN_QUANT_X over T rows of length K (x is [T][K]).
// Workgroup (g = workgroup_id.x, t = workgroup_id.y) quantizes group g
// (128 elements) of row t:
//   scale_x[t*(K/128) + g]      = max|x| / 127   (1.0 for an all-zero group)
//   x_q[t*(K/4) + g*32 + i]     = group elements 4i..4i+3 as four int8
// A group is exactly 32 packed words, so only threads 0..31 quantize. A thread
// with tid >= 32 would write the next group's words of the same row, using this
// group's scale, and race with the workgroup that owns them.
// Assumes K is a multiple of 128: the scale index uses K/128 groups per row.
// Immediates: m = (K, T).
var DYN_QUANT_X_T = `
requires immediate_address_space;
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> x_q: array<u32>;
@group(0) @binding(2) var<storage, read_write> scale_x: array<f32>;
var<immediate> m: vec2<u32>; // K, T
var<workgroup> sh_max: array<f32, 64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let g = wid.x; let t = wid.y; let tid = lid.x; let K = m.x; let T = m.y;
if (t >= T) { return; }
let row_base = t * K; let base = row_base + g * 128u;
var local_max = 0.0;
let idx0 = base + tid; let idx1 = base + tid + 64u;
if (g * 128u + tid < K) { local_max = max(local_max, abs(x[idx0])); }
if (g * 128u + tid + 64u < K) { local_max = max(local_max, abs(x[idx1])); }
sh_max[tid] = local_max;
workgroupBarrier();
for (var s = 32u; s > 0u; s = s / 2u) {
if (tid < s) { sh_max[tid] = max(sh_max[tid], sh_max[tid + s]); }
workgroupBarrier();
}
let gmax = sh_max[0]; let scale = select(gmax / 127.0, 1.0, gmax == 0.0);
let groupsPerRow = K / 128u;
if (tid == 0u) { scale_x[t * groupsPerRow + g] = scale; }
// 128 elements = 32 packed words: only threads 0..31 own a word of this group.
let pidx = base + tid * 4u;
if (tid < 32u && g * 128u + tid * 4u < K) {
let q0 = clamp(i32(round(x[pidx] / scale)), -128, 127) & 0xff;
let q1 = clamp(i32(round(x[pidx + 1u] / scale)), -128, 127) & 0xff;
let q2 = clamp(i32(round(x[pidx + 2u] / scale)), -128, 127) & 0xff;
let q3 = clamp(i32(round(x[pidx + 3u] / scale)), -128, 127) & 0xff;
x_q[t * (K / 4u) + g * 32u + tid] = u32(q0 | (q1 << 8u) | (q2 << 16u) | (q3 << 24u));
}
}`;
// GEMV4_W4A8(hasDP4a, wgSize = 64): returns the WGSL source of a W4A8 GEMV:
//   y[n] = sum_k x[k] * W[n][k] (+ bias[n]) (+ LoRA delta)
// Input layouts:
//   - W uses the GEMV4 int4 layout: 8 sign-extended nibbles per u32, low nibble
//     first, with one scale per 128 columns at scale[n*gpr + k/128].
//   - x was pre-quantized to int8 by DYN_QUANT_X: x_q holds 4 int8 per u32 and
//     scale_x has one entry per 128 elements.
// Each weight word is re-packed into two 4 x int8 words, so each 8-element
// chunk costs two dot4I8Packed calls. The result is rescaled by sc_w * sc_x.
// wgSize sets the workgroup size, the loop stride and the per-subgroup
// partials array.
// hasDP4a selects the builtin dot4I8Packed:
//   - packed_4x8_integer_dot_product is a WGSL *language* extension, so it is
//     declared with `requires`.
//   - `enable` only accepts enable-extensions (f16, subgroups, ...) and rejects
//     this name at shader creation.
// Otherwise an integer fallback with the same name is emitted.
// NOTE(review): the fallback still calls unpack4xI8, which belongs to that
// same language extension.
var GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 };
@group(0) @binding(0) var<storage,read> x_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> w: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read> loraD: array<f32>;
@group(0) @binding(6) var<storage,read> loraB: array<f32>;
@group(0) @binding(7) var<storage,read_write> y: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let va = unpack4xI8(a);
let vb = unpack4xI8(b);
return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w;
}
`}
var<workgroup> part: array<f32, ${wgSize}>;
@compute @workgroup_size(${wgSize})
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; }
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + ${wgSize}u) {
let word = w[rb+c]; let bk = c*8u;
let sc_w = scale[sbase + (bk >> 7u)];
let sc_x = scale_x[bk >> 7u];
let w0 = (i32(word << 28u) >> 28u) & 0xff;
let w1 = (i32(word << 24u) >> 28u) & 0xff;
let w2 = (i32(word << 20u) >> 28u) & 0xff;
let w3 = (i32(word << 16u) >> 28u) & 0xff;
let w4 = (i32(word << 12u) >> 28u) & 0xff;
let w5 = (i32(word << 8u) >> 28u) & 0xff;
let w6 = (i32(word << 4u) >> 28u) & 0xff;
let w7 = (i32(word) >> 28u) & 0xff;
let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u));
let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u));
let px0 = x_q[c * 2u];
let px1 = x_q[c * 2u + 1u];
let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1);
acc = acc + f32(sum) * sc_w * sc_x;
}
let ssum = subgroupAdd(acc);
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
if (m.hasBias == 1u) { o = o + bias[n]; }
if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; }
y[n] = o;
}
}
`, "GEMV4_W4A8");
/**
 * WGSL generator: W4A8 GEMV that accumulates into its output buffer,
 *   y[n] = y[n] + dot(W[n, :], x) (+ bias[n] when hasBias == 1) (+ LoRA delta when hasLora == 1).
 *
 * One workgroup per output row n = wid.x + wid.y * gridX. Each u32 of `w` holds eight signed
 * 4-bit weights (sign-extended with shifts, then re-packed as int8 lanes). Each u32 of `x_q` holds
 * four int8 activations. Weights and activations each carry one f32 scale per 128-element K group
 * (`bk >> 7u`). The LoRA delta is scaleLo * sum_r loraD[r] * loraB[r * N + n].
 *
 * Directive fix: packed_4x8_integer_dot_product is a WGSL *language* extension, which is declared
 * with `requires`. `enable` only names enable-extensions (f16, subgroups, ...), so the previous
 * `enable packed_4x8_integer_dot_product;` was not a valid directive.
 *
 * Fallback fix: the software dot4I8Packed now uses only core built-ins (extractBits sign-extends
 * i32 fields). The previous fallback called unpack4xI8, which is part of the same language
 * extension the fallback exists to avoid.
 *
 * @param {boolean} hasDP4a - Use the built-in dot4I8Packed (requires the language extension).
 * @param {number} [wgSize=64] - Workgroup size; also the K-loop stride and the length of `part`.
 * @returns {string} WGSL shader source.
 */
var GEMV4_ADD_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 };
@group(0) @binding(0) var<storage,read> x_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> w: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read> loraD: array<f32>;
@group(0) @binding(6) var<storage,read> loraB: array<f32>;
@group(0) @binding(7) var<storage,read_write> y: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let ia = bitcast<i32>(a);
let ib = bitcast<i32>(b);
return extractBits(ia, 0u, 8u) * extractBits(ib, 0u, 8u)
+ extractBits(ia, 8u, 8u) * extractBits(ib, 8u, 8u)
+ extractBits(ia, 16u, 8u) * extractBits(ib, 16u, 8u)
+ extractBits(ia, 24u, 8u) * extractBits(ib, 24u, 8u);
}
`}
var<workgroup> part: array<f32, ${wgSize}>;
@compute @workgroup_size(${wgSize})
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; }
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + ${wgSize}u) {
let word = w[rb+c]; let bk = c*8u;
let sc_w = scale[sbase + (bk >> 7u)];
let sc_x = scale_x[bk >> 7u];
let w0 = (i32(word << 28u) >> 28u) & 0xff;
let w1 = (i32(word << 24u) >> 28u) & 0xff;
let w2 = (i32(word << 20u) >> 28u) & 0xff;
let w3 = (i32(word << 16u) >> 28u) & 0xff;
let w4 = (i32(word << 12u) >> 28u) & 0xff;
let w5 = (i32(word << 8u) >> 28u) & 0xff;
let w6 = (i32(word << 4u) >> 28u) & 0xff;
let w7 = (i32(word) >> 28u) & 0xff;
let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u));
let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u));
let px0 = x_q[c * 2u];
let px1 = x_q[c * 2u + 1u];
let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1);
acc = acc + f32(sum) * sc_w * sc_x;
}
let ssum = subgroupAdd(acc);
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
if (m.hasBias == 1u) { o = o + bias[n]; }
if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; }
y[n] = y[n] + o;
}
}
`, "GEMV4_ADD_W4A8");
/**
 * WGSL generator: fused W4A8 GEMV for the concatenated Q/K/V projection, with bias always added.
 *
 * One workgroup per output row n = wid.x + wid.y * gridX, where n < totalN. Rows are routed by
 * index into three outputs:
 *   - [0, qN) -> qOut
 *   - [qN, qN + kN) -> kOut
 *   - the remaining rows -> vOut
 * Each u32 of `w` holds eight signed 4-bit weights. Each u32 of `x_q` holds four int8 activations.
 * Both use one f32 scale per 128-element K group (`bk >> 7u`).
 *
 * Directive fix: packed_4x8_integer_dot_product is a WGSL *language* extension, which is declared
 * with `requires`. The previous `enable packed_4x8_integer_dot_product;` was not a valid directive.
 *
 * Fallback fix: the software dot4I8Packed now uses core extractBits (sign-extending for i32)
 * instead of unpack4xI8, which belongs to that same language extension.
 *
 * @param {boolean} hasDP4a - Use the built-in dot4I8Packed (requires the language extension).
 * @param {number} [wgSize=64] - Workgroup size; also the K-loop stride and the length of `part`.
 * @returns {string} WGSL shader source.
 */
var QKV_GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, totalN:u32, qN:u32, kN:u32, vN:u32, gpr:u32, gridX:u32, p0:u32 };
@group(0) @binding(0) var<storage,read> x_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> w: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read_write> qOut: array<f32>;
@group(0) @binding(6) var<storage,read_write> kOut: array<f32>;
@group(0) @binding(7) var<storage,read_write> vOut: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let ia = bitcast<i32>(a);
let ib = bitcast<i32>(b);
return extractBits(ia, 0u, 8u) * extractBits(ib, 0u, 8u)
+ extractBits(ia, 8u, 8u) * extractBits(ib, 8u, 8u)
+ extractBits(ia, 16u, 8u) * extractBits(ib, 16u, 8u)
+ extractBits(ia, 24u, 8u) * extractBits(ib, 24u, 8u);
}
`}
var<workgroup> part: array<f32, ${wgSize}>;
@compute @workgroup_size(${wgSize})
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.totalN) { return; }
let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr;
var acc = 0.0;
for (var c = tid; c < K8; c = c + ${wgSize}u) {
let word = w[rb+c]; let bk = c*8u;
let sc_w = scale[sbase + (bk >> 7u)];
let sc_x = scale_x[bk >> 7u];
let w0 = (i32(word << 28u) >> 28u) & 0xff;
let w1 = (i32(word << 24u) >> 28u) & 0xff;
let w2 = (i32(word << 20u) >> 28u) & 0xff;
let w3 = (i32(word << 16u) >> 28u) & 0xff;
let w4 = (i32(word << 12u) >> 28u) & 0xff;
let w5 = (i32(word << 8u) >> 28u) & 0xff;
let w6 = (i32(word << 4u) >> 28u) & 0xff;
let w7 = (i32(word) >> 28u) & 0xff;
let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u));
let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u));
let px0 = x_q[c * 2u];
let px1 = x_q[c * 2u + 1u];
let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1);
acc = acc + f32(sum) * sc_w * sc_x;
}
let ssum = subgroupAdd(acc);
if (sgid == 0u) { part[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; }
o = o + bias[n];
if (n < m.qN) {
qOut[n] = o;
} else if (n < m.qN + m.kN) {
kOut[n - m.qN] = o;
} else {
vOut[n - m.qN - m.kN] = o;
}
}
}
`, "QKV_GEMV4_W4A8");
/**
 * WGSL generator: fused W4A8 gate/up GEMV followed by SiLU gating,
 *   y[n] = silu(gate[n]) * up[n],
 * where gate uses weight row n and up uses weight row N + n of the same packed buffer `w`.
 *
 * One workgroup per n = wid.x + wid.y * gridX. Optional LoRA deltas are added before the
 * activation:
 *   - gate: gateScaleLo * sum_r gateD[r] * gateB[r * N + n], when hasGateLora == 1
 *   - up: upScaleLo * sum_r upD[r] * upB[r * N + n], when hasUpLora == 1
 * Each u32 of `w` holds eight signed 4-bit weights. Each u32 of `x_q` holds four int8
 * activations. Both use one f32 scale per 128-element K group.
 *
 * Directive fix: packed_4x8_integer_dot_product is a WGSL *language* extension, which is declared
 * with `requires`. The previous `enable packed_4x8_integer_dot_product;` was not a valid directive.
 *
 * Fallback fix: the software dot4I8Packed now uses core extractBits instead of unpack4xI8, which
 * belongs to that same language extension.
 *
 * @param {boolean} hasDP4a - Use the built-in dot4I8Packed (requires the language extension).
 * @param {number} [wgSize=64] - Workgroup size; also the K-loop stride and the partial-array length.
 * @returns {string} WGSL shader source.
 */
var GATE_UP_SILU_GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, N:u32, gpr:u32, gridX:u32, gateRank:u32, upRank:u32, hasGateLora:u32, hasUpLora:u32, gateScaleLo:f32, upScaleLo:f32, p0:f32, p1:f32 };
@group(0) @binding(0) var<storage,read> x_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> w: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read_write> y: array<f32>;
@group(0) @binding(5) var<storage,read> gateD: array<f32>;
@group(0) @binding(6) var<storage,read> gateB: array<f32>;
@group(0) @binding(7) var<storage,read> upD: array<f32>;
@group(0) @binding(8) var<storage,read> upB: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let ia = bitcast<i32>(a);
let ib = bitcast<i32>(b);
return extractBits(ia, 0u, 8u) * extractBits(ib, 0u, 8u)
+ extractBits(ia, 8u, 8u) * extractBits(ib, 8u, 8u)
+ extractBits(ia, 16u, 8u) * extractBits(ib, 16u, 8u)
+ extractBits(ia, 24u, 8u) * extractBits(ib, 24u, 8u);
}
`}
var<workgroup> partG: array<f32, ${wgSize}>;
var<workgroup> partU: array<f32, ${wgSize}>;
@compute @workgroup_size(${wgSize})
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let n = wid.x + wid.y * m.gridX; let tid = lid.x;
if (n >= m.N) { return; }
let K8 = m.K/8u; let rbG = n*K8; let rbU = (m.N + n)*K8;
let sbG = n*m.gpr; let sbU = (m.N + n)*m.gpr;
var accG = 0.0; var accU = 0.0;
for (var c = tid; c < K8; c = c + ${wgSize}u) {
let wg = w[rbG+c]; let wu = w[rbU+c];
let bk = c*8u;
let scG = scale[sbG + (bk >> 7u)]; let scU = scale[sbU + (bk >> 7u)];
let sc_x = scale_x[bk >> 7u];
let wg0 = (i32(wg << 28u) >> 28u) & 0xff;
let wg1 = (i32(wg << 24u) >> 28u) & 0xff;
let wg2 = (i32(wg << 20u) >> 28u) & 0xff;
let wg3 = (i32(wg << 16u) >> 28u) & 0xff;
let wg4 = (i32(wg << 12u) >> 28u) & 0xff;
let wg5 = (i32(wg << 8u) >> 28u) & 0xff;
let wg6 = (i32(wg << 4u) >> 28u) & 0xff;
let wg7 = (i32(wg) >> 28u) & 0xff;
let pwg0 = u32(wg0 | (wg1 << 8u) | (wg2 << 16u) | (wg3 << 24u));
let pwg1 = u32(wg4 | (wg5 << 8u) | (wg6 << 16u) | (wg7 << 24u));
let wu0 = (i32(wu << 28u) >> 28u) & 0xff;
let wu1 = (i32(wu << 24u) >> 28u) & 0xff;
let wu2 = (i32(wu << 20u) >> 28u) & 0xff;
let wu3 = (i32(wu << 16u) >> 28u) & 0xff;
let wu4 = (i32(wu << 12u) >> 28u) & 0xff;
let wu5 = (i32(wu << 8u) >> 28u) & 0xff;
let wu6 = (i32(wu << 4u) >> 28u) & 0xff;
let wu7 = (i32(wu) >> 28u) & 0xff;
let pwu0 = u32(wu0 | (wu1 << 8u) | (wu2 << 16u) | (wu3 << 24u));
let pwu1 = u32(wu4 | (wu5 << 8u) | (wu6 << 16u) | (wu7 << 24u));
let px0 = x_q[c * 2u];
let px1 = x_q[c * 2u + 1u];
let sumG = dot4I8Packed(pwg0, px0) + dot4I8Packed(pwg1, px1);
let sumU = dot4I8Packed(pwu0, px0) + dot4I8Packed(pwu1, px1);
accG = accG + f32(sumG) * scG * sc_x;
accU = accU + f32(sumU) * scU * sc_x;
}
let sg = subgroupAdd(accG); let su = subgroupAdd(accU);
if (sgid == 0u) { partG[tid / sgsz] = sg; partU[tid / sgsz] = su; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var gate = 0.0; var up = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { gate = gate + partG[i]; up = up + partU[i]; }
if (m.hasGateLora == 1u) {
var dl = 0.0; for (var r = 0u; r < m.gateRank; r = r + 1u) { dl = dl + gateD[r] * gateB[r*m.N + n]; }
gate = gate + m.gateScaleLo * dl;
}
if (m.hasUpLora == 1u) {
var dl = 0.0; for (var r = 0u; r < m.upRank; r = r + 1u) { dl = dl + upD[r] * upB[r*m.N + n]; }
up = up + m.upScaleLo * dl;
}
y[n] = (gate / (1.0 + exp(-gate))) * up;
}
}
`, "GATE_UP_SILU_GEMV4_W4A8");
/**
 * WGSL generator: tiled W4A8 GEMM for multi-token inputs, Y[t * N + col] = A[t, :] · W[col, :] (+ bias).
 *
 * Tiling:
 *   - Each workgroup (64 invocations) covers BM = 16 token rows starting at wid.y * 16 and
 *     BN = 64 output columns starting at wid.x * 64, one column per invocation.
 *   - For every 8-element K step `c`, invocations 0..31 stage the two int8x4 activation words of
 *     each of the 16 rows into As_q, and invocations 0..15 stage the matching per-row activation
 *     scales. A_q has K / 4 words per row; scale_x has K / 128 scales per row.
 *   - Rows >= T stage zeros, and columns >= N skip compute and stores.
 * Weights hold eight signed 4-bit values per u32, with one f32 scale per 128-element K group.
 *
 * Note: `enable subgroups;` is emitted although this kernel calls no subgroup built-ins.
 *
 * Directive fix: packed_4x8_integer_dot_product is a WGSL *language* extension, which is declared
 * with `requires`. The previous `enable packed_4x8_integer_dot_product;` was not a valid directive.
 *
 * Fallback fix: the software dot4I8Packed now uses core extractBits instead of unpack4xI8, which
 * belongs to that same language extension.
 *
 * @param {boolean} hasDP4a - Use the built-in dot4I8Packed (requires the language extension).
 * @returns {string} WGSL shader source.
 */
var GEMM4_W4A8 = /* @__PURE__ */ __name((hasDP4a) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 };
@group(0) @binding(0) var<storage,read> A_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> W: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read_write> Y: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let ia = bitcast<i32>(a);
let ib = bitcast<i32>(b);
return extractBits(ia, 0u, 8u) * extractBits(ib, 0u, 8u)
+ extractBits(ia, 8u, 8u) * extractBits(ib, 8u, 8u)
+ extractBits(ia, 16u, 8u) * extractBits(ib, 16u, 8u)
+ extractBits(ia, 24u, 8u) * extractBits(ib, 24u, 8u);
}
`}
const BM = 16u; const BN = 64u;
var<workgroup> As_q: array<u32, 32>;
var<workgroup> As_scale: array<f32, 16>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N;
let K8 = m.K/8u; let rb = col*K8;
var acc: array<f32, 16>;
for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; }
let groupsPerRow = m.K / 128u;
for (var c = 0u; c < K8; c = c + 1u) {
if (lid.x < BM * 2u) {
let tt = lid.x / 2u; let trow = tTile + tt; let wordIdx = lid.x % 2u;
As_q[lid.x] = select(0u, A_q[trow * (m.K / 4u) + c * 2u + wordIdx], trow < m.T);
}
if (lid.x < BM) {
let trow = tTile + lid.x;
As_scale[lid.x] = select(0.0, scale_x[trow * groupsPerRow + ((c * 8u) >> 7u)], trow < m.T);
}
workgroupBarrier();
if (valid) {
let word = W[rb + c]; let sc_w = scale[col*m.gpr + ((c*8u) >> 7u)];
let w0 = (i32(word << 28u) >> 28u) & 0xff;
let w1 = (i32(word << 24u) >> 28u) & 0xff;
let w2 = (i32(word << 20u) >> 28u) & 0xff;
let w3 = (i32(word << 16u) >> 28u) & 0xff;
let w4 = (i32(word << 12u) >> 28u) & 0xff;
let w5 = (i32(word << 8u) >> 28u) & 0xff;
let w6 = (i32(word << 4u) >> 28u) & 0xff;
let w7 = (i32(word) >> 28u) & 0xff;
let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u));
let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u));
for (var t = 0u; t < BM; t = t + 1u) {
let px0 = As_q[t * 2u]; let px1 = As_q[t * 2u + 1u];
let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1);
acc[t] = acc[t] + f32(sum) * sc_w * As_scale[t];
}
}
workgroupBarrier();
}
if (valid) {
let bv = select(0.0, bias[col], m.hasBias == 1u);
for (var t = 0u; t < BM; t = t + 1u) { let trow = tTile + t; if (trow < m.T) { Y[trow*m.N + col] = acc[t] + bv; } }
}
}
`, "GEMM4_W4A8");
/**
 * WGSL generator: tiled W4A8 GEMM that accumulates into its output,
 *   Y[t * N + col] = Y[t * N + col] + A[t, :] · W[col, :] (+ bias).
 *
 * Tiling:
 *   - Each workgroup (64 invocations) covers BM = 16 token rows starting at wid.y * 16 and
 *     BN = 64 output columns starting at wid.x * 64, one column per invocation.
 *   - Each 8-element K step stages 16 rows x 2 int8x4 activation words plus 16 per-row
 *     activation scales in workgroup memory.
 *   - Rows >= T stage zeros, and columns >= N skip compute and stores.
 * Weights hold eight signed 4-bit values per u32, with one f32 scale per 128-element K group.
 *
 * Note: `enable subgroups;` is emitted although this kernel calls no subgroup built-ins.
 *
 * Directive fix: packed_4x8_integer_dot_product is a WGSL *language* extension, which is declared
 * with `requires`. The previous `enable packed_4x8_integer_dot_product;` was not a valid directive.
 *
 * Fallback fix: the software dot4I8Packed now uses core extractBits instead of unpack4xI8, which
 * belongs to that same language extension.
 *
 * @param {boolean} hasDP4a - Use the built-in dot4I8Packed (requires the language extension).
 * @returns {string} WGSL shader source.
 */
var GEMM4_ADD_T_W4A8 = /* @__PURE__ */ __name((hasDP4a) => `
enable subgroups;
${hasDP4a ? `
requires packed_4x8_integer_dot_product;
` : ""}
requires immediate_address_space;
struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 };
@group(0) @binding(0) var<storage,read> A_q: array<u32>;
@group(0) @binding(1) var<storage,read> scale_x: array<f32>;
@group(0) @binding(2) var<storage,read> W: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read_write> Y: array<f32>;
var<immediate> m: Meta;
${hasDP4a ? "" : `
fn dot4I8Packed(a: u32, b: u32) -> i32 {
let ia = bitcast<i32>(a);
let ib = bitcast<i32>(b);
return extractBits(ia, 0u, 8u) * extractBits(ib, 0u, 8u)
+ extractBits(ia, 8u, 8u) * extractBits(ib, 8u, 8u)
+ extractBits(ia, 16u, 8u) * extractBits(ib, 16u, 8u)
+ extractBits(ia, 24u, 8u) * extractBits(ib, 24u, 8u);
}
`}
const BM = 16u; const BN = 64u;
var<workgroup> As_q: array<u32, 32>;
var<workgroup> As_scale: array<f32, 16>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N;
let K8 = m.K/8u; let rb = col*K8;
var acc: array<f32, 16>;
for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; }
let groupsPerRow = m.K / 128u;
for (var c = 0u; c < K8; c = c + 1u) {
if (lid.x < BM * 2u) {
let tt = lid.x / 2u; let trow = tTile + tt; let wordIdx = lid.x % 2u;
As_q[lid.x] = select(0u, A_q[trow * (m.K / 4u) + c * 2u + wordIdx], trow < m.T);
}
if (lid.x < BM) {
let trow = tTile + lid.x;
As_scale[lid.x] = select(0.0, scale_x[trow * groupsPerRow + ((c * 8u) >> 7u)], trow < m.T);
}
workgroupBarrier();
if (valid) {
let word = W[rb + c]; let sc_w = scale[col*m.gpr + ((c*8u) >> 7u)];
let w0 = (i32(word << 28u) >> 28u) & 0xff;
let w1 = (i32(word << 24u) >> 28u) & 0xff;
let w2 = (i32(word << 20u) >> 28u) & 0xff;
let w3 = (i32(word << 16u) >> 28u) & 0xff;
let w4 = (i32(word << 12u) >> 28u) & 0xff;
let w5 = (i32(word << 8u) >> 28u) & 0xff;
let w6 = (i32(word << 4u) >> 28u) & 0xff;
let w7 = (i32(word) >> 28u) & 0xff;
let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u));
let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u));
for (var t = 0u; t < BM; t = t + 1u) {
let px0 = As_q[t * 2u]; let px1 = As_q[t * 2u + 1u];
let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1);
acc[t] = acc[t] + f32(sum) * sc_w * As_scale[t];
}
}
workgroupBarrier();
}
if (valid) {
let bv = select(0.0, bias[col], m.hasBias == 1u);
for (var t = 0u; t < BM; t = t + 1u) {
let trow = tTile + t;
if (trow < m.T) { Y[trow*m.N + col] = Y[trow*m.N + col] + acc[t] + bv; }
}
}
}
`, "GEMM4_ADD_T_W4A8");
/*
 * WGSL: write one token's K and V vectors (kvd floats each, read from k_src / v_src) into a
 * paged KV cache.
 * Immediates m = (pos, seq_id, max_blocks, kvd).
 * Logical position `pos` of sequence `seq_id` is looked up as
 *   page = block_table[seq_id * max_blocks + pos / 16].
 * The element is stored at physical slot page * 16 + pos % 16, so each page holds 16 positions of
 * kvd contiguous floats. One invocation per element (workgroup size 256); invocations with
 * idx >= kvd do nothing.
 * NOTE(review): kvd presumably equals nKV * headDim, matching the row stride used by the paged
 * attention kernels — confirm at the call site.
 */
var WRITE_KV_PAGE = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read> k_src: array<f32>;
@group(0) @binding(1) var<storage,read> v_src: array<f32>;
@group(0) @binding(2) var<storage,read_write> kc: array<f32>;
@group(0) @binding(3) var<storage,read_write> vc: array<f32>;
@group(0) @binding(4) var<storage,read> block_table: array<u32>;
var<immediate> m: vec4<u32>; // pos, seq_id, max_blocks, kvd
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x; let pos = m.x; let seq_id = m.y; let max_blocks = m.z; let kvd = m.w;
if (idx >= kvd) { return; }
let page_idx = block_table[seq_id * max_blocks + (pos / 16u)];
let page_offset = pos % 16u;
let physical_pos = page_idx * 16u + page_offset;
let dst_offset = physical_pos * kvd + idx;
kc[dst_offset] = k_src[idx];
vc[dst_offset] = v_src[idx];
}`;
/*
 * WGSL: batched form of the paged KV write. Copies T consecutive tokens into the cache of
 * sequence `seq_id`.
 * k_src / v_src are indexed as t * kvd + d, i.e. T rows of kvd floats. Row t lands at logical
 * position off + t, which maps through block_table to physical slot page * 16 + (off + t) % 16,
 * so pages hold 16 positions. One invocation per element of the T * kvd batch (workgroup size
 * 256); invocations past T * kvd do nothing.
 */
var WRITE_KV_PAGE_BATCH = `
requires immediate_address_space;
struct KVBatchMeta { T:u32, seq_id:u32, max_blocks:u32, kvd:u32, off:u32 };
@group(0) @binding(0) var<storage,read> k_src: array<f32>;
@group(0) @binding(1) var<storage,read> v_src: array<f32>;
@group(0) @binding(2) var<storage,read_write> kc: array<f32>;
@group(0) @binding(3) var<storage,read_write> vc: array<f32>;
@group(0) @binding(4) var<storage,read> block_table: array<u32>;
var<immediate> m: KVBatchMeta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x; let T = m.T; let seq_id = m.seq_id; let max_blocks = m.max_blocks; let kvd = m.kvd; let off = m.off;
let total = T * kvd; if (idx >= total) { return; }
let t = idx / kvd; let d = idx % kvd;
let page_idx = block_table[seq_id * max_blocks + ((off + t) / 16u)];
let page_offset = (off + t) % 16u;
let physical_pos = page_idx * 16u + page_offset;
let dst_offset = physical_pos * kvd + d;
kc[dst_offset] = k_src[idx];
vc[dst_offset] = v_src[idx];
}`;
/*
 * WGSL: single-query attention over a paged KV cache, split along the context dimension.
 *
 * Workgroup (h, s), with 128 invocations, handles query head h (q at h * hd) and context slice
 * [s * chunk, min(s * chunk + chunk, ctx)).
 *   - Each invocation scores one key: dot(q, k) / sqrt(hd).
 *   - The slice max M and sum Z of exp(score - M) are reduced via subgroups plus `red`.
 *   - po[(h * nsplit + s) * hd + d] receives the unnormalized sum of exp(score - M) * v.
 *   - pm / pz[h * nsplit + s] receive M and Z. Combining the splits presumably happens in a
 *     separate pass (not in this chunk).
 *
 * GQA: head h reads KV head h / (nHeads / nKV). A cache row holds nKV * hd floats at physical
 * slot page * 16 + t % 16, with page = block_table[seq_id * max_blocks + t / 16].
 *
 * Constraints visible here:
 *   - chunk <= 128, since one score per invocation is kept in sc[128].
 *   - subgroup size >= 4, since red[32] holds 128 / sgsz partials.
 * NOTE(review): if a split starts past ctx (s * chunk > ctx), t1 is clamped below t0 and
 * `len = t1 - t0` underflows as u32 — confirm the host sizes nsplit as ceil(ctx / chunk).
 */
var ATTN_PARTIAL_PAGED = `
enable subgroups;
requires immediate_address_space;
struct Meta { nHeads:u32, nKV:u32, ctx:u32, hd:u32, nsplit:u32, chunk:u32, seq_id:u32, max_blocks:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> pm: array<f32>;
@group(0) @binding(4) var<storage,read_write> pz: array<f32>;
@group(0) @binding(5) var<storage,read_write> po: array<f32>;
@group(0) @binding(6) var<storage,read> block_table: array<u32>;
var<immediate> m: Meta;
var<workgroup> sc: array<f32,128>;
var<workgroup> red: array<f32,32>;
@compute @workgroup_size(128)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let s = wid.y; let tid = lid.x;
let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd;
let nsplit = m.nsplit; let chunk = m.chunk; let seq_id = m.seq_id; let max_blocks = m.max_blocks;
let kvh = h / (nHeads / nKV);
let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0/sqrt(f32(hd));
let nsg = (128u + sgsz - 1u) / sgsz;
let t0 = s*chunk; var t1 = t0 + chunk; if (t1 > ctx) { t1 = ctx; }
let t = t0 + tid; var sv = -1e30;
if (t < t1) {
var dot = 0.0;
let page_idx = block_table[seq_id * max_blocks + (t / 16u)];
let page_offset = t % 16u;
let kb = (page_idx * 16u + page_offset) * stride + hoff;
for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; }
sv = dot*scale;
}
let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; }
workgroupBarrier();
var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); }
workgroupBarrier();
var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev;
let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; }
workgroupBarrier();
var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; }
workgroupBarrier();
let len = t1 - t0; let pbase = (h*nsplit + s)*hd;
for (var d = tid; d < hd; d = d + 128u) {
var acc = 0.0;
for (var tt = 0u; tt < len; tt = tt + 1u) {
let t_curr = t0 + tt;
let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)];
let page_offset = t_curr % 16u;
let physical_t = page_idx * 16u + page_offset;
acc = acc + sc[tt]*vc[physical_t*stride + hoff + d];
}
po[pbase + d] = acc;
}
if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; }
}`;
/*
 * WGSL: causal prefill attention over a paged KV cache.
 *
 * Workgroup (h, t), with 256 invocations, produces the output of query head h for token t.
 *   - q and o are indexed as t * nHeads * hd + h * hd, i.e. [T][nHeads][hd].
 *   - Token t attends to keys 0..t (ctx = t + 1), so positions are taken to start at 0 for this
 *     batch. m.T and the padding fields are not read.
 *   - Keys are processed in blocks of 256 using a running max (mrun), a running sum (lrun) and an
 *     accumulator `acc` that is rescaled by exp(mrun - mnew) whenever the max grows. The output
 *     is acc / lrun.
 *   - GQA: head h reads KV head h / (nHeads / nKV) from paged rows of nKV * hd floats
 *     (16 positions per page).
 *
 * Constraints visible here:
 *   - hd <= 128, since acc has 128 entries.
 *   - subgroup size >= 4, since red[64] holds 256 / sgsz partials.
 */
var ATTN_PREFILL_PAGED = `
enable subgroups;
requires immediate_address_space;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, seq_id:u32, max_blocks:u32, p0:u32, p1:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>;
@group(0) @binding(4) var<storage,read> block_table: array<u32>;
var<immediate> m: Meta;
var<workgroup> ps: array<f32,256>;
var<workgroup> acc: array<f32,128>;
var<workgroup> red: array<f32,64>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let t = wid.y; let tid = lid.x; let nHeads = m.nHeads; let nKV = m.nKV; let hd = m.hd;
let ctx = t + 1u; let kvh = h / (nHeads / nKV);
let qbase = t*nHeads*hd + h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scl = 1.0/sqrt(f32(hd));
let nsg = (256u + sgsz - 1u) / sgsz;
let seq_id = m.seq_id; let max_blocks = m.max_blocks;
for (var d = tid; d < hd; d = d + 256u) { acc[d] = 0.0; }
var mrun = -1e30; var lrun = 0.0;
let nblk = (ctx + 255u) / 256u;
for (var blk = 0u; blk < nblk; blk = blk + 1u) {
let kbase = blk*256u; let kk = kbase + tid;
var s = -1e30;
if (kk < ctx) {
var dot = 0.0;
let page_idx = block_table[seq_id * max_blocks + (kk / 16u)];
let page_offset = kk % 16u;
let kb = (page_idx * 16u + page_offset)*stride + hoff;
for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; }
s = dot*scl;
}
let sgm = subgroupMax(s); if (sgid == 0u) { red[tid/sgsz] = sgm; }
workgroupBarrier();
var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[i]); }
let mnew = max(mrun, bm); let corr = exp(mrun - mnew);
var p = 0.0; if (kk < ctx) { p = exp(s - mnew); }
ps[tid] = p;
workgroupBarrier();
let sgs = subgroupAdd(p); if (sgid == 0u) { red[tid/sgsz] = sgs; }
workgroupBarrier();
var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[i]; }
lrun = lrun*corr + bs;
let bcount = min(256u, ctx - kbase);
for (var d = tid; d < hd; d = d + 256u) {
var aa = acc[d]*corr;
for (var j = 0u; j < bcount; j = j + 1u) {
let t_curr = kbase + j;
let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)];
let page_offset = t_curr % 16u;
let physical_t = page_idx * 16u + page_offset;
aa = aa + ps[j]*vc[physical_t*stride + hoff + d];
}
acc[d] = aa;
}
mrun = mnew;
workgroupBarrier();
}
let invL = 1.0/lrun;
for (var d = tid; d < hd; d = d + 256u) { o[qbase + d] = acc[d]*invL; }
}`;
/*
 * WGSL: causal prefill attention that amortizes each K/V load over BQ = 4 query rows.
 *
 * Workgroup (h, qBlock), with 128 invocations, handles query head h for rows qt = qBlock * 4 + r,
 * r = 0..3; rows with qt >= T are skipped.
 *   - Row qt sits at absolute position qStart + qt and may attend to key kk only when kk < ctx
 *     and kk <= qStart + qt.
 *   - q and o are indexed as qt * nHeads * hd + h * hd, i.e. [T][nHeads][hd].
 *   - Keys are processed BK = 128 at a time, one key per invocation.
 *   - Each row keeps its own running max, running sum and accumulator slice acc[r * hd ..].
 *     The accumulator is rescaled by exp(mrun - mnew) when the max grows; the output is
 *     acc / lrun.
 *   - GQA: head h reads KV head h / (nHeads / nKV) from paged rows of nKV * hd floats
 *     (16 positions per page).
 *
 * Constraints visible here:
 *   - hd <= 128, since acc has BQ * hd <= 512 entries.
 *   - subgroup size >= 4, since red reserves 32 partial slots per query row.
 */
var ATTN_PREFILL_BLOCK_PAGED = `
enable subgroups;
requires immediate_address_space;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, qStart:u32, ctx:u32, seq_id:u32, max_blocks:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read_write> o: array<f32>;
@group(0) @binding(4) var<storage,read> block_table: array<u32>;
var<immediate> m: Meta;
const BQ = 4u; const BK = 128u;
var<workgroup> ps: array<f32, 512>;
var<workgroup> acc: array<f32, 512>;
var<workgroup> red: array<f32, 128>;
@compute @workgroup_size(128)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let h = wid.x; let qBlock = wid.y; let tid = lid.x; let hd = m.hd;
let kvh = h / (m.nHeads / m.nKV); let stride = m.nKV * hd; let hoff = kvh * hd;
let nsg = (128u + sgsz - 1u) / sgsz; let scl = 1.0 / sqrt(f32(hd));
let seq_id = m.seq_id; let max_blocks = m.max_blocks;
var mrun: array<f32, 4>; var lrun: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) { mrun[r] = -1e30; lrun[r] = 0.0; }
for (var i = tid; i < BQ*hd; i = i + 128u) { acc[i] = 0.0; }
workgroupBarrier();
let nblk = (m.ctx + BK - 1u) / BK;
for (var blk = 0u; blk < nblk; blk = blk + 1u) {
let kbase = blk * BK; let kk = kbase + tid;
var score: array<f32, 4>;
var validQ: array<bool, 4>;
var dot: array<f32, 4>;
var corrRun: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r; let absQ = m.qStart + qt;
validQ[r] = qt < m.T && kk < m.ctx && kk <= absQ;
dot[r] = 0.0; score[r] = -1e30;
}
if (kk < m.ctx) {
let page_idx = block_table[seq_id * max_blocks + (kk / 16u)];
let page_offset = kk % 16u;
let kb = (page_idx * 16u + page_offset)*stride + hoff;
for (var d = 0u; d < hd; d = d + 1u) {
let kval = kc[kb+d];
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r;
if (validQ[r]) { dot[r] = dot[r] + q[qt*m.nHeads*hd + h*hd + d] * kval; }
}
}
for (var r = 0u; r < BQ; r = r + 1u) {
if (validQ[r]) { score[r] = dot[r] * scl; }
}
}
for (var r = 0u; r < BQ; r = r + 1u) {
let s = score[r];
let sgm = subgroupMax(s);
if (sgid == 0u) { red[r*32u + tid/sgsz] = sgm; }
workgroupBarrier();
var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[r*32u+i]); }
let mnew = max(mrun[r], bm); let corr = exp(mrun[r] - mnew);
corrRun[r] = corr;
var p = 0.0; if (validQ[r]) { p = exp(s - mnew); }
ps[r*BK + tid] = p;
workgroupBarrier();
let sgs = subgroupAdd(p);
if (sgid == 0u) { red[r*32u + tid/sgsz] = sgs; }
workgroupBarrier();
var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[r*32u+i]; }
lrun[r] = lrun[r] * corr + bs;
mrun[r] = mnew;
workgroupBarrier();
}
let bcount = min(BK, m.ctx - kbase);
for (var d = tid; d < hd; d = d + 128u) {
var aa: array<f32, 4>;
for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = acc[r*hd+d] * corrRun[r]; }
for (var j = 0u; j < bcount; j = j + 1u) {
let t_curr = kbase + j;
let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)];
let page_offset = t_curr % 16u;
let physical_t = page_idx * 16u + page_offset;
let vv = vc[physical_t*stride + hoff + d];
for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = aa[r] + ps[r*BK+j] * vv; }
}
for (var r = 0u; r < BQ; r = r + 1u) { acc[r*hd+d] = aa[r]; }
}
workgroupBarrier();
}
for (var r = 0u; r < BQ; r = r + 1u) {
let qt = qBlock * BQ + r;
if (qt < m.T) {
let invL = 1.0 / lrun[r]; let ob = qt*m.nHeads*hd + h*hd;
for (var d = tid; d < hd; d = d + 128u) { o[ob+d] = acc[r*hd+d] * invL; }
}
}
}`;
/*
 * WGSL: fused RMSNorm + int4 QKV projection (+ bias) + RoPE for a single token.
 *
 * Each workgroup (64 invocations) produces one *pair* of outputs, with pair index
 * wid.x + wid.y * gridX.
 *   - Q and K pairs are (j, j + headDim / 2) inside a head. The rotation uses
 *     cosT / sinT[pos * headDim + j]:
 *       lo' = lo * c - hi * s
 *       hi' = hi * c + lo * s
 *   - V pairs are two adjacent outputs and are written unrotated.
 * Every workgroup recomputes inverseSqrt(mean(hidden^2) + eps) and normalizes x on the fly
 * with gain rms_g. Weights hold eight signed 4-bit values per u32, with one f32 scale per
 * 128-element K group.
 *
 * Race fix: partSum[0] broadcasts the inverse RMS, and partSum is then reused for the projection
 * partial sums. Previously nothing ordered `let inv = partSum[0];` before the later
 * `partSum[tid / sgsz] = ssum0` store. A subgroup that finished its K loop early could therefore
 * overwrite partSum[0] before a slower subgroup had read it — a workgroup-memory race that
 * corrupts the norm. A workgroupBarrier now separates the read from the reuse.
 */
var GEMV4_QKV_ROPE_RMS = `
enable subgroups;
requires immediate_address_space;
struct Meta {
K: u32, totalPairs: u32, qPairs: u32, kPairs: u32, vPairs: u32, gpr: u32, gridX: u32,
pos: u32, headDim: u32, eps: f32,
qN: u32, kN: u32
};
@group(0) @binding(0) var<storage,read> hidden: array<f32>;
@group(0) @binding(1) var<storage,read> rms_g: array<f32>;
@group(0) @binding(2) var<storage,read> w: array<u32>;
@group(0) @binding(3) var<storage,read> scale: array<f32>;
@group(0) @binding(4) var<storage,read> bias: array<f32>;
@group(0) @binding(5) var<storage,read> cosT: array<f32>;
@group(0) @binding(6) var<storage,read> sinT: array<f32>;
@group(0) @binding(7) var<storage,read_write> qOut: array<f32>;
@group(0) @binding(8) var<storage,read_write> kOut: array<f32>;
@group(0) @binding(9) var<storage,read_write> vOut: array<f32>;
var<immediate> m: Meta;
var<workgroup> partSum: array<f32, 64>;
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>,
@builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) {
let pair_idx = wid.x + wid.y * m.gridX;
if (pair_idx >= m.totalPairs) { return; }
let tid = lid.x;
var s = 0.0;
for (var k = tid; k < m.K; k = k + 64u) { let v = hidden[k]; s = s + v*v; }
let ssum = subgroupAdd(s);
if (sgid == 0u) { partSum[tid / sgsz] = ssum; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz; var red = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { red = red + partSum[i]; }
partSum[0] = inverseSqrt(red / f32(m.K) + m.eps);
}
workgroupBarrier();
let inv = partSum[0];
// partSum is reused for the projection partials below: every invocation must read inv first.
workgroupBarrier();
let half = m.headDim / 2u;
var n0: u32; var n1: u32;
var isQ = false; var isK = false; var isV = false;
var out_idx0: u32; var out_idx1: u32;
var rope_j: u32 = 0u;
if (pair_idx < m.qPairs) {
isQ = true;
let h = pair_idx / half; let j = pair_idx % half;
n0 = h * m.headDim + j;
n1 = n0 + half;
out_idx0 = n0; out_idx1 = n1;
rope_j = j;
} else if (pair_idx < m.qPairs + m.kPairs) {
isK = true;
let p = pair_idx - m.qPairs;
let h = p / half; let j = p % half;
n0 = m.qN + h * m.headDim + j;
n1 = n0 + half;
out_idx0 = h * m.headDim + j; out_idx1 = out_idx0 + half;
rope_j = j;
} else {
isV = true;
let p = pair_idx - m.qPairs - m.kPairs;
n0 = m.qN + m.kN + p * 2u;
n1 = n0 + 1u;
out_idx0 = p * 2u; out_idx1 = out_idx0 + 1u;
}
let K8 = m.K / 8u;
let rb0 = n0 * K8; let rb1 = n1 * K8;
let sbase0 = n0 * m.gpr; let sbase1 = n1 * m.gpr;
var acc0 = 0.0; var acc1 = 0.0;
for (var c = tid; c < K8; c = c + 64u) {
let w0 = w[rb0 + c]; let w1 = w[rb1 + c];
let bk = c * 8u;
let sc0 = scale[sbase0 + (bk >> 7u)]; let sc1 = scale[sbase1 + (bk >> 7u)];
// We compute normalized X on the fly
let x0 = hidden[bk] * inv * rms_g[bk];
let x1 = hidden[bk+1u] * inv * rms_g[bk+1u];
let x2 = hidden[bk+2u] * inv * rms_g[bk+2u];
let x3 = hidden[bk+3u] * inv * rms_g[bk+3u];
let x4 = hidden[bk+4u] * inv * rms_g[bk+4u];
let x5 = hidden[bk+5u] * inv * rms_g[bk+5u];
let x6 = hidden[bk+6u] * inv * rms_g[bk+6u];
let x7 = hidden[bk+7u] * inv * rms_g[bk+7u];
var p0 = 0.0; var p1 = 0.0;
p0 = p0 + x0 * f32(i32(w0 << 28u) >> 28u); p1 = p1 + x0 * f32(i32(w1 << 28u) >> 28u);
p0 = p0 + x1 * f32(i32(w0 << 24u) >> 28u); p1 = p1 + x1 * f32(i32(w1 << 24u) >> 28u);
p0 = p0 + x2 * f32(i32(w0 << 20u) >> 28u); p1 = p1 + x2 * f32(i32(w1 << 20u) >> 28u);
p0 = p0 + x3 * f32(i32(w0 << 16u) >> 28u); p1 = p1 + x3 * f32(i32(w1 << 16u) >> 28u);
p0 = p0 + x4 * f32(i32(w0 << 12u) >> 28u); p1 = p1 + x4 * f32(i32(w1 << 12u) >> 28u);
p0 = p0 + x5 * f32(i32(w0 << 8u) >> 28u); p1 = p1 + x5 * f32(i32(w1 << 8u) >> 28u);
p0 = p0 + x6 * f32(i32(w0 << 4u) >> 28u); p1 = p1 + x6 * f32(i32(w1 << 4u) >> 28u);
p0 = p0 + x7 * f32(i32(w0) >> 28u); p1 = p1 + x7 * f32(i32(w1) >> 28u);
acc0 = acc0 + p0 * sc0;
acc1 = acc1 + p1 * sc1;
}
let ssum0 = subgroupAdd(acc0); let ssum1 = subgroupAdd(acc1);
if (sgid == 0u) { partSum[tid / sgsz] = ssum0; partSum[32u + tid / sgsz] = ssum1; }
workgroupBarrier();
if (tid == 0u) {
let nsg = (64u + sgsz - 1u) / sgsz;
var o0 = 0.0; var o1 = 0.0;
for (var i = 0u; i < nsg; i = i + 1u) { o0 = o0 + partSum[i]; o1 = o1 + partSum[32u + i]; }
o0 = o0 + bias[n0];
o1 = o1 + bias[n1];
if (isQ || isK) {
let off = m.pos * m.headDim + rope_j;
let c = cosT[off]; let s = sinT[off];
let rl = fma(o0, c, 0.0) + fma(-o1, s, 0.0);
let rh = fma(o1, c, 0.0) + fma(o0, s, 0.0);
o0 = rl; o1 = rh;
}
if (isQ) { qOut[out_idx0] = o0; qOut[out_idx1] = o1; }
else if (isK) { kOut[out_idx0] = o0; kOut[out_idx1] = o1; }
else { vOut[out_idx0] = o0; vOut[out_idx1] = o1; }
}
}`;
| // src/qwgpu/model_schema.js | |
// Shallow element-wise (===) equality of two arrays/typed arrays of equal length.
// Expressed as "no differing element" (De Morgan of every), so holes are skipped exactly as before.
var arrEq = /* @__PURE__ */ __name((a, b) => a.length === b.length && !a.some((v, i) => v !== b[i]), "arrEq");
/**
 * Descriptor for an int4-quantized linear projection weight of one decoder layer.
 * @param {number} layer layer index.
 * @param {string} subpath "self_attn.<proj>" or "mlp.<proj>", e.g. "self_attn.q_proj".
 * @param {number} outDim rows of the weight matrix.
 * @param {number} inDim columns of the weight matrix.
 * @param {{bias?: boolean}} [options] bias: also name the matching ".bias" tensor.
 * @returns {{name: string, role: "projection", quant: "int4", shape: number[],
 *   loraKey: string, biasName: string|null}} loraKey is "layers.<i>.<group>.<proj>".
 * @throws {Error} when subpath does not start with "self_attn." or "mlp.".
 */
function projDesc(layer, subpath, outDim, inDim, { bias = false } = {}) {
  const name = `model.layers.${layer}.${subpath}.weight`;
  const parts = subpath.match(/^(self_attn|mlp)\.(.+)$/);
  if (!parts) {
    // Previously this surfaced as "Cannot read properties of null".
    throw new Error(`projDesc: subpath must start with "self_attn." or "mlp.", got "${subpath}"`);
  }
  const [, group, proj] = parts;
  return {
    name,
    role: "projection",
    quant: "int4",
    shape: [outDim, inDim],
    loraKey: `layers.${layer}.${group}.${proj}`,
    biasName: bias ? name.replace(/\.weight$/, ".bias") : null
  };
}
__name(projDesc, "projDesc");
/**
 * Descriptor for a tensor uploaded as plain f32 (norm weights and biases in this schema).
 * @param {string} name checkpoint tensor name.
 * @param {number[]} shape expected shape.
 * @param {string} [role="f32"] schema role tag.
 * @returns {{name: string, role: string, quant: "f32", shape: number[]}}
 */
function f32Desc(name, shape, role = "f32") {
  const quant = "f32";
  const desc = { name, role, quant, shape };
  return desc;
}
__name(f32Desc, "f32Desc");
/**
 * Build the tensor schema for a Qwen2-style decoder from its config.
 * Tensors are registered in a fixed order: token embedding (int8), final norm,
 * then per layer the two norms, the seven int4 projections (q, k, v, o, gate, up,
 * down) and — when cfg.attentionBias is truthy — the q/k/v bias vectors.
 * @param {object} cfg hiddenSize, numLayers, numHeads, numKVHeads, headDim,
 *   intermediateSize, vocabSize, tieWordEmbeddings, attentionBias.
 * @returns {object} { cfg, tensors, byName, expectedNames, layers, embed, finalNorm,
 *   projectionDescs, validateTensor(name, shape), assertComplete(seen) }.
 * @throws {Error} when cfg.tieWordEmbeddings is explicitly set to a falsy value.
 */
function createQwenSchema(cfg) {
  const tied = cfg.tieWordEmbeddings;
  // `undefined` counts as tied; any other falsy value is rejected.
  if (tied !== void 0 && !tied) {
    throw new Error("QwenWGPU currently requires tied input/output embeddings");
  }
  const hidden = cfg.hiddenSize;
  const qDim = cfg.numHeads * cfg.headDim;
  const kvDim = cfg.numKVHeads * cfg.headDim;
  const inter = cfg.intermediateSize;
  const biasOpt = { bias: !!cfg.attentionBias };
  // [key, subpath, outDim, inDim, projDesc options] in registration order.
  const projSpecs = [
    ["q", "self_attn.q_proj", qDim, hidden, biasOpt],
    ["k", "self_attn.k_proj", kvDim, hidden, biasOpt],
    ["v", "self_attn.v_proj", kvDim, hidden, biasOpt],
    ["o", "self_attn.o_proj", hidden, qDim, void 0],
    ["gate", "mlp.gate_proj", inter, hidden, void 0],
    ["up", "mlp.up_proj", inter, hidden, void 0],
    ["down", "mlp.down_proj", hidden, inter, void 0]
  ];
  const tensors = [];
  const layers = [];
  const register = (desc) => {
    tensors.push(desc);
    return desc;
  };
  const embed = register({ name: "model.embed_tokens.weight", role: "embedding", quant: "int8", shape: [cfg.vocabSize, hidden] });
  const finalNorm = register(f32Desc("model.norm.weight", [hidden], "final_norm"));
  for (let index = 0; index < cfg.numLayers; index++) {
    const prefix = `model.layers.${index}`;
    const layer = {
      index,
      inputNorm: register(f32Desc(`${prefix}.input_layernorm.weight`, [hidden], "input_norm")),
      postAttentionNorm: register(f32Desc(`${prefix}.post_attention_layernorm.weight`, [hidden], "post_attention_norm")),
      projections: {},
      biases: {}
    };
    for (const [key, subpath, outDim, inDim, opts] of projSpecs) {
      layer.projections[key] = register(projDesc(index, subpath, outDim, inDim, opts));
    }
    // Bias vectors are registered only after all seven projections of the layer.
    for (const key of ["q", "k", "v"]) {
      const { biasName, shape } = layer.projections[key];
      if (biasName) layer.biases[key] = register(f32Desc(biasName, [shape[0]], `${key}_bias`));
    }
    layers.push(layer);
  }
  const byName = new Map(tensors.map((t) => [t.name, t]));
  const expectedNames = new Set(byName.keys());
  return {
    cfg,
    tensors,
    byName,
    expectedNames,
    layers,
    embed,
    finalNorm,
    projectionDescs: tensors.filter((t) => t.role === "projection"),
    // Returns the descriptor for a known tensor (null for unknown names); throws on shape mismatch.
    validateTensor(name, shape) {
      const desc = byName.get(name);
      if (!desc) return null;
      if (!arrEq(shape, desc.shape)) {
        throw new Error(`shape mismatch for ${name}: got [${shape.join(",")}], expected [${desc.shape.join(",")}]`);
      }
      return desc;
    },
    // Throws listing up to 12 expected tensor names absent from `seen`.
    assertComplete(seen) {
      const missing = [...expectedNames].filter((name) => !seen.has(name));
      if (missing.length === 0) return;
      const sample = missing.slice(0, 12).join(", ");
      throw new Error(`missing ${missing.length} required tensor(s): ${sample}${missing.length > 12 ? ", \u2026" : ""}`);
    }
  };
}
__name(createQwenSchema, "createQwenSchema");
/**
 * Map a LoRA adapter tensor name to the projection key used by projDesc's loraKey,
 * e.g. "...layers.3.self_attn.q_proj.lora_A.weight" -> "layers.3.self_attn.q_proj".
 * A missing "_proj" suffix is normalized back on.
 * @param {string} name adapter tensor name.
 * @returns {string|null} module key, or null when the name is not a lora_A/lora_B tensor.
 */
function moduleKeyFromTensorName(name) {
  const hit = /layers\.(\d+)\.(self_attn|mlp)\.([a-z_]+?)(_proj)?\.(lora_[ABab])/i.exec(name);
  if (hit === null) return null;
  const [, layerIdx, group, base] = hit;
  const stem = base.replace(/_proj$/, "");
  return `layers.${layerIdx}.${group}.${stem}_proj`;
}
__name(moduleKeyFromTensorName, "moduleKeyFromTensorName");
| // src/qwgpu/dispatch_plan.js | |
/**
 * Flatten a schema (from createQwenSchema) into per-layer name references.
 * Each layer entry carries index, the two norm tensor names and, for every projection
 * key q/k/v/o/gate/up/down, { weight, bias, loraKey }. Only q/k/v may have a bias;
 * the other projections always report bias: null.
 * @param {object} schema schema with embed, finalNorm and layers.
 * @returns {{embed: object, finalNorm: object, layers: object[]}}
 */
function createDispatchPlan(schema) {
  const PROJ_KEYS = ["q", "k", "v", "o", "gate", "up", "down"];
  const BIASED = new Set(["q", "k", "v"]);
  const planLayer = (layer) => {
    const entry = {
      index: layer.index,
      inputNorm: layer.inputNorm.name,
      postAttentionNorm: layer.postAttentionNorm.name
    };
    for (const key of PROJ_KEYS) {
      const proj = layer.projections[key];
      entry[key] = {
        weight: proj.name,
        bias: BIASED.has(key) ? (layer.biases[key]?.name || null) : null,
        loraKey: proj.loraKey
      };
    }
    return entry;
  };
  return {
    embed: schema.embed,
    finalNorm: schema.finalNorm,
    layers: schema.layers.map((layer) => planLayer(layer))
  };
}
__name(createDispatchPlan, "createDispatchPlan");
| // src/qwgpu/safetensors_loader.js | |
/**
 * Decode bfloat16 bytes into a Float32Array. A bf16 value is the upper 16 bits of
 * the matching f32, so each element is widened with a 16-bit left shift.
 * Elements are read in host byte order (little-endian on mainstream platforms,
 * which matches safetensors).
 * @param {Uint8Array} u8 raw bytes, at least 2 * numel long; any byteOffset.
 * @param {number} numel element count.
 * @returns {Float32Array} freshly allocated decoded values.
 */
function decodeBf16ToF32(u8, numel) {
  // Uint16Array views must start on a 2-byte boundary; copy unaligned input first
  // instead of failing with a RangeError.
  const src = u8.byteOffset % 2 === 0 ? u8 : u8.slice(0, numel * 2);
  const u16 = new Uint16Array(src.buffer, src.byteOffset, numel);
  const out = new Float32Array(numel);
  const o32 = new Uint32Array(out.buffer);
  for (let i = 0; i < numel; i++) o32[i] = u16[i] << 16;
  return out;
}
__name(decodeBf16ToF32, "decodeBf16ToF32");
/**
 * Decode IEEE-754 half-precision bytes into a Float32Array, handling subnormals,
 * signed zero, infinities and NaN. Elements are read in host byte order
 * (little-endian on mainstream platforms, which matches safetensors).
 * @param {Uint8Array} u8 raw bytes, at least 2 * numel long; any byteOffset.
 * @param {number} numel element count.
 * @returns {Float32Array} freshly allocated decoded values.
 */
function decodeF16ToF32(u8, numel) {
  // Uint16Array views must start on a 2-byte boundary; copy unaligned input first
  // instead of failing with a RangeError.
  const src = u8.byteOffset % 2 === 0 ? u8 : u8.slice(0, numel * 2);
  const u16 = new Uint16Array(src.buffer, src.byteOffset, numel);
  const out = new Float32Array(numel);
  for (let i = 0; i < numel; i++) {
    const h = u16[i], s = (h & 32768) >> 15, e = (h & 31744) >> 10, f = h & 1023;
    // e == 0: subnormal (or zero); e == 31: Inf / NaN; otherwise normal with implicit 1.
    if (e === 0) out[i] = (s ? -1 : 1) * Math.pow(2, -14) * (f / 1024);
    else if (e === 31) out[i] = f ? NaN : s ? -Infinity : Infinity;
    else out[i] = (s ? -1 : 1) * Math.pow(2, e - 15) * (1 + f / 1024);
  }
  return out;
}
__name(decodeF16ToF32, "decodeF16ToF32");
/**
 * Copy f32 bytes into a new, aligned Float32Array.
 * The copy is taken from the underlying buffer at u8.byteOffset, so unaligned
 * views are handled and the result never aliases the input.
 * @param {Uint8Array} u8 raw bytes.
 * @param {number} numel element count.
 * @returns {Float32Array}
 */
function decodeF32(u8, numel) {
  const start = u8.byteOffset;
  const copy = u8.buffer.slice(start, start + 4 * numel);
  return new Float32Array(copy);
}
__name(decodeF32, "decodeF32");
// Upper-cased safetensors dtype -> decoder (u8 bytes, element count) -> Float32Array.
// FP16 / FP32 are accepted as aliases of F16 / F32.
var DECODERS = Object.fromEntries([
  ["BF16", decodeBf16ToF32],
  ["F16", decodeF16ToF32],
  ["FP16", decodeF16ToF32],
  ["F32", decodeF32],
  ["FP32", decodeF32]
]);
/**
 * Read the sharded-checkpoint index, if any.
 * When "model.safetensors.index.json" cannot be fetched, the checkpoint is treated
 * as a single unsharded "model.safetensors" file. An index that exists but is not
 * valid JSON, or has no weight_map object, is an error rather than a silent
 * fallback. Previously that case either pointed at a non-existent single file or
 * produced an empty shard list, so nothing was loaded.
 * @param {{text: (path: string) => Promise<string>}} reader
 * @returns {Promise<{weightMap: object|null, shards: string[]}>} shards are unique, in
 *   first-seen order of the weight_map values.
 * @throws {Error} on a malformed index file.
 */
async function loadIndex(reader) {
  const INDEX = "model.safetensors.index.json";
  let raw;
  try {
    raw = await reader.text(INDEX);
  } catch {
    // No index file: single unsharded checkpoint.
    return { weightMap: null, shards: ["model.safetensors"] };
  }
  let idx;
  try {
    idx = JSON.parse(raw);
  } catch (err) {
    throw new Error(`${INDEX} is not valid JSON`, { cause: err });
  }
  const weightMap = idx?.weight_map;
  if (!weightMap || typeof weightMap !== "object" || Array.isArray(weightMap)) {
    throw new Error(`${INDEX} has no weight_map`);
  }
  return { weightMap, shards: [...new Set(Object.values(weightMap))] };
}
__name(loadIndex, "loadIndex");
/**
 * Decide which tensors to read from which shard.
 * Without both an index weightMap and a name filter, every shard maps to null
 * (read all its tensors). Otherwise each wanted name is bucketed under its shard;
 * names absent from the weightMap are skipped, and shards with no wanted names
 * are omitted.
 * @param {string[]} shards
 * @param {object|null} weightMap tensor name -> shard file.
 * @param {Iterable<string>|null} names wanted tensor names.
 * @returns {Map<string, Set<string>|null>}
 */
function shardPlan(shards, weightMap, names) {
  if (!(weightMap && names)) {
    return new Map(shards.map((s) => [s, null]));
  }
  const plan = new Map();
  for (const tensorName of names) {
    const shard = weightMap[tensorName];
    if (!shard) continue;
    let bucket = plan.get(shard);
    if (bucket === undefined) {
      bucket = new Set();
      plan.set(shard, bucket);
    }
    bucket.add(tensorName);
  }
  return plan;
}
__name(shardPlan, "shardPlan");
/**
 * Stream tensors out of a (possibly sharded) safetensors checkpoint, one at a time.
 * For each shard: read the 8-byte little-endian u64 header length, parse the JSON
 * header, then range-read and decode each selected tensor.
 * @param {string|{range: Function, text: Function}} source base URL (wrapped with
 *   urlReader) or a reader exposing range(path, start, end) and text(path).
 * @param {object} options
 * @param {Set<string>|null} [options.names] tensor names to load; null loads everything.
 * @param {Function} options.onTensor awaited once per tensor with
 *   { name, shape, dtype, data: Float32Array, shard }.
 * @param {Function} [options.onProgress] called as (name, fraction) after each tensor;
 *   fraction = visited / names.size capped at 0.95, or 0.3 when names is null/empty.
 * @throws {Error} when onTensor is missing, a dtype has no decoder, or a tensor's
 *   data_offsets span does not match its shape and dtype.
 */
async function streamSafetensors(source, { names = null, onTensor, onProgress = /* @__PURE__ */ __name(() => {
}, "onProgress") } = {}) {
  if (!onTensor) throw new Error("streamSafetensors requires onTensor");
  // Element sizes of the dtypes DECODERS handles, used to validate data_offsets.
  const BYTES_PER_ELEMENT = { BF16: 2, F16: 2, FP16: 2, F32: 4, FP32: 4 };
  const reader = typeof source === "string" ? urlReader(source) : source;
  const { weightMap, shards } = await loadIndex(reader);
  const plan = shardPlan(shards, weightMap, names);
  let visited = 0;
  const total = names?.size || 0;
  for (const [shard, wantedInShard] of plan) {
    const lenBuf = await reader.range(shard, 0, 8);
    const headerLen = Number(new DataView(lenBuf).getBigUint64(0, true));
    const hdrBuf = await reader.range(shard, 8, 8 + headerLen);
    const header = JSON.parse(new TextDecoder().decode(new Uint8Array(hdrBuf)));
    const dataStart = 8 + headerLen;
    const allNames = Object.keys(header).filter((k) => k !== "__metadata__");
    const tensorNames = wantedInShard ? allNames.filter((n) => wantedInShard.has(n)) : names ? allNames.filter((n) => names.has(n)) : allNames;
    for (const name of tensorNames) {
      const t = header[name];
      if (!t) continue;
      const dtype = String(t.dtype || "").toUpperCase();
      const dec = DECODERS[dtype];
      if (!dec) throw new Error(`unsupported dtype ${dtype} for ${name}`);
      const numel = t.shape.reduce((a, b) => a * b, 1);
      const [s, e] = t.data_offsets;
      // A span that disagrees with shape x dtype would otherwise surface as a
      // RangeError or as a silently short / truncated Float32Array.
      const elemBytes = BYTES_PER_ELEMENT[dtype];
      if (elemBytes !== undefined && e - s !== numel * elemBytes) {
        throw new Error(`size mismatch for ${name}: data_offsets [${s}, ${e}] span ${e - s} bytes, expected ${numel * elemBytes} for ${dtype}[${t.shape.join(",")}]`);
      }
      const buf = await reader.range(shard, dataStart + s, dataStart + e);
      const data = dec(new Uint8Array(buf), numel);
      await onTensor({ name, shape: t.shape, dtype, data, shard });
      visited++;
      onProgress(name, total ? Math.min(0.95, visited / total) : 0.3);
    }
  }
}
__name(streamSafetensors, "streamSafetensors");
| // src/qwgpu/quantize.js | |
/**
 * Symmetric per-row int8 quantization of a row-major [outDim, inDim] matrix.
 * Each row's scale is amax / 127 (1 for an all-zero row); values are rounded and
 * clamped to [-128, 127]. Four int8 bytes are packed per u32, element 0 in the low byte.
 * @param {ArrayLike<number>} f322 row-major values.
 * @param {number} outDim rows.
 * @param {number} inDim columns.
 * @returns {{packed: Uint32Array, scale: Float32Array, outDim: number, inDim: number}}
 */
function quantizeInt8RowMajor(f322, outDim, inDim) {
  const scale = new Float32Array(outDim);
  const q = new Int8Array(outDim * inDim);
  for (let row = 0; row < outDim; row++) {
    const start = row * inDim;
    const end = start + inDim;
    let absMax = 0;
    for (let k = start; k < end; k++) {
      const mag = Math.abs(f322[k]);
      // Plain comparison (not Math.max) so NaN entries never become the row max.
      if (mag > absMax) absMax = mag;
    }
    const step = absMax > 0 ? absMax / 127 : 1;
    scale[row] = step;
    const inv = 1 / step;
    for (let k = start; k < end; k++) {
      q[k] = Math.min(127, Math.max(-128, Math.round(f322[k] * inv)));
    }
  }
  const bytes = new Uint8Array(q.buffer);
  const packed = new Uint32Array(outDim * inDim / 4);
  for (let w = 0; w < packed.length; w++) {
    const b = w * 4;
    packed[w] = bytes[b] | (bytes[b + 1] << 8) | (bytes[b + 2] << 16) | (bytes[b + 3] << 24);
  }
  return { packed, scale, outDim, inDim };
}
__name(quantizeInt8RowMajor, "quantizeInt8RowMajor");
/**
 * Symmetric per-group int4 quantization of a row-major [outDim, inDim] matrix.
 * Each run of `group` consecutive values within a row shares one scale = amax / 7
 * (1 for an all-zero group); values are rounded and clamped to [-8, 7]. Eight
 * two's-complement nibbles are packed per u32, element j of a word in bits 4j..4j+3.
 * @param {ArrayLike<number>} f322 at least outDim * inDim row-major values.
 * @param {number} outDim rows.
 * @param {number} inDim columns; must be a multiple of `group`.
 * @param {number} [group=128] values per scale group; positive integer.
 * @returns {{packed: Uint32Array, scale: Float32Array, groupsPerRow: number}}
 *   scale is row-major [outDim, groupsPerRow].
 * @throws {RangeError} for an invalid group, inDim not divisible by group, an element
 *   count not divisible by 8, or input shorter than outDim * inDim. Previously these
 *   produced silently wrong or truncated output.
 */
function quantizeInt4Group(f322, outDim, inDim, group = 128) {
  if (!Number.isInteger(group) || group <= 0) {
    throw new RangeError(`quantizeInt4Group: group must be a positive integer, got ${group}`);
  }
  if (inDim % group !== 0) {
    throw new RangeError(`quantizeInt4Group: inDim ${inDim} is not a multiple of group ${group}`);
  }
  const total = outDim * inDim;
  if (total % 8 !== 0) {
    throw new RangeError(`quantizeInt4Group: ${total} elements cannot be packed 8 per u32`);
  }
  if (f322.length < total) {
    throw new RangeError(`quantizeInt4Group: expected ${total} values, got ${f322.length}`);
  }
  const groupsPerRow = inDim / group;
  const scale = new Float32Array(outDim * groupsPerRow);
  const q = new Int8Array(total);
  for (let o = 0; o < outDim; o++) {
    for (let g = 0; g < groupsPerRow; g++) {
      const base = o * inDim + g * group;
      let amax = 0;
      for (let i = 0; i < group; i++) {
        const a = Math.abs(f322[base + i]);
        if (a > amax) amax = a;
      }
      const s = amax > 0 ? amax / 7 : 1;
      scale[o * groupsPerRow + g] = s;
      const inv = 1 / s;
      for (let i = 0; i < group; i++) {
        let v = Math.round(f322[base + i] * inv);
        if (v > 7) v = 7;
        else if (v < -8) v = -8;
        q[base + i] = v;
      }
    }
  }
  const packed = new Uint32Array(total / 8);
  for (let w = 0; w < packed.length; w++) {
    let acc = 0;
    for (let j = 0; j < 8; j++) acc |= (q[w * 8 + j] & 15) << j * 4;
    // `<< 28` can set the sign bit; >>> 0 stores the word as unsigned.
    packed[w] = acc >>> 0;
  }
  return { packed, scale, groupsPerRow };
}
__name(quantizeInt4Group, "quantizeInt4Group");
| // src/qwgpu/model_uploader.js | |
/**
 * Receives streamed checkpoint tensors, validates each against the schema,
 * quantizes / uploads it according to its descriptor's quant mode, and records
 * which tensors were seen so finalize() can check completeness.
 * Upload results are stored into the caller-supplied q (int8), q4 (int4) and
 * bufs (f32) maps.
 */
var ModelUploader = class {
  static {
    __name(this, "ModelUploader");
  }
  constructor({ schema, q, q4, bufs, uploadF32, uploadU32, groupSize = 128 }) {
    Object.assign(this, { schema, q, q4, bufs, uploadF32, uploadU32, groupSize });
    this.seen = new Set();
  }
  // Tensors unknown to the schema are ignored; known tensors may be visited only once.
  visit({ name, shape, data }) {
    const desc = this.schema.validateTensor(name, shape);
    if (!desc) return;
    if (this.seen.has(name)) throw new Error(`duplicate tensor ${name}`);
    switch (desc.quant) {
      case "int8": {
        const [rows, cols] = [shape[0], shape[1]];
        const { packed, scale } = quantizeInt8RowMajor(data, rows, cols);
        this.q[name] = { w: this.uploadU32(packed), scale: this.uploadF32(scale), N: rows, K: cols };
        break;
      }
      case "int4": {
        const [rows, cols] = [shape[0], shape[1]];
        const { packed, scale, groupsPerRow } = quantizeInt4Group(data, rows, cols, this.groupSize);
        const w = this.uploadU32(packed);
        const scaleBuf = this.uploadF32(scale);
        this.q4[name] = { w, scale: scaleBuf, N: rows, K: cols, gpr: groupsPerRow, desc };
        break;
      }
      case "f32":
        this.bufs[name] = this.uploadF32(data);
        break;
      default:
        throw new Error(`unsupported quant mode ${desc.quant} for ${name}`);
    }
    this.seen.add(name);
  }
  // Throws (via the schema) if any expected tensor was never visited.
  finalize() {
    this.schema.assertComplete(this.seen);
  }
};
| // src/qwgpu/buffer_pool.js | |
/**
 * Thin allocation layer over a GPUDevice:
 * - buffer / uploadF32 / uploadU32: create (and fill) buffers, counting creations.
 * - dynamicUniform: a pool of uniform buffers reused in order and rewound by
 *   resetUniforms().
 * - staticUniform: uniform buffers cached by caller-supplied key (written once).
 * - cachedBindGroup: optional bind-group cache keyed by pipeline, key and buffer
 *   identities; entries marked sensitive can be dropped with clearSensitiveBindGroups().
 * stats() reports counters for all of the above.
 */
var GPUBufferPool = class {
  static {
    __name(this, "GPUBufferPool");
  }
  constructor(device, { cacheBindGroups = true } = {}) {
    this.dev = device;
    this.cacheBindGroups = cacheBindGroups;
    this.uniformPool = [];
    this.uniformIdx = 0;
    this.staticUniforms = /* @__PURE__ */ new Map();
    this.bindGroups = /* @__PURE__ */ new Map();
    this.sensitiveBindGroups = /* @__PURE__ */ new Set();
    // WeakMaps so id assignment never keeps a buffer/pipeline alive.
    this.bufferIds = /* @__PURE__ */ new WeakMap();
    this.pipelineIds = /* @__PURE__ */ new WeakMap();
    this.nextBufferId = 1;
    this.nextPipelineId = 1;
    this._stats = this._emptyStats();
  }
  _emptyStats() {
    return {
      buffersCreated: 0,
      dynamicUniformWrites: 0,
      staticUniformHits: 0,
      staticUniformMisses: 0,
      bindGroupHits: 0,
      bindGroupMisses: 0,
      uncachedBindGroups: 0
    };
  }
  resetStats() {
    this._stats = this._emptyStats();
  }
  stats() {
    return {
      ...this._stats,
      uniformPoolSize: this.uniformPool.length,
      staticUniforms: this.staticUniforms.size,
      bindGroups: this.bindGroups.size
    };
  }
  buffer(size, usage) {
    this._stats.buffersCreated++;
    return this.dev.createBuffer({ size, usage });
  }
  uploadF32(arr, usage) {
    const b = this.buffer(arr.byteLength, usage);
    this.dev.queue.writeBuffer(b, 0, arr);
    return b;
  }
  uploadU32(arr, usage) {
    const b = this.buffer(arr.byteLength, usage);
    this.dev.queue.writeBuffer(b, 0, arr);
    return b;
  }
  // Byte size for a uniform buffer holding `arr`: never below 32 bytes (the fixed
  // size used previously) and rounded up to a multiple of 16. Uniform-address-space
  // structs in WGSL are sized in 16-byte multiples. Payloads larger than 32 bytes
  // used to overflow the fixed 32-byte buffers.
  _uniformSize(arr) {
    return Math.max(32, Math.ceil(arr.byteLength / 16) * 16);
  }
  dynamicUniform(arr, usage) {
    const size = this._uniformSize(arr);
    let b = this.uniformPool[this.uniformIdx];
    // Reuse the pooled slot unless it is too small for this payload. A replaced
    // buffer is not destroyed, since it may still be referenced by pending work or
    // cached bind groups.
    if (!b || b.size < size) {
      b = this.buffer(size, usage);
      this.uniformPool[this.uniformIdx] = b;
    }
    this.uniformIdx++;
    this._stats.dynamicUniformWrites++;
    this.dev.queue.writeBuffer(b, 0, arr.buffer, arr.byteOffset, arr.byteLength);
    return b;
  }
  resetUniforms() {
    this.uniformIdx = 0;
  }
  // Cached by `key` only: a later call with the same key returns the first buffer
  // without rewriting its contents.
  staticUniform(key, arr, usage) {
    let b = this.staticUniforms.get(key);
    if (!b) {
      this._stats.staticUniformMisses++;
      b = this.buffer(this._uniformSize(arr), usage);
      this.dev.queue.writeBuffer(b, 0, arr.buffer, arr.byteOffset, arr.byteLength);
      this.staticUniforms.set(key, b);
    } else this._stats.staticUniformHits++;
    return b;
  }
  idForBuffer(buffer) {
    let id = this.bufferIds.get(buffer);
    if (!id) {
      id = this.nextBufferId++;
      this.bufferIds.set(buffer, id);
    }
    return id;
  }
  idForPipeline(pipe) {
    let id = this.pipelineIds.get(pipe);
    if (!id) {
      id = this.nextPipelineId++;
      this.pipelineIds.set(pipe, id);
    }
    return id;
  }
  uncachedBindGroup(pipe, buffers) {
    this._stats.uncachedBindGroups++;
    return this.dev.createBindGroup({
      label: pipe.__name ? `${pipe.__name}:bg:${buffers.length}` : void 0,
      layout: pipe.getBindGroupLayout(0),
      entries: buffers.map((buffer, i) => ({ binding: i, resource: { buffer } }))
    });
  }
  /*
   * TECHNIQUE: Bind group caching (opt-in per call site)
   * Frequently reused (pipeline + buffer set) combinations are stored in a Map.
   * Avoids repeated GPU bind group creation on the hot GEMV / attention paths.
   * Sensitive / one-shot groups are deliberately not cached.
   * A call without a key (or with caching disabled) always creates a new group;
   * groups cached with { sensitive: true } are dropped by clearSensitiveBindGroups().
   */
  cachedBindGroup(pipe, buffers, key, { sensitive = false } = {}) {
    if (!this.cacheBindGroups || !key) return this.uncachedBindGroup(pipe, buffers);
    const fullKey = `${this.idForPipeline(pipe)}:${key}:${buffers.map((b) => this.idForBuffer(b)).join(",")}`;
    let bg = this.bindGroups.get(fullKey);
    if (!bg) {
      this._stats.bindGroupMisses++;
      bg = this.uncachedBindGroup(pipe, buffers);
      this.bindGroups.set(fullKey, bg);
      if (sensitive) this.sensitiveBindGroups.add(fullKey);
    } else this._stats.bindGroupHits++;
    return bg;
  }
  clearSensitiveBindGroups() {
    for (const key of this.sensitiveBindGroups) this.bindGroups.delete(key);
    this.sensitiveBindGroups.clear();
  }
};
| // src/qwgpu/runtime.js | |
// Default usage for runtime storage buffers: bindable as storage and usable as both
// copy source and destination (uploads via writeBuffer, readbacks via copies).
var STORAGE = GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE;
// Usage for uniform buffers, which are filled with queue.writeBuffer.
var UNIFORM = GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM;
| var QwenWGPU = class { | |
| static { | |
| __name(this, "QwenWGPU"); | |
| } | |
| // opts: { maxCtx, maxPrefillT, decodeBatchSize, samplingTopK } — context | |
| // window + batched-prefill cap (default 8192 each; KV cache grows linearly). | |
| constructor(device, cfg, opts = {}) { | |
| this.dev = device; | |
| this.cfg = cfg; | |
| this.lora = null; | |
| this.bufs = {}; | |
| this.opts = opts; | |
| this.features = this._normalizeFeatures(opts); | |
| this.pool = new GPUBufferPool(device, { cacheBindGroups: opts.cacheBindGroups !== false }); | |
| this._loraEpoch = 0; | |
| this.lastDispatchCount = 0; | |
| this.packedBytes = 0; | |
| this.workgroupAutotunePromise = null; | |
| this._argmaxReadBusy = false; | |
| this._topKReadBusy = false; | |
| } | |
| _normalizeFeatures(opts = {}) { | |
| const prefillAttention = opts.prefillAttention || "block"; | |
| if (!["row", "block"].includes(prefillAttention)) | |
| throw new Error(`unsupported prefillAttention ${prefillAttention}`); | |
| return { | |
| // fuseRMSNormQKVRoPE: fused RMSNorm + int4 QKV GEMV + RoPE for no-LoRA decode | |
| // (one workgroup per (head,rot) pair; verified logitDiff 0 vs PyTorch ref). | |
| // fuseQKV selects the alternate qkvGemv4 path and stays OFF by default since | |
| // the fused-RMS path already covers the fast no-LoRA decode; LoRA layers are | |
| // routed to the unfused gemv4x3 + ropeQK path automatically (see step()). | |
| fuseQKV: opts.fuseQKV === true, | |
| fuseRoPE: opts.fuseRoPE !== false, | |
| fuseMLP: opts.fuseMLP !== false, | |
| fuseResidual: opts.fuseResidual !== false, | |
| prefillAttention, | |
| prefillChunkSize: Math.max(0, opts.prefillChunkSize || 0), | |
| actQuant: !!opts.actQuant, | |
| // Default OFF: the GEMV4_QKV_ROPE_RMS kernel still computes zero outputs even | |
| // with the corrected (totalPairs) dispatch — there is a deeper bug in the | |
| // fused kernel itself. The unfused gemv4x3 + ropeQK decode is verified | |
| // logitDiff 0 vs the PyTorch ref, so it stays the default until the fused | |
| // kernel is debugged. The wrapper dispatch is now correct for that work. | |
| fuseRMSNormQKVRoPE: opts.fuseRMSNormQKVRoPE === true, | |
| pagedAttention: !!opts.pagedAttention | |
| }; | |
| } | |
| setFeatureFlags(flags = {}) { | |
| this.features = this._normalizeFeatures({ ...this.features, ...flags }); | |
| this.pool.clearSensitiveBindGroups(); | |
| } | |
| featureFlags() { | |
| return { ...this.features }; | |
| } | |
| // Phase 3 (f16): when shader-f16 is available we can switch hot kernels to f16 | |
| // storage/compute for bandwidth wins. Stub for now; real kernel variants + selection | |
| // will be added. Evaluation: compare f16 vs f32 logits within tolerance + bench speedup. | |
| hasF16Compute() { | |
| return !!this.hasF16; | |
| } | |
| setUseF16(v) { | |
| this._useF16 = !!v && this.hasF16Compute(); | |
| } | |
| usingF16() { | |
| return !!this._useF16; | |
| } | |
| // Phase 4: allow caller / autotuner to override workgroup size after build if desired. | |
| // Note: affects *future* pipes / re-pipes; existing pipes keep their specialization. | |
| setWorkgroupSize(wg) { | |
| if (wg && wg > 0) this.workgroupSize = wg | 0; | |
| } | |
| // Basic load-time / on-demand workgroup autotuner (Phase 4). | |
| // Tries a few WG sizes for simple override-supporting kernels (add / rms for now). | |
| // Uses wall time + onSubmittedWorkDone for broad compatibility. | |
| // Returns a map of best sizes; optionally hot-swaps the pipe for 'add'. | |
| async autotuneWorkgroups(opts = {}) { | |
| const iters = opts.iters || 6; | |
| const cands = opts.candidates || [32, 64, 128, 256]; | |
| const results = {}; | |
| const useTS = this.hasTimestampQuery; | |
| const timeKernel = /* @__PURE__ */ __name(async (spec, pipe, label) => { | |
| const n = spec.n; | |
| const a = this._buf(n * 4); | |
| const g = this._buf(n * 4); | |
| const y = this._buf(n * 4); | |
| const buffers = spec.buffers(a, y, g); | |
| const imm = spec.imm(n); | |
| let gpuMs = 0; | |
| let usedGPU = false; | |
| if (useTS) { | |
| const qs = this.dev.createQuerySet({ type: "timestamp", count: 2 }); | |
| const resolveBuf = this._buf(16, GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC); | |
| const readBuf = this._buf(16, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| const tWall0 = typeof performance !== "undefined" ? performance.now() : Date.now(); | |
| for (let i = 0; i < iters; i++) { | |
| const enc = this.dev.createCommandEncoder(); | |
| const bg = this._bg(pipe, buffers); | |
| const p = enc.beginComputePass({ | |
| timestampWrites: { | |
| querySet: qs, | |
| beginningOfPassWriteIndex: 0, | |
| endOfPassWriteIndex: 1 | |
| } | |
| }); | |
| p.setPipeline(pipe); | |
| if (bg) p.setBindGroup(0, bg); | |
| if (imm) p.setImmediates(0, imm); | |
| p.dispatchWorkgroups(Math.ceil(n / (pipe.__wg || 256)), 1); | |
| p.end(); | |
| enc.resolveQuerySet(qs, 0, 2, resolveBuf, 0); | |
| enc.copyBufferToBuffer(resolveBuf, 0, readBuf, 0, 16); | |
| this.dev.queue.submit([enc.finish()]); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| await readBuf.mapAsync(GPUMapMode.READ); | |
| const t = new BigInt64Array(readBuf.getMappedRange()); | |
| const us = Number(t[1] - t[0]) / 1e3; | |
| gpuMs += us; | |
| readBuf.unmap(); | |
| } | |
| const wallMs = (typeof performance !== "undefined" ? performance.now() : Date.now()) - tWall0; | |
| resolveBuf.destroy?.(); | |
| readBuf.destroy?.(); | |
| qs.destroy?.(); | |
| usedGPU = true; | |
| a.destroy?.(); | |
| g.destroy?.(); | |
| y.destroy?.(); | |
| return gpuMs / iters / 1e3; | |
| } | |
| const t0 = typeof performance !== "undefined" ? performance.now() : Date.now(); | |
| for (let i = 0; i < iters; i++) { | |
| const enc = this.dev.createCommandEncoder(); | |
| const bg = this._bg(pipe, buffers); | |
| this._dispatch(enc, pipe, bg, Math.ceil(n / (pipe.__wg || 256)), 1, label + ":bench", imm); | |
| this.dev.queue.submit([enc.finish()]); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| } | |
| const ms = (typeof performance !== "undefined" ? performance.now() : Date.now()) - t0; | |
| a.destroy?.(); | |
| g.destroy?.(); | |
| y.destroy?.(); | |
| return ms / iters; | |
| }, "timeKernel"); | |
| const kernels = [ | |
| { name: "add", src: ADD, n: 8192, buffers: /* @__PURE__ */ __name((a, y) => [a, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Uint32Array([n]), "imm") }, | |
| { name: "rms", src: RMSNORM, n: 4096, buffers: /* @__PURE__ */ __name((a, y, g) => [a, g, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Float32Array([n, this.cfg.rmsNormEps]), "imm") }, | |
| { name: "silu", src: SILUMUL, n: 8192, buffers: /* @__PURE__ */ __name((a, y) => [a, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Uint32Array([n]), "imm") } | |
| ]; | |
| for (const k of kernels) { | |
| try { | |
| let best = { wg: 256, ms: Infinity }; | |
| for (const wg of cands) { | |
| const p = this._pipe(k.src, `${k.name}:autotune:${wg}`, { WG: wg }); | |
| p.__wg = wg; | |
| const ms = await timeKernel(k, p, `${k.name}${wg}`); | |
| results[`${k.name}:${wg}`] = ms; | |
| if (ms < best.ms) best = { wg, ms }; | |
| } | |
| results[`best${k.name[0].toUpperCase()}${k.name.slice(1)}`] = best; | |
| if (opts.apply && this.pipes[k.name]) { | |
| this.pipes[k.name] = this._pipe(k.src, k.name, { WG: best.wg }); | |
| this.pipes[k.name].__wg = best.wg; | |
| } | |
| } catch (e) { | |
| results[`${k.name}Error`] = String(e); | |
| } | |
| } | |
| this.bestWorkgroupSizes = { | |
| add: results.bestAdd?.wg, | |
| rms: results.bestRms?.wg, | |
| silu: results.bestSilu?.wg, | |
| source: useTS ? "gpu-ts" : "wall" | |
| }; | |
| console.log("[autotune] WG microbench results (ms/iter, source=" + (useTS ? "gpu-ts" : "wall") + "):", results); | |
| return results; | |
| } | |
| _buf(size, usage = STORAGE) { | |
| return this.pool.buffer(size, usage); | |
| } | |
| _f32(arr, usage = STORAGE) { | |
| return this.pool.uploadF32(arr, usage); | |
| } | |
| _u32(arr) { | |
| return this.pool.uploadU32(arr, STORAGE); | |
| } | |
| _uni(arr) { | |
| return this.pool.dynamicUniform(arr, UNIFORM); | |
| } | |
| _staticUni(key, arr) { | |
| return this.pool.staticUniform(key, arr, UNIFORM); | |
| } | |
| _resetUni() { | |
| this.pool.resetUniforms(); | |
| this.lastDispatchCount = 0; | |
| } | |
| _pipe(code, name, overrides = null) { | |
| const processedCode = typeof code === "string" ? code.replaceAll("WG_SIZE", this.workgroupSize || 64) : code; | |
| const m = this.dev.createShaderModule({ | |
| label: name || void 0, | |
| code: processedCode | |
| }); | |
| const comp = { module: m, entryPoint: "main" }; | |
| if (overrides && typeof overrides === "object") comp.constants = overrides; | |
| const pipe = this.dev.createComputePipeline({ | |
| label: name ? `${name}-pipeline` : void 0, | |
| layout: "auto", | |
| compute: comp | |
| }); | |
| if (overrides?.WG) pipe.__wg = overrides.WG; | |
| if (name) pipe.__name = name; | |
| return pipe; | |
| } | |
| /* | |
| * TECHNIQUE: Specialization via pipeline constants (overrides) | |
| * Workgroup size and other small values are passed as pipeline-overridable | |
| * constants instead of uniforms or JS branches. Allows the shader compiler | |
| * to specialize the binary (better than runtime if). | |
| */ | |
| // `source` is a base URL string OR a reader { range, text } (e.g. hfReader/fileReader). | |
| async build(source, onProgress = () => { | |
| }) { | |
| const shaderCompileStart = performance.now(); | |
| const dev = this.dev, c = this.cfg; | |
| this.CHUNK = 128; | |
| this._initRuntimeOptions(); | |
| this.maxCtx = this.opts.maxCtx || 8192; | |
| this.maxPrefillT = Math.min(this.opts.maxPrefillT || 8192, this.maxCtx); | |
| const isAppleSilicon = this.dev.limits.minStorageBufferOffsetAlignment === 4; | |
| const isIntelArc = this.dev.limits.minStorageBufferOffsetAlignment === 256; | |
| this.workgroupSize = isAppleSilicon || isIntelArc ? 32 : 64; | |
| onProgress && onProgress(`workgroup size chosen: ${this.workgroupSize} (apple/intel bias toward 32)`, 0); | |
| let hasDP4a = false; | |
| if (typeof navigator !== "undefined" && navigator.gpu?.wgslLanguageFeatures?.has?.("packed_4x8_integer_dot_product")) { | |
| dev.pushErrorScope("validation"); | |
| try { | |
| dev.createShaderModule({ | |
| code: `enable packed_4x8_integer_dot_product; @compute @workgroup_size(1) fn main() {}` | |
| }); | |
| const error = await dev.popErrorScope(); | |
| if (!error) { | |
| hasDP4a = true; | |
| } | |
| } catch (e) { | |
| await dev.popErrorScope(); | |
| } | |
| } | |
| this.hasDP4a = hasDP4a; | |
| const hasF16 = this.dev.features.has("shader-f16"); | |
| this.hasF16 = hasF16; | |
| this.hasTimestampQuery = this.dev.features.has("timestamp-query"); | |
| this.pam = new PagedAttentionManager(this.maxCtx); | |
| this.pipes = { | |
| gemv: this._pipe(GEMV, "gemv"), | |
| loraA: this._pipe(LORA_A, "loraA"), | |
| loraABatch: this._pipe(LORA_A_BATCH, "loraABatch"), | |
| loraBAdd: this._pipe(LORA_B_ADD, "loraBAdd"), | |
| loraBAddT: this._pipe(LORA_B_ADD_T, "loraBAddT"), | |
| rms: this._pipe(RMSNORM, "rms", { WG: this.workgroupSize || 256 }), | |
| rmsF16: hasF16 ? this._pipe(RMSNORM_F16, "rmsF16", { WG: this.workgroupSize || 256 }) : null, | |
| rope: this._pipe(ROPE, "rope"), | |
| ropeF16: hasF16 ? this._pipe(ROPE_F16, "ropeF16") : null, | |
| ropeQK: this._pipe(ROPE_QK, "ropeQK"), | |
| ropeQKF16: hasF16 ? this._pipe(ROPE_QK_F16, "ropeQKF16") : null, | |
| ropeT: this._pipe(ROPE_T, "ropeT"), | |
| ropeTF16: hasF16 ? this._pipe(ROPE_T_F16, "ropeTF16") : null, | |
| attnP: this._pipe(ATTN_PARTIAL, "attnP", { WG: 128 }), | |
| attnPF16: hasF16 ? this._pipe(ATTN_PARTIAL_F16, "attnPF16", { WG: 128 }) : null, | |
| attnC: this._pipe(ATTN_COMBINE, "attnC", { WG: 128 }), | |
| attnCF16: hasF16 ? this._pipe(ATTN_COMBINE_F16, "attnCF16", { WG: 128 }) : null, | |
| add: this._pipe(ADD, "add", { WG: this.workgroupSize || 256 }), | |
| silu: this._pipe(SILUMUL, "silu", { WG: this.workgroupSize || 256 }), | |
| addF16: hasF16 ? this._pipe(ADD_F16, "addF16", { WG: this.workgroupSize || 256 }) : null, | |
| siluF16: hasF16 ? this._pipe(SILUMUL_F16, "siluF16", { WG: this.workgroupSize || 256 }) : null, | |
| embed: this._pipe(EMBED, "embed"), | |
| embedBuf: this._pipe(EMBED_BUF, "embedBuf"), | |
| argmax: this._pipe(ARGMAX, "argmax"), | |
| gemv4: this._pipe(GEMV4, "gemv4"), | |
| gemv4Add: this._pipe(GEMV4_ADD, "gemv4Add"), | |
| qkvGemv4: this._pipe(QKV_GEMV4, "qkvGemv4"), | |
| gateUpSiluGemv4: this._pipe(GATE_UP_SILU_GEMV4, "gateUpSiluGemv4"), | |
| topkSelect: this._pipe(TOPK_SELECT, "topkSelect"), | |
| sampleTopK: this._pipe(SAMPLE_TOPK, "sampleTopK"), | |
| gemm4: this._pipe(GEMM4, "gemm4"), | |
| gemm4AddT: this._pipe(GEMM4_ADD_T, "gemm4AddT"), | |
| rmsT: this._pipe(RMSNORM_T, "rmsT", { WG: this.workgroupSize || 256 }), | |
| rmsTF16: hasF16 ? this._pipe(RMSNORM_T_F16, "rmsTF16", { WG: this.workgroupSize || 256 }) : null, | |
| embedT: this._pipe(EMBED_T, "embedT"), | |
| attnPrefill: this._pipe(ATTN_PREFILL, "attnPrefill"), | |
| attnPrefillBlock: this._pipe(ATTN_PREFILL_BLOCK, "attnPrefillBlock"), | |
| dynQuant: this._pipe(DYN_QUANT_X, "dynQuant"), | |
| dynQuantT: this._pipe(DYN_QUANT_X_T, "dynQuantT"), | |
| gemv4W4A8: this._pipe(GEMV4_W4A8(hasDP4a, this.workgroupSize), "gemv4W4A8"), | |
| gemv4AddW4A8: this._pipe(GEMV4_ADD_W4A8(hasDP4a, this.workgroupSize), "gemv4AddW4A8"), | |
| qkvGemv4W4A8: this._pipe(QKV_GEMV4_W4A8(hasDP4a, this.workgroupSize), "qkvGemv4W4A8"), | |
| gateUpSiluGemv4W4A8: this._pipe(GATE_UP_SILU_GEMV4_W4A8(hasDP4a, this.workgroupSize), "gateUpSiluGemv4W4A8"), | |
| gemm4W4A8: this._pipe(GEMM4_W4A8(hasDP4a), "gemm4W4A8"), | |
| gemm4AddTW4A8: this._pipe(GEMM4_ADD_T_W4A8(hasDP4a), "gemm4AddTW4A8"), | |
| rmsNormQkvRope: this._pipe(GEMV4_QKV_ROPE_RMS, "rmsNormQkvRope"), | |
| writeKvPage: this._pipe(WRITE_KV_PAGE, "writeKvPage"), | |
| writeKvPageBatch: this._pipe(WRITE_KV_PAGE_BATCH, "writeKvPageBatch"), | |
| attnPartialPaged: this._pipe(ATTN_PARTIAL_PAGED, "attnPartialPaged"), | |
| attnPrefillPaged: this._pipe(ATTN_PREFILL_PAGED, "attnPrefillPaged"), | |
| attnPrefillBlockPaged: this._pipe(ATTN_PREFILL_BLOCK_PAGED, "attnPrefillBlockPaged") | |
| }; | |
| this.shaderCompileMs = performance.now() - shaderCompileStart; | |
| if (hasF16) { | |
| this.setUseF16(true); | |
| onProgress("f16 compute enabled (add/silu/rms/rope/attn-partial/combine paths)", 0); | |
| } | |
| if (this.hasTimestampQuery) { | |
| onProgress("timestamp-query available (precise GPU timing + autotune)", 0); | |
| } | |
| onProgress("streaming + quantizing weights", 0); | |
| this.schema = createQwenSchema(c); | |
| this.plan = createDispatchPlan(this.schema); | |
| this.q = {}; | |
| this.q4 = {}; | |
| this.qkv = []; | |
| this.gateUp = []; | |
| const uploader = new ModelUploader({ | |
| schema: this.schema, | |
| q: this.q, | |
| q4: this.q4, | |
| bufs: this.bufs, | |
| uploadF32: /* @__PURE__ */ __name((arr) => this._f32(arr), "uploadF32"), | |
| uploadU32: /* @__PURE__ */ __name((arr) => this._u32(arr), "uploadU32") | |
| }); | |
| if (source === "mock") { | |
| for (const name of this.schema.expectedNames) { | |
| const desc = this.schema.tensors.find((t) => t.name === name); | |
| const shape = desc.shape; | |
| const numel = shape.reduce((a, b) => a * b, 1); | |
| const type = desc.quant === "int8" ? "I8" : "F32"; | |
| uploader.visit({ name, shape, data: new Uint8Array(numel * (type === "I8" ? 1 : 4)), type }); | |
| } | |
| } else { | |
| await streamSafetensors(source, { | |
| names: this.schema.expectedNames, | |
| onProgress, | |
| onTensor: /* @__PURE__ */ __name(async (tensor) => { | |
| uploader.visit(tensor); | |
| if (uploader.seen.size % 48 === 0) await new Promise((r) => setTimeout(r, 0)); | |
| }, "onTensor") | |
| }); | |
| } | |
| uploader.finalize(); | |
| await this._buildPackedProjectionBuffers(); | |
| this._buildRope(this.maxCtx); | |
| this.kc = [], this.vc = []; | |
| const kvSize = c.numKVHeads * this.maxCtx * c.headDim * 4; | |
| for (let i = 0; i < c.numLayers; i++) { | |
| this.kc.push(this._buf(kvSize)); | |
| this.vc.push(this._buf(kvSize)); | |
| } | |
| const H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; | |
| const NSPLITMAX = Math.ceil(this.maxCtx / this.CHUNK); | |
| this.s = { | |
| hidden: this._buf(H * 4), | |
| normed: this._buf(H * 4), | |
| q: this._buf(qd * 4), | |
| k: this._buf(kvd * 4), | |
| v: this._buf(kvd * 4), | |
| attn: this._buf(qd * 4), | |
| tmp: this._buf(Math.max(qd, I) * 4), | |
| tmp2: this._buf(I * 4), | |
| logits: this._buf(c.vocabSize * 4), | |
| dummy: this._buf(64), | |
| loraD: this._buf(256 * 4), | |
| loraD2: this._buf(256 * 4), | |
| amax: this._buf(4), | |
| pm: this._buf(c.numHeads * NSPLITMAX * 4), | |
| pz: this._buf(c.numHeads * NSPLITMAX * 4), | |
| po: this._buf(c.numHeads * NSPLITMAX * c.headDim * 4), | |
| idsBuf: this._buf(this.decodeBatchCapacity * 4), | |
| sampleIds: this._buf(this.maxSamplingTopK * 4), | |
| sampleVals: this._buf(this.maxSamplingTopK * 4), | |
| sampled: this._buf(4), | |
| // single u32 chosen by GPU sampler (Phase 5) | |
| x_q: this._buf(Math.max(qd, I) * 4), | |
| scale_x: this._buf(256 * 4), | |
| blockTableBuf: this._buf(this.pam.maxBlocksPerSeq * 4, STORAGE | GPUBufferUsage.COPY_DST) | |
| }; | |
| this.idsRead = this._buf(this.decodeBatchCapacity * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| this.argmaxRead = this._buf(4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| this.sampleIdsRead = this._buf(this.maxSamplingTopK * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| this.sampleValsRead = this._buf(this.maxSamplingTopK * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| this.sampledRead = this._buf(4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| this.sT = null; | |
| this.sTcap = 0; | |
| this._initStaticUniforms(); | |
| if (this.decodeBatchMode === "auto") { | |
| onProgress("autotuning decode batch", 0.98); | |
| await this.autotuneDecodeBatch(); | |
| } | |
| onProgress("ready", 1); | |
| if (!this._didAutoWG) { | |
| this._didAutoWG = true; | |
| this.workgroupAutotunePromise = this.autotuneWorkgroups({ iters: 2, apply: true }).catch((e) => ({ | |
| error: String(e) | |
| })); | |
| } | |
| return this; | |
| } | |
| _initRuntimeOptions() { | |
| const opts = this.opts; | |
| this.decodeBatchMode = opts.decodeBatchSize === "auto" ? "auto" : "fixed"; | |
| this.decodeBatchCandidates = (opts.decodeBatchCandidates || [1, 2, 4, 8, 16, 32]).map((x) => Math.max(1, Math.floor(Number(x) || 0))).filter(Boolean); | |
| const requested = opts.decodeBatchSize === void 0 || opts.decodeBatchSize === "auto" ? 16 : Math.max(1, Math.floor(Number(opts.decodeBatchSize))); | |
| this.maxDecodeBatchSize = Math.max( | |
| 1, | |
| Math.floor(Number(opts.maxDecodeBatchSize || Math.max(requested, ...this.decodeBatchCandidates, 16))) | |
| ); | |
| this.decodeBatchCapacity = Math.min(this.maxDecodeBatchSize, Math.max(requested, ...this.decodeBatchCandidates)); | |
| this.MAXBATCH = Math.min(requested, this.decodeBatchCapacity); | |
| this.decodeBatchWarmupTokens = Math.max(0, Math.floor(Number(opts.decodeBatchWarmupTokens ?? 4))); | |
| this.decodeBatchWarmupSize = Math.min( | |
| this.decodeBatchCapacity, | |
| Math.max(1, Math.floor(Number(opts.decodeBatchWarmupSize ?? 4))) | |
| ); | |
| this.decodeBatchMaxLatencyMs = Number(opts.decodeBatchMaxLatencyMs ?? 250); | |
| this.samplingTopK = Math.max(1, Math.floor(Number(opts.samplingTopK ?? 40))); | |
| this.maxSamplingTopK = Math.max(this.samplingTopK, Math.floor(Number(opts.maxSamplingTopK ?? 64))); | |
| this.decodeBatchTuning = { | |
| selected: this.MAXBATCH, | |
| candidates: [], | |
| reason: this.decodeBatchMode === "auto" ? "pending" : "fixed" | |
| }; | |
| } | |
| _buildRope(maxSeq) { | |
| const { headDim, ropeTheta } = this.cfg; | |
| const half = headDim / 2; | |
| const cos = new Float32Array(maxSeq * headDim), sin = new Float32Array(maxSeq * headDim); | |
| for (let p = 0; p < maxSeq; p++) | |
| for (let i = 0; i < half; i++) { | |
| const a = p / Math.pow(ropeTheta, 2 * i / headDim); | |
| const cc = Math.cos(a), ss = Math.sin(a); | |
| cos[p * headDim + i] = cc; | |
| cos[p * headDim + half + i] = cc; | |
| sin[p * headDim + i] = ss; | |
| sin[p * headDim + half + i] = ss; | |
| } | |
| this.ropeCos = this._f32(cos); | |
| this.ropeSin = this._f32(sin); | |
| this._ropeRow = headDim * 4; | |
| } | |
| _initStaticUniforms() { | |
| const c = this.cfg; | |
| const rms = new ArrayBuffer(8); | |
| const rmsDv = new DataView(rms); | |
| rmsDv.setFloat32(0, c.hiddenSize, true); | |
| rmsDv.setFloat32(4, c.rmsNormEps, true); | |
| this.u = { | |
| rmsHidden: this._staticUni(`rms:${c.hiddenSize}:${c.rmsNormEps}`, new Uint8Array(rms)), | |
| addHidden: this._staticUni(`u32:${c.hiddenSize}`, new Uint32Array([c.hiddenSize])), | |
| siluIntermediate: this._staticUni(`u32:${c.intermediateSize}`, new Uint32Array([c.intermediateSize])), | |
| embedBuf: this._staticUni(`embedBuf:${c.hiddenSize}`, new Uint32Array([c.hiddenSize])), | |
| argmax: this._staticUni(`argmax:${c.vocabSize}`, new Uint32Array([c.vocabSize])) | |
| }; | |
| } | |
| async _buildPackedProjectionBuffers() { | |
| const enc = this.dev.createCommandEncoder(); | |
| const copy = /* @__PURE__ */ __name((src, dst, dstOffset, bytes) => enc.copyBufferToBuffer(src, 0, dst, dstOffset, bytes), "copy"); | |
| this.packedBytes = 0; | |
| for (const L of this.plan.layers) { | |
| const q = this.q4[L.q.weight], k = this.q4[L.k.weight], v = this.q4[L.v.weight]; | |
| if (q.K !== k.K || q.K !== v.K || q.gpr !== k.gpr || q.gpr !== v.gpr) | |
| throw new Error(`layer ${L.index} qkv packing requires matching K/gpr`); | |
| const totalN = q.N + k.N + v.N; | |
| const wBytes = totalN * (q.K / 8) * 4; | |
| const scaleBytes = totalN * q.gpr * 4; | |
| const biasBytes = totalN * 4; | |
| const w = this._buf(wBytes); | |
| const scale = this._buf(scaleBytes); | |
| const bias = this._buf(biasBytes); | |
| enc.clearBuffer(bias); | |
| let wOff = 0, sOff = 0, bOff = 0; | |
| for (const part of [L.q, L.k, L.v]) { | |
| const qq = this.q4[part.weight]; | |
| const rowsW = qq.N * (qq.K / 8) * 4; | |
| const rowsS = qq.N * qq.gpr * 4; | |
| copy(qq.w, w, wOff, rowsW); | |
| wOff += rowsW; | |
| copy(qq.scale, scale, sOff, rowsS); | |
| sOff += rowsS; | |
| if (part.bias) copy(this.bufs[part.bias], bias, bOff, qq.N * 4); | |
| bOff += qq.N * 4; | |
| } | |
| this.qkv[L.index] = { w, scale, bias, K: q.K, qN: q.N, kN: k.N, vN: v.N, totalN, gpr: q.gpr }; | |
| this.packedBytes += wBytes + scaleBytes + biasBytes; | |
| const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; | |
| if (gate.K !== up.K || gate.N !== up.N || gate.gpr !== up.gpr) | |
| throw new Error(`layer ${L.index} gate/up packing requires matching shape`); | |
| const guWBytes = (gate.N + up.N) * (gate.K / 8) * 4; | |
| const guScaleBytes = (gate.N + up.N) * gate.gpr * 4; | |
| const guW = this._buf(guWBytes); | |
| const guScale = this._buf(guScaleBytes); | |
| copy(gate.w, guW, 0, gate.N * (gate.K / 8) * 4); | |
| copy(up.w, guW, gate.N * (gate.K / 8) * 4, up.N * (up.K / 8) * 4); | |
| copy(gate.scale, guScale, 0, gate.N * gate.gpr * 4); | |
| copy(up.scale, guScale, gate.N * gate.gpr * 4, up.N * up.gpr * 4); | |
| this.gateUp[L.index] = { w: guW, scale: guScale, K: gate.K, N: gate.N, gpr: gate.gpr }; | |
| this.packedBytes += guWBytes + guScaleBytes; | |
| } | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.dev.queue.onSubmittedWorkDone(); | |
| } | |
| memoryFootprintBytes() { | |
| const c = this.cfg; | |
| const kvBytes = c.numLayers * 2 * c.numKVHeads * this.maxCtx * c.headDim * 4; | |
| const decodeScratchBytes = c.hiddenSize * 2 * 4 + (c.numHeads * c.headDim + 2 * c.numKVHeads * c.headDim + c.numHeads * c.headDim) * 4 + (Math.max(c.numHeads * c.headDim, c.intermediateSize) + c.intermediateSize + c.vocabSize) * 4; | |
| const prefillScratchBytes = this.sTcap ? this.sTcap * (3 * c.hiddenSize + c.numHeads * c.headDim + 2 * c.numKVHeads * c.headDim + c.numHeads * c.headDim + 2 * c.intermediateSize) * 4 : 0; | |
| return { kvBytes, decodeScratchBytes, prefillScratchBytes, packedBytes: this.packedBytes }; | |
| } | |
| _gemvMeta(q, biasBuf, mod) { | |
| const gx = Math.min(q.N, 65535); | |
| const bytes = new Uint8Array(32); | |
| const dv = new DataView(bytes.buffer); | |
| dv.setUint32(0, q.K, true); | |
| dv.setUint32(4, q.N, true); | |
| dv.setUint32(8, mod ? mod.rank : 0, true); | |
| dv.setUint32(12, biasBuf ? 1 : 0, true); | |
| dv.setUint32(16, mod ? 1 : 0, true); | |
| dv.setUint32(20, gx, true); | |
| dv.setFloat32(24, mod ? mod.scale : 0, true); | |
| return { | |
| gx, | |
| gy: Math.ceil(q.N / gx), | |
| bytes | |
| }; | |
| } | |
| _gemv4Meta(q, biasBuf, mod) { | |
| const gx = Math.min(q.N, 65535); | |
| const bytes = new Uint8Array(32); | |
| const dv = new DataView(bytes.buffer); | |
| dv.setUint32(0, q.K, true); | |
| dv.setUint32(4, q.N, true); | |
| dv.setUint32(8, mod ? mod.rank : 0, true); | |
| dv.setUint32(12, biasBuf ? 1 : 0, true); | |
| dv.setUint32(16, mod ? 1 : 0, true); | |
| dv.setUint32(20, gx, true); | |
| dv.setFloat32(24, mod ? mod.scale : 0, true); | |
| dv.setUint32(28, q.gpr, true); | |
| return { | |
| gx, | |
| gy: Math.ceil(q.N / gx), | |
| bytes | |
| }; | |
| } | |
| setLora(adapter) { | |
| this.lora = adapter; | |
| this._loraEpoch++; | |
| this.pool.clearSensitiveBindGroups(); | |
| } | |
| // {modules: {key:{A,B,rank,scale}}} A:[K][rank], B:[rank][N] f32 GPUBuffers | |
| clearLora() { | |
| this.lora = null; | |
| this._loraEpoch++; | |
| this.pool.clearSensitiveBindGroups(); | |
| } | |
| // Called after an in-place mutation of the active adapter's A/B buffers (e.g. an | |
| // optimizer step during training). Bumps the LoRA epoch so cached bind groups that | |
| // referenced the old contents are dropped and inference re-binds the mutated buffers. | |
| invalidateLora() { | |
| this._loraEpoch++; | |
| this.pool.clearSensitiveBindGroups(); | |
| } | |
| _bg(pipe, buffers) { | |
| return this.pool.uncachedBindGroup(pipe, buffers); | |
| } | |
| _bgCached(pipe, buffers, key, opts) { | |
| return this.pool.cachedBindGroup(pipe, buffers, key, opts); | |
| } | |
| _dispatch(enc, pipe, bg, gx, gy = 1, cat, imm = null) { | |
| this.lastDispatchCount++; | |
| let ts; | |
| if (this.prof && this.prof.idx < this.prof.cap) { | |
| const i = this.prof.idx++; | |
| this.prof.cats.push(cat || "misc"); | |
| ts = { querySet: this.prof.qs, beginningOfPassWriteIndex: 2 * i, endOfPassWriteIndex: 2 * i + 1 }; | |
| } | |
| const p = enc.beginComputePass(ts ? { timestampWrites: ts } : void 0); | |
| p.setPipeline(pipe); | |
| if (bg) p.setBindGroup(0, bg); | |
| if (imm) { | |
| if (Array.isArray(imm)) { | |
| let off = 0; | |
| for (const part of imm) { | |
| p.setImmediates(off, part); | |
| off += part.byteLength || part.length * (part.BYTES_PER_ELEMENT || 4); | |
| } | |
| } else { | |
| p.setImmediates(0, imm); | |
| } | |
| } | |
| p.dispatchWorkgroups(gx, gy); | |
| p.end(); | |
| } | |
| enableProf(cap = 700) { | |
| this.prof = { | |
| qs: this.dev.createQuerySet({ type: "timestamp", count: cap * 2 }), | |
| cap, | |
| idx: 0, | |
| cats: [], | |
| resolve: this._buf(cap * 16, GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC), | |
| read: this._buf(cap * 16, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ) | |
| }; | |
| } | |
| async profToken(id, pos) { | |
| this._resetUni(); | |
| this.prof.idx = 0; | |
| this.prof.cats = []; | |
| const enc = this.dev.createCommandEncoder(); | |
| this.embedRow(enc, id); | |
| this.step(enc, id, pos); | |
| const n = this.prof.idx; | |
| enc.resolveQuerySet(this.prof.qs, 0, n * 2, this.prof.resolve, 0); | |
| enc.copyBufferToBuffer(this.prof.resolve, 0, this.prof.read, 0, n * 16); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.prof.read.mapAsync(GPUMapMode.READ); | |
| const t = new BigInt64Array(this.prof.read.getMappedRange()); | |
| const sums = {}; | |
| for (let i = 0; i < n; i++) { | |
| const us = Number(t[2 * i + 1] - t[2 * i]) / 1e3; | |
| const c = this.prof.cats[i]; | |
| sums[c] = (sums[c] || 0) + us; | |
| } | |
| this.prof.read.unmap(); | |
| return sums; | |
| } | |
| poolStats() { | |
| return this.pool.stats(); | |
| } | |
| // Phase 4 observability: best workgroup sizes chosen by autotune (or null if not run). | |
| getBestWorkgroupSizes() { | |
| return this.bestWorkgroupSizes ? { ...this.bestWorkgroupSizes } : null; | |
| } | |
| resetPoolStats() { | |
| this.pool.resetStats(); | |
| } | |
| estimateKvCacheBytes() { | |
| const c = this.cfg; | |
| return c.numLayers * 2 * c.numKVHeads * this.maxCtx * c.headDim * 4; | |
| } | |
| estimatePrefillScratchBytes(T, loraRank = this._activeMaxLoraRank()) { | |
| const c = this.cfg, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; | |
| return T * H * 4 * 2 + T * qd * 4 * 2 + T * kvd * 4 * 2 + T * I * 4 * 2 + T * 4 + Math.max(1, T * Math.max(1, loraRank)) * 4; | |
| } | |
| greedyBatchSizeFor({ emitted = 0, remaining = Infinity, pos = 0 } = {}) { | |
| const interactive = emitted < this.decodeBatchWarmupTokens ? this.decodeBatchWarmupSize : this.MAXBATCH; | |
| return Math.max(0, Math.min(interactive, remaining, this.maxCtx - pos, this.decodeBatchCapacity)); | |
| } | |
| async _resetAutotuneDecodeState(tokens, seedTokenId = 0) { | |
| const c = this.cfg, S = this.s, H = c.hiddenSize, hd = c.headDim, qd = c.numHeads * hd, kvd = c.numKVHeads * hd, I = c.intermediateSize; | |
| const nsplitMax = Math.ceil(this.maxCtx / this.CHUNK); | |
| const touchedTokens = Math.min(Math.max(0, Math.floor(tokens)), this.maxCtx); | |
| const enc = this.dev.createCommandEncoder(); | |
| const clear = /* @__PURE__ */ __name((buf, bytes) => { | |
| if (bytes > 0) enc.clearBuffer(buf, 0, bytes); | |
| }, "clear"); | |
| clear(S.hidden, H * 4); | |
| clear(S.normed, H * 4); | |
| clear(S.q, qd * 4); | |
| clear(S.k, kvd * 4); | |
| clear(S.v, kvd * 4); | |
| clear(S.attn, qd * 4); | |
| clear(S.tmp, Math.max(qd, I) * 4); | |
| clear(S.tmp2, I * 4); | |
| clear(S.logits, c.vocabSize * 4); | |
| clear(S.loraD, 256 * 4); | |
| clear(S.idsBuf, this.decodeBatchCapacity * 4); | |
| clear(S.pm, c.numHeads * nsplitMax * 4); | |
| clear(S.pz, c.numHeads * nsplitMax * 4); | |
| clear(S.po, c.numHeads * nsplitMax * hd * 4); | |
| const kvBytes = touchedTokens * kvd * 4; | |
| for (let i = 0; i < c.numLayers; i++) { | |
| clear(this.kc[i], kvBytes); | |
| clear(this.vc[i], kvBytes); | |
| } | |
| this.dev.queue.submit([enc.finish()]); | |
| this.dev.queue.writeBuffer(S.amax, 0, new Uint32Array([seedTokenId])); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| } | |
| async autotuneDecodeBatch() { | |
| const candidates = [...new Set(this.decodeBatchCandidates)].filter((k) => k >= 1 && k <= this.decodeBatchCapacity && k <= this.maxCtx).sort((a, b) => a - b); | |
| const rows = []; | |
| const resetTokens = candidates.length ? Math.max(...candidates) : 0; | |
| let selected = candidates[0] ?? this.MAXBATCH, best = Infinity; | |
| try { | |
| for (const k of candidates) { | |
| await this._resetAutotuneDecodeState(resetTokens); | |
| const t0 = performance.now(); | |
| await this.decodeGreedyBatch(0, k); | |
| const ms = performance.now() - t0; | |
| const msPerToken = ms / k; | |
| rows.push({ k, ms, msPerToken }); | |
| const latencyOk = !Number.isFinite(this.decodeBatchMaxLatencyMs) || ms <= this.decodeBatchMaxLatencyMs; | |
| if (latencyOk && msPerToken < best) { | |
| best = msPerToken; | |
| selected = k; | |
| } | |
| } | |
| if (!rows.some((r) => r.k === selected) && rows.length) | |
| selected = rows.reduce((a, b) => a.msPerToken <= b.msPerToken ? a : b).k; | |
| this.MAXBATCH = selected; | |
| this.decodeBatchTuning = { | |
| selected, | |
| candidates: rows, | |
| reason: "auto wall-clock decodeGreedyBatch with reset state" | |
| }; | |
| } catch (e) { | |
| this.decodeBatchTuning = { selected: this.MAXBATCH, candidates: rows, reason: `auto failed: ${e.message}` }; | |
| } finally { | |
| if (resetTokens > 0) { | |
| try { | |
| await this._resetAutotuneDecodeState(resetTokens); | |
| } catch { | |
| } | |
| } | |
| } | |
| return this.decodeBatchTuning; | |
| } | |
| // y = int8-GEMV(x, q) [+bias] [+lora]. q={w,scale,N,K}. moduleKey for LoRA lookup. | |
| gemv(enc, xBuf, q, yBuf, biasBuf, moduleKey) { | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); | |
| const meta = this._gemvMeta(q, biasBuf, mod); | |
| const key = `gemv:${moduleKey || "base"}:${q.K}:${q.N}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; | |
| const bg = this._bgCached( | |
| this.pipes.gemv, | |
| [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], | |
| key, | |
| { sensitive: !!mod } | |
| ); | |
| this._dispatch(enc, this.pipes.gemv, bg, meta.gx, meta.gy, `gemv:${q.N}x${q.K}`, meta.bytes); | |
| } | |
| gemv4(enc, xBuf, q, yBuf, biasBuf, moduleKey) { | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (this.debugCapture) console.log("VWG gemv4: " + moduleKey + " mod=" + !!mod); | |
| if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); | |
| const meta = this._gemv4Meta(q, biasBuf, mod); | |
| const key = `gemv4:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; | |
| const bg = this._bgCached( | |
| this.pipes.gemv4, | |
| [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], | |
| key, | |
| { sensitive: !!mod } | |
| ); | |
| this._dispatch(enc, this.pipes.gemv4, bg, meta.gx, meta.gy, `g4:${q.N}x${q.K}`, meta.bytes); | |
| if (mod) { | |
| if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { | |
| enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.ySeq, this.debugStep * q.N * 4, q.N * 4); | |
| this.debugStep++; | |
| } | |
| } | |
| } | |
| _loraA(enc, xBuf, q, mod, dBuf, moduleKey, label = "loraA") { | |
| const imm = new Uint32Array([q.K, mod.rank]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.loraA, | |
| this._bgCached(this.pipes.loraA, [xBuf, mod.A, dBuf], `${label}:${moduleKey}:${this._loraEpoch}`, { | |
| sensitive: true | |
| }), | |
| mod.rank, | |
| 1, | |
| label, | |
| imm | |
| ); | |
| if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { | |
| enc.copyBufferToBuffer(xBuf, 0, this.debugBufs.xSeq, this.debugStep * q.K * 4, q.K * 4); | |
| enc.copyBufferToBuffer(dBuf, 0, this.debugBufs.dSeq, this.debugStep * mod.rank * 4, mod.rank * 4); | |
| } | |
| } | |
| _loraBAdd(enc, yBuf, q, mod, dBuf, moduleKey) { | |
| const meta = new ArrayBuffer(32); | |
| const dv = new DataView(meta); | |
| dv.setUint32(0, q.N, true); | |
| dv.setUint32(4, mod.rank, true); | |
| dv.setFloat32(16, mod.scale, true); | |
| const bg = this._bgCached( | |
| this.pipes.loraBAdd, | |
| [dBuf, mod.B, yBuf], | |
| `loraBAdd:${moduleKey}:${this._loraEpoch}`, | |
| { sensitive: true } | |
| ); | |
| this._dispatch(enc, this.pipes.loraBAdd, bg, Math.ceil(q.N / 256), 1, "loraB", new Uint8Array(meta)); | |
| if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { | |
| enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.ySeq, this.debugStep * q.N * 4, q.N * 4); | |
| this.debugStep++; | |
| } | |
| } | |
| gemv4Add(enc, xBuf, q, yBuf, biasBuf, moduleKey) { | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); | |
| const meta = this._gemv4Meta(q, biasBuf, mod); | |
| const key = `gemv4add:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; | |
| const bg = this._bgCached( | |
| this.pipes.gemv4Add, | |
| [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], | |
| key, | |
| { sensitive: !!mod } | |
| ); | |
| this._dispatch(enc, this.pipes.gemv4Add, bg, meta.gx, meta.gy, `g4add:${q.N}x${q.K}`, meta.bytes); | |
| } | |
| dynQuant(enc, xBuf, x_qBuf, scale_xBuf, K) { | |
| const numGroups = Math.ceil(K / 128); | |
| const imm = new Uint32Array([K]); | |
| const bg = this._bg(this.pipes.dynQuant, [xBuf, x_qBuf, scale_xBuf]); | |
| this._dispatch(enc, this.pipes.dynQuant, bg, numGroups, 1, "dynQuant", imm); | |
| } | |
| dynQuantT(enc, xBuf, x_qBuf, scale_xBuf, K, T) { | |
| const numGroups = Math.ceil(K / 128); | |
| const imm = new Uint32Array([K, T]); | |
| const bg = this._bg(this.pipes.dynQuantT, [xBuf, x_qBuf, scale_xBuf]); | |
| this._dispatch(enc, this.pipes.dynQuantT, bg, numGroups, T, "dynQuantT", imm); | |
| } | |
| gemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, q, yBuf, biasBuf, moduleKey) { | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); | |
| const meta = this._gemv4Meta(q, biasBuf, mod); | |
| const key = `gemv4_w4a8:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; | |
| const bg = this._bgCached( | |
| this.pipes.gemv4W4A8, | |
| [ | |
| x_qBuf, | |
| scale_xBuf, | |
| q.w, | |
| q.scale, | |
| biasBuf || this.s.dummy, | |
| this.s.loraD, | |
| mod ? mod.B : this.s.dummy, | |
| yBuf | |
| ], | |
| key, | |
| { sensitive: !!mod } | |
| ); | |
| this._dispatch(enc, this.pipes.gemv4W4A8, bg, meta.gx, meta.gy, `g4w4a8:${q.N}x${q.K}`, meta.bytes); | |
| } | |
| gemv4AddW4A8(enc, xBuf, x_qBuf, scale_xBuf, q, yBuf, biasBuf, moduleKey) { | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); | |
| const meta = this._gemv4Meta(q, biasBuf, mod); | |
| const key = `gemv4add_w4a8:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; | |
| const bg = this._bgCached( | |
| this.pipes.gemv4AddW4A8, | |
| [ | |
| x_qBuf, | |
| scale_xBuf, | |
| q.w, | |
| q.scale, | |
| biasBuf || this.s.dummy, | |
| this.s.loraD, | |
| mod ? mod.B : this.s.dummy, | |
| yBuf | |
| ], | |
| key, | |
| { sensitive: !!mod } | |
| ); | |
| this._dispatch(enc, this.pipes.gemv4AddW4A8, bg, meta.gx, meta.gy, `g4addw4a8:${q.N}x${q.K}`, meta.bytes); | |
| } | |
| qkvGemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, packed, qBuf, kBuf, vBuf, L) { | |
| const gx = Math.min(packed.totalN, 65535); | |
| const imm = new Uint32Array([packed.K, packed.totalN, packed.qN, packed.kN, packed.vN, packed.gpr, gx, 0]); | |
| const bg = this._bgCached( | |
| this.pipes.qkvGemv4W4A8, | |
| [x_qBuf, scale_xBuf, packed.w, packed.scale, packed.bias, qBuf, kBuf, vBuf], | |
| `qkv_w4a8:${L.index}`, | |
| { sensitive: false } | |
| ); | |
| this._dispatch( | |
| enc, | |
| this.pipes.qkvGemv4W4A8, | |
| bg, | |
| gx, | |
| Math.ceil(packed.totalN / gx), | |
| `qkvw4a8:${packed.totalN}x${packed.K}`, | |
| imm | |
| ); | |
| for (const [part, out] of [ | |
| [L.q, qBuf], | |
| [L.k, kBuf], | |
| [L.v, vBuf] | |
| ]) { | |
| const mod = this.lora?.modules?.[part.loraKey]; | |
| if (!mod) continue; | |
| const q = this.q4[part.weight]; | |
| this._loraA(enc, xBuf, q, mod, this.s.loraD, part.loraKey); | |
| this._loraBAdd(enc, out, q, mod, this.s.loraD, part.loraKey); | |
| } | |
| } | |
| _gateUpImmediate(packed, gx, gateMod, upMod) { | |
| const imm = new Uint32Array(12); | |
| imm.set([ | |
| packed.K, | |
| packed.N, | |
| packed.gpr, | |
| gx, | |
| gateMod ? gateMod.rank : 0, | |
| upMod ? upMod.rank : 0, | |
| gateMod ? 1 : 0, | |
| upMod ? 1 : 0 | |
| ]); | |
| const f322 = new Float32Array(imm.buffer); | |
| f322[8] = gateMod ? gateMod.scale : 0; | |
| f322[9] = upMod ? upMod.scale : 0; | |
| return imm; | |
| } | |
| gateUpSiluGemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, packed, yBuf, L) { | |
| const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; | |
| const gateMod = this.lora?.modules?.[L.gate.loraKey]; | |
| const upMod = this.lora?.modules?.[L.up.loraKey]; | |
| if (gateMod) this._loraA(enc, xBuf, gate, gateMod, this.s.loraD, L.gate.loraKey, "loraA:gate"); | |
| if (upMod) this._loraA(enc, xBuf, up, upMod, this.s.loraD2, L.up.loraKey, "loraA:up"); | |
| const gx = Math.min(packed.N, 65535); | |
| const imm = this._gateUpImmediate(packed, gx, gateMod, upMod); | |
| const bg = this._bgCached( | |
| this.pipes.gateUpSiluGemv4W4A8, | |
| [ | |
| x_qBuf, | |
| scale_xBuf, | |
| packed.w, | |
| packed.scale, | |
| yBuf, | |
| this.s.loraD, | |
| gateMod ? gateMod.B : this.s.dummy, | |
| this.s.loraD2, | |
| upMod ? upMod.B : this.s.dummy | |
| ], | |
| `gu_w4a8:${L.index}:${this._loraEpoch}:${gateMod ? 1 : 0}:${upMod ? 1 : 0}`, | |
| { sensitive: !!(gateMod || upMod) } | |
| ); | |
| this._dispatch( | |
| enc, | |
| this.pipes.gateUpSiluGemv4W4A8, | |
| bg, | |
| gx, | |
| Math.ceil(packed.N / gx), | |
| `guw4a8:${packed.N}x${packed.K}`, | |
| imm | |
| ); | |
| } | |
| gemm4W4A8(enc, aBuf, a_qBuf, scale_xBuf, q, yBuf, T, biasBuf, moduleKey) { | |
| const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); | |
| const bg = this._bg(this.pipes.gemm4W4A8, [a_qBuf, scale_xBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); | |
| this._dispatch(enc, this.pipes.gemm4W4A8, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4W4A8", imm); | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); | |
| } | |
| gemm4AddTW4A8(enc, aBuf, a_qBuf, scale_xBuf, q, yBuf, T, biasBuf, moduleKey) { | |
| const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); | |
| const bg = this._bg(this.pipes.gemm4AddTW4A8, [ | |
| a_qBuf, | |
| scale_xBuf, | |
| q.w, | |
| q.scale, | |
| biasBuf || this.s.dummy, | |
| yBuf | |
| ]); | |
| this._dispatch(enc, this.pipes.gemm4AddTW4A8, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4AddTW4A8", imm); | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); | |
| } | |
// Fused decode step: RMSNorm + int4 QKV GEMV + RoPE in one dispatch on the
// rmsNormQkvRope pipeline, using the layer's packed Q|K|V buffers (this.qkv[i]),
// its input-norm weight, and the RoPE cos/sin tables.
//
// Launch geometry: the kernel assigns ONE workgroup per (head, rotation) pair,
// so it must be launched with totalPairs = (qN + kN + vN) / 2 workgroups, folded
// into a (gx, ceil(totalPairs / gx)) grid. A previous fixed 20-workgroup launch
// (with an element-count meta) left most Q/K/V outputs unwritten and produced
// garbage tokens.
//
// LoRA: the kernel normalizes x on the fly and has no `normed` output, and this
// method never consults this.lora. It is therefore for the NO-LoRA case only;
// callers must route LoRA-bearing layers to the unfused gemv4x3 path, which can
// add the adapter.
//
// Outputs: this.s.q, this.s.k, this.s.v.
rmsNormQkvRope(enc, xBuf, layerIndex, pos) {
  const c = this.cfg, L = this.plan.layers[layerIndex];
  const packed = this.qkv[L.index];
  // Output rows are processed in pairs; V rows are counted the same way.
  const qPairs = packed.qN / 2, kPairs = packed.kN / 2, vPairs = packed.vN / 2;
  const totalPairs = qPairs + kPairs + vPairs;
  // Fold the 1-D pair index into 2-D; 65535 is WebGPU's default
  // maxComputeWorkgroupsPerDimension.
  const gx = Math.min(totalPairs, 65535);
  // Immediate layout (12 x 32-bit words) — must match the kernel's struct:
  //   [0] K  [1] totalPairs  [2] qPairs  [3] kPairs  [4] vPairs  [5] gpr
  //   [6] gx [7] pos         [8] headDim
  //   [9..11] raw f32 bit patterns of rmsNormEps, qN, kN
  const meta = new Uint32Array([
    packed.K,
    totalPairs,
    qPairs,
    kPairs,
    vPairs,
    packed.gpr,
    gx,
    pos,
    c.headDim,
    ...new Uint32Array(new Float32Array([c.rmsNormEps, packed.qN, packed.kN]).buffer)
  ]);
  // Uncached bind group: xBuf, norm weight, packed w/scale/bias, RoPE tables, q/k/v outputs.
  const bg = this._bg(
    this.pipes.rmsNormQkvRope,
    [
      xBuf,
      this.bufs[L.inputNorm],
      packed.w,
      packed.scale,
      packed.bias,
      this.ropeCos,
      this.ropeSin,
      this.s.q,
      this.s.k,
      this.s.v
    ]
  );
  this._dispatch(enc, this.pipes.rmsNormQkvRope, bg, gx, Math.ceil(totalPairs / gx), "rmsNormQkvRope", meta);
}
| writeKvPage(enc, kBuf, vBuf, kcBuf, vcBuf, pos, layerIndex) { | |
| const c = this.cfg; | |
| const kvd = c.numKVHeads * c.headDim; | |
| this.pam.ensureBlocks(0, pos + 1); | |
| const btArr = this.pam.getBlockTableArray(0); | |
| this.dev.queue.writeBuffer(this.s.blockTableBuf, 0, btArr); | |
| const meta = new Uint32Array([pos, 0, this.pam.maxBlocksPerSeq, kvd]); | |
| const bg = this._bg(this.pipes.writeKvPage, [kBuf, vBuf, kcBuf, vcBuf, this.s.blockTableBuf]); | |
| this._dispatch(enc, this.pipes.writeKvPage, bg, Math.ceil(kvd / 256), 1, "writeKvPage", meta); | |
| } | |
| writeKvPageBatch(enc, kBuf, vBuf, kcBuf, vcBuf, T, off, layerIndex) { | |
| const c = this.cfg; | |
| const kvd = c.numKVHeads * c.headDim; | |
| this.pam.ensureBlocks(0, off + T); | |
| const btArr = this.pam.getBlockTableArray(0); | |
| this.dev.queue.writeBuffer(this.s.blockTableBuf, 0, btArr); | |
| const meta = new Uint32Array([T, 0, this.pam.maxBlocksPerSeq, kvd, off]); | |
| const bg = this._bg(this.pipes.writeKvPageBatch, [kBuf, vBuf, kcBuf, vcBuf, this.s.blockTableBuf]); | |
| this._dispatch(enc, this.pipes.writeKvPageBatch, bg, Math.ceil(T * kvd / 256), 1, "writeKvPageBatch", meta); | |
| } | |
| attnPaged(enc, qBuf, kc, vc, oBuf, ctx) { | |
| const c = this.cfg, S = this.s; | |
| const nsplit = Math.ceil(ctx / this.CHUNK); | |
| const bgP = this._bg(this.pipes.attnPartialPaged, [ | |
| qBuf, | |
| kc, | |
| vc, | |
| S.pm, | |
| S.pz, | |
| S.po, | |
| S.blockTableBuf | |
| ]); | |
| const immP = new Uint32Array([c.numHeads, c.numKVHeads, ctx, c.headDim, nsplit, this.CHUNK, 0, this.pam.maxBlocksPerSeq]); | |
| this._dispatch(enc, this.pipes.attnPartialPaged, bgP, c.numHeads, nsplit, "attnP_paged", immP); | |
| const useF16C = this.usingF16() && this.pipes.attnCF16; | |
| const pipeC = useF16C ? this.pipes.attnCF16 : this.pipes.attnC; | |
| const bgC = this._bg(pipeC, [ | |
| S.pm, | |
| S.pz, | |
| S.po, | |
| oBuf | |
| ]); | |
| const immC = new Uint32Array([c.numHeads, c.headDim, nsplit, 0]); | |
| this._dispatch(enc, pipeC, bgC, c.numHeads, 1, useF16C ? "attnCF16" : "attnC", immC); | |
| } | |
| attnPrefillPaged(enc, qBuf, kc, vc, oBuf, T, qStart = 0, ctx = T) { | |
| const c = this.cfg; | |
| if (this.features.prefillAttention === "block" || qStart !== 0 || ctx !== T) { | |
| const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, qStart, ctx, 0, this.pam.maxBlocksPerSeq]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.attnPrefillBlockPaged, | |
| this._bg(this.pipes.attnPrefillBlockPaged, [qBuf, kc, vc, oBuf, this.s.blockTableBuf]), | |
| c.numHeads, | |
| Math.ceil(T / 4), | |
| "attnPrefillBlockPaged", | |
| imm | |
| ); | |
| } else { | |
| const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, 0, this.pam.maxBlocksPerSeq, 0, 0]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.attnPrefillPaged, | |
| this._bg(this.pipes.attnPrefillPaged, [ | |
| qBuf, | |
| kc, | |
| vc, | |
| oBuf, | |
| this.s.blockTableBuf | |
| ]), | |
| c.numHeads, | |
| T, | |
| "attnPrefillPaged", | |
| imm | |
| ); | |
| } | |
| } | |
| qkvGemv4(enc, xBuf, packed, qBuf, kBuf, vBuf, L) { | |
| const gx = Math.min(packed.totalN, 65535); | |
| const imm = new Uint32Array([packed.K, packed.totalN, packed.qN, packed.kN, packed.vN, packed.gpr, gx, 0]); | |
| const bg = this._bgCached( | |
| this.pipes.qkvGemv4, | |
| [xBuf, packed.w, packed.scale, packed.bias, qBuf, kBuf, vBuf], | |
| `qkv:${L.index}`, | |
| { sensitive: false } | |
| ); | |
| this._dispatch(enc, this.pipes.qkvGemv4, bg, gx, Math.ceil(packed.totalN / gx), `qkv:${packed.totalN}x${packed.K}`, imm); | |
| for (const [part, out] of [ | |
| [L.q, qBuf], | |
| [L.k, kBuf], | |
| [L.v, vBuf] | |
| ]) { | |
| const mod = this.lora?.modules?.[part.loraKey]; | |
| if (!mod) continue; | |
| const q = this.q4[part.weight]; | |
| this._loraA(enc, xBuf, q, mod, this.s.loraD, part.loraKey); | |
| this._loraBAdd(enc, out, q, mod, this.s.loraD, part.loraKey); | |
| } | |
| } | |
| fusedRmsQkvRope(enc, hiddenBuf, inputNormBuf, packed, qBuf, kBuf, vBuf, pos, L) { | |
| const qPairs = packed.qN / 2; | |
| const kPairs = packed.kN / 2; | |
| const vPairs = packed.vN / 2; | |
| const totalPairs = qPairs + kPairs + vPairs; | |
| const gx = Math.min(totalPairs, 65535); | |
| const meta = new Uint32Array([ | |
| packed.K, | |
| totalPairs, | |
| qPairs, | |
| kPairs, | |
| vPairs, | |
| packed.gpr, | |
| gx, | |
| pos, | |
| this.cfg.headDim, | |
| ...new Uint32Array(new Float32Array([this.cfg.rmsNormEps, packed.qN, packed.kN]).buffer) | |
| ]); | |
| const bg = this._bg( | |
| this.pipes.rmsNormQkvRope, | |
| [ | |
| hiddenBuf, | |
| inputNormBuf, | |
| packed.w, | |
| packed.scale, | |
| packed.bias, | |
| this.ropeCos, | |
| this.ropeSin, | |
| qBuf, | |
| kBuf, | |
| vBuf | |
| ] | |
| ); | |
| this._dispatch( | |
| enc, | |
| this.pipes.rmsNormQkvRope, | |
| bg, | |
| gx, | |
| Math.ceil(totalPairs / gx), | |
| `fusedQkvRope:${totalPairs}x${packed.K}`, | |
| meta | |
| ); | |
| } | |
| gateUpSiluGemv4(enc, xBuf, packed, yBuf, L) { | |
| const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; | |
| const gateMod = this.lora?.modules?.[L.gate.loraKey]; | |
| const upMod = this.lora?.modules?.[L.up.loraKey]; | |
| if (gateMod) this._loraA(enc, xBuf, gate, gateMod, this.s.loraD, L.gate.loraKey, "loraA:gate"); | |
| if (upMod) this._loraA(enc, xBuf, up, upMod, this.s.loraD2, L.up.loraKey, "loraA:up"); | |
| const gx = Math.min(packed.N, 65535); | |
| const imm = this._gateUpImmediate(packed, gx, gateMod, upMod); | |
| const bg = this._bgCached( | |
| this.pipes.gateUpSiluGemv4, | |
| [ | |
| xBuf, | |
| packed.w, | |
| packed.scale, | |
| yBuf, | |
| this.s.loraD, | |
| gateMod ? gateMod.B : this.s.dummy, | |
| this.s.loraD2, | |
| upMod ? upMod.B : this.s.dummy | |
| ], | |
| `gu:${L.index}:${this._loraEpoch}:${gateMod ? 1 : 0}:${upMod ? 1 : 0}`, | |
| { sensitive: !!(gateMod || upMod) } | |
| ); | |
| this._dispatch(enc, this.pipes.gateUpSiluGemv4, bg, gx, Math.ceil(packed.N / gx), `gu:${packed.N}x${packed.K}`, imm); | |
| } | |
| rms(enc, xBuf, gBuf, yBuf, K) { | |
| const imm = new Float32Array([K, this.cfg.rmsNormEps]); | |
| const useF16 = this.usingF16() && this.pipes.rmsF16; | |
| const pipe = useF16 ? this.pipes.rmsF16 : this.pipes.rms; | |
| const key = `rms:${K}${useF16 ? ":f16" : ""}`; | |
| this._dispatch(enc, pipe, this._bgCached(pipe, [xBuf, gBuf, yBuf], key), 1, 1, useF16 ? "rmsF16" : "rms", imm); | |
| } | |
| rope(enc, xBuf, pos, nHeads) { | |
| const useF16 = this.usingF16() && this.pipes.ropeF16; | |
| const pipe = useF16 ? this.pipes.ropeF16 : this.pipes.rope; | |
| this._dispatch( | |
| enc, | |
| pipe, | |
| this._bg(pipe, [ | |
| xBuf, | |
| this.ropeCos, | |
| this.ropeSin | |
| ]), | |
| Math.ceil(nHeads * (this.cfg.headDim / 2) / 256), | |
| 1, | |
| useF16 ? "ropeF16" : "rope", | |
| new Uint32Array([nHeads, this.cfg.headDim, pos]) | |
| ); | |
| } | |
| ropeQK(enc, qBuf, kBuf, pos) { | |
| const c = this.cfg; | |
| const pairs = (c.numHeads + c.numKVHeads) * (c.headDim / 2); | |
| const useF16 = this.usingF16() && this.pipes.ropeQKF16; | |
| const pipe = useF16 ? this.pipes.ropeQKF16 : this.pipes.ropeQK; | |
| this._dispatch( | |
| enc, | |
| pipe, | |
| this._bg(pipe, [ | |
| qBuf, | |
| kBuf, | |
| this.ropeCos, | |
| this.ropeSin | |
| ]), | |
| Math.ceil(pairs / 256), | |
| 1, | |
| useF16 ? "ropeQKF16" : "ropeQK", | |
| new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, pos]) | |
| ); | |
| } | |
| attn(enc, qBuf, kc, vc, oBuf, ctx) { | |
| const c = this.cfg, S = this.s; | |
| const nsplit = Math.ceil(ctx / this.CHUNK); | |
| const useF16P = this.usingF16() && this.pipes.attnPF16; | |
| const pipeP = useF16P ? this.pipes.attnPF16 : this.pipes.attnP; | |
| const bgP = this._bg(pipeP, [ | |
| qBuf, | |
| kc, | |
| vc, | |
| S.pm, | |
| S.pz, | |
| S.po | |
| ]); | |
| const immP = new Uint32Array([c.numHeads, c.numKVHeads, ctx, c.headDim, nsplit, this.CHUNK]); | |
| this._dispatch(enc, pipeP, bgP, c.numHeads, nsplit, useF16P ? "attnPF16" : "attnP", immP); | |
| const useF16C = this.usingF16() && this.pipes.attnCF16; | |
| const pipeC = useF16C ? this.pipes.attnCF16 : this.pipes.attnC; | |
| const bgC = this._bg(pipeC, [ | |
| S.pm, | |
| S.pz, | |
| S.po, | |
| oBuf | |
| ]); | |
| const immC = new Uint32Array([c.numHeads, c.headDim, nsplit, 0]); | |
| this._dispatch(enc, pipeC, bgC, c.numHeads, 1, useF16C ? "attnCF16" : "attnC", immC); | |
| } | |
| // Decode one token at absolute position `pos`. Writes logits to s.logits. Returns nothing. | |
| step(enc, tokenId, pos) { | |
| const c = this.cfg, S = this.s, hd = c.headDim, kvd = c.numKVHeads * hd; | |
| for (let i = 0; i < c.numLayers; i++) { | |
| const L = this.plan.layers[i]; | |
| const hasQkvLora = this.lora && (this.lora.modules[L.q.loraKey] || this.lora.modules[L.k.loraKey] || this.lora.modules[L.v.loraKey]); | |
| if (this.features.fuseRMSNormQKVRoPE && !hasQkvLora && !this.features.actQuant) { | |
| this.rmsNormQkvRope(enc, S.hidden, i, pos); | |
| } else { | |
| this.rms(enc, S.hidden, this.bufs[L.inputNorm], S.normed, c.hiddenSize); | |
| if (this.features.actQuant) { | |
| this.dynQuant(enc, S.normed, S.x_q, S.scale_x, c.hiddenSize); | |
| this.qkvGemv4W4A8(enc, S.normed, S.x_q, S.scale_x, this.qkv[L.index], S.q, S.k, S.v, L); | |
| } else { | |
| if (!hasQkvLora && this.features.fuseQKV) { | |
| this.fusedRmsQkvRope(enc, S.hidden, this.bufs[L.inputNorm], this.qkv[L.index], S.q, S.k, S.v, pos, L); | |
| } else if (this.features.fuseQKV) { | |
| this.qkvGemv4(enc, S.normed, this.qkv[L.index], S.q, S.k, S.v, L); | |
| if (this.features.fuseRoPE) this.ropeQK(enc, S.q, S.k, pos); | |
| else { | |
| this.rope(enc, S.q, pos, c.numHeads); | |
| this.rope(enc, S.k, pos, c.numKVHeads); | |
| } | |
| } else { | |
| this.gemv4(enc, S.normed, this.q4[L.q.weight], S.q, this.bufs[L.q.bias], L.q.loraKey); | |
| this.gemv4(enc, S.normed, this.q4[L.k.weight], S.k, this.bufs[L.k.bias], L.k.loraKey); | |
| this.gemv4(enc, S.normed, this.q4[L.v.weight], S.v, this.bufs[L.v.bias], L.v.loraKey); | |
| if (this.features.fuseRoPE) this.ropeQK(enc, S.q, S.k, pos); | |
| else { | |
| this.rope(enc, S.q, pos, c.numHeads); | |
| this.rope(enc, S.k, pos, c.numKVHeads); | |
| } | |
| } | |
| } | |
| } | |
| if (this.features.pagedAttention) { | |
| this.writeKvPage(enc, S.k, S.v, this.kc[i], this.vc[i], pos, i); | |
| } else { | |
| enc.copyBufferToBuffer(S.k, 0, this.kc[i], pos * kvd * 4, kvd * 4); | |
| enc.copyBufferToBuffer(S.v, 0, this.vc[i], pos * kvd * 4, kvd * 4); | |
| } | |
| if (this.features.pagedAttention) { | |
| this.attnPaged(enc, S.q, this.kc[i], this.vc[i], S.attn, pos + 1); | |
| } else { | |
| this.attn(enc, S.q, this.kc[i], this.vc[i], S.attn, pos + 1); | |
| } | |
| if (this.features.actQuant) { | |
| this.dynQuant(enc, S.attn, S.x_q, S.scale_x, c.hiddenSize); | |
| if (this.features.fuseResidual) { | |
| this.gemv4AddW4A8(enc, S.attn, S.x_q, S.scale_x, this.q4[L.o.weight], S.hidden, null, L.o.loraKey); | |
| } else { | |
| this.gemv4W4A8(enc, S.attn, S.x_q, S.scale_x, this.q4[L.o.weight], S.tmp, null, L.o.loraKey); | |
| this._addInto(enc, S.hidden, S.tmp, c.hiddenSize); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) this.gemv4Add(enc, S.attn, this.q4[L.o.weight], S.hidden, null, L.o.loraKey); | |
| else { | |
| this.gemv4(enc, S.attn, this.q4[L.o.weight], S.tmp, null, L.o.loraKey); | |
| this._addInto(enc, S.hidden, S.tmp, c.hiddenSize); | |
| } | |
| } | |
| this.rms(enc, S.hidden, this.bufs[L.postAttentionNorm], S.normed, c.hiddenSize); | |
| if (this.features.actQuant) { | |
| this.dynQuant(enc, S.normed, S.x_q, S.scale_x, c.hiddenSize); | |
| this.gateUpSiluGemv4W4A8(enc, S.normed, S.x_q, S.scale_x, this.gateUp[L.index], S.tmp, L); | |
| } else { | |
| if (this.features.fuseMLP) { | |
| this.gateUpSiluGemv4(enc, S.normed, this.gateUp[L.index], S.tmp, L); | |
| } else { | |
| this.gemv4(enc, S.normed, this.q4[L.gate.weight], S.tmp, null, L.gate.loraKey); | |
| this.gemv4(enc, S.normed, this.q4[L.up.weight], S.tmp2, null, L.up.loraKey); | |
| this._siluMul(enc, S.tmp, S.tmp2, c.intermediateSize); | |
| } | |
| } | |
| if (this.features.actQuant) { | |
| this.dynQuant(enc, S.tmp, S.x_q, S.scale_x, c.intermediateSize); | |
| if (this.features.fuseResidual) { | |
| this.gemv4AddW4A8(enc, S.tmp, S.x_q, S.scale_x, this.q4[L.down.weight], S.hidden, null, L.down.loraKey); | |
| } else { | |
| this.gemv4W4A8(enc, S.tmp, S.x_q, S.scale_x, this.q4[L.down.weight], S.normed, null, L.down.loraKey); | |
| this._addInto(enc, S.hidden, S.normed, c.hiddenSize); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemv4Add(enc, S.tmp, this.q4[L.down.weight], S.hidden, null, L.down.loraKey); | |
| else { | |
| this.gemv4(enc, S.tmp, this.q4[L.down.weight], S.normed, null, L.down.loraKey); | |
| this._addInto(enc, S.hidden, S.normed, c.hiddenSize); | |
| } | |
| } | |
| } | |
| this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, c.hiddenSize); | |
| this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); | |
| } | |
| _addInto(enc, yBuf, aBuf, n) { | |
| const imm = new Uint32Array([n]); | |
| const useF16 = this.usingF16() && this.pipes.addF16; | |
| const pipe = useF16 ? this.pipes.addF16 : this.pipes.add; | |
| const bg = this._bgCached(pipe, [aBuf, yBuf], `add:${n}${useF16 ? ":f16" : ""}`); | |
| const wg = pipe.__wg || 256; | |
| this._dispatch(enc, pipe, bg, Math.min(Math.ceil(n / wg), 65535), 1, useF16 ? "addF16" : "add", imm); | |
| } | |
| _siluMul(enc, gateBuf, upBuf, n) { | |
| const imm = new Uint32Array([n]); | |
| const useF16 = this.usingF16() && this.pipes.siluF16; | |
| const pipe = useF16 ? this.pipes.siluF16 : this.pipes.silu; | |
| const bg = this._bgCached(pipe, [gateBuf, upBuf], `silu:${n}${useF16 ? ":f16" : ""}`); | |
| const wg = pipe.__wg || 256; | |
| this._dispatch(enc, pipe, bg, Math.min(Math.ceil(n / wg), 65535), 1, useF16 ? "siluF16" : "silu", imm); | |
| } | |
| embedRow(enc, id) { | |
| const e = this.q[this.plan.embed.name]; | |
| const imm = new Uint32Array([id, this.cfg.hiddenSize]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.embed, | |
| this._bg(this.pipes.embed, [e.w, e.scale, this.s.hidden]), | |
| Math.ceil(this.cfg.hiddenSize / 256), | |
| 1, | |
| "embed", | |
| imm | |
| ); | |
| } | |
| async argmaxLogits() { | |
| if (this._argmaxReadBusy) | |
| throw new Error("argmaxLogits() is already in flight; concurrent generation is not supported"); | |
| this._argmaxReadBusy = true; | |
| const enc = this.dev.createCommandEncoder(); | |
| const n = this.cfg.vocabSize || 0; | |
| this._dispatch( | |
| enc, | |
| this.pipes.argmax, | |
| this._bgCached(this.pipes.argmax, [this.s.logits, this.s.amax], "argmax"), | |
| 1, | |
| 1, | |
| "argmax", | |
| new Uint32Array([n]) | |
| ); | |
| enc.copyBufferToBuffer(this.s.amax, 0, this.argmaxRead, 0, 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| try { | |
| await this.argmaxRead.mapAsync(GPUMapMode.READ); | |
| const id = new Uint32Array(this.argmaxRead.getMappedRange())[0]; | |
| this.argmaxRead.unmap(); | |
| return id; | |
| } finally { | |
| this._argmaxReadBusy = false; | |
| } | |
| } | |
| // Convenience for numeric comparison harnesses (Phase 3 f16 eval etc.). | |
| // Returns a fresh Float32Array copy of the current final logits buffer. | |
| async readLogits() { | |
| const n = this.cfg.vocabSize; | |
| if (!this._logitsRead) { | |
| this._logitsRead = this._buf(n * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| } | |
| const enc = this.dev.createCommandEncoder(); | |
| enc.copyBufferToBuffer(this.s.logits, 0, this._logitsRead, 0, n * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| await this._logitsRead.mapAsync(GPUMapMode.READ); | |
| const out = new Float32Array(this._logitsRead.getMappedRange()).slice(); | |
| this._logitsRead.unmap(); | |
| return out; | |
| } | |
| async topKLogits(k = this.samplingTopK) { | |
| if (this._topKReadBusy) throw new Error("topKLogits() is already in flight; concurrent sampling is not supported"); | |
| this._topKReadBusy = true; | |
| try { | |
| k = Math.min(Math.max(1, Math.floor(k)), this.maxSamplingTopK, this.cfg.vocabSize); | |
| const enc = this.dev.createCommandEncoder(); | |
| for (let i = 0; i < k; i++) { | |
| const imm = new Uint32Array([this.cfg.vocabSize, i]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.topkSelect, | |
| this._bgCached(this.pipes.topkSelect, [this.s.logits, this.s.sampleIds, this.s.sampleVals], `topk:${i}`), | |
| 1, | |
| 1, | |
| "topk", | |
| imm | |
| ); | |
| } | |
| enc.copyBufferToBuffer(this.s.sampleIds, 0, this.sampleIdsRead, 0, k * 4); | |
| enc.copyBufferToBuffer(this.s.sampleVals, 0, this.sampleValsRead, 0, k * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await Promise.all([this.sampleIdsRead.mapAsync(GPUMapMode.READ), this.sampleValsRead.mapAsync(GPUMapMode.READ)]); | |
| const ids = Array.from(new Uint32Array(this.sampleIdsRead.getMappedRange(), 0, k)); | |
| const vals = Array.from(new Float32Array(this.sampleValsRead.getMappedRange(), 0, k)); | |
| return ids.map((id, i) => ({ id, logit: vals[i] })); | |
| } finally { | |
| if (this.sampleIdsRead.mapState !== "unmapped") this.sampleIdsRead.unmap(); | |
| if (this.sampleValsRead.mapState !== "unmapped") this.sampleValsRead.unmap(); | |
| this._topKReadBusy = false; | |
| } | |
| } | |
| // Phase 5: GPU-resident sampling (pure-GPU top-k + sample chaining). | |
| // Runs the iterative top-k selection dispatches directly into the GPU sampleIds/sampleVals | |
| // buffers, then immediately chains the SAMPLE_TOPK kernel in the same submission. | |
| // Only a single u32 (the chosen token) is ever read back from the GPU. | |
| // This eliminates the previous k-value readbacks for the sampling path. | |
| async sampleToken(temp = 1, r = typeof Math !== "undefined" ? Math.random() : 0.5) { | |
| if (this._topKReadBusy) throw new Error("sampleToken: top-k selection already in flight"); | |
| this._topKReadBusy = true; | |
| const k = Math.min(this.samplingTopK, this.maxSamplingTopK, this.cfg.vocabSize); | |
| try { | |
| const enc = this.dev.createCommandEncoder(); | |
| for (let i = 0; i < k; i++) { | |
| const imm2 = new Uint32Array([this.cfg.vocabSize, i]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.topkSelect, | |
| this._bgCached(this.pipes.topkSelect, [this.s.logits, this.s.sampleIds, this.s.sampleVals], `topk:${i}`), | |
| 1, | |
| 1, | |
| "topk", | |
| imm2 | |
| ); | |
| } | |
| const bg = this._bg(this.pipes.sampleTopK, [ | |
| this.s.sampleIds, | |
| this.s.sampleVals, | |
| this.s.sampled | |
| ]); | |
| const imm = new Uint32Array(4); | |
| imm[0] = k; | |
| const f322 = new Float32Array(imm.buffer); | |
| f322[2] = temp > 0 ? temp : 1; | |
| f322[3] = Math.max(0, Math.min(1, r)); | |
| this._dispatch(enc, this.pipes.sampleTopK, bg, 1, 1, "sampleTopK", imm); | |
| enc.copyBufferToBuffer(this.s.sampled, 0, this.sampledRead, 0, 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); | |
| await this.sampledRead.mapAsync(GPUMapMode.READ); | |
| const id = new Uint32Array(this.sampledRead.getMappedRange())[0]; | |
| this.sampledRead.unmap(); | |
| return id; | |
| } finally { | |
| this._topKReadBusy = false; | |
| } | |
| } | |
| // Run one token end-to-end (embed + step) and submit. | |
| token(id, pos) { | |
| this._resetUni(); | |
| const enc = this.dev.createCommandEncoder(); | |
| this.embedRow(enc, id); | |
| this.step(enc, id, pos); | |
| this.dev.queue.submit([enc.finish()]); | |
| } | |
| // embed the token id held in s.amax (GPU-resident, from a prior argmax) | |
| embedFromBuf(enc) { | |
| const e = this.q[this.plan.embed.name]; | |
| const imm = new Uint32Array([this.cfg.hiddenSize]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.embedBuf, | |
| this._bgCached(this.pipes.embedBuf, [e.w, e.scale, this.s.hidden, this.s.amax], "embedBuf"), | |
| Math.ceil(this.cfg.hiddenSize / 256), | |
| 1, | |
| "embed", | |
| imm | |
| ); | |
| } | |
| // argmax(logits) -> s.amax, within the given encoder (no submit/readback) | |
| argmaxInto(enc) { | |
| const n = this.cfg.vocabSize || 0; | |
| this._dispatch( | |
| enc, | |
| this.pipes.argmax, | |
| this._bgCached(this.pipes.argmax, [this.s.logits, this.s.amax], "argmax"), | |
| 1, | |
| 1, | |
| "argmax", | |
| new Uint32Array([n]) | |
| ); | |
| } | |
| // GPU-resident batched GREEDY decode only: chains embed->step->argmax for K | |
| // tokens in ONE submit, reads back K ids once, and checks stop tokens only | |
| // after readback. It assumes s.amax already holds the current token id to | |
| // embed. Do not use for sampled decoding; sampled tokens must be written by | |
| // the CPU/GPU sampler one step at a time. | |
| async decodeBatch(startPos, K) { | |
| K = Math.min(K, this.decodeBatchCapacity, this.maxCtx - startPos); | |
| if (K <= 0) return []; | |
| this._resetUni(); | |
| const enc = this.dev.createCommandEncoder(); | |
| for (let k = 0; k < K; k++) { | |
| this.embedFromBuf(enc); | |
| this.step(enc, 0, startPos + k); | |
| this.argmaxInto(enc); | |
| enc.copyBufferToBuffer(this.s.amax, 0, this.s.idsBuf, k * 4, 4); | |
| } | |
| enc.copyBufferToBuffer(this.s.idsBuf, 0, this.idsRead, 0, K * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.idsRead.mapAsync(GPUMapMode.READ); | |
| const ids = Array.from(new Uint32Array(this.idsRead.getMappedRange(), 0, K)); | |
| this.idsRead.unmap(); | |
| return ids; | |
| } | |
| async decodeGreedyBatch(startPos, K) { | |
| return this.decodeBatch(startPos, K); | |
| } | |
| // ---- PREFILL (T>1): process the whole prompt at once via tiled GEMM. If a LoRA | |
| // adapter has the projection module, add its batched delta immediately after base GEMM. | |
| gemm4(enc, aBuf, q, yBuf, T, biasBuf, moduleKey) { | |
| const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); | |
| const bg = this._bg(this.pipes.gemm4, [aBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); | |
| this._dispatch(enc, this.pipes.gemm4, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4", imm); | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); | |
| } | |
| gemm4AddT(enc, aBuf, q, yBuf, T, biasBuf, moduleKey) { | |
| const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); | |
| const bg = this._bg(this.pipes.gemm4AddT, [aBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); | |
| this._dispatch(enc, this.pipes.gemm4AddT, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4AddT", imm); | |
| const mod = this.lora?.modules?.[moduleKey]; | |
| if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); | |
| } | |
| loraBatchDelta(enc, xBuf, yBuf, q, T, mod, moduleKey) { | |
| if (this.debugCapture) console.log("VWG loraBatchDelta: " + moduleKey + " mod=" + !!mod); | |
| const imm = new Uint32Array([q.K, mod.rank, T, 0]); | |
| const bgA = this._bg(this.pipes.loraABatch, [xBuf, mod.A, this.sT.loraD]); | |
| this._dispatch(enc, this.pipes.loraABatch, bgA, mod.rank, T, "loraA:T", imm); | |
| if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj") { | |
| enc.copyBufferToBuffer(xBuf, 0, this.debugBufs.xBat, 0, T * q.K * 4); | |
| enc.copyBufferToBuffer(this.sT.loraD, 0, this.debugBufs.dBat, 0, T * mod.rank * 4); | |
| } | |
| const totalGroups = Math.ceil(T * q.N / 256); | |
| let gx = totalGroups; | |
| let gy = 1; | |
| if (gx > 65535) { | |
| gx = 256; | |
| gy = Math.ceil(totalGroups / 256); | |
| } | |
| const meta = new ArrayBuffer(32); | |
| const dv = new DataView(meta); | |
| dv.setUint32(0, T, true); | |
| dv.setUint32(4, q.N, true); | |
| dv.setUint32(8, mod.rank, true); | |
| dv.setUint32(12, gx, true); | |
| dv.setFloat32(16, mod.scale, true); | |
| const bgB = this._bg(this.pipes.loraBAddT, [this.sT.loraD, mod.B, yBuf]); | |
| this._dispatch(enc, this.pipes.loraBAddT, bgB, gx, gy, "loraB:T", new Uint8Array(meta)); | |
| if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj") { | |
| enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.yBat, 0, T * q.N * 4); | |
| this.debugCaptured = true; | |
| } | |
| } | |
| rmsT(enc, xBuf, gBuf, yBuf, T, K) { | |
| const imm = new Float32Array([K, this.cfg.rmsNormEps]); | |
| const useF16 = this.usingF16() && this.pipes.rmsTF16; | |
| const pipe = useF16 ? this.pipes.rmsTF16 : this.pipes.rmsT; | |
| this._dispatch(enc, pipe, this._bg(pipe, [xBuf, gBuf, yBuf]), T, 1, useF16 ? "rmsTF16" : "rmsT", imm); | |
| } | |
| ropeT(enc, xBuf, T, nHeads, pos0 = 0) { | |
| const hd = this.cfg.headDim; | |
| const imm = new Uint32Array([nHeads, hd, T, pos0]); | |
| const useF16 = this.usingF16() && this.pipes.ropeTF16; | |
| const pipe = useF16 ? this.pipes.ropeTF16 : this.pipes.ropeT; | |
| this._dispatch( | |
| enc, | |
| pipe, | |
| this._bg(pipe, [xBuf, this.ropeCos, this.ropeSin]), | |
| Math.ceil(T * nHeads * (hd / 2) / 256), | |
| 1, | |
| useF16 ? "ropeTF16" : "ropeT", | |
| imm | |
| ); | |
| } | |
| attnPrefill(enc, qBuf, kc, vc, oBuf, T, qStart = 0, ctx = T) { | |
| const c = this.cfg; | |
| if (this.features.prefillAttention === "block" || qStart !== 0 || ctx !== T) { | |
| const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, qStart, ctx, 0, 0]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.attnPrefillBlock, | |
| this._bg(this.pipes.attnPrefillBlock, [qBuf, kc, vc, oBuf]), | |
| c.numHeads, | |
| Math.ceil(T / 4), | |
| "attnPrefillBlock", | |
| imm | |
| ); | |
| } else { | |
| const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.attnPrefill, | |
| this._bg(this.pipes.attnPrefill, [qBuf, kc, vc, oBuf]), | |
| c.numHeads, | |
| Math.ceil(T / 4), | |
| "attnPrefill", | |
| imm | |
| ); | |
| } | |
| } | |
| // (re)allocate prefill scratch sized to T (grows as needed; only paid when prefilling). | |
| _ensurePrefillScratch(T, loraRank = 0, idsCap = T) { | |
| if (this.sTcap >= T && (this.sTLoraRank || 0) >= loraRank && (this.sTidsCap || 0) >= idsCap) return; | |
| const need = this.estimatePrefillScratchBytes(T, loraRank); | |
| if (this.opts.maxPrefillScratchBytes && need > this.opts.maxPrefillScratchBytes) { | |
| throw new Error( | |
| `prefill scratch ${Math.ceil(need / 1048576)}MiB exceeds maxPrefillScratchBytes; lower maxPrefillT or use shorter prompt chunks` | |
| ); | |
| } | |
| if (this.sT) for (const k in this.sT) this.sT[k].destroy(); | |
| const c = this.cfg, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; | |
| this.sT = { | |
| hidden: this._buf(T * H * 4), | |
| normed: this._buf(T * H * 4), | |
| q: this._buf(T * qd * 4), | |
| k: this._buf(T * kvd * 4), | |
| v: this._buf(T * kvd * 4), | |
| attn: this._buf(T * qd * 4), | |
| tmp: this._buf(T * I * 4), | |
| tmp2: this._buf(T * I * 4), | |
| ids: this._buf(idsCap * 4), | |
| loraD: this._buf(Math.max(1, T * Math.max(1, loraRank)) * 4), | |
| x_q: this._buf(T * Math.max(H, I) * 4), | |
| scale_x: this._buf(T * Math.max(H, I) / 128 * 4) | |
| }; | |
| this.sTcap = T; | |
| this.sTLoraRank = loraRank; | |
| this.sTidsCap = idsCap; | |
| } | |
| _activeMaxLoraRank() { | |
| let rank = 0; | |
| const mods = this.lora?.modules; | |
| if (!mods) return 0; | |
| for (const key of Object.keys(mods)) rank = Math.max(rank, mods[key].rank || 0); | |
| return rank; | |
| } | |
| // Prefill the prompt (positions 0..T-1). Leaves last-row logits in s.logits and the | |
| // KV cache populated, so decode continues from pos=T. T must be <= maxPrefillT. | |
| prefillBatch(ids) { | |
| const T = ids.length; | |
| if (T > this.maxPrefillT) throw new Error(`prompt ${T} > maxPrefillT ${this.maxPrefillT}`); | |
| if (T > this.maxCtx) throw new Error(`prompt ${T} > maxCtx ${this.maxCtx}`); | |
| const chunk = this.features.prefillChunkSize; | |
| if (chunk > 0 && T > chunk) return this._prefillChunked(ids, chunk); | |
| return this._prefillFull(ids); | |
| } | |
| _prefillFull(ids) { | |
| const c = this.cfg, S = this.s, T = ids.length, hd = c.headDim, kvd = c.numKVHeads * hd, H = c.hiddenSize; | |
| this._ensurePrefillScratch(T, this._activeMaxLoraRank()); | |
| const ST = this.sT; | |
| this._resetUni(); | |
| this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(ids)); | |
| const enc = this.dev.createCommandEncoder(); | |
| const e = this.q[this.plan.embed.name]; | |
| const imm = new Uint32Array([T, H, 0, 0]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.embedT, | |
| this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), | |
| Math.min(Math.ceil(T * H / 256), 65535), | |
| 1, | |
| "embedT", | |
| imm | |
| ); | |
| for (let i = 0; i < c.numLayers; i++) { | |
| const L = this.plan.layers[i]; | |
| this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, T, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.q.weight], | |
| ST.q, | |
| T, | |
| this.bufs[L.q.bias], | |
| L.q.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.k.weight], | |
| ST.k, | |
| T, | |
| this.bufs[L.k.bias], | |
| L.k.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.v.weight], | |
| ST.v, | |
| T, | |
| this.bufs[L.v.bias], | |
| L.v.loraKey | |
| ); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, T, this.bufs[L.q.bias], L.q.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, T, this.bufs[L.k.bias], L.k.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, T, this.bufs[L.v.bias], L.v.loraKey); | |
| } | |
| this.ropeT(enc, ST.q, T, c.numHeads); | |
| this.ropeT(enc, ST.k, T, c.numKVHeads); | |
| if (this.features.pagedAttention) { | |
| this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], T, 0, i); | |
| } else { | |
| enc.copyBufferToBuffer(ST.k, 0, this.kc[i], 0, T * kvd * 4); | |
| enc.copyBufferToBuffer(ST.v, 0, this.vc[i], 0, T * kvd * 4); | |
| } | |
| if (this.features.pagedAttention) { | |
| this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, 0, T); | |
| } else { | |
| this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, 0, T); | |
| } | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, T); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); | |
| } else { | |
| this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, T * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); | |
| else { | |
| this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, T * H); | |
| } | |
| } | |
| this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, T, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); | |
| } | |
| this._siluMul(enc, ST.tmp, ST.tmp2, T * c.intermediateSize); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, T); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8( | |
| enc, | |
| ST.tmp, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.down.weight], | |
| ST.hidden, | |
| T, | |
| null, | |
| L.down.loraKey | |
| ); | |
| } else { | |
| this.gemm4W4A8(enc, ST.tmp, ST.x_q, ST.scale_x, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); | |
| this._addInto(enc, ST.hidden, ST.normed, T * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, T, null, L.down.loraKey); | |
| else { | |
| this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); | |
| this._addInto(enc, ST.hidden, ST.normed, T * H); | |
| } | |
| } | |
| } | |
| enc.copyBufferToBuffer(ST.hidden, (T - 1) * H * 4, S.hidden, 0, H * 4); | |
| this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, H); | |
| this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); | |
| this.dev.queue.submit([enc.finish()]); | |
| } | |
| _prefillChunked(ids, chunkSize) { | |
| const c = this.cfg, S = this.s, H = c.hiddenSize, hd = c.headDim, kvd = c.numKVHeads * hd; | |
| const T = ids.length; | |
| this._ensurePrefillScratch(Math.min(chunkSize, T), this._activeMaxLoraRank(), T); | |
| const ST = this.sT; | |
| this._resetUni(); | |
| this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(ids)); | |
| const enc = this.dev.createCommandEncoder(); | |
| const e = this.q[this.plan.embed.name]; | |
| for (let off = 0; off < T; off += chunkSize) { | |
| const end = Math.min(T, off + chunkSize); | |
| const CT = end - off; | |
| this._dispatch( | |
| enc, | |
| this.pipes.embedT, | |
| this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), | |
| Math.min(Math.ceil(CT * H / 256), 65535), | |
| 1, | |
| "embedT", | |
| new Uint32Array([CT, H, off, 0]) | |
| ); | |
| for (let i = 0; i < c.numLayers; i++) { | |
| const L = this.plan.layers[i]; | |
| this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, CT, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, CT); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.q.weight], | |
| ST.q, | |
| CT, | |
| this.bufs[L.q.bias], | |
| L.q.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.k.weight], | |
| ST.k, | |
| CT, | |
| this.bufs[L.k.bias], | |
| L.k.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.v.weight], | |
| ST.v, | |
| CT, | |
| this.bufs[L.v.bias], | |
| L.v.loraKey | |
| ); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, CT, this.bufs[L.q.bias], L.q.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, CT, this.bufs[L.k.bias], L.k.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, CT, this.bufs[L.v.bias], L.v.loraKey); | |
| } | |
| this.ropeT(enc, ST.q, CT, c.numHeads, off); | |
| this.ropeT(enc, ST.k, CT, c.numKVHeads, off); | |
| if (this.features.pagedAttention) { | |
| this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], CT, off, i); | |
| } else { | |
| enc.copyBufferToBuffer(ST.k, 0, this.kc[i], off * kvd * 4, CT * kvd * 4); | |
| enc.copyBufferToBuffer(ST.v, 0, this.vc[i], off * kvd * 4, CT * kvd * 4); | |
| } | |
| if (this.features.pagedAttention) { | |
| this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, CT, off, end); | |
| } else { | |
| this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, CT, off, end); | |
| } | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, CT); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, CT, null, L.o.loraKey); | |
| } else { | |
| this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, CT, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, CT * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, CT, null, L.o.loraKey); | |
| else { | |
| this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, CT, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, CT * H); | |
| } | |
| } | |
| this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, CT, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, CT); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, CT, null, L.gate.loraKey); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, CT, null, L.up.loraKey); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, CT, null, L.gate.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, CT, null, L.up.loraKey); | |
| } | |
| this._siluMul(enc, ST.tmp, ST.tmp2, CT * c.intermediateSize); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, CT); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8( | |
| enc, | |
| ST.tmp, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.down.weight], | |
| ST.hidden, | |
| CT, | |
| null, | |
| L.down.loraKey | |
| ); | |
| } else { | |
| this.gemm4W4A8( | |
| enc, | |
| ST.tmp, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.down.weight], | |
| ST.normed, | |
| CT, | |
| null, | |
| L.down.loraKey | |
| ); | |
| this._addInto(enc, ST.hidden, ST.normed, CT * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, CT, null, L.down.loraKey); | |
| else { | |
| this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, CT, null, L.down.loraKey); | |
| this._addInto(enc, ST.hidden, ST.normed, CT * H); | |
| } | |
| } | |
| } | |
| if (end === T) { | |
| enc.copyBufferToBuffer(ST.hidden, (CT - 1) * H * 4, S.hidden, 0, H * 4); | |
| } | |
| } | |
| this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, H); | |
| this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); | |
| this.dev.queue.submit([enc.finish()]); | |
| } | |
| async speculativeDecode(draftModel, promptIds, maxNewTokens, onToken) { | |
| await this.prefillBatch(promptIds); | |
| await draftModel.prefillBatch(promptIds); | |
| let currentPos = promptIds.length; | |
| const generatedIds = []; | |
| let nextToken = await this.argmaxLogits(); | |
| generatedIds.push(nextToken); | |
| if (onToken) onToken(nextToken); | |
| draftModel.dev.queue.writeBuffer(draftModel.s.amax, 0, new Uint32Array([nextToken])); | |
| this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([nextToken])); | |
| const gamma = 4; | |
| while (generatedIds.length < maxNewTokens) { | |
| const draftCandidates = await draftModel.decodeBatch(currentPos, gamma); | |
| if (draftCandidates.length === 0) break; | |
| const T = draftCandidates.length; | |
| this._resetUni(); | |
| this._ensurePrefillScratch(T, this._activeMaxLoraRank()); | |
| const ST = this.sT; | |
| const c = this.cfg, H = c.hiddenSize, kvd = c.numKVHeads * c.headDim; | |
| this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(draftCandidates)); | |
| const enc = this.dev.createCommandEncoder(); | |
| const e = this.q[this.plan.embed.name]; | |
| const embedUni = new Uint32Array([T, H, 0, 0]); | |
| this._dispatch( | |
| enc, | |
| this.pipes.embedT, | |
| this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), | |
| Math.min(Math.ceil(T * H / 256), 65535), | |
| 1, | |
| "embedT", | |
| embedUni | |
| ); | |
| for (let i = 0; i < c.numLayers; i++) { | |
| const L = this.plan.layers[i]; | |
| this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, T, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.q.weight], | |
| ST.q, | |
| T, | |
| this.bufs[L.q.bias], | |
| L.q.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.k.weight], | |
| ST.k, | |
| T, | |
| this.bufs[L.k.bias], | |
| L.k.loraKey | |
| ); | |
| this.gemm4W4A8( | |
| enc, | |
| ST.normed, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.v.weight], | |
| ST.v, | |
| T, | |
| this.bufs[L.v.bias], | |
| L.v.loraKey | |
| ); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, T, this.bufs[L.q.bias], L.q.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, T, this.bufs[L.k.bias], L.k.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, T, this.bufs[L.v.bias], L.v.loraKey); | |
| } | |
| this.ropeT(enc, ST.q, T, c.numHeads, currentPos); | |
| this.ropeT(enc, ST.k, T, c.numKVHeads, currentPos); | |
| if (this.features.pagedAttention) { | |
| this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], T, currentPos, i); | |
| } else { | |
| enc.copyBufferToBuffer(ST.k, 0, this.kc[i], currentPos * kvd * 4, T * kvd * 4); | |
| enc.copyBufferToBuffer(ST.v, 0, this.vc[i], currentPos * kvd * 4, T * kvd * 4); | |
| } | |
| if (this.features.pagedAttention) { | |
| this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, currentPos, currentPos + T); | |
| } else { | |
| this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, currentPos, currentPos + T); | |
| } | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, T); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); | |
| } else { | |
| this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, T * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); | |
| else { | |
| this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); | |
| this._addInto(enc, ST.hidden, ST.tmp, T * H); | |
| } | |
| } | |
| this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, T, H); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); | |
| this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); | |
| } else { | |
| this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); | |
| this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); | |
| } | |
| this._siluMul(enc, ST.tmp, ST.tmp2, T * c.intermediateSize); | |
| if (this.features.actQuant) { | |
| this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, T); | |
| if (this.features.fuseResidual) { | |
| this.gemm4AddTW4A8( | |
| enc, | |
| ST.tmp, | |
| ST.x_q, | |
| ST.scale_x, | |
| this.q4[L.down.weight], | |
| ST.hidden, | |
| T, | |
| null, | |
| L.down.loraKey | |
| ); | |
| } else { | |
| this.gemm4W4A8(enc, ST.tmp, ST.x_q, ST.scale_x, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); | |
| this._addInto(enc, ST.hidden, ST.normed, T * H); | |
| } | |
| } else { | |
| if (this.features.fuseResidual) | |
| this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, T, null, L.down.loraKey); | |
| else { | |
| this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); | |
| this._addInto(enc, ST.hidden, ST.normed, T * H); | |
| } | |
| } | |
| } | |
| if (!this.s.logitsT || this.sTcap < T) { | |
| if (this.s.logitsT) this.s.logitsT.destroy(); | |
| this.s.logitsT = this._buf(T * c.vocabSize * 4); | |
| if (this.logitsTRead) this.logitsTRead.destroy(); | |
| this.logitsTRead = this._buf(T * c.vocabSize * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); | |
| } | |
| for (let t = 0; t < T; t++) { | |
| enc.copyBufferToBuffer(ST.hidden, t * H * 4, this.s.hidden, 0, H * 4); | |
| this.rms(enc, this.s.hidden, this.bufs[this.plan.finalNorm.name], this.s.normed, H); | |
| this.gemv(enc, this.s.normed, this.q[this.plan.embed.name], this.s.logits, null, null); | |
| enc.copyBufferToBuffer(this.s.logits, 0, this.s.logitsT, t * c.vocabSize * 4, c.vocabSize * 4); | |
| } | |
| enc.copyBufferToBuffer(this.s.logitsT, 0, this.logitsTRead, 0, T * c.vocabSize * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.logitsTRead.mapAsync(GPUMapMode.READ); | |
| const logitsArray = new Float32Array(this.logitsTRead.getMappedRange()); | |
| let acceptedCount = 0; | |
| let targetToken = 0; | |
| for (let t = 0; t < T; t++) { | |
| let maxVal = -1e30; | |
| let argmaxId = 0; | |
| const offset = t * c.vocabSize; | |
| for (let v = 0; v < c.vocabSize; v++) { | |
| const l = logitsArray[offset + v]; | |
| if (l > maxVal) { | |
| maxVal = l; | |
| argmaxId = v; | |
| } | |
| } | |
| targetToken = argmaxId; | |
| if (t < T) { | |
| if (draftCandidates[t] === targetToken) { | |
| acceptedCount++; | |
| } else { | |
| break; | |
| } | |
| } | |
| } | |
| this.logitsTRead.unmap(); | |
| for (let a = 0; a < acceptedCount; a++) { | |
| generatedIds.push(draftCandidates[a]); | |
| if (onToken) onToken(draftCandidates[a]); | |
| } | |
| generatedIds.push(targetToken); | |
| if (onToken) onToken(targetToken); | |
| const nextPos = currentPos + acceptedCount + 1; | |
| this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([targetToken])); | |
| draftModel.dev.queue.writeBuffer(draftModel.s.amax, 0, new Uint32Array([targetToken])); | |
| if (this.features.pagedAttention) { | |
| this.pam.ensureBlocks(0, nextPos); | |
| } | |
| currentPos = nextPos; | |
| } | |
| return generatedIds; | |
| } | |
| // Simple high-level generation helper (Phase 5 wiring). | |
| // If opts.sample === true, uses the GPU sampler (sampleToken) with given temp; | |
| // otherwise falls back to argmax (greedy). | |
| // This makes sampleToken part of the real generation path. | |
| async generate(promptIds, maxNewTokens = 32, opts = {}) { | |
| const doSample = !!opts.sample; | |
| const temp = opts.temp != null && opts.temp > 0 ? opts.temp : 1; | |
| await this.prefillBatch(promptIds); | |
| const generatedIds = []; | |
| let pos = promptIds.length; | |
| let next = doSample ? await this.sampleToken(temp) : await this.argmaxLogits(); | |
| generatedIds.push(next); | |
| if (opts.onToken) opts.onToken(next); | |
| this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([next])); | |
| while (generatedIds.length < maxNewTokens) { | |
| this._resetUni(); | |
| const enc = this.dev.createCommandEncoder(); | |
| this.embedFromBuf(enc); | |
| this.step(enc, 0, pos); | |
| this.dev.queue.submit([enc.finish()]); | |
| next = doSample ? await this.sampleToken(temp) : await this.argmaxLogits(); | |
| generatedIds.push(next); | |
| if (opts.onToken) opts.onToken(next); | |
| this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([next])); | |
| pos += 1; | |
| } | |
| return generatedIds; | |
| } | |
| setupDebugCapture(T, K, rank, N) { | |
| this.debugCapture = true; | |
| this.debugT = T; | |
| this.debugK = K; | |
| this.debugRank = rank; | |
| this.debugN = N; | |
| this.debugStep = 0; | |
| this.debugCaptured = false; | |
| this.debugBufs = { | |
| xSeq: this._buf(T * K * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), | |
| dSeq: this._buf(T * rank * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), | |
| ySeq: this._buf(T * N * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), | |
| xBat: this._buf(T * K * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), | |
| dBat: this._buf(T * rank * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), | |
| yBat: this._buf(T * N * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ) | |
| }; | |
| } | |
| async readDebugCapture() { | |
| this.debugCapture = false; | |
| const bufs = this.debugBufs; | |
| if (!bufs) return null; | |
| await Promise.all([ | |
| bufs.xSeq.mapAsync(GPUMapMode.READ), | |
| bufs.dSeq.mapAsync(GPUMapMode.READ), | |
| bufs.ySeq.mapAsync(GPUMapMode.READ), | |
| bufs.xBat.mapAsync(GPUMapMode.READ), | |
| bufs.dBat.mapAsync(GPUMapMode.READ), | |
| bufs.yBat.mapAsync(GPUMapMode.READ) | |
| ]); | |
| const res = { | |
| xSeq: new Float32Array(bufs.xSeq.getMappedRange()).slice(), | |
| dSeq: new Float32Array(bufs.dSeq.getMappedRange()).slice(), | |
| ySeq: new Float32Array(bufs.ySeq.getMappedRange()).slice(), | |
| xBat: new Float32Array(bufs.xBat.getMappedRange()).slice(), | |
| dBat: new Float32Array(bufs.dBat.getMappedRange()).slice(), | |
| yBat: new Float32Array(bufs.yBat.getMappedRange()).slice() | |
| }; | |
| bufs.xSeq.unmap(); | |
| bufs.xSeq.destroy(); | |
| bufs.dSeq.unmap(); | |
| bufs.dSeq.destroy(); | |
| bufs.ySeq.unmap(); | |
| bufs.ySeq.destroy(); | |
| bufs.xBat.unmap(); | |
| bufs.xBat.destroy(); | |
| bufs.dBat.unmap(); | |
| bufs.dBat.destroy(); | |
| bufs.yBat.unmap(); | |
| bufs.yBat.destroy(); | |
| this.debugBufs = null; | |
| return res; | |
| } | |
| }; | |
/**
 * Host-side bookkeeping for paged KV-cache blocks: a free list of block ids plus, per
 * sequence, the ordered list of ids it owns.
 *
 * The initial pool holds enough blocks for four maximum-length sequences. When it runs dry,
 * fresh ids are minted from a counter, so no id is ever handed to two owners at once.
 */
var PagedAttentionManager = class {
  static {
    __name(this, "PagedAttentionManager");
  }
  /**
   * @param {number} maxCtx  maximum tokens per sequence.
   * @param {number} [pageSize=16]  tokens per block.
   */
  constructor(maxCtx, pageSize = 16) {
    this.pageSize = pageSize;
    this.maxCtx = maxCtx;
    this.maxBlocksPerSeq = Math.ceil(maxCtx / pageSize);
    this.freeBlocks = [];
    this.seqBlocks = /* @__PURE__ */ new Map();
    const totalBlocks = this.maxBlocksPerSeq * 4;
    for (let i = 0; i < totalBlocks; i++) {
      this.freeBlocks.push(i);
    }
    // Next never-issued id, used only once the initial pool is exhausted.
    this.nextBlockId = totalBlocks;
  }
  /** Starts (or restarts) `seqId` with an empty block list, returning any blocks it held. */
  allocateSeq(seqId) {
    if (this.seqBlocks.has(seqId)) this.freeSeq(seqId);
    this.seqBlocks.set(seqId, []);
  }
  /** Returns all of `seqId`'s blocks to the free list and forgets the sequence. */
  freeSeq(seqId) {
    const blocks = this.seqBlocks.get(seqId) || [];
    this.freeBlocks.push(...blocks);
    this.seqBlocks.delete(seqId);
  }
  /**
   * Grows `seqId`'s block list until it covers `numTokens` tokens.
   * @returns {number[]} the sequence's (live) block list.
   * @throws {Error} if `seqId` was never allocated.
   */
  ensureBlocks(seqId, numTokens) {
    const neededBlocks = Math.ceil(numTokens / this.pageSize);
    const blocks = this.seqBlocks.get(seqId);
    if (!blocks) throw new Error(`Sequence ${seqId} not allocated`);
    while (blocks.length < neededBlocks) {
      // The old `blocks.length + 1000` fallback could give two sequences the same page.
      // NOTE(review): minted ids lie beyond the initial pool; confirm the GPU KV pages can address them.
      blocks.push(this.freeBlocks.length > 0 ? this.freeBlocks.pop() : this.nextBlockId++);
    }
    return blocks;
  }
  /**
   * Returns `seqId`'s block table as a fixed-length Uint32Array (maxBlocksPerSeq entries);
   * unused trailing slots are 0.
   */
  getBlockTableArray(seqId) {
    const blocks = this.seqBlocks.get(seqId) || [];
    const arr = new Uint32Array(this.maxBlocksPerSeq);
    arr.set(blocks);
    return arr;
  }
};
// src/services/device_service.js
/**
 * Requests a high-performance WebGPU adapter and creates the device used by the WGSL kernels.
 *
 * The following are hard requirements; an Error is thrown when any is missing:
 * - a WebGPU implementation (`navigator.gpu`),
 * - an adapter,
 * - the WGSL `immediate_address_space` language feature,
 * - the `subgroups` adapter feature.
 *
 * `shader-f16` and `timestamp-query` are requested when the adapter offers them. The device
 * asks for the adapter's maxBufferSize, maxStorageBufferBindingSize and
 * maxStorageBuffersPerShaderStage. Uncaptured GPU errors are logged to the console.
 *
 * @param {{ log?: (msg: string) => void }} [options]  progress logger (no-op by default).
 * @returns {Promise<GPUDevice>}
 */
async function initWebGPUDevice({ log: log2 = /* @__PURE__ */ __name(() => {
}, "log") } = {}) {
  log2("requesting WebGPU device\u2026");
  // Fail with a readable message rather than a TypeError when WebGPU is absent.
  const gpu = globalThis.navigator?.gpu;
  if (!gpu) throw new Error("WebGPU is not available (navigator.gpu is missing); use a WebGPU-capable browser");
  const adapter = await gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) throw new Error("no WebGPU adapter (use a WebGPU-capable browser)");
  const wgslFeatures = gpu.wgslLanguageFeatures;
  if (!wgslFeatures?.has("immediate_address_space"))
    throw new Error("WGSL immediate_address_space is not available (upgrade to Chrome 149+)");
  if (!adapter.features.has("subgroups"))
    throw new Error(
      'GPU lacks the required "subgroups" feature. The current fast WGSL kernels require subgroups and no fallback kernel set is bundled.'
    );
  const hasSubgroupId = !!wgslFeatures.has("subgroup_id");
  const hasLinearIndexing = !!wgslFeatures.has("linear_indexing");
  const hasF16 = adapter.features.has("shader-f16");
  const hasTimestamp = adapter.features.has("timestamp-query");
  const reqFeatures = ["subgroups"];
  if (hasF16) reqFeatures.push("shader-f16");
  if (hasTimestamp) reqFeatures.push("timestamp-query");
  const dev = await adapter.requestDevice({
    requiredFeatures: reqFeatures,
    requiredLimits: {
      maxBufferSize: adapter.limits.maxBufferSize,
      maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
      maxStorageBuffersPerShaderStage: adapter.limits.maxStorageBuffersPerShaderStage
    }
  });
  dev.addEventListener?.("uncapturederror", (e) => console.error("GPUERR", e.error.message));
  log2(`WebGPU ready. maxBuffer=${(Number(adapter.limits.maxBufferSize) / 1e9).toFixed(2)}GB subgroupId=${hasSubgroupId} linearIdx=${hasLinearIndexing} f16=${hasF16} tsQuery=${hasTimestamp}`);
  return dev;
}
__name(initWebGPUDevice, "initWebGPUDevice");
// src/services/prompt_formatter.js
/**
 * Renders chat messages as a ChatML prompt string, ending with an open assistant turn.
 * A default system turn ("You are a helpful assistant.") is prepended unless the first
 * message already has role "system".
 *
 * @param {Array<{ role: string, content: string }>} messages
 * @returns {string}
 */
function chatML(messages) {
  const hasSystemTurn = messages[0]?.role === "system";
  const parts = hasSystemTurn ? [] : ["<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"];
  for (const message of messages) {
    parts.push(`<|im_start|>${message.role}\n${message.content}<|im_end|>\n`);
  }
  parts.push("<|im_start|>assistant\n");
  return parts.join("");
}
__name(chatML, "chatML");
/**
 * Formats chat messages into a prompt string.
 *
 * Prefers the tokenizer's own chat template (untokenized, with the generation prompt
 * appended). If that call throws for any reason, including a tokenizer without
 * `apply_chat_template`, it falls back to the built-in ChatML renderer.
 *
 * @param {object} tokenizer
 * @param {Array<{ role: string, content: string }>} messages
 * @returns {string}
 */
function formatMessages(tokenizer, messages) {
  let prompt;
  try {
    prompt = tokenizer.apply_chat_template(messages, { tokenize: false, add_generation_prompt: true });
  } catch {
    prompt = chatML(messages);
  }
  return prompt;
}
__name(formatMessages, "formatMessages");
// src/services/model_session.js
/**
 * Builds a transformers.js `PreTrainedTokenizer` from `tokenizer.json` and
 * `tokenizer_config.json` read through `reader`.
 *
 * The two reads and the dynamic import of `@huggingface/transformers` do not depend on each
 * other, so they run concurrently rather than one after another.
 *
 * @param {{ text(path: string): Promise<string> }} reader  source of the tokenizer files.
 * @returns {Promise<object>} the constructed tokenizer.
 */
async function buildTokenizer(reader) {
  const [tokenizerText, configText, { PreTrainedTokenizer }] = await Promise.all([
    reader.text("tokenizer.json"),
    reader.text("tokenizer_config.json"),
    import("@huggingface/transformers")
  ]);
  return new PreTrainedTokenizer(JSON.parse(tokenizerText), JSON.parse(configText));
}
__name(buildTokenizer, "buildTokenizer");
/**
 * Returns a uniformly distributed number in [0, 1).
 * Uses the Web Crypto CSPRNG (one 32-bit draw scaled by 2^32) when `crypto.getRandomValues`
 * exists, and falls back to Math.random otherwise.
 * @returns {number}
 */
function randomUnit() {
  const cryptoApi = globalThis.crypto;
  if (!cryptoApi?.getRandomValues) return Math.random();
  const [word] = cryptoApi.getRandomValues(new Uint32Array(1));
  return word / 4294967296;
}
__name(randomUnit, "randomUnit");
/**
 * Samples a token id from top-k candidates with temperature and optional nucleus (top-p)
 * filtering.
 *
 * `candidates[0]` is treated as the highest-logit entry (assumed sorted descending).
 * - temperature missing or <= 0: returns candidates[0].id (0 if empty).
 * - Otherwise each weight is exp((logit - best) / temperature).
 * - When 0 < topP < 1 (and there is more than one candidate), the list is cut after the
 *   first entry whose cumulative share reaches topP.
 * - An id is then drawn proportionally to the remaining weights.
 *
 * @param {Array<{ id: number, logit: number }>} candidates
 * @param {{ temperature: number, topP?: number }} options
 * @returns {number}
 */
function sampleTopK(candidates, { temperature, topP = 1 }) {
  const head = candidates[0];
  if (!temperature || temperature <= 0) return head?.id ?? 0;
  const bestLogit = head?.logit ?? 0;
  const pool = [];
  for (const cand of candidates) {
    pool.push({ id: cand.id, w: Math.exp((cand.logit - bestLogit) / temperature) });
  }
  let total = 0;
  for (const entry of pool) total += entry.w;
  const applyNucleus = topP > 0 && topP < 1 && pool.length > 1 && total > 0;
  if (applyNucleus) {
    let cumulative = 0;
    let kept = pool.length;
    for (let i = 0; i < pool.length; i++) {
      cumulative += pool[i].w / total;
      if (cumulative >= topP) {
        kept = i + 1;
        break;
      }
    }
    pool.length = Math.max(1, kept);
    total = 0;
    for (const entry of pool) total += entry.w;
  }
  const target = randomUnit() * total;
  let running = 0;
  for (const entry of pool) {
    running += entry.w;
    if (target <= running) return entry.id;
  }
  return pool[pool.length - 1]?.id ?? head?.id ?? 0;
}
__name(sampleTopK, "sampleTopK");
/**
 * One loaded model: owns the WebGPU device, the QwenWGPU runtime and the tokenizer, and
 * exposes streaming chat generation.
 */
var ModelSession = class {
  static {
    __name(this, "ModelSession");
  }
  /**
   * @param {object} [options]
   * @param {object} [options.cfg=QWEN25_3B]  model hyper-parameters passed to QwenWGPU.
   * @param {(msg: string) => void} [options.log]  progress logger; a no-op by default.
   * @param {object} [options.runtimeOptions]  overrides merged on top of
   *   `{ decodeBatchSize: "auto", samplingTopK: 40 }` and passed to QwenWGPU.
   */
  constructor({ cfg = QWEN25_3B, log: log2 = /* @__PURE__ */ __name(() => {
  }, "log"), runtimeOptions = {} } = {}) {
    this.cfg = cfg;
    this.log = log2;
    this.runtimeOptions = { decodeBatchSize: "auto", samplingTopK: 40, ...runtimeOptions };
    this.dev = null;
    this.rt = null;
    this.tokenizer = null;
  }
  /**
   * Initializes WebGPU, loads the tokenizer, then builds the runtime (streaming and int4
   * quantizing the weights) from `reader`.
   * @param {object} reader  file source (e.g. urlReader); used by buildTokenizer and QwenWGPU.build.
   * @param {string} label  human-readable source name for log messages.
   * @returns {Promise<ModelSession>} this session.
   */
  async loadWith(reader, label) {
    this.dev = await initWebGPUDevice({ log: this.log });
    this.log(`loading tokenizer from ${label}\u2026`);
    this.tokenizer = await buildTokenizer(reader);
    this.log(`tokenizer loaded. streaming + quantizing weights (int4) from ${label}\u2026`);
    const t0 = performance.now();
    this.rt = new QwenWGPU(this.dev, this.cfg, this.runtimeOptions);
    await this.rt.build(reader, (msg, frac) => this.log(`weights: ${msg} ${(frac * 100).toFixed(0)}%`));
    // Console debug handles; globalThis (unlike window) also exists inside workers.
    globalThis.__rt = this.rt;
    globalThis.__tokenizer = this.tokenizer;
    const tuning = this.rt.decodeBatchTuning;
    const tuned = tuning ? ` decodeBatch=${tuning.selected} (${tuning.reason})` : "";
    this.log(
      `READY in ${((performance.now() - t0) / 1e3).toFixed(1)}s \u2014 base loaded once; adapters hot-swap live.${tuned}`
    );
    return this;
  }
  /**
   * Copies the runtime's current logits (cfg.vocabSize f32 values) into a new Float32Array.
   * The temporary readback buffer is destroyed even if mapping fails.
   * @returns {Promise<Float32Array>}
   */
  async readLogits() {
    const bytes = this.cfg.vocabSize * 4;
    const rb = this.dev.createBuffer({ size: bytes, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
    try {
      const enc = this.dev.createCommandEncoder();
      enc.copyBufferToBuffer(this.rt.s.logits, 0, rb, 0, bytes);
      this.dev.queue.submit([enc.finish()]);
      await rb.mapAsync(GPUMapMode.READ);
      // slice() copies out of the mapped range before destroy() detaches it.
      return new Float32Array(rb.getMappedRange()).slice();
    } finally {
      rb.destroy();
    }
  }
  /** Samples the next token from the runtime's top-k logits (see sampleTopK). */
  async sampleNextToken({ temperature, topK = this.rt.samplingTopK, topP = 1 } = {}) {
    return sampleTopK(await this.rt.topKLogits(topK), { temperature, topP });
  }
  /**
   * Streams the decoded text of the assistant reply to `messages`.
   *
   * The prompt is prefilled in one batch when it fits `rt.maxPrefillT`, otherwise one token
   * at a time. temperature > 0 samples token by token; otherwise greedy decoding runs in
   * batches. Generation stops at:
   * - a stop id,
   * - `maxTokens` emitted tokens (nothing for maxTokens <= 0),
   * - the end of the KV cache (`rt.maxCtx`),
   * - an empty greedy batch.
   *
   * @param {Array<{ role: string, content: string }>} messages
   * @param {{ maxTokens?: number, temperature?: number, topK?: number, topP?: number, stopIds?: number[] }} [options]
   * @yields {string} decoded pieces (special tokens skipped; empty pieces are not yielded).
   * @throws {RangeError} if the prompt alone exceeds `rt.maxCtx` tokens.
   */
  async *generate(messages, { maxTokens = 1024, temperature = 0, topK, topP = 1, stopIds = [151645, 151643] } = {}) {
    const rt = this.rt, tokenizer = this.tokenizer;
    const ids = tokenizer.encode(formatMessages(tokenizer, messages));
    if (ids.length > rt.maxCtx) {
      throw new RangeError(`prompt has ${ids.length} tokens but the context holds ${rt.maxCtx}`);
    }
    // prefillBatch is async (other callers await it); finish it before reading logits.
    if (ids.length <= rt.maxPrefillT) await rt.prefillBatch(ids);
    else for (let p = 0; p < ids.length; p++) rt.token(ids[p], p);
    if (maxTokens <= 0) return;
    let pos = ids.length;
    const emit = (id) => tokenizer.decode([id], { skip_special_tokens: true });
    if (temperature > 0) {
      let next = await this.sampleNextToken({ temperature, topK, topP });
      for (let step = 0; step < maxTokens; step++) {
        if (stopIds.includes(next)) break;
        const d = emit(next);
        if (d) yield d;
        // Skip the forward pass when the budget is spent or `next` has no KV slot left.
        if (step + 1 >= maxTokens || pos >= rt.maxCtx) break;
        rt.token(next, pos);
        pos++;
        next = await this.sampleNextToken({ temperature, topK, topP });
      }
      return;
    }
    const first = await rt.argmaxLogits();
    if (stopIds.includes(first)) return;
    const firstPiece = emit(first);
    if (firstPiece) yield firstPiece;
    let emitted = 1;
    while (emitted < maxTokens && pos < rt.maxCtx) {
      const K = rt.greedyBatchSizeFor({ emitted, remaining: maxTokens - emitted, pos });
      const batch = await rt.decodeGreedyBatch(pos, K);
      // An empty batch would otherwise leave pos/emitted unchanged and loop forever.
      if (batch.length === 0) break;
      pos += batch.length;
      for (const id of batch) {
        if (stopIds.includes(id)) return;
        const d = emit(id);
        if (d) yield d;
        emitted++;
        if (emitted >= maxTokens) return;
      }
    }
  }
};
// src/qwgpu/backward_kernels.js
// GEMM_DX_INT4 (WGSL): input gradient of an int4-weight linear layer.
//   dX[t][k] += sum_n dY[t][n] * W[n][k]
// W is dequantized on the fly. Each u32 packs eight 4-bit values (k >> 3 selects the word,
// (k & 7) * 4 the nibble), which are sign-extended by the shift pair. The result is scaled
// by scaleW[n][k >> 7], i.e. one f32 scale per 128 columns with `gpr` scales per row.
// Invocations walk T*K outputs with a stride of num_workgroups * 256, so any workgroup count
// covers the whole matrix. dX is accumulated into (+=), not overwritten.
// Parameters arrive via var<immediate> (needs the WGSL immediate_address_space feature).
var GEMM_DX_INT4 = `
requires immediate_address_space;
struct Meta { T:u32, N:u32, K:u32, gpr:u32 };
@group(0) @binding(0) var<storage,read> dY: array<f32>; // [T][N]
@group(0) @binding(1) var<storage,read> W: array<u32>; // [N][K/8] int4
@group(0) @binding(2) var<storage,read> scaleW: array<f32>; // [N][gpr]
@group(0) @binding(3) var<storage,read_write> dX: array<f32>; // [T][K]
var<immediate> m: Meta;
fn deq4(n: u32, k: u32, K8: u32) -> f32 {
let word = W[n*K8 + (k >> 3u)];
let shift = (k & 7u) * 4u;
let nib = i32(word << (28u - shift)) >> 28u;
return f32(nib) * scaleW[n*m.gpr + (k >> 7u)];
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.T * m.K; let stride = nwg.x * 256u; let K8 = m.K / 8u;
for (var i = gid.x; i < total; i = i + stride) {
let t = i / m.K; let k = i % m.K;
var acc = 0.0;
let yb = t * m.N;
for (var n = 0u; n < m.N; n = n + 1u) { acc = acc + dY[yb + n] * deq4(n, k, K8); }
dX[i] = dX[i] + acc;
}
}`;
// LORA_DD (WGSL): gradient w.r.t. the LoRA intermediate.
//   dD[t][r] = scale * sum_n dY[t][n] * B[r][n]
// There is one 256-thread workgroup per (t, r) pair (workgroup id = t*rank + r). Each thread
// strides over N and the partial sums are tree-reduced in workgroup memory (the reduction
// assumes exactly 256 threads). dD is overwritten, not accumulated.
// Meta fields p, f0..f2 are not read here (presumably padding).
var LORA_DD = `
requires immediate_address_space;
struct Meta { T:u32, N:u32, rank:u32, p:u32, scale:f32, f0:f32, f1:f32, f2:f32 };
@group(0) @binding(0) var<storage,read> dY: array<f32>; // [T][N]
@group(0) @binding(1) var<storage,read> B: array<f32>; // [rank][N]
@group(0) @binding(2) var<storage,read_write> dD: array<f32>; // [T][rank]
var<immediate> m: Meta;
var<workgroup> part: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let idx = wid.x; let t = idx / m.rank; let r = idx % m.rank; let tid = lid.x;
if (t >= m.T) { return; }
var s = 0.0; let yb = t*m.N; let bb = r*m.N;
for (var n = tid; n < m.N; n = n + 256u) { s = s + dY[yb + n] * B[bb + n]; }
part[tid] = s; workgroupBarrier();
for (var st = 128u; st > 0u; st = st/2u) { if (tid < st) { part[tid] = part[tid] + part[tid+st]; } workgroupBarrier(); }
if (tid == 0u) { dD[t*m.rank + r] = m.scale * part[0]; }
}`;
// LORA_GRAD_A (WGSL): gradient of the LoRA down-projection A.
//   dA[r][k] += sum_t dD[t][r] * X[t][k]
// No scale is applied here; any LoRA scale must already be folded into dD (LORA_DD does this
// when its output is used). Invocations stride over rank*K outputs by num_workgroups * 256.
// dA is accumulated into (+=).
var LORA_GRAD_A = `
requires immediate_address_space;
struct Meta { T:u32, K:u32, rank:u32, p:u32 };
@group(0) @binding(0) var<storage,read> dD: array<f32>; // [T][rank]
@group(0) @binding(1) var<storage,read> X: array<f32>; // [T][K]
@group(0) @binding(2) var<storage,read_write> dA: array<f32>; // [rank][K]
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.rank * m.K; let stride = nwg.x * 256u;
for (var i = gid.x; i < total; i = i + stride) {
let r = i / m.K; let k = i % m.K;
var acc = 0.0;
for (var t = 0u; t < m.T; t = t + 1u) { acc = acc + dD[t*m.rank + r] * X[t*m.K + k]; }
dA[i] = dA[i] + acc;
}
}`;
// LORA_GRAD_B (WGSL): gradient of the LoRA up-projection B.
//   dB[r][n] += scale * sum_t D[t][r] * dY[t][n]
// D is the forward-pass LoRA intermediate. Invocations stride over rank*N outputs by
// num_workgroups * 256. dB is accumulated into (+=).
var LORA_GRAD_B = `
requires immediate_address_space;
struct Meta { T:u32, N:u32, rank:u32, p:u32, scale:f32, f0:f32, f1:f32, f2:f32 };
@group(0) @binding(0) var<storage,read> D: array<f32>; // [T][rank]
@group(0) @binding(1) var<storage,read> dY: array<f32>; // [T][N]
@group(0) @binding(2) var<storage,read_write> dB: array<f32>; // [rank][N]
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.rank * m.N; let stride = nwg.x * 256u;
for (var i = gid.x; i < total; i = i + stride) {
let r = i / m.N; let n = i % m.N;
var acc = 0.0;
for (var t = 0u; t < m.T; t = t + 1u) { acc = acc + D[t*m.rank + r] * dY[t*m.N + n]; }
dB[i] = dB[i] + m.scale * acc;
}
}`;
// LORA_DX_ADD (WGSL): adds the LoRA path's contribution to the input gradient.
//   dX[t][k] += sum_r dD[t][r] * A[r][k]
// Invocations stride over T*K outputs by num_workgroups * 256. dX is accumulated into (+=), so
// it can follow GEMM_DX_INT4 on the same buffer.
var LORA_DX_ADD = `
requires immediate_address_space;
struct Meta { T:u32, K:u32, rank:u32, p:u32 };
@group(0) @binding(0) var<storage,read> dD: array<f32>; // [T][rank]
@group(0) @binding(1) var<storage,read> A: array<f32>; // [rank][K]
@group(0) @binding(2) var<storage,read_write> dX: array<f32>; // [T][K]
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.T * m.K; let stride = nwg.x * 256u;
for (var i = gid.x; i < total; i = i + stride) {
let t = i / m.K; let k = i % m.K;
var acc = 0.0;
for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + dD[t*m.rank + r] * A[r*m.K + k]; }
dX[i] = dX[i] + acc;
}
}`;
// RMSNORM_BWD_T (WGSL): input gradient of y = x * g * inv with inv = 1/sqrt(mean(x^2) + eps).
// One workgroup per row (row base = workgroup_id.x * K). It computes
//   dx = inv * g * dy - (inv^3 / K) * x * sum_k(dy * g * x)
// using two workgroup tree reductions. dx is overwritten.
// K and eps arrive as f32 immediates (m.x, m.y).
// The WG override must stay a power of two <= 256, because the `red` scratch array is fixed at
// 256 entries and the reduction halves WG each step.
var RMSNORM_BWD_T = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>; // [T][K]
@group(0) @binding(1) var<storage,read> g: array<f32>; // [K]
@group(0) @binding(2) var<storage,read> dy: array<f32>; // [T][K]
@group(0) @binding(3) var<storage,read_write> dx: array<f32>; // [T][K]
var<immediate> m: vec2<f32>; // K, eps
var<workgroup> red: array<f32, 256>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; let K = u32(m.x); let base = wid.x * K;
// sum of squares for inv
var ss = 0.0;
for (var k = tid; k < K; k = k + WG) { let v = x[base+k]; ss = ss + v*v; }
red[tid] = ss; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); }
let ms = red[0] / m.x;
let inv = inverseSqrt(ms + m.y);
workgroupBarrier();
// c = sum dy*g*x
var cc = 0.0;
for (var k = tid; k < K; k = k + WG) { cc = cc + dy[base+k]*g[k]*x[base+k]; }
red[tid] = cc; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); }
let c = red[0];
let inv3overK = inv*inv*inv / m.x;
for (var k = tid; k < K; k = k + WG) {
dx[base+k] = inv*g[k]*dy[base+k] - inv3overK * x[base+k] * c;
}
}`;
// SWIGLU_BWD (WGSL): backward of out = silu(gate) * up, elementwise over n values.
//   dUp   = dOut * silu(gate)
//   dGate = dOut * up * sig * (1 + gate * (1 - sig)),  where sig = sigmoid(gate)
// Both outputs are overwritten. Invocations stride by num_workgroups * WG.
var SWIGLU_BWD = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> gate: array<f32>;
@group(0) @binding(1) var<storage,read> up: array<f32>;
@group(0) @binding(2) var<storage,read> dOut: array<f32>;
@group(0) @binding(3) var<storage,read_write> dGate: array<f32>;
@group(0) @binding(4) var<storage,read_write> dUp: array<f32>;
var<immediate> n: u32;
@compute @workgroup_size(WG)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * WG;
for (var i = gid.x; i < n; i = i + stride) {
let z = gate[i]; let sig = 1.0/(1.0+exp(-z)); let sl = z*sig;
let d = dOut[i];
dUp[i] = d * sl;
dGate[i] = d * up[i] * (sig * (1.0 + z*(1.0 - sig)));
}
}`;
// RoPE backward, applied in place to a [T][nHeads*headDim] gradient.
// Each invocation owns one (row, head, j) with j < headDim/2. It applies the transpose of
// the 2x2 rotation to elements j and j + headDim/2:
//   dlo' = c*dlo + s*dhi,   dhi' = -s*dlo + c*dhi.
// Immediates m = (nHeads, headDim, T, pos0).
// NOTE(review): assumes the forward rope uses the same half-split pairing, and that
// cosT/sinT are laid out [position][headDim] (indexed (pos0+row)*D + j). Confirm against
// the forward kernel. There is no stride loop: the dispatch must cover T*H*D/2 invocations.
var ROPE_BWD_T = `
requires immediate_address_space;
@group(0) @binding(0) var<storage,read_write> dx: array<f32>; // [T][nHeads*headDim] gradient
@group(0) @binding(1) var<storage,read> cosT: array<f32>;
@group(0) @binding(2) var<storage,read> sinT: array<f32>;
var<immediate> m: vec4<u32>; // nHeads, headDim, T, pos0
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u;
let perRow = H*half; if (g >= T*perRow) { return; }
let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half;
let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j;
let c = cosT[off]; let s = sinT[off];
let dl = dx[lo]; let dh = dx[hi];
dx[lo] = c*dl + s*dh;
dx[hi] = -s*dl + c*dh;
}`;
// Attention backward, pass 1. One workgroup per (query head h = wid.x, query t = wid.y).
// Over the causal keys j <= t it recomputes
//   lse[h*T + t]   = M + log(sum_j exp(s_j - M)),  with s_j = dot(q_t, k_j) / sqrt(hd)
//   delta[h*T + t] = dot(dO_t, O_t)
// GQA: query head h reads KV head h / (nHeads / nKV).
// Scores are recomputed in both the max pass and the sum pass rather than stored.
// `red` has 128 slots, so WG must stay <= 128 (and be a power of two).
var ATTN_BWD_STATS = `
requires immediate_address_space;
override WG: u32 = 128u;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>; // [T][nHeads*hd]
@group(0) @binding(1) var<storage,read> kc: array<f32>; // [T][nKV*hd]
@group(0) @binding(2) var<storage,read> o: array<f32>; // [T][nHeads*hd] attn output
@group(0) @binding(3) var<storage,read> doo: array<f32>; // [T][nHeads*hd] grad of attn output
@group(0) @binding(4) var<storage,read_write> lse: array<f32>; // [nHeads*T]
@group(0) @binding(5) var<storage,read_write> delta: array<f32>; // [nHeads*T]
var<immediate> m: Meta;
var<workgroup> red: array<f32, 128>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let h = wid.x; let t = wid.y; let tid = lid.x;
let hd = m.hd; let nKV = m.nKV; let kvh = h / (m.nHeads / nKV);
let qb = t*m.nHeads*hd + h*hd; let kvstride = nKV*hd; let hoff = kvh*hd;
let scl = 1.0 / sqrt(f32(hd));
// running max
var lmax = -1e30;
for (var j = tid; j <= t; j = j + WG) {
var dot = 0.0; let kb = j*kvstride + hoff;
for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qb+d]*kc[kb+d]; }
lmax = max(lmax, dot*scl);
}
red[tid] = lmax; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = max(red[tid], red[tid+s]); } workgroupBarrier(); }
let M = red[0];
workgroupBarrier();
var lsum = 0.0;
for (var j = tid; j <= t; j = j + WG) {
var dot = 0.0; let kb = j*kvstride + hoff;
for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qb+d]*kc[kb+d]; }
lsum = lsum + exp(dot*scl - M);
}
red[tid] = lsum; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); }
// delta
var dl = 0.0;
for (var d = tid; d < hd; d = d + WG) { dl = dl + doo[qb+d]*o[qb+d]; }
// reuse red after sum captured
let Z = red[0];
workgroupBarrier();
red[tid] = dl; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); }
if (tid == 0u) { lse[h*m.T + t] = M + log(Z); delta[h*m.T + t] = red[0]; }
}`;
// Attention backward, pass 2 (dQ). One workgroup per (query head h = wid.x, query t = wid.y),
// with one invocation per head dimension d. For each causal key j <= t it rebuilds the score s
// and dp = dot(dO_t, v_j) via two workgroup reductions, then accumulates
//   ds = exp(s - lse) * (dp - delta);   dq[t][d] += scl * sum_j ds * k_j[d].
// dq is accumulated (+=), so the caller must clear it first (_backward does).
// NOTE(review): only dims d < WG (128) get an invocation, so hd must be <= WG. A larger
// head dim would silently drop dims >= WG from both the dot products and dq.
var ATTN_BWD_DQ = `
requires immediate_address_space;
override WG: u32 = 128u;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read> doo: array<f32>;
@group(0) @binding(4) var<storage,read> lse: array<f32>;
@group(0) @binding(5) var<storage,read> delta: array<f32>;
@group(0) @binding(6) var<storage,read_write> dq: array<f32>;
var<immediate> m: Meta;
var<workgroup> red: array<f32, 128>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let h = wid.x; let t = wid.y; let d = lid.x;
let hd = m.hd; let nKV = m.nKV; let kvh = h / (m.nHeads / nKV);
let qb = t*m.nHeads*hd + h*hd; let kvstride = nKV*hd; let hoff = kvh*hd;
let scl = 1.0 / sqrt(f32(hd));
let lse_t = lse[h*m.T + t]; let delta_t = delta[h*m.T + t];
// Guard every storage read behind (d < hd): WGSL select() is eager and would
// still evaluate the buffer load for inactive lanes (OOB when hd < WG). Barriers
// stay at uniform control flow so the reductions remain valid.
let inHd = d < hd;
var acc = 0.0;
for (var j = 0u; j <= t; j = j + 1u) {
let kb = j*kvstride + hoff;
// s = scl * dot(q, k_j)
var sv = 0.0; if (inHd) { sv = q[qb+d] * kc[kb+d]; }
red[d] = sv; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); }
let sval = red[0] * scl;
workgroupBarrier();
// dp = dot(do, v_j)
var dpv = 0.0; if (inHd) { dpv = doo[qb+d] * vc[kb+d]; }
red[d] = dpv; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); }
let dp = red[0];
workgroupBarrier();
let p = exp(sval - lse_t);
let ds = p * (dp - delta_t);
if (inHd) { acc = acc + ds * kc[kb+d]; }
}
if (inHd) { dq[qb+d] = dq[qb+d] + scl * acc; }
}`;
// Attention backward, pass 3 (dK/dV). One workgroup per (KV head = wid.x, key j = wid.y),
// with one invocation per head dimension. It sums over every query head sharing this KV head
// (the GQA group) and over causal queries t >= j:
//   dk[j][d] += scl * ds * q_t[d],   dv[j][d] += p * dO_t[d]
// where p = exp(s - lse) and ds = p * (dp - delta), using the lse/delta buffers filled by
// ATTN_BWD_STATS. Outputs are accumulated, so the caller clears dk/dv first.
// The same hd <= WG (128) limit as ATTN_BWD_DQ applies.
var ATTN_BWD_DKV = `
requires immediate_address_space;
override WG: u32 = 128u;
struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 };
@group(0) @binding(0) var<storage,read> q: array<f32>;
@group(0) @binding(1) var<storage,read> kc: array<f32>;
@group(0) @binding(2) var<storage,read> vc: array<f32>;
@group(0) @binding(3) var<storage,read> doo: array<f32>;
@group(0) @binding(4) var<storage,read> lse: array<f32>;
@group(0) @binding(5) var<storage,read> delta: array<f32>;
@group(0) @binding(6) var<storage,read_write> dk: array<f32>;
@group(0) @binding(7) var<storage,read_write> dv: array<f32>;
var<immediate> m: Meta;
var<workgroup> red: array<f32, 128>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let kvh = wid.x; let j = wid.y; let d = lid.x;
let hd = m.hd; let nKV = m.nKV; let group = m.nHeads / nKV;
let kvstride = nKV*hd; let hoff = kvh*hd; let kb = j*kvstride + hoff;
let scl = 1.0 / sqrt(f32(hd));
// Guard storage reads behind (d < hd) \u2014 see ATTN_BWD_DQ note on eager select().
let inHd = d < hd;
var dkacc = 0.0; var dvacc = 0.0;
for (var hi = 0u; hi < group; hi = hi + 1u) {
let h = kvh*group + hi;
for (var t = j; t < m.T; t = t + 1u) {
let qb = t*m.nHeads*hd + h*hd;
var sv = 0.0; if (inHd) { sv = q[qb+d] * kc[kb+d]; }
red[d] = sv; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); }
let sval = red[0] * scl;
workgroupBarrier();
var dpv = 0.0; if (inHd) { dpv = doo[qb+d] * vc[kb+d]; }
red[d] = dpv; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); }
let dp = red[0];
workgroupBarrier();
let p = exp(sval - lse[h*m.T + t]);
let ds = p * (dp - delta[h*m.T + t]);
if (inHd) {
dkacc = dkacc + scl * ds * q[qb+d];
dvacc = dvacc + p * doo[qb+d];
}
}
}
if (inHd) { dk[kb+d] = dk[kb+d] + dkacc; dv[kb+d] = dv[kb+d] + dvacc; }
}`;
// LM-head logits for one block of T rows, using the int8 embedding table:
//   logits[t][v] = scaleE[v] * sum_k normed[tOff + t][k] * E_int8[v][k]
// E packs four signed int8 values per u32, lowest byte = lowest k; sx8 sign-extends one byte.
// K4 = K/4 uses integer division, so K must be a multiple of 4.
// Uses a grid-stride loop over the T*vocab outputs. `logits` is block-local ([T][vocab]),
// while `normed` is the full-sequence buffer offset by tOff.
var LOGITS_GEMM_I8 = `
requires immediate_address_space;
struct Meta { T:u32, vocab:u32, K:u32, tOff:u32 };
@group(0) @binding(0) var<storage,read> normed: array<f32>; // [T][K] (full-seq buffer, offset by tOff)
@group(0) @binding(1) var<storage,read> E: array<u32>; // [vocab][K/4] int8
@group(0) @binding(2) var<storage,read> scaleE: array<f32>; // [vocab]
@group(0) @binding(3) var<storage,read_write> logits: array<f32>; // [Tblock][vocab]
var<immediate> m: Meta;
fn sx8(v: u32) -> i32 {
return i32(v << 24u) >> 24u;
}
fn unpack4xI8(x: u32) -> vec4<i32> {
return vec4<i32>(
sx8(x & 0xffu),
sx8((x >> 8u) & 0xffu),
sx8((x >> 16u) & 0xffu),
sx8((x >> 24u) & 0xffu)
);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.T * m.vocab; let stride = nwg.x * 256u; let K4 = m.K / 4u;
for (var i = gid.x; i < total; i = i + stride) {
let t = i / m.vocab; let v = i % m.vocab;
let nb = (m.tOff + t) * m.K; let eb = v * K4;
var acc = 0.0;
for (var c = 0u; c < K4; c = c + 1u) {
let p = unpack4xI8(E[eb + c]); let kk = c*4u;
acc = acc + normed[nb+kk]*f32(p.x) + normed[nb+kk+1u]*f32(p.y)
+ normed[nb+kk+2u]*f32(p.z) + normed[nb+kk+3u]*f32(p.w);
}
logits[i] = acc * scaleE[v];
}
}`;
// Fused softmax cross-entropy forward + backward for one block of rows, one workgroup
// per row (wid.x = row within the block; tOff maps it to the global token index).
// For row gt with target tgt = labels[gt] and mask mk = mask[gt]:
//   lossOut[gt] = mk * (log(Z) - (logit[tgt] - M))    (M = row max, Z = sum exp(logit - M))
//   logits row  <- mk * lossScale * (softmax - onehot(tgt))    (dLogits, written in place)
// The target logit is read by every lane BEFORE storageBarrier(). Without that barrier,
// lane 0's read raced with the lane owning index tgt overwriting it with a gradient value,
// which corrupts the reported loss.
// The reductions assume WG is a power of two and <= 256 (the size of `red`).
var CE_SOFTMAX_GRAD = `
requires immediate_address_space;
override WG: u32 = 256u;
struct Meta { vocab:u32, tOff:u32, lossScale:f32, p:u32 };
@group(0) @binding(0) var<storage,read_write> logits: array<f32>; // [bt][vocab] -> dLogits
@group(0) @binding(1) var<storage,read> labels: array<u32>; // [T] token id (global)
@group(0) @binding(2) var<storage,read> mask: array<f32>; // [T] 1 train / 0 skip (global)
@group(0) @binding(3) var<storage,read_write> lossOut: array<f32>;// [T] (global)
var<immediate> m: Meta;
var<workgroup> red: array<f32, 256>;
@compute @workgroup_size(WG)
fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
let lt = wid.x; let tid = lid.x; let base = lt*m.vocab;
let gt = m.tOff + lt; // global token index for target/mask/loss
let mk = mask[gt];
// max
var mx = -1e30;
for (var v = tid; v < m.vocab; v = v + WG) { mx = max(mx, logits[base+v]); }
red[tid] = mx; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = max(red[tid], red[tid+s]); } workgroupBarrier(); }
let M = red[0]; workgroupBarrier();
// sum exp
var sm = 0.0;
for (var v = tid; v < m.vocab; v = v + WG) { sm = sm + exp(logits[base+v] - M); }
red[tid] = sm; workgroupBarrier();
for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); }
let Z = red[0];
let tgt = labels[gt];
// Read the target logit in every lane before any lane overwrites this row with
// gradients below; the storage barrier orders this read ahead of those writes.
let ltgt = logits[base + tgt];
storageBarrier();
if (tid == 0u) {
lossOut[gt] = mk * (log(Z) - (ltgt - M));
}
// dLogits = mask*lossScale*(p - onehot)
let invZ = 1.0 / Z; let g = mk * m.lossScale;
for (var v = tid; v < m.vocab; v = v + WG) {
var p = exp(logits[base+v] - M) * invZ;
if (v == tgt) { p = p - 1.0; }
logits[base+v] = g * p;
}
}`;
// Backprop through the int8 LM head for one block of rows:
//   dHidden[tOff + t][k] += sum_v dLogits[t][v] * scaleE[v] * E_int8[v][k]
// Each output element loops over the whole vocabulary, extracting byte (k & 3) of word k/4.
// It accumulates (+=) into dHidden; _backward clears that buffer (s.dnorm) before the first block.
var DHIDDEN_FROM_DLOGITS_I8 = `
requires immediate_address_space;
struct Meta { T:u32, vocab:u32, K:u32, tOff:u32 };
@group(0) @binding(0) var<storage,read> dLogits: array<f32>; // [Tblock][vocab]
@group(0) @binding(1) var<storage,read> E: array<u32>; // [vocab][K/4] int8
@group(0) @binding(2) var<storage,read> scaleE: array<f32>; // [vocab]
@group(0) @binding(3) var<storage,read_write> dHidden: array<f32>; // [T][K] (offset tOff)
var<immediate> m: Meta;
fn sx8(v: u32) -> i32 {
return i32(v << 24u) >> 24u;
}
fn unpack4xI8(x: u32) -> vec4<i32> {
return vec4<i32>(
sx8(x & 0xffu),
sx8((x >> 8u) & 0xffu),
sx8((x >> 16u) & 0xffu),
sx8((x >> 24u) & 0xffu)
);
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let total = m.T * m.K; let stride = nwg.x * 256u; let K4 = m.K / 4u;
for (var i = gid.x; i < total; i = i + stride) {
let t = i / m.K; let k = i % m.K;
let lb = t * m.vocab;
var acc = 0.0;
let word_idx = k >> 2u; let lane = k & 3u;
for (var v = 0u; v < m.vocab; v = v + 1u) {
let p = unpack4xI8(E[v*K4 + word_idx]);
var b: i32; if (lane==0u){b=p.x;} else if (lane==1u){b=p.y;} else if (lane==2u){b=p.z;} else {b=p.w;}
acc = acc + dLogits[lb + v] * scaleE[v] * f32(b);
}
dHidden[(m.tOff + t)*m.K + k] = dHidden[(m.tOff + t)*m.K + k] + acc;
}
}`;
// AdamW update over n elements, with a grid-stride loop.
// Meta is 12 x 4 bytes, matching _meta's 48-byte buffer:
//   n, p (unused here), lr, beta1, beta2, eps, wd, gScale, b1c, b2c, f0/f1 (unused padding).
// The gradient is first multiplied by gScale (presumably clip/accumulation scaling; confirm at
// the call site). Then mhat = m / b1c and vhat = v / b2c, so b1c/b2c are presumably the
// bias-correction terms 1 - beta^step. Weight decay is decoupled:
//   param -= lr * (mhat / (sqrt(vhat) + eps) + wd * param)
var ADAMW_STEP = `
requires immediate_address_space;
struct Meta { n:u32, p:u32, lr:f32, beta1:f32, beta2:f32, eps:f32, wd:f32, gScale:f32, b1c:f32, b2c:f32, f0:f32, f1:f32 };
@group(0) @binding(0) var<storage,read_write> param: array<f32>;
@group(0) @binding(1) var<storage,read> grad: array<f32>;
@group(0) @binding(2) var<storage,read_write> mBuf: array<f32>;
@group(0) @binding(3) var<storage,read_write> vBuf: array<f32>;
var<immediate> m: Meta;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
let stride = nwg.x * 256u;
for (var i = gid.x; i < m.n; i = i + stride) {
let gr = grad[i] * m.gScale;
let mm = m.beta1 * mBuf[i] + (1.0 - m.beta1) * gr;
let vv = m.beta2 * vBuf[i] + (1.0 - m.beta2) * gr * gr;
mBuf[i] = mm; vBuf[i] = vv;
let mhat = mm / m.b1c; let vhat = vv / m.b2c;
param[i] = param[i] - m.lr * (mhat / (sqrt(vhat) + m.eps) + m.wd * param[i]);
}
}`;
// Sum of squares of x[0..n), ADDED into out[0]. It accumulates, so clear out first.
// The kernel indexes only by local_invocation_id, so it must be dispatched as exactly ONE
// workgroup: any extra workgroup would re-sum all of x and race on out[0].
var SUMSQ = `
requires immediate_address_space;
override WG: u32 = 256u;
@group(0) @binding(0) var<storage,read> x: array<f32>;
@group(0) @binding(1) var<storage,read_write> out: array<f32>; // [1]
var<immediate> n: u32;
var<workgroup> red: array<f32, 256>;
@compute @workgroup_size(WG)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
let tid = lid.x; var s = 0.0;
for (var i = tid; i < n; i = i + WG) { let v = x[i]; s = s + v*v; }
red[tid] = s; workgroupBarrier();
for (var st = WG/2u; st > 0u; st = st/2u) { if (tid < st) { red[tid] = red[tid] + red[tid+st]; } workgroupBarrier(); }
if (tid == 0u) { out[0] = out[0] + red[0]; }
}`;
// src/qwgpu/trainer.js
// Buffer usage sets: general storage scratch (copyable both ways), and CPU-mappable staging.
var STORAGE2 = GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE;
var READBACK = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
// Millisecond timestamp: uses performance.now() when the host provides it, else Date.now().
var nowMs = /* @__PURE__ */ __name(() => {
  const perf = globalThis.performance;
  const hiRes = perf?.now?.();
  return hiRes ?? Date.now();
}, "nowMs");
// Short projection names used as keys into each plan layer (L[name]) and as LoRA targets.
var ALL_PROJ = ["q", "k", "v", "o", "gate", "up", "down"];
/**
 * Build a freshly initialised, trainable LoRA adapter on the GPU.
 *
 * For every layer in rt.plan and every targeted projection it allocates two f32 buffers:
 *   A: rank*K floats, filled with N(0, stddev^2) samples (Box-Muller);
 *   B: rank*N floats, zero-filled, so the adapter's contribution starts at zero.
 * K and N come from the projection's quantised weight, rt.q4[part.weight].
 *
 * @param {object} rt    built runtime exposing plan.layers, q4 and dev.
 * @param {object} [opts]
 * @param {number} [opts.rank=16]           LoRA rank (floored, at least 1).
 * @param {number} [opts.alpha=2*rank]
 * @param {number} [opts.scale=alpha/rank]
 * @param {string[]} [opts.targetModules=ALL_PROJ]  short ("q", "gate") or HF/PEFT-style
 *   ("q_proj", "gate_proj") names; a trailing "_proj" is ignored.
 * @param {number} [opts.stddev=1/sqrt(rank)]  std-dev of A's initial values.
 * @param {string} [opts.name="trainable"]
 * @returns {{name: string, modules: object}} modules keyed by each projection's loraKey,
 *   each entry holding { A, B, rank, scale, inDim: K, outDim: N }.
 */
function createTrainableAdapter(rt, opts = {}) {
  const rank = Math.max(1, Math.floor(opts.rank ?? 16));
  const alpha = opts.alpha ?? rank * 2;
  const scale = opts.scale ?? alpha / rank;
  const rawTargets = opts.targetModules ?? ALL_PROJ;
  // PEFT configs name modules "q_proj", "gate_proj", ...; map them onto ALL_PROJ's short
  // names so they are not silently skipped (which would yield an empty adapter).
  const targets = Array.isArray(rawTargets)
    ? rawTargets.map((t) => (typeof t === "string" ? t.replace(/_proj$/, "") : t))
    : rawTargets;
  const stddev = opts.stddev ?? 1 / Math.sqrt(rank);
  const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
  // Standard normal sample via Box-Muller; u and v are kept away from 0 for log().
  const gauss = /* @__PURE__ */ __name(() => {
    let u = 0, v = 0;
    while (u === 0) u = Math.random();
    while (v === 0) v = Math.random();
    return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
  }, "gauss");
  const modules = {};
  for (const L of rt.plan.layers) {
    for (const name of ALL_PROJ) {
      if (!targets.includes(name)) continue;
      const part = L[name];
      const q4 = rt.q4[part.weight];
      const K = q4.K, N = q4.N;
      const Aarr = new Float32Array(rank * K);
      for (let i = 0; i < Aarr.length; i++) Aarr[i] = gauss() * stddev;
      const Barr = new Float32Array(rank * N);
      const A = rt.dev.createBuffer({ size: Aarr.byteLength, usage });
      const B = rt.dev.createBuffer({ size: Barr.byteLength, usage });
      rt.dev.queue.writeBuffer(A, 0, Aarr);
      rt.dev.queue.writeBuffer(B, 0, Barr);
      modules[part.loraKey] = { A, B, rank, scale, inDim: K, outDim: N };
    }
  }
  return { name: opts.name || "trainable", modules };
}
__name(createTrainableAdapter, "createTrainableAdapter");
| var QwenLoraTrainer = class { | |
static {
// esbuild keep-names helper: pins QwenLoraTrainer.name via __name (defined at the top of the bundle).
__name(this, "QwenLoraTrainer");
}
| // rt: a built QwenWGPU. opts: see _normalizeOpts. | |
| constructor(rt, opts = {}) { | |
| this.rt = rt; | |
| this.dev = rt.dev; | |
| this.cfg = rt.cfg; | |
| this.opts = this._normalizeOpts(opts); | |
| this.step = 0; | |
| this._microInWindow = 0; | |
| this.scratchT = 0; | |
| this._buildPipes(); | |
| } | |
| _normalizeOpts(o) { | |
| return { | |
| lr: o.lr ?? 1e-4, | |
| beta1: o.beta1 ?? 0.9, | |
| beta2: o.beta2 ?? 0.999, | |
| eps: o.eps ?? 1e-8, | |
| weightDecay: o.weightDecay ?? 0, | |
| maxGradNorm: o.maxGradNorm ?? 1, | |
| gradAccumSteps: Math.max(1, Math.floor(o.gradAccumSteps ?? 1)), | |
| lmHeadBlock: Math.max(1, Math.floor(o.lmHeadBlock ?? 128)), | |
| maxTrainSeq: Math.max(1, Math.floor(o.maxTrainSeq ?? 512)), | |
| warmupSteps: Math.max(0, Math.floor(o.warmupSteps ?? 0)), | |
| totalSteps: o.totalSteps ?? 0, | |
| // for cosine decay; 0 disables decay | |
| minLrRatio: o.minLrRatio ?? 0.1, | |
| targetModules: o.targetModules ?? ALL_PROJ | |
| }; | |
| } | |
| _buildPipes() { | |
| const rt = this.rt; | |
| this.p = { | |
| dx4: rt._pipe(GEMM_DX_INT4, "bwd_dx4"), | |
| dd: rt._pipe(LORA_DD, "bwd_lora_dd"), | |
| gradA: rt._pipe(LORA_GRAD_A, "bwd_lora_dA"), | |
| gradB: rt._pipe(LORA_GRAD_B, "bwd_lora_dB"), | |
| dxAdd: rt._pipe(LORA_DX_ADD, "bwd_lora_dx"), | |
| rmsBwd: rt._pipe(RMSNORM_BWD_T, "bwd_rms"), | |
| swiglu: rt._pipe(SWIGLU_BWD, "bwd_swiglu"), | |
| ropeBwd: rt._pipe(ROPE_BWD_T, "bwd_rope"), | |
| attnStats: rt._pipe(ATTN_BWD_STATS, "bwd_attn_stats"), | |
| attnDq: rt._pipe(ATTN_BWD_DQ, "bwd_attn_dq"), | |
| attnDkv: rt._pipe(ATTN_BWD_DKV, "bwd_attn_dkv"), | |
| logits: rt._pipe(LOGITS_GEMM_I8, "bwd_logits"), | |
| ceGrad: rt._pipe(CE_SOFTMAX_GRAD, "bwd_ce"), | |
| dHidden: rt._pipe(DHIDDEN_FROM_DLOGITS_I8, "bwd_dhidden"), | |
| adamw: rt._pipe(ADAMW_STEP, "adamw"), | |
| sumsq: rt._pipe(SUMSQ, "sumsq") | |
| }; | |
| } | |
| // ---- adapter attach: build per-module grad + Adam moment state ---- | |
| // The adapter must already be uploaded (loadLoraAdapterGPU) and set on rt. | |
| attach(adapter) { | |
| if (!adapter || !adapter.modules) throw new Error("trainer.attach: adapter with modules required"); | |
| this.adapter = adapter; | |
| this.rt.setLora(adapter); | |
| const rt = this.rt; | |
| const byKey = /* @__PURE__ */ new Map(); | |
| for (const L of rt.plan.layers) { | |
| for (const name of ALL_PROJ) { | |
| const part = L[name]; | |
| byKey.set(part.loraKey, { part, kind: name, q4: rt.q4[part.weight] }); | |
| } | |
| } | |
| this.state = {}; | |
| let maxRank = 1; | |
| for (const key of Object.keys(adapter.modules)) { | |
| const mod = adapter.modules[key]; | |
| const info = byKey.get(key); | |
| if (!info) continue; | |
| const kind = info.kind.replace(/_proj$/, ""); | |
| if (!this.opts.targetModules.includes(kind)) continue; | |
| const K = info.q4.K, N = info.q4.N, rank = mod.rank; | |
| maxRank = Math.max(maxRank, rank); | |
| this.state[key] = { | |
| mod, | |
| q4: info.q4, | |
| K, | |
| N, | |
| rank, | |
| scale: mod.scale, | |
| dA: rt._buf(rank * K * 4), | |
| dB: rt._buf(rank * N * 4), | |
| mA: rt._buf(rank * K * 4), | |
| vA: rt._buf(rank * K * 4), | |
| mB: rt._buf(rank * N * 4), | |
| vB: rt._buf(rank * N * 4) | |
| }; | |
| } | |
| this.maxRank = maxRank; | |
| this.trainedKeys = Object.keys(this.state); | |
| if (!this.trainedKeys.length) throw new Error("trainer.attach: no trainable modules matched targetModules"); | |
| this._zeroAdamMoments(); | |
| this.zeroGrads(); | |
| return this; | |
| } | |
| _zeroAdamMoments() { | |
| const enc = this.dev.createCommandEncoder(); | |
| for (const k of this.trainedKeys) { | |
| const st = this.state[k]; | |
| enc.clearBuffer(st.mA); | |
| enc.clearBuffer(st.vA); | |
| enc.clearBuffer(st.mB); | |
| enc.clearBuffer(st.vB); | |
| } | |
| this.dev.queue.submit([enc.finish()]); | |
| } | |
| zeroGrads() { | |
| const enc = this.dev.createCommandEncoder(); | |
| for (const k of this.trainedKeys) { | |
| enc.clearBuffer(this.state[k].dA); | |
| enc.clearBuffer(this.state[k].dB); | |
| } | |
| this.dev.queue.submit([enc.finish()]); | |
| this._microInWindow = 0; | |
| } | |
| // ---- activation/gradient scratch sized to the sequence ---- | |
| _ensureScratch(T) { | |
| if (this.scratchT >= T && this.s) return; | |
| if (this.s) for (const k in this.s) this.s[k].destroy?.(); | |
| if (this.ckpt) for (const c2 of this.ckpt) c2.destroy?.(); | |
| this.lossRead?.destroy?.(); | |
| this.normRead?.destroy?.(); | |
| const c = this.cfg; | |
| const H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize, nH = c.numHeads, R = this.maxRank, lmB = this.opts.lmHeadBlock, V = c.vocabSize; | |
| const b = /* @__PURE__ */ __name((n) => this.rt._buf(n * 4), "b"); | |
| this.ckpt = []; | |
| for (let i = 0; i <= c.numLayers; i++) this.ckpt.push(b(T * H)); | |
| this.s = { | |
| hid: b(T * H), | |
| normed1: b(T * H), | |
| normed2: b(T * H), | |
| normedF: b(T * H), | |
| q: b(T * qd), | |
| k: b(T * kvd), | |
| v: b(T * kvd), | |
| attn: b(T * qd), | |
| hmid: b(T * H), | |
| gate: b(T * I), | |
| up: b(T * I), | |
| swig: b(T * I), | |
| dHidden: b(T * H), | |
| dnorm: b(T * H), | |
| dtmp: b(T * H), | |
| dhmid: b(T * H), | |
| dq: b(T * qd), | |
| dk: b(T * kvd), | |
| dv: b(T * kvd), | |
| dob: b(T * qd), | |
| dgate: b(T * I), | |
| dup: b(T * I), | |
| dswig: b(T * I), | |
| dD: b(T * R), | |
| Dmat: b(T * R), | |
| lse: b(nH * T), | |
| delta: b(nH * T), | |
| logits: b(lmB * V), | |
| loss: b(T), | |
| targets: this.rt._buf(T * 4), | |
| mask: b(T), | |
| normBuf: b(1) | |
| }; | |
| this.lossRead = this.rt._buf(T * 4, READBACK); | |
| this.normRead = this.rt._buf(4, READBACK); | |
| this.scratchT = T; | |
| } | |
| // ---- small dispatch helpers ---- | |
| _grid1d(n) { | |
| return Math.min(Math.ceil(n / 256), 65535); | |
| } | |
| _disp(enc, pipe, buffers, gx, gy, imm, cat) { | |
| const bg = this.rt._bg(pipe, buffers); | |
| this.rt._dispatch(enc, pipe, bg, gx, gy, cat || "train", imm); | |
| } | |
| _u32(arr) { | |
| return new Uint32Array(arr); | |
| } | |
| _meta(u32parts, f32parts = {}) { | |
| const buf = new ArrayBuffer(48); | |
| const dv = new DataView(buf); | |
| for (const [i, v] of u32parts) dv.setUint32(i * 4, v >>> 0, true); | |
| for (const [i, v] of Object.entries(f32parts)) dv.setFloat32(Number(i) * 4, v, true); | |
| return new Uint8Array(buf); | |
| } | |
| // ---- forward with checkpoints (LoRA-modified, f32) ---- | |
| _layerForward(enc, L, hid, T) { | |
| const rt = this.rt, c = this.cfg, s = this.s; | |
| const H = c.hiddenSize; | |
| rt.rmsT(enc, hid, rt.bufs[L.inputNorm], s.normed1, T, H); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.q.weight], s.q, T, rt.bufs[L.q.bias], L.q.loraKey); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.k.weight], s.k, T, rt.bufs[L.k.bias], L.k.loraKey); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.v.weight], s.v, T, rt.bufs[L.v.bias], L.v.loraKey); | |
| rt.ropeT(enc, s.q, T, c.numHeads); | |
| rt.ropeT(enc, s.k, T, c.numKVHeads); | |
| rt.attnPrefill(enc, s.q, s.k, s.v, s.attn, T, 0, T); | |
| rt.gemm4AddT(enc, s.attn, rt.q4[L.o.weight], hid, T, null, L.o.loraKey); | |
| rt.rmsT(enc, hid, rt.bufs[L.postAttentionNorm], s.normed2, T, H); | |
| rt.gemm4(enc, s.normed2, rt.q4[L.gate.weight], s.gate, T, null, L.gate.loraKey); | |
| rt.gemm4(enc, s.normed2, rt.q4[L.up.weight], s.up, T, null, L.up.loraKey); | |
| enc.copyBufferToBuffer(s.gate, 0, s.swig, 0, T * c.intermediateSize * 4); | |
| rt._siluMul(enc, s.swig, s.up, T * c.intermediateSize); | |
| rt.gemm4AddT(enc, s.swig, rt.q4[L.down.weight], hid, T, null, L.down.loraKey); | |
| } | |
| _forward(enc, ids, T) { | |
| const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize; | |
| rt._ensurePrefillScratch(T, this.maxRank); | |
| rt._resetUni(); | |
| const e = rt.q[rt.plan.embed.name]; | |
| this.dev.queue.writeBuffer(rt.sT.ids, 0, new Uint32Array(ids)); | |
| rt._dispatch( | |
| enc, | |
| rt.pipes.embedT, | |
| rt._bg(rt.pipes.embedT, [e.w, e.scale, this.ckpt[0], rt.sT.ids]), | |
| Math.min(Math.ceil(T * H / 256), 65535), | |
| 1, | |
| "embedT", | |
| this._u32([T, H, 0, 0]) | |
| ); | |
| enc.copyBufferToBuffer(this.ckpt[0], 0, s.hid, 0, T * H * 4); | |
| for (let i = 0; i < c.numLayers; i++) { | |
| this._layerForward(enc, rt.plan.layers[i], s.hid, T); | |
| enc.copyBufferToBuffer(s.hid, 0, this.ckpt[i + 1], 0, T * H * 4); | |
| } | |
| } | |
| // recompute one layer's forward internals (from its checkpoint) into scratch, also | |
| // producing hmid (= ckpt + attnProj) which the backward needs as the post-attn input. | |
| _recomputeLayer(enc, L, T) { | |
| const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, idx = L.index; | |
| rt.rmsT(enc, this.ckpt[idx], rt.bufs[L.inputNorm], s.normed1, T, H); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.q.weight], s.q, T, rt.bufs[L.q.bias], L.q.loraKey); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.k.weight], s.k, T, rt.bufs[L.k.bias], L.k.loraKey); | |
| rt.gemm4(enc, s.normed1, rt.q4[L.v.weight], s.v, T, rt.bufs[L.v.bias], L.v.loraKey); | |
| rt.ropeT(enc, s.q, T, c.numHeads); | |
| rt.ropeT(enc, s.k, T, c.numKVHeads); | |
| rt.attnPrefill(enc, s.q, s.k, s.v, s.attn, T, 0, T); | |
| enc.copyBufferToBuffer(this.ckpt[idx], 0, s.hmid, 0, T * H * 4); | |
| rt.gemm4AddT(enc, s.attn, rt.q4[L.o.weight], s.hmid, T, null, L.o.loraKey); | |
| rt.rmsT(enc, s.hmid, rt.bufs[L.postAttentionNorm], s.normed2, T, H); | |
| rt.gemm4(enc, s.normed2, rt.q4[L.gate.weight], s.gate, T, null, L.gate.loraKey); | |
| rt.gemm4(enc, s.normed2, rt.q4[L.up.weight], s.up, T, null, L.up.loraKey); | |
| enc.copyBufferToBuffer(s.gate, 0, s.swig, 0, T * c.intermediateSize * 4); | |
| rt._siluMul(enc, s.swig, s.up, T * c.intermediateSize); | |
| } | |
| // ---- LoRA + base projection backward ---- | |
| // dY [T][N] -> accumulate into dXbuf [T][K] (base + LoRA), plus dA/dB grads. | |
| _projBackward(enc, key, Xbuf, dYbuf, dXbuf, T) { | |
| const st = this.state[key]; | |
| if (!st) { | |
| this._dispatch_dx4(enc, dYbuf, st, dXbuf, T, key); | |
| return; | |
| } | |
| const { K, N, rank, scale, q4, dA, dB } = st; | |
| const s = this.s; | |
| this._disp( | |
| enc, | |
| this.p.dx4, | |
| [dYbuf, q4.w, q4.scale, dXbuf], | |
| this._grid1d(T * K), | |
| 1, | |
| this._meta([[0, T], [1, N], [2, K], [3, q4.gpr]]), | |
| "dx4" | |
| ); | |
| this._disp( | |
| enc, | |
| this.p.dd, | |
| [dYbuf, st.mod.B, s.dD], | |
| T * rank, | |
| 1, | |
| this._meta([[0, T], [1, N], [2, rank]], { 4: scale }), | |
| "dd" | |
| ); | |
| this._disp( | |
| enc, | |
| this.p.gradA, | |
| [s.dD, Xbuf, dA], | |
| this._grid1d(rank * K), | |
| 1, | |
| this._meta([[0, T], [1, K], [2, rank]]), | |
| "gradA" | |
| ); | |
| this._disp( | |
| enc, | |
| this.rt.pipes.loraABatch, | |
| [Xbuf, st.mod.A, s.Dmat], | |
| rank, | |
| T, | |
| this._u32([K, rank, T, 0]), | |
| "loraABatch" | |
| ); | |
| this._disp( | |
| enc, | |
| this.p.gradB, | |
| [s.Dmat, dYbuf, dB], | |
| this._grid1d(rank * N), | |
| 1, | |
| this._meta([[0, T], [1, N], [2, rank]], { 4: scale }), | |
| "gradB" | |
| ); | |
| this._disp( | |
| enc, | |
| this.p.dxAdd, | |
| [s.dD, st.mod.A, dXbuf], | |
| this._grid1d(T * K), | |
| 1, | |
| this._meta([[0, T], [1, K], [2, rank]]), | |
| "dxAdd" | |
| ); | |
| } | |
| _dispatch_dx4(enc, dYbuf, st, dXbuf, T, key) { | |
| const info = this._infoForKey(key); | |
| const q4 = info.q4; | |
| this._disp( | |
| enc, | |
| this.p.dx4, | |
| [dYbuf, q4.w, q4.scale, dXbuf], | |
| this._grid1d(T * q4.K), | |
| 1, | |
| this._meta([[0, T], [1, q4.N], [2, q4.K], [3, q4.gpr]]), | |
| "dx4" | |
| ); | |
| } | |
| _infoForKey(key) { | |
| for (const L of this.rt.plan.layers) | |
| for (const name of ALL_PROJ) if (L[name].loraKey === key) return { q4: this.rt.q4[L[name].weight] }; | |
| throw new Error(`unknown loraKey ${key}`); | |
| } | |
| _rmsBwd(enc, xBuf, gBuf, dyBuf, dxBuf, T) { | |
| const c = this.cfg; | |
| this._disp( | |
| enc, | |
| this.p.rmsBwd, | |
| [xBuf, gBuf, dyBuf, dxBuf], | |
| T, | |
| 1, | |
| new Float32Array([c.hiddenSize, c.rmsNormEps]), | |
| "rmsBwd" | |
| ); | |
| } | |
// ---- full backward for one micro-batch; accumulates grads, returns nothing ----
// Recompute-based backward: _forward keeps only per-layer inputs (this.ckpt[0..numLayers]).
// Each layer's internals are rebuilt by _recomputeLayer just before its own backward.
// LoRA gradients accumulate into this.state[*].dA / dB through _projBackward.
// numActive: number of mask-1 positions; the CE gradient is scaled by 1 / max(1, numActive).
// NOTE(review): qd and kvd below are unused. This loop indexes ckpt by i while
// _recomputeLayer uses L.index, which assumes rt.plan.layers[i].index === i (confirm).
_backward(enc, T, numActive) {
const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize, V = c.vocabSize;
// Final RMSNorm recomputed from the last checkpoint; dnorm will hold d(normedF).
rt.rmsT(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.normedF, T, H);
enc.clearBuffer(s.dnorm);
const e = rt.q[rt.plan.embed.name];
const lossScale = 1 / Math.max(1, numActive);
const lmB = this.opts.lmHeadBlock;
// Tied LM head + cross-entropy, streamed lmHeadBlock rows at a time so only an
// [lmB][vocab] logits buffer is live. Per block: logits -> dLogits (in place) -> dnorm += dLogits @ E.
for (let off = 0; off < T; off += lmB) {
const bt = Math.min(lmB, T - off);
this._disp(
enc,
this.p.logits,
[s.normedF, e.w, e.scale, s.logits],
this._grid1d(bt * V),
1,
this._meta([[0, bt], [1, V], [2, H], [3, off]]),
"logits"
);
this._disp(
enc,
this.p.ceGrad,
[s.logits, s.targets, s.mask, s.loss],
bt,
1,
this._meta([[0, V], [1, off]], { 2: lossScale }),
"ce"
);
this._disp(
enc,
this.p.dHidden,
[s.logits, e.w, e.scale, s.dnorm],
this._grid1d(bt * H),
1,
this._meta([[0, bt], [1, V], [2, H], [3, off]]),
"dHidden"
);
}
// Back through the final norm: dHidden = gradient w.r.t. the last layer's output.
this._rmsBwd(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.dnorm, s.dHidden, T);
// Layers in reverse. Per layer: hmid = h_in + o(attn(...)); h_out = hmid + down(swig).
for (let i = c.numLayers - 1; i >= 0; i--) {
const L = rt.plan.layers[i];
this._recomputeLayer(enc, L, T);
// MLP: dswig via the down projection, then SwiGLU backward into dgate / dup.
enc.clearBuffer(s.dswig);
this._projBackward(enc, L.down.loraKey, s.swig, s.dHidden, s.dswig, T);
this._disp(
enc,
this.p.swiglu,
[s.gate, s.up, s.dswig, s.dgate, s.dup],
this._grid1d(T * I),
1,
this._u32([T * I]),
"swiglu"
);
// d(normed2) summed over the gate and up projections, then through the post-attention norm.
enc.clearBuffer(s.dnorm);
this._projBackward(enc, L.gate.loraKey, s.normed2, s.dgate, s.dnorm, T);
this._projBackward(enc, L.up.loraKey, s.normed2, s.dup, s.dnorm, T);
this._rmsBwd(enc, s.hmid, rt.bufs[L.postAttentionNorm], s.dnorm, s.dtmp, T);
// Residual: d(hmid) = dHidden (skip path) + MLP-branch gradient (presumably dhmid += dtmp).
enc.copyBufferToBuffer(s.dHidden, 0, s.dhmid, 0, T * H * 4);
rt._addInto(enc, s.dhmid, s.dtmp, T * H);
// Output projection: dob = gradient w.r.t. the attention output.
enc.clearBuffer(s.dob);
this._projBackward(enc, L.o.loraKey, s.attn, s.dhmid, s.dob, T);
// Attention core: lse/delta stats, then dQ and dK/dV (those kernels accumulate, hence the clears).
const am = this._u32([c.numHeads, c.numKVHeads, c.headDim, T]);
this._disp(enc, this.p.attnStats, [s.q, s.k, s.attn, s.dob, s.lse, s.delta], c.numHeads, T, am, "attnStats");
enc.clearBuffer(s.dq);
enc.clearBuffer(s.dk);
enc.clearBuffer(s.dv);
this._disp(enc, this.p.attnDq, [s.q, s.k, s.v, s.dob, s.lse, s.delta, s.dq], c.numHeads, T, am, "attnDq");
this._disp(
enc,
this.p.attnDkv,
[s.q, s.k, s.v, s.dob, s.lse, s.delta, s.dk, s.dv],
c.numKVHeads,
T,
am,
"attnDkv"
);
// Undo RoPE on the q and k gradients (pos0 = 0).
this._disp(
enc,
this.p.ropeBwd,
[s.dq, rt.ropeCos, rt.ropeSin],
Math.ceil(T * c.numHeads * (c.headDim / 2) / 256),
1,
this._u32([c.numHeads, c.headDim, T, 0]),
"ropeBwd"
);
this._disp(
enc,
this.p.ropeBwd,
[s.dk, rt.ropeCos, rt.ropeSin],
Math.ceil(T * c.numKVHeads * (c.headDim / 2) / 256),
1,
this._u32([c.numKVHeads, c.headDim, T, 0]),
"ropeBwd"
);
// d(normed1) summed over the q/k/v projections, then through the input norm into dtmp.
enc.clearBuffer(s.dnorm);
this._projBackward(enc, L.q.loraKey, s.normed1, s.dq, s.dnorm, T);
this._projBackward(enc, L.k.loraKey, s.normed1, s.dk, s.dnorm, T);
this._projBackward(enc, L.v.loraKey, s.normed1, s.dv, s.dnorm, T);
this._rmsBwd(enc, this.ckpt[i], rt.bufs[L.inputNorm], s.dnorm, s.dtmp, T);
// Residual: d(h_in) = d(hmid) + attention-branch gradient; becomes the lower layer's dHidden.
enc.copyBufferToBuffer(s.dhmid, 0, s.dHidden, 0, T * H * 4);
rt._addInto(enc, s.dHidden, s.dtmp, T * H);
}
}
| // shifted-label targets + mask into the scratch buffers; returns numActive. | |
| _writeTargets(tokens, lossMask, T) { | |
| const targets = new Uint32Array(T); | |
| const mask = new Float32Array(T); | |
| let numActive = 0; | |
| for (let t = 0; t < T - 1; t++) { | |
| targets[t] = tokens[t + 1] >>> 0; | |
| const mk = lossMask ? lossMask[t] ? 1 : 0 : 1; | |
| mask[t] = mk; | |
| numActive += mk; | |
| } | |
| targets[T - 1] = 0; | |
| mask[T - 1] = 0; | |
| this.dev.queue.writeBuffer(this.s.targets, 0, targets); | |
| this.dev.queue.writeBuffer(this.s.mask, 0, mask); | |
| return numActive; | |
| } | |
| // loss head only (final norm + streamed logits + CE), no backward sweep. Used by | |
| // evalLoss(). CE overwrites s.logits with dLogits but we ignore that here. | |
| _lossOnly(enc, T, numActive) { | |
| const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, V = c.vocabSize; | |
| rt.rmsT(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.normedF, T, H); | |
| const e = rt.q[rt.plan.embed.name]; | |
| const lossScale = 1 / Math.max(1, numActive); | |
| const lmB = this.opts.lmHeadBlock; | |
| for (let off = 0; off < T; off += lmB) { | |
| const bt = Math.min(lmB, T - off); | |
| this._disp(enc, this.p.logits, [s.normedF, e.w, e.scale, s.logits], this._grid1d(bt * V), 1, this._meta([[0, bt], [1, V], [2, H], [3, off]]), "logits"); | |
| this._disp(enc, this.p.ceGrad, [s.logits, s.targets, s.mask, s.loss], bt, 1, this._meta([[0, V], [1, off]], { 2: lossScale }), "ce"); | |
| } | |
| } | |
| // ---- public: forward-only mean cross-entropy (no grads). For held-out eval. ---- | |
| async evalLoss(tokens, lossMask) { | |
| const T = tokens.length; | |
| if (T > this.opts.maxTrainSeq) throw new Error(`seq ${T} > maxTrainSeq ${this.opts.maxTrainSeq}`); | |
| this._ensureScratch(T); | |
| const wasF16 = this.rt.usingF16?.(); | |
| this.rt.setUseF16?.(false); | |
| try { | |
| const numActive = this._writeTargets(tokens, lossMask, T); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._forward(enc, tokens, T); | |
| this._lossOnly(enc, T, numActive); | |
| enc.copyBufferToBuffer(this.s.loss, 0, this.lossRead, 0, T * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.lossRead.mapAsync(GPUMapMode.READ); | |
| const arr = new Float32Array(this.lossRead.getMappedRange().slice(0)); | |
| this.lossRead.unmap(); | |
| let sum = 0; | |
| for (let t = 0; t < T; t++) sum += arr[t]; | |
| return { loss: sum / Math.max(1, numActive), numActive }; | |
| } finally { | |
| if (wasF16) this.rt.setUseF16?.(true); | |
| } | |
| } | |
| // ---- public: accumulate one micro-batch. tokens: Int array, lossMask: 0/1 array. ---- | |
| // lossMask[t]==1 means "train the prediction of tokens[t+1] from position t". | |
| async microStep(tokens, lossMask) { | |
| const c = this.cfg; | |
| const T = tokens.length; | |
| const t0 = nowMs(); | |
| if (T > this.opts.maxTrainSeq) throw new Error(`seq ${T} > maxTrainSeq ${this.opts.maxTrainSeq}`); | |
| this._ensureScratch(T); | |
| const wasF16 = this.rt.usingF16?.(); | |
| this.rt.setUseF16?.(false); | |
| try { | |
| const numActive = this._writeTargets(tokens, lossMask, T); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._forward(enc, tokens, T); | |
| this._backward(enc, T, numActive); | |
| enc.copyBufferToBuffer(this.s.loss, 0, this.lossRead, 0, T * 4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.lossRead.mapAsync(GPUMapMode.READ); | |
| const lossArr = new Float32Array(this.lossRead.getMappedRange().slice(0)); | |
| this.lossRead.unmap(); | |
| let lossSum = 0; | |
| for (let t = 0; t < T; t++) lossSum += lossArr[t]; | |
| this._microInWindow++; | |
| const microStepMs = nowMs() - t0; | |
| return { | |
| loss: lossSum / Math.max(1, numActive), | |
| numActive, | |
| tokens: T, | |
| microStepMs, | |
| trainTokPerSec: T / Math.max(1e-6, microStepMs / 1e3) | |
| }; | |
| } finally { | |
| if (wasF16) this.rt.setUseF16?.(true); | |
| } | |
| } | |
| // ---- public: apply accumulated grads with AdamW + global-norm clip ---- | |
async optimizerStep() {
  // Applies the gradients accumulated since the last zeroGrads():
  //   1) GPU pass: the `sumsq` kernel accumulates each trained A/B gradient
  //      into s.normBuf (presumably a sum of squares, given Math.sqrt below);
  //   2) read that scalar back to get the global gradient norm and clip factor;
  //   3) GPU pass: one AdamW dispatch per A and B matrix, then zero the grads.
  // Returns { lr, gradNorm, clip, optimizerStepMs }.
  const t0 = nowMs();
  const o = this.opts;
  // Number of micro-batches accumulated in this window (at least 1).
  const accum = this._microInWindow || 1;
  const encN = this.dev.createCommandEncoder();
  encN.clearBuffer(this.s.normBuf);
  for (const k of this.trainedKeys) {
    const st = this.state[k];
    this._disp(encN, this.p.sumsq, [st.dA, this.s.normBuf], 1, 1, this._u32([st.rank * st.K]), "sumsq");
    this._disp(encN, this.p.sumsq, [st.dB, this.s.normBuf], 1, 1, this._u32([st.rank * st.N]), "sumsq");
  }
  encN.copyBufferToBuffer(this.s.normBuf, 0, this.normRead, 0, 4);
  this.dev.queue.submit([encN.finish()]);
  await this.normRead.mapAsync(GPUMapMode.READ);
  const sumsq = new Float32Array(this.normRead.getMappedRange().slice(0))[0];
  this.normRead.unmap();
  // Gradients were summed over micro-batches; scale to their mean.
  const gradScale = 1 / accum;
  const gnorm = Math.sqrt(sumsq) * gradScale;
  // Global-norm clipping factor (1 when clipping is disabled or not needed).
  const clip2 = o.maxGradNorm > 0 && gnorm > o.maxGradNorm ? o.maxGradNorm / (gnorm + 1e-6) : 1;
  // Single multiplier handed to the AdamW kernel: mean scaling times clipping.
  const gScale = gradScale * clip2;
  // 1-based step count, used by the LR schedule and the bias corrections.
  this.step++;
  const lr = this._lrAt(this.step);
  // Adam bias-correction denominators 1 - beta^t.
  const b1c = 1 - Math.pow(o.beta1, this.step);
  const b2c = 1 - Math.pow(o.beta2, this.step);
  const enc = this.dev.createCommandEncoder();
  for (const k of this.trainedKeys) {
    const st = this.state[k];
    const metaA = this._adamMeta(st.rank * st.K, lr, gScale, b1c, b2c);
    this._disp(enc, this.p.adamw, [st.mod.A, st.dA, st.mA, st.vA], this._grid1d(st.rank * st.K), 1, metaA, "adamw");
    const metaB = this._adamMeta(st.rank * st.N, lr, gScale, b1c, b2c);
    this._disp(enc, this.p.adamw, [st.mod.B, st.dB, st.mB, st.vB], this._grid1d(st.rank * st.N), 1, metaB, "adamw");
  }
  this.dev.queue.submit([enc.finish()]);
  // NOTE(review): presumably tells the runtime its cached LoRA weights changed — confirm in rt.
  this.rt.invalidateLora();
  // NOTE(review): presumably also resets _microInWindow — confirm in zeroGrads().
  this.zeroGrads();
  return { lr, gradNorm: gnorm, clip: clip2, optimizerStepMs: nowMs() - t0 };
}
| _lrAt(step) { | |
| const o = this.opts; | |
| if (o.warmupSteps > 0 && step <= o.warmupSteps) return o.lr * (step / o.warmupSteps); | |
| if (o.totalSteps > 0 && step > o.warmupSteps) { | |
| const prog = (step - o.warmupSteps) / Math.max(1, o.totalSteps - o.warmupSteps); | |
| const cos = 0.5 * (1 + Math.cos(Math.PI * Math.min(1, prog))); | |
| return o.lr * (o.minLrRatio + (1 - o.minLrRatio) * cos); | |
| } | |
| return o.lr; | |
| } | |
| _adamMeta(n, lr, gScale, b1c, b2c) { | |
| const o = this.opts; | |
| const buf = new ArrayBuffer(48); | |
| const dv = new DataView(buf); | |
| dv.setUint32(0, n >>> 0, true); | |
| dv.setFloat32(8, lr, true); | |
| dv.setFloat32(12, o.beta1, true); | |
| dv.setFloat32(16, o.beta2, true); | |
| dv.setFloat32(20, o.eps, true); | |
| dv.setFloat32(24, o.weightDecay, true); | |
| dv.setFloat32(28, gScale, true); | |
| dv.setFloat32(32, b1c, true); | |
| dv.setFloat32(36, b2c, true); | |
| return new Uint8Array(buf); | |
| } | |
| // ---- convenience: one full optimization step over a list of micro-batches ---- | |
| async trainStep(batches) { | |
| const list = Array.isArray(batches) ? batches : [batches]; | |
| let lossSum = 0, n = 0, numActive = 0, tokens = 0, microStepMs = 0; | |
| for (const b of list) { | |
| const r = await this.microStep(b.tokens, b.lossMask); | |
| lossSum += r.loss; | |
| numActive += r.numActive || 0; | |
| tokens += r.tokens || b.tokens?.length || 0; | |
| microStepMs += r.microStepMs || 0; | |
| n++; | |
| } | |
| const opt = await this.optimizerStep(); | |
| const totalStepMs = microStepMs + (opt.optimizerStepMs || 0); | |
| return { | |
| loss: lossSum / Math.max(1, n), | |
| microBatches: n, | |
| numActive, | |
| tokens, | |
| microStepMs, | |
| totalStepMs, | |
| trainTokPerSec: tokens / Math.max(1e-6, totalStepMs / 1e3), | |
| ...opt | |
| }; | |
| } | |
| }; | |
| // src/services/training_controller.js | |
// Token id appended after every completion by TrainingController.prepareExample()
// as the end-of-turn marker. NOTE(review): presumably the id of Qwen2.5's
// `<|im_end|>` special token — confirm against the tokenizer if the base model changes.
var IM_END = 151645;
var TrainingController = class {
  static {
    __name(this, "TrainingController");
  }
  /**
   * High-level driver for LoRA fine-tuning of a loaded model session.
   * @param {object} [opts]
   * @param {object} opts.session - loaded ModelSession; its `rt` and `tokenizer` are used.
   * @param {object} opts.adapters - AdapterRegistry that adapters are registered into / looked up from.
   * @param {(msg: string) => void} [opts.log] - progress logger; no-op by default.
   * @param {object} [opts.trainerOptions] - passed to every QwenLoraTrainer created here.
   */
  constructor({ session: session2, adapters: adapters2, log: log2 = /* @__PURE__ */ __name(() => {
  }, "log"), trainerOptions = {} } = {}) {
    this.session = session2;
    this.adapters = adapters2;
    this.log = log2;
    this.trainerOptions = trainerOptions;
    this.trainer = null;
    this.adapter = null;
  }
  get rt() {
    return this.session.rt;
  }
  get tokenizer() {
    return this.session.tokenizer;
  }
  // Create + register a fresh trainable adapter and attach a new trainer to it.
  initAdapter(name = "trainable", { rank = 16, alpha = 32, targetModules } = {}) {
    const adapter = createTrainableAdapter(this.rt, { name, rank, alpha, targetModules });
    this.adapters.adapters[name] = adapter;
    this.adapter = adapter;
    this.trainer = new QwenLoraTrainer(this.rt, this.trainerOptions);
    this.trainer.attach(adapter);
    this.log(`init adapter "${name}" rank=${rank} alpha=${alpha} modules=${Object.keys(adapter.modules).length}`);
    return adapter;
  }
  // Attach to an already-registered adapter (e.g. continue training a loaded one).
  // Throws when the registry has no adapter under `name`.
  attachAdapter(name) {
    const adapter = this.adapters.get(name);
    if (!adapter) throw new Error(`adapter "${name}" not found`);
    this.adapter = adapter;
    this.trainer = new QwenLoraTrainer(this.rt, this.trainerOptions);
    this.trainer.attach(adapter);
    return adapter;
  }
  /*
   * TECHNIQUE: Completion-only loss masking with shifted labels
   * Tokenize prompt (chat-formatted when `messages` is given) and completion
   * separately, then append IM_END. mask[t]=1 trains the prediction of
   * tokens[t+1] from position t, so positions whose NEXT token is part of the
   * completion (incl. the final IM_END) are masked in. Prompt positions get
   * mask=0 unless trainPromptToo is set. The last position is always 0.
   */
  prepareExample({ messages, prompt, completion, trainPromptToo = false }) {
    const tk = this.tokenizer;
    const promptIds = messages ? tk.encode(formatMessages(tk, messages)) : tk.encode(prompt);
    const compIds = tk.encode(completion, { add_special_tokens: false });
    const tokens = [...promptIds, ...compIds, IM_END];
    const T = tokens.length;
    const lossMask = new Array(T).fill(0);
    const firstTrainPos = trainPromptToo ? 0 : Math.max(0, promptIds.length - 1);
    for (let t = firstTrainPos; t < T - 1; t++) lossMask[t] = 1;
    return {
      tokens,
      lossMask,
      promptLength: promptIds.length,
      completionLength: compIds.length,
      firstTrainPos
    };
  }
  // Debug view of prepareExample(): one row per position with its segment,
  // decoded text, and the next-token target it is (or is not) trained on.
  inspectExample(example) {
    const prepared = this.prepareExample(example);
    const { tokens, lossMask, promptLength, completionLength, firstTrainPos } = prepared;
    const rows = tokens.map((id, index) => {
      const targetId = index + 1 < tokens.length ? tokens[index + 1] : null;
      const segment = index < promptLength ? "prompt" : index < promptLength + completionLength ? "completion" : "eos";
      return {
        index,
        id,
        text: decodeToken(this.tokenizer, id),
        segment,
        trainsNext: !!lossMask[index],
        targetId,
        targetText: targetId == null ? "" : decodeToken(this.tokenizer, targetId)
      };
    });
    return {
      ...prepared,
      trainPositions: lossMask.reduce((n, v) => n + (v ? 1 : 0), 0),
      firstTrainPos,
      rows
    };
  }
  prepareBatch(examples) {
    return examples.map((e) => this.prepareExample(e));
  }
  // One optimizer step over `microBatches` (array of {tokens, lossMask}); grads
  // accumulate across them, then a single AdamW update is applied.
  async step(microBatches) {
    if (!this.trainer) throw new Error("call initAdapter()/attachAdapter() first");
    return this.trainer.trainStep(microBatches);
  }
  /**
   * Full training run. Each epoch shuffles the examples and groups them into
   * accumulation windows of trainer.opts.gradAccumSteps (at least 1); a trailing
   * partial window is still trained. Each step is logged and reported through
   * onStep({ step, epoch, ...trainStep result }). Sequences longer than the cap
   * are truncated. The cap is `maxTrainSeq` if given, otherwise the trainer's
   * limit, and it is never above the trainer's own maxTrainSeq.
   * @returns {Promise<{steps: number, adapter: object}>}
   */
  async train(examples, { epochs = 1, onStep = /* @__PURE__ */ __name(() => {
  }, "onStep"), maxTrainSeq } = {}) {
    if (!this.trainer) this.initAdapter();
    const trainerOpts = this.trainer.opts;
    // A missing/zero gradAccumSteps would make the `=== accum` check never fire
    // and silently merge a whole epoch into one accumulation window.
    const accum = Math.max(1, Math.floor(trainerOpts.gradAccumSteps) || 1);
    // microStep() throws for sequences above the trainer's own limit, so a
    // caller-supplied cap may only tighten it.
    const limit = trainerOpts.maxTrainSeq;
    const requested = maxTrainSeq ?? limit;
    const cap = limit != null && requested != null ? Math.min(requested, limit) : requested;
    let globalStep = 0;
    const runWindow = /* @__PURE__ */ __name(async (windowBatches, epoch) => {
      const r = await this.step(windowBatches);
      globalStep++;
      this.log(`step ${globalStep} epoch ${epoch} loss=${r.loss.toFixed(4)} lr=${r.lr.toExponential(2)} |g|=${r.gradNorm.toFixed(3)}`);
      onStep({ step: globalStep, epoch, ...r });
    }, "runWindow");
    for (let ep = 0; ep < epochs; ep++) {
      const order = shuffle([...Array(examples.length).keys()]);
      let pending = [];
      for (const idx of order) {
        let mb = this.prepareExample(examples[idx]);
        if (mb.tokens.length > cap) mb = truncate(mb, cap);
        pending.push(mb);
        if (pending.length === accum) {
          await runWindow(pending, ep);
          pending = [];
        }
      }
      if (pending.length) await runWindow(pending, ep);
    }
    this.adapters.applyToRuntime(this.adapter.name, this.rt);
    return { steps: globalStep, adapter: this.adapter };
  }
};
// Copy of a prepared micro-batch with tokens and lossMask cut to at most `cap`
// entries; every other field is carried over unchanged.
function truncate(mb, cap) {
  const head = /* @__PURE__ */ __name((values) => values.slice(0, cap), "head");
  return { ...mb, tokens: head(mb.tokens), lossMask: head(mb.lossMask) };
}
__name(truncate, "truncate");
// Best-effort single-token decode (special tokens kept). Falls back to the
// numeric id as a string when there is no decoder or decoding throws.
function decodeToken(tokenizer, id) {
  try {
    if (!tokenizer?.decode) {
      return String(id);
    }
    return tokenizer.decode([id], { skip_special_tokens: false });
  } catch {
    return String(id);
  }
}
__name(decodeToken, "decodeToken");
// In-place Fisher-Yates shuffle using Math.random(); returns the same array.
function shuffle(a) {
  let i = a.length;
  while (i > 1) {
    i -= 1;
    const j = Math.floor(Math.random() * (i + 1));
    const held = a[i];
    a[i] = a[j];
    a[j] = held;
  }
  return a;
}
__name(shuffle, "shuffle");
| // src/lora_export.js | |
/**
 * Copy the first `byteLen` bytes of a GPU buffer back to the CPU as float32.
 * Uses a temporary COPY_DST|MAP_READ staging buffer that is always destroyed,
 * including when mapping rejects, so failed reads do not leak GPU memory.
 * @param {GPUDevice} dev
 * @param {GPUBuffer} src - assumed to have COPY_SRC usage (not checked here).
 * @param {number} byteLen - bytes to copy; presumably a multiple of 4.
 * @returns {Promise<Float32Array>} CPU-owned copy of the data.
 */
async function readBufferF32(dev, src, byteLen) {
  const rb = dev.createBuffer({ size: byteLen, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
  try {
    const enc = dev.createCommandEncoder();
    enc.copyBufferToBuffer(src, 0, rb, 0, byteLen);
    dev.queue.submit([enc.finish()]);
    await rb.mapAsync(GPUMapMode.READ);
    // slice(0) copies out of the mapped range, which is detached by unmap().
    const out = new Float32Array(rb.getMappedRange().slice(0));
    rb.unmap();
    return out;
  } finally {
    rb.destroy();
  }
}
__name(readBufferF32, "readBufferF32");
// Transpose a row-major rows x cols matrix into a new row-major cols x rows
// Float32Array (same length as the input).
function transpose2d(arr, rows, cols) {
  const out = new Float32Array(arr.length);
  for (let c = 0; c < cols; c++) {
    const dstRow = c * rows;
    for (let r = 0; r < rows; r++) {
      out[dstRow + r] = arr[r * cols + c];
    }
  }
  return out;
}
__name(transpose2d, "transpose2d");
/**
 * Serialize F32 tensors into a safetensors byte stream:
 * [u64 LE header length][JSON header, space-padded to a multiple of 8][raw data].
 * @param {{name: string, shape: number[], data: Float32Array}[]} tensors
 * @param {object|null} [metadata] - stored under `__metadata__` when truthy.
 * @returns {Uint8Array}
 */
function buildSafetensors(tensors, metadata = { format: "pt" }) {
  const header = {};
  if (metadata) header.__metadata__ = metadata;
  let dataBytes = 0;
  for (const { name, shape, data } of tensors) {
    const end = dataBytes + data.byteLength;
    header[name] = { dtype: "F32", shape, data_offsets: [dataBytes, end] };
    dataBytes = end;
  }
  const json = new TextEncoder().encode(JSON.stringify(header));
  // Pad with spaces so the data section starts on an 8-byte boundary.
  const padding = (8 - json.length % 8) % 8;
  const headerLen = json.length + padding;
  const out = new Uint8Array(8 + headerLen + dataBytes);
  new DataView(out.buffer).setBigUint64(0, BigInt(headerLen), true);
  out.set(json, 8);
  out.fill(0x20, 8 + json.length, 8 + headerLen);
  let cursor = 8 + headerLen;
  for (const { data } of tensors) {
    out.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength), cursor);
    cursor += data.byteLength;
  }
  return out;
}
__name(buildSafetensors, "buildSafetensors");
/**
 * Read every trained LoRA module back from the GPU and package it as a PEFT
 * adapter: safetensors bytes (lora_A as [r, K], lora_B transposed to [N, r])
 * plus an adapter_config object/JSON. The most common rank/alpha (or
 * opts.rank / opts.alpha) becomes the global value. Modules that differ are
 * listed in rank_pattern / alpha_pattern.
 * @returns {Promise<{safetensors: Uint8Array, config: object, configJson: string}>}
 */
async function exportLoraAdapter(trainer, opts = {}) {
  const dev = trainer.rt.dev;
  const tensors = [];
  const targetModules = new Set();
  const rankByKey = {};
  const alphaByKey = {};
  for (const key of trainer.trainedKeys) {
    const { mod, rank, K, N, scale } = trainer.state[key];
    const loraA = await readBufferF32(dev, mod.A, rank * K * 4);
    const loraB = await readBufferF32(dev, mod.B, rank * N * 4);
    const prefix = `base_model.model.model.${key}`;
    tensors.push(
      { name: `${prefix}.lora_A.weight`, shape: [rank, K], data: loraA },
      { name: `${prefix}.lora_B.weight`, shape: [N, rank], data: transpose2d(loraB, rank, N) }
    );
    rankByKey[key] = rank;
    // Effective alpha, since the trainer stores scale = alpha / rank.
    alphaByKey[key] = scale * rank;
    targetModules.add(key.split(".").pop());
  }
  const safetensors = buildSafetensors(tensors);
  const r = opts.rank ?? mode(Object.values(rankByKey)) ?? 0;
  const alpha = opts.alpha ?? mode(Object.values(alphaByKey)) ?? 0;
  const rankPattern = {};
  const alphaPattern = {};
  for (const [key, moduleRank] of Object.entries(rankByKey)) {
    if (moduleRank !== r) rankPattern[key] = moduleRank;
    if (alphaByKey[key] !== alpha) alphaPattern[key] = alphaByKey[key];
  }
  const config = {
    peft_type: "LORA",
    auto_mapping: null,
    base_model_name_or_path: opts.baseModel || "WeiboAI/VibeThinker-3B",
    r,
    lora_alpha: alpha,
    target_modules: Array.from(targetModules),
    lora_dropout: 0,
    bias: "none",
    fan_in_fan_out: false,
    inference_mode: true,
    task_type: "CAUSAL_LM"
  };
  if (Object.keys(rankPattern).length) config.rank_pattern = rankPattern;
  if (Object.keys(alphaPattern).length) config.alpha_pattern = alphaPattern;
  const configJson = JSON.stringify(config, null, 2);
  return { safetensors, config, configJson };
}
__name(exportLoraAdapter, "exportLoraAdapter");
// Most frequent value in `arr` (undefined when empty). On ties, the value whose
// count reached the winning total first wins.
function mode(arr) {
  if (!arr.length) return void 0;
  const tally = new Map();
  let winner = arr[0];
  let winnerCount = 0;
  for (let i = 0; i < arr.length; i++) {
    const value = arr[i];
    const count = (tally.get(value) || 0) + 1;
    tally.set(value, count);
    if (count > winnerCount) {
      winnerCount = count;
      winner = value;
    }
  }
  return winner;
}
__name(mode, "mode");
// Export the trainer's adapter and trigger two browser downloads:
// "<name>.safetensors" and "adapter_config.json". The name comes from opts.name,
// then the trainer's adapter name, then "adapter".
async function downloadLoraAdapter(trainer, opts = {}) {
  const exported = await exportLoraAdapter(trainer, opts);
  const baseName = opts.name || trainer.adapter?.name || "adapter";
  const weightsBlob = new Blob([exported.safetensors], { type: "application/octet-stream" });
  triggerDownload(weightsBlob, `${baseName}.safetensors`);
  const configBlob = new Blob([exported.configJson], { type: "application/json" });
  triggerDownload(configBlob, "adapter_config.json");
}
__name(downloadLoraAdapter, "downloadLoraAdapter");
// Save a Blob via a temporary <a download> link. No-op outside a DOM.
// The object URL is revoked after one second rather than immediately,
// presumably so the browser can start the download first.
function triggerDownload(blob, filename) {
  if (typeof document === "undefined") return;
  const href = URL.createObjectURL(blob);
  const link = Object.assign(document.createElement("a"), { href, download: filename });
  document.body.appendChild(link);
  link.click();
  link.remove();
  setTimeout(() => URL.revokeObjectURL(href), 1e3);
}
__name(triggerDownload, "triggerDownload");
| // src/lora_gpu.js | |
// Split a safetensors ArrayBuffer into its parsed JSON header, the byte offset
// where tensor data begins, and a byte view over the whole buffer.
function parseSt(buf) {
  const headerLen = Number(new DataView(buf).getBigUint64(0, true));
  const headerBytes = new Uint8Array(buf, 8, headerLen);
  const header = JSON.parse(new TextDecoder().decode(headerBytes));
  return { header, dataStart: 8 + headerLen, u8: new Uint8Array(buf) };
}
__name(parseSt, "parseSt");
/**
 * Decode `n` little-endian bfloat16 values starting at byte `off` of `u8` into
 * a new Float32Array (bf16 is the upper half of an f32 bit pattern).
 * A Uint16Array view needs a 2-byte-aligned offset, and tensor offsets inside a
 * safetensors file are not guaranteed to be even. When the offset is odd, the
 * bytes are copied into a fresh (aligned) buffer first instead of throwing
 * RangeError.
 * Assumes a little-endian host, like the typed-array views used elsewhere.
 */
function bf16f32(u8, off, n) {
  const start = u8.byteOffset + off;
  const halves = start % 2 === 0 ? new Uint16Array(u8.buffer, start, n) : new Uint16Array(u8.buffer.slice(start, start + n * 2));
  const o = new Float32Array(n);
  const o32 = new Uint32Array(o.buffer);
  for (let i = 0; i < n; i++) o32[i] = halves[i] << 16;
  return o;
}
__name(bf16f32, "bf16f32");
// Copy `n` float32 values starting at byte `off` of `u8` into a new
// Float32Array. Copying sidesteps the 4-byte alignment rule of typed-array views.
function f32(u8, off, n) {
  const begin = u8.byteOffset + off;
  const copied = u8.buffer.slice(begin, begin + 4 * n);
  return new Float32Array(copied);
}
__name(f32, "f32");
/**
 * Decode one tensor from a parsed safetensors file (see parseSt) to float32.
 * Supports BF16, F16 and F32. Any other dtype, previously misread silently
 * as F32, now throws.
 * @param {{header: object, dataStart: number, u8: Uint8Array}} st
 * @param {string} name - tensor key in the header.
 * @returns {{arr: Float32Array, shape: number[]}}
 * @throws {Error} when the tensor is missing or its dtype is unsupported.
 */
function readTensor(st, name) {
  const t = st.header[name];
  if (!t) throw new Error(`tensor "${name}" not found in safetensors header`);
  const n = t.shape.reduce((a, b) => a * b, 1);
  const start = st.dataStart + t.data_offsets[0];
  const dt = t.dtype.toUpperCase();
  // IEEE 754 half precision -> float32; DataView keeps reads alignment-safe.
  function f16ToF32(u8, off, count) {
    const view = new DataView(u8.buffer, u8.byteOffset + off, count * 2);
    const out = new Float32Array(count);
    for (let i = 0; i < count; i++) {
      const h = view.getUint16(i * 2, true);
      const sign = h & 32768 ? -1 : 1;
      const exp = h >> 10 & 31;
      const frac = h & 1023;
      if (exp === 0) out[i] = sign * frac * 2 ** -24;
      else if (exp === 31) out[i] = frac ? NaN : sign * Infinity;
      else out[i] = sign * (1 + frac / 1024) * 2 ** (exp - 15);
    }
    return out;
  }
  let arr;
  switch (dt) {
    case "BF16":
      arr = bf16f32(st.u8, start, n);
      break;
    case "F16":
      arr = f16ToF32(st.u8, start, n);
      break;
    case "F32":
      arr = f32(st.u8, start, n);
      break;
    default:
      throw new Error(`unsupported LoRA tensor dtype ${t.dtype} for "${name}"`);
  }
  return { arr, shape: t.shape };
}
__name(readTensor, "readTensor");
// True for LoRA "A" (down-projection) tensor names, matched case-insensitively.
var isA = /* @__PURE__ */ __name((name) => /lora_a/i.exec(name) !== null, "isA");
// Row-major rows x cols -> row-major cols x rows transpose into a new Float32Array.
function transpose2d2(arr, rows, cols) {
  const out = new Float32Array(arr.length);
  let src = 0;
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < cols; c++, src++) {
      out[c * rows + r] = arr[src];
    }
  }
  return out;
}
__name(transpose2d2, "transpose2d");
/**
 * Load a LoRA adapter (safetensors + optional adapter_config.json/config.json)
 * into GPU storage buffers. A is stored as [r, K], B as [r, N] (transposed when
 * the file stores them the other way round).
 *
 * Per-module scale:
 *  - MLX `lora_parameters.scale` is used as-is for every module;
 *  - otherwise alpha / r, where alpha is the PEFT `alpha_pattern` entry matching
 *    the module key (exact, or as a dotted suffix either way), else
 *    `lora_alpha` / `alpha`, and r is the module's actual rank from its tensor
 *    shapes (alpha / sqrt(r) when `use_rslora` is set). Adapters with
 *    rank_pattern/alpha_pattern, such as those written by exportLoraAdapter(),
 *    therefore keep their per-module scaling;
 *  - with no usable config the legacy default scale of 2 applies.
 * NOTE(review): PEFT pattern keys may also be regexes; only literal keys are matched here.
 *
 * @param {GPUDevice} dev
 * @param {File[]} files - must contain one *.safetensors file.
 * @param {object} cfg - currently unused.
 * @returns {Promise<{name: string, modules: object}>}
 * @throws {Error} when no safetensors is given or no module has both A and B.
 */
async function loadLoraAdapterGPU(dev, files, cfg) {
  const stFile = files.find((f) => f.name.endsWith(".safetensors"));
  if (!stFile) throw new Error("no .safetensors in adapter files");
  const cfgFile = files.find((f) => /adapter_config\.json|config\.json/.test(f.name));
  let fixedScale = null;
  let alphaCfg = null;
  let alphaPattern = {};
  let useRsLora = false;
  if (cfgFile) {
    const c = JSON.parse(await cfgFile.text());
    const lp = c.lora_parameters || {};
    if (lp.scale != null) fixedScale = lp.scale;
    else alphaCfg = c.lora_alpha ?? c.alpha ?? null;
    alphaPattern = c.alpha_pattern || {};
    useRsLora = !!c.use_rslora;
  }
  const patternValue = /* @__PURE__ */ __name((pattern, key) => {
    if (Object.hasOwn(pattern, key)) return pattern[key];
    for (const [pat, value] of Object.entries(pattern)) {
      if (key.endsWith("." + pat) || pat.endsWith("." + key)) return value;
    }
    return void 0;
  }, "patternValue");
  const scaleFor = /* @__PURE__ */ __name((key, rank) => {
    if (fixedScale != null) return fixedScale;
    const alpha = patternValue(alphaPattern, key) ?? alphaCfg;
    if (alpha == null) return 2;
    return alpha / (useRsLora ? Math.sqrt(rank) : rank);
  }, "scaleFor");
  const st = parseSt(await stFile.arrayBuffer());
  const names = Object.keys(st.header).filter((k) => k !== "__metadata__" && /lora_[abAB]/.test(k));
  const groups = {};
  for (const nm of names) {
    const key = moduleKeyFromTensorName(nm);
    if (!key) continue;
    (groups[key] ||= {})[isA(nm) ? "A" : "B"] = readTensor(st, nm);
  }
  const S = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST;
  const mk = /* @__PURE__ */ __name((arr) => {
    const b = dev.createBuffer({ size: arr.byteLength, usage: S });
    dev.queue.writeBuffer(b, 0, arr);
    return b;
  }, "mk");
  const modules = {};
  for (const key of Object.keys(groups)) {
    const g = groups[key];
    if (!g.A || !g.B) continue;
    // The rank is the smallest dimension across both factors.
    const r = Math.min(...g.A.shape, ...g.B.shape);
    const Aarr = g.A.shape[0] !== r ? transpose2d2(g.A.arr, g.A.shape[0], g.A.shape[1]) : g.A.arr;
    const Barr = g.B.shape[0] !== r ? transpose2d2(g.B.arr, g.B.shape[0], g.B.shape[1]) : g.B.arr;
    modules[key] = { A: mk(Aarr), B: mk(Barr), rawA: Aarr, rawB: Barr, rank: r, scale: scaleFor(key, r) };
  }
  if (!Object.keys(modules).length) throw new Error("no LoRA modules matched layers.*.{self_attn,mlp}.*_proj");
  const name = stFile.name.replace(/\.safetensors$/, "");
  return { name, modules };
}
__name(loadLoraAdapterGPU, "loadLoraAdapterGPU");
| // src/services/store.js | |
// Namespace object for the store service (src/services/store.js). __export
// installs each key below as an enumerable getter that returns the function of
// the same name defined later in this section.
var store_exports = {};
__export(store_exports, {
  connectDirectory: () => connectDirectory,
  deleteRun: () => deleteRun,
  ensurePermission: () => ensurePermission,
  forgetDirectory: () => forgetDirectory,
  fsSupported: () => fsSupported,
  getRun: () => getRun,
  getRunBlobs: () => getRunBlobs,
  listRuns: () => listRuns,
  loadRunFiles: () => loadRunFiles,
  newId: () => newId,
  readDirText: () => readDirText,
  saveRun: () => saveRun,
  savedDirectory: () => savedDirectory,
  writeFileToDir: () => writeFileToDir
});
// localStorage key holding the JSON array of run metadata (read by listRuns, written by writeIndex).
var LS_KEY = "emberglass.history.v2";
// IndexedDB database opened by db(); opening with a higher DB_VERSION runs its onupgradeneeded handler.
var DB_NAME = "emberglass";
var DB_VERSION = 1;
// Object store of per-run records { safetensors: Blob, configJson: string }, keyed by run id (see saveRun).
var BLOB_STORE = "adapters";
// Object store for the persisted directory handle, stored under the key "dir".
var HANDLE_STORE = "handles";
// Cached promise for the open database, shared by every db() caller.
var _dbp = null;
/**
 * Open (once) the IndexedDB database, creating the blob and handle object stores
 * on first use. The open promise is cached and shared. If opening fails, the
 * cache is cleared so a later call can retry instead of failing forever on the
 * same rejected promise.
 * @returns {Promise<IDBDatabase>}
 */
function db() {
  if (_dbp) return _dbp;
  const pending = new Promise((resolve, reject) => {
    const r = indexedDB.open(DB_NAME, DB_VERSION);
    r.onupgradeneeded = () => {
      const d = r.result;
      if (!d.objectStoreNames.contains(BLOB_STORE)) d.createObjectStore(BLOB_STORE);
      if (!d.objectStoreNames.contains(HANDLE_STORE)) d.createObjectStore(HANDLE_STORE);
    };
    r.onsuccess = () => resolve(r.result);
    r.onerror = () => reject(r.error);
  });
  _dbp = pending;
  pending.catch(() => {
    // Only drop the cache if it still refers to this failed attempt.
    if (_dbp === pending) _dbp = null;
  });
  return pending;
}
__name(db, "db");
/**
 * Store `val` under `key` in object store `store`, resolving once the readwrite
 * transaction commits. Rejects on request errors and on transaction abort.
 * Without the abort handler, an abort that fires no error event (e.g. quota
 * exceeded at commit) would leave the promise pending forever.
 */
async function idbPut(store, key, val) {
  const d = await db();
  return new Promise((res, rej) => {
    const tx = d.transaction(store, "readwrite");
    tx.objectStore(store).put(val, key);
    tx.oncomplete = () => res();
    // tx.error is still null while a request error bubbles; fall back to the request's error.
    tx.onerror = (ev) => rej(tx.error || ev?.target?.error || new Error(`IndexedDB transaction on "${store}" failed`));
    tx.onabort = () => rej(tx.error || new Error(`IndexedDB transaction on "${store}" aborted`));
  });
}
__name(idbPut, "idbPut");
// Read the value stored under `key` in object store `store` (undefined if absent);
// rejects with the request's error.
async function idbGet(store, key) {
  const database = await db();
  return new Promise((resolve, reject) => {
    const request = database.transaction(store, "readonly").objectStore(store).get(key);
    request.onsuccess = () => resolve(request.result);
    request.onerror = () => reject(request.error);
  });
}
__name(idbGet, "idbGet");
/**
 * Delete `key` from object store `store`, resolving once the readwrite
 * transaction commits. Rejects on request errors and on transaction abort,
 * so an abort without an error event cannot leave the promise pending forever.
 */
async function idbDel(store, key) {
  const d = await db();
  return new Promise((res, rej) => {
    const tx = d.transaction(store, "readwrite");
    tx.objectStore(store).delete(key);
    tx.oncomplete = () => res();
    // tx.error is still null while a request error bubbles; fall back to the request's error.
    tx.onerror = (ev) => rej(tx.error || ev?.target?.error || new Error(`IndexedDB transaction on "${store}" failed`));
    tx.onabort = () => rej(tx.error || new Error(`IndexedDB transaction on "${store}" aborted`));
  });
}
__name(idbDel, "idbDel");
// Run metadata index from localStorage; any missing, unparsable or non-array
// value yields an empty list.
function listRuns() {
  let parsed;
  try {
    parsed = JSON.parse(localStorage.getItem(LS_KEY) || "[]");
  } catch {
    return [];
  }
  return Array.isArray(parsed) ? parsed : [];
}
__name(listRuns, "listRuns");
// Persist the run index to localStorage. Failures (quota, serialization) are
// logged with console.warn and otherwise ignored.
function writeIndex(arr) {
  try {
    const serialized = JSON.stringify(arr);
    localStorage.setItem(LS_KEY, serialized);
  } catch (err) {
    console.warn("[store] localStorage write failed", err);
  }
}
__name(writeIndex, "writeIndex");
// Metadata entry for run `id`, or null when it is not in the index.
function getRun(id) {
  const match = listRuns().find(({ id: runId }) => runId === id);
  return match || null;
}
__name(getRun, "getRun");
// Run id: "run_" + base-36 timestamp + "_" + up to 5 random base-36 chars.
// Not cryptographically random; meant only to avoid accidental collisions.
function newId() {
  const stamp = Date.now().toString(36);
  const salt = Math.random().toString(36).slice(2, 7);
  return `run_${stamp}_${salt}`;
}
__name(newId, "newId");
// Store a run's adapter files in IndexedDB (keyed by meta.id) and move its
// metadata to the front of the localStorage index, replacing any entry with the
// same id. Returns `meta`.
async function saveRun(meta, files) {
  const { safetensors, configJson } = files;
  const bytes = safetensors instanceof Uint8Array ? safetensors : new Uint8Array(safetensors);
  const record = {
    safetensors: new Blob([bytes], { type: "application/octet-stream" }),
    configJson: configJson || "{}"
  };
  await idbPut(BLOB_STORE, meta.id, record);
  const others = listRuns().filter((run) => run.id !== meta.id);
  writeIndex([meta, ...others]);
  return meta;
}
__name(saveRun, "saveRun");
// Remove run `id` from the index, then delete its blob on a best-effort basis.
// A failed blob delete is ignored, since the run is already gone from the index.
async function deleteRun(id) {
  const remaining = listRuns().filter((run) => run.id !== id);
  writeIndex(remaining);
  await idbDel(BLOB_STORE, id).catch(() => {
  });
}
__name(deleteRun, "deleteRun");
// Rebuild a run's adapter as File objects ("<name>.safetensors" with unsafe
// characters replaced by "_", plus "adapter_config.json"). Throws when the
// blob record is missing.
async function loadRunFiles(id) {
  const rec = await idbGet(BLOB_STORE, id);
  if (!rec) throw new Error("adapter blob missing for " + id);
  const stem = (getRun(id)?.name || id).replace(/[^\w.-]+/g, "_");
  const weights = new File([rec.safetensors], `${stem}.safetensors`, { type: "application/octet-stream" });
  const config = new File([rec.configJson], "adapter_config.json", { type: "application/json" });
  return [weights, config];
}
__name(loadRunFiles, "loadRunFiles");
// Raw stored record for run `id` ({ safetensors: Blob, configJson: string });
// throws when the record is missing.
async function getRunBlobs(id) {
  const rec = await idbGet(BLOB_STORE, id);
  if (!rec) throw new Error("adapter blob missing for " + id);
  const { safetensors, configJson } = rec;
  return { safetensors, configJson };
}
__name(getRunBlobs, "getRunBlobs");
// True only where a `window` global exposes showDirectoryPicker (File System
// Access API). Evaluated once, when this module is loaded.
var fsSupported = typeof window !== "undefined" && "showDirectoryPicker" in window;
// Ask the user for a read/write directory and persist its handle in IndexedDB
// (key "dir") so it can be reused later. Throws when the API is unavailable.
async function connectDirectory() {
  if (!fsSupported) {
    throw new Error("File System Access API not available in this browser");
  }
  const dirHandle = await window.showDirectoryPicker({ id: "emberglass", mode: "readwrite" });
  await idbPut(HANDLE_STORE, "dir", dirHandle);
  return dirHandle;
}
__name(connectDirectory, "connectDirectory");
// Previously saved directory handle, or null when unsupported, absent, or
// unreadable.
async function savedDirectory() {
  if (!fsSupported) return null;
  let handle;
  try {
    handle = await idbGet(HANDLE_STORE, "dir");
  } catch {
    return null;
  }
  return handle || null;
}
__name(savedDirectory, "savedDirectory");
// Drop the saved directory handle; failures are deliberately ignored.
async function forgetDirectory() {
  await idbDel(HANDLE_STORE, "dir").catch(() => {
  });
}
__name(forgetDirectory, "forgetDirectory");
// Resolve true when `handle` has (or is granted on request) the given permission
// mode. A missing handle resolves false.
async function ensurePermission(handle, mode2 = "readwrite") {
  if (!handle) return false;
  const descriptor = { mode: mode2 };
  const current = await handle.queryPermission(descriptor);
  if (current === "granted") return true;
  const answer = await handle.requestPermission(descriptor);
  return answer === "granted";
}
__name(ensurePermission, "ensurePermission");
// Concatenate the text of the top-level files in a directory handle whose
// extension is in `exts`. Each file is preceded by "\n# <name>\n". Reading stops
// once more than maxChars have been collected, and the result is cut to maxChars.
// Unreadable files are skipped. Returns { text, names } listing the files used.
async function readDirText(handle, { exts = ["txt", "md", "json", "csv"], maxChars = 2e5 } = {}) {
  const names = [];
  let out = "";
  for await (const [name, entry] of handle.entries()) {
    if (entry.kind !== "file") continue;
    const extension = name.split(".").pop().toLowerCase();
    if (!exts.includes(extension)) continue;
    try {
      const file = await entry.getFile();
      out += `\n# ${name}\n` + await file.text();
      names.push(name);
      if (out.length > maxChars) break;
    } catch {
      // Skip files that cannot be read.
    }
  }
  return { text: out.slice(0, maxChars), names };
}
__name(readDirText, "readDirText");
/**
 * Create (or overwrite) file `name` in a directory handle and write `data` to it.
 * If the write fails, the writable stream is aborted so the pending change is
 * discarded and the stream is not left open, then the original error is rethrown.
 * @param {FileSystemDirectoryHandle} handle
 * @param {string} name
 * @param {*} data - anything FileSystemWritableFileStream.write accepts.
 */
async function writeFileToDir(handle, name, data) {
  const fh = await handle.getFileHandle(name, { create: true });
  const w = await fh.createWritable();
  try {
    await w.write(data);
  } catch (err) {
    try {
      await w.abort();
    } catch {
      // Ignore abort failures so the caller sees the original write error.
    }
    throw err;
  }
  await w.close();
}
__name(writeFileToDir, "writeFileToDir");
| // src/skills.js | |
// One-line signature list for a skill spec: "op(a, b) -> ret; other()".
function specSig(spec) {
  const describe = /* @__PURE__ */ __name(({ name, params, ret }) => {
    const args = (params || []).join(", ");
    const suffix = ret ? " -> " + ret : "";
    return `${name}(${args})${suffix}`;
  }, "describe");
  return spec.ops.map(describe).join("; ");
}
__name(specSig, "specSig");
// System prompt that restricts the model to emitting macros built only from the
// spec's operations, or the literal OUT_OF_SCOPE for requests outside spec.scope.
function skillSystem(domain, spec) {
  const lines = [
    `You are ${domain}. Convert the request into a macro using ONLY these operations:`,
    `${specSig(spec)}.`,
    `Output ONLY the macro, one call per line, no prose. If the request is outside ${spec.scope}, output exactly: OUT_OF_SCOPE.`
  ];
  return lines.join("\n");
}
__name(skillSystem, "skillSystem");
// Parse "op(k=v, ...)" lines (optionally "x = op(...)" and a trailing ";") into
// { op, keys } records. Blank lines, OUT_OF_SCOPE and non-call lines are ignored.
// Keys are the identifiers directly followed by "=" at the start or after a comma.
function parseMacroCalls(text) {
  const callPattern = /^(?:[A-Za-z_]\w*\s*=\s*)?([A-Za-z_]\w*)\s*\((.*)\)\s*;?\s*$/;
  const keyPattern = /(?:^|,)\s*([A-Za-z_]\w*)\s*=/g;
  return String(text).split("\n").map((raw) => raw.trim()).filter((line) => line && line !== "OUT_OF_SCOPE").map((line) => line.match(callPattern)).filter(Boolean).map((m) => ({
    op: m[1],
    keys: Array.from(m[2].matchAll(keyPattern), (k) => k[1])
  }));
}
__name(parseMacroCalls, "parseMacroCalls");
/**
 * Check model output against a skill's operation spec.
 *
 * Status values:
 * - "oos": no calls parsed, and a line consisting only of OUT_OF_SCOPE is present.
 * - "empty": nothing parsed at all.
 * - "ok": every call uses a known op with only declared keyword args.
 * - "bad": otherwise.
 *
 * @param {*} text - raw model output (coerced with String()).
 * @param {{ops: Array<{name: string, params?: string[]}>}} spec - allowed operations.
 * @returns {{status: string, calls: Array<{op: string, ok: boolean}>, issues: string[], n: number}}
 *   per-call verdicts, human-readable issues, and the number of calls parsed.
 */
function verifyMacro(text, spec) {
  const source = String(text);
  const parsed = parseMacroCalls(source);
  if (parsed.length === 0) {
    const outOfScope = /(^|\n)\s*OUT_OF_SCOPE\s*($|\n)/.test(source);
    return { status: outOfScope ? "oos" : "empty", calls: [], issues: [], n: 0 };
  }
  // Later ops with a duplicate name override earlier ones.
  const opsByName = new Map();
  for (const op of spec.ops) opsByName.set(op.name, op);
  const issues = [];
  const checked = [];
  for (const { op: name, keys } of parsed) {
    const def = opsByName.get(name);
    let problem = null;
    if (def === undefined) {
      problem = `unknown op: ${name}`;
    } else {
      const allowed = new Set(def.params || []);
      const unexpected = keys.filter((k) => !allowed.has(k));
      if (unexpected.length > 0) problem = `${name}: unexpected arg ${unexpected.join(", ")}`;
    }
    if (problem !== null) issues.push(problem);
    checked.push({ op: name, ok: problem === null });
  }
  return { status: issues.length ? "bad" : "ok", calls: checked, issues, n: parsed.length };
}
__name(verifyMacro, "verifyMacro");
/**
 * 32-bit FNV-1a hash (xor, then multiply) over the string's UTF-16 code
 * units. Used to derive deterministic PRNG seeds.
 *
 * Characters above 0xFF are XORed in as whole 16-bit units. Non-ASCII input
 * therefore hashes differently from byte-wise FNV-1a over UTF-8.
 *
 * @param {string} s
 * @returns {number} unsigned 32-bit hash.
 */
function hashStr(s) {
  const FNV_OFFSET_BASIS = 0x811c9dc5;
  const FNV_PRIME = 0x01000193;
  let hash = FNV_OFFSET_BASIS;
  for (let i = 0, len = s.length; i < len; i += 1) {
    hash = Math.imul(hash ^ s.charCodeAt(i), FNV_PRIME);
  }
  return hash >>> 0;
}
__name(hashStr, "hashStr");
/**
 * Mulberry32 pseudo-random generator. This is a small, fast, seedable PRNG
 * for reproducible sampling; it is not cryptographically secure.
 *
 * @param {number} a - seed; truncated to a 32-bit integer on the first draw.
 * @returns {() => number} a generator returning floats in [0, 1).
 */
function mulberry32(a) {
  let state = a;
  return function () {
    // Advance the 32-bit state by a fixed odd increment (0x6D2B79F5).
    state = ((state | 0) + 0x6d2b79f5) | 0;
    // Two xorshift-multiply rounds scramble the state into the output word.
    let t = Math.imul(state ^ (state >>> 15), state | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    // Map the unsigned 32-bit result onto [0, 1).
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}
__name(mulberry32, "mulberry32");
/**
 * Substitute `{slot}` placeholders in a template with values from `choice`.
 * Placeholders without an own property in `choice` are left untouched.
 * Only own properties count, so names like `{constructor}` or `{toString}`
 * are never filled with Object.prototype members.
 *
 * @param {string} tpl - template text.
 * @param {object} choice - slot name -> value.
 * @returns {string} the filled template.
 */
function fill(tpl, choice) {
  return tpl.replace(/\{(\w+)\}/g, (whole, k) =>
    Object.prototype.hasOwnProperty.call(choice, k) ? choice[k] : whole);
}
__name(fill, "fill");
/**
 * Deterministically generate synthetic [request, macro] example pairs from
 * a skill definition's templates.
 *
 * Each `{slot}` is filled with a value drawn from `def.vocab[slot]`, using a
 * PRNG seeded from `def.key`, so the same definition always yields the same
 * examples. Requests are de-duplicated across all templates. Each template
 * produces up to `perTemplate` unique requests within `perTemplate * 8`
 * attempts.
 *
 * @param {object} def - skill definition: {key, templates?: [{req, macro}], vocab?}.
 * @param {number} perTemplate - target number of examples per template.
 * @returns {Array<[string, string]>} [request, macro] pairs in template order.
 */
function expand(def, perTemplate) {
  const rnd = mulberry32(hashStr(def.key));
  const vocab = def.vocab || {};
  const out = [];
  const seen = new Set();
  for (const t of def.templates || []) {
    // Request slots come first, which keeps the original draw order. Then come
    // slots that appear only in the macro (e.g. an {emoji} the request never
    // mentions), so those are filled too instead of leaking out literally.
    const slotText = `${t.req}\n${t.macro}`;
    const slots = [...new Set([...slotText.matchAll(/\{(\w+)\}/g)].map((m) => m[1]))];
    let made = 0;
    let tries = 0;
    const cap = perTemplate * 8;
    while (made < perTemplate && tries < cap) {
      tries++;
      const choice = {};
      for (const s of slots) {
        const pool = Object.prototype.hasOwnProperty.call(vocab, s) ? vocab[s] : null;
        // A missing or empty list falls back to "x"; an empty list previously
        // produced "undefined". rnd() is still drawn once per slot.
        const arr = pool && pool.length ? pool : ["x"];
        choice[s] = arr[Math.floor(rnd() * arr.length)];
      }
      const req = fill(t.req, choice);
      if (seen.has(req)) continue;
      seen.add(req);
      out.push([req, fill(t.macro, choice)]);
      made++;
    }
  }
  return out;
}
__name(expand, "expand");
/**
 * Assemble a ready-to-use skill record from its definition.
 *
 * The examples are built in this order:
 * 1. the hand-written `fixed` pairs;
 * 2. `perTemplate` generated pairs per template (via expand);
 * 3. one [request, "OUT_OF_SCOPE"] pair per `oos` request.
 *
 * @param {object} def - skill definition (key, label, icon, desc, domain,
 *   scope, ops, suggest, and optional fixed/templates/vocab/oos).
 * @param {number} [perTemplate=6] - generated examples per template.
 * @returns {object} {key, label, icon, desc, domain, spec, system, suggest, examples}.
 */
function buildSkill(def, perTemplate = 6) {
  const spec = { scope: def.scope, ops: def.ops };
  const handWritten = def.fixed || [];
  const generated = expand(def, perTemplate);
  const bounced = (def.oos || []).map((request) => [request, "OUT_OF_SCOPE"]);
  const { key, label, icon, desc, domain, suggest } = def;
  return {
    key,
    label,
    icon,
    desc,
    domain,
    spec,
    system: skillSystem(domain, spec),
    suggest,
    examples: [...handWritten, ...generated, ...bounced]
  };
}
__name(buildSkill, "buildSkill");
// Shared slot vocabularies referenced by the skill definitions that follow:
// people, topics, and "when" phrases (day word plus HH:MM time).
var PEOPLE = [
  "mom",
  "Sarah",
  "Alex",
  "the design team",
  "my manager",
  "Priya",
  "John",
  "the landlord",
  "accounting",
  "Dana"
];
var TOPICS = [
  "the Q3 roadmap",
  "the launch",
  "the budget",
  "onboarding",
  "the API redesign",
  "the offsite",
  "the bug report",
  "the contract"
];
var WHENS = [
  "today 17:00",
  "tomorrow 09:00",
  "Friday 14:00",
  "next Monday 10:00",
  "Thursday 16:30",
  "tonight 19:00"
];
| var DEFS = [ | |
| { | |
| key: "inbox-calendar", | |
| label: "Inbox & Calendar", | |
| icon: "\u2709", | |
| domain: "an Inbox & Calendar operator", | |
| scope: "inbox or calendar", | |
| desc: "Compiles requests like \u201Cemail my mom and book a reminder to respond\u201D into a verifiable macro over a fixed set of inbox/calendar actions; bounces anything else.", | |
| suggest: "Email the design team this week's notes, then put a 30-minute review on my calendar for Monday morning.", | |
| ops: [ | |
| { name: "find_email", params: ["query"], ret: "thread" }, | |
| { name: "compose_email", params: ["to", "subject", "body"] }, | |
| { name: "reply_email", params: ["thread", "body"] }, | |
| { name: "forward_email", params: ["thread", "to", "note"] }, | |
| { name: "archive_email", params: ["thread"] }, | |
| { name: "label_email", params: ["thread", "label"] }, | |
| { name: "schedule_send", params: ["to", "subject", "body", "when"] }, | |
| { name: "create_event", params: ["title", "start", "end", "remind_min"] }, | |
| { name: "set_reminder", params: ["text", "when"] }, | |
| { name: "find_slot", params: ["duration_min", "after", "before"], ret: "slot" }, | |
| { name: "rsvp", params: ["event", "response"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "email my mom and book a calendar event to remind me to respond", | |
| 'compose_email(to="mom", subject="Hi mom", body="Just checking in \u2014 talk soon!")\ncreate_event(title="Respond to mom", start="tomorrow 09:00", end="tomorrow 09:15", remind_min=10)' | |
| ], | |
| [ | |
| "schedule a 30 minute focus block tomorrow afternoon", | |
| 's = find_slot(duration_min=30, after="tomorrow 13:00", before="tomorrow 18:00")\ncreate_event(title="Focus block", start=s.start, end=s.end, remind_min=5)' | |
| ], | |
| [ | |
| "reply yes to the standup invite and add it to my calendar", | |
| 't = find_email(query="standup invite")\nrsvp(event=t, response="yes")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "email {person} about {topic}", macro: 'compose_email(to="{person}", subject="{topic}", body="Quick note about {topic}.")' }, | |
| { req: "remind me to follow up on {topic} {when}", macro: 'set_reminder(text="Follow up on {topic}", when="{when}")' }, | |
| { req: "find the email from {person} and reply that I will review it by {when}", macro: 't = find_email(query="from:{person}")\nreply_email(thread=t, body="Thanks \u2014 I will review this by {when}.")' }, | |
| { req: "archive the emails about {topic}", macro: 't = find_email(query="{topic}")\narchive_email(thread=t)' }, | |
| { req: "forward the {topic} email to {person}", macro: 't = find_email(query="{topic}")\nforward_email(thread=t, to="{person}", note="FYI \u2014 for your records.")' }, | |
| { req: "label the email from {person} as {label}", macro: 't = find_email(query="from:{person}")\nlabel_email(thread=t, label="{label}")' }, | |
| { req: "send {person} a note {when} saying thanks for {topic}", macro: 'schedule_send(to="{person}", subject="Thank you", body="Thanks for {topic}.", when="{when}")' }, | |
| { req: "set up a meeting about {topic} with {person} {when} for 30 minutes", macro: 'create_event(title="{topic} with {person}", start="{when}", end="{when}", remind_min=10)' }, | |
| { req: "find a 45 minute slot {when} and book {topic}", macro: 's = find_slot(duration_min=45, after="{when}", before="{when}")\ncreate_event(title="{topic}", start=s.start, end=s.end, remind_min=10)' } | |
| ], | |
| vocab: { person: PEOPLE, topic: TOPICS, when: WHENS, label: ["housing", "urgent", "finance", "travel", "follow-up", "receipts"] }, | |
| oos: ["order me a pizza", "what is the capital of France?", "play some jazz"] | |
| }, | |
| { | |
| key: "music", | |
| label: "Music", | |
| icon: "\u266A", | |
| domain: "a music player operator", | |
| scope: "music playback", | |
| desc: "Turns \u201Cplay some lo-fi and turn it down\u201D into a macro over a music action space \u2014 find/play/queue/volume/playlist \u2014 and bounces non-music asks.", | |
| suggest: "Play something upbeat for cooking and add it to a new playlist called Dinner.", | |
| ops: [ | |
| { name: "find_track", params: ["query"], ret: "track" }, | |
| { name: "play_track", params: ["track"] }, | |
| { name: "queue_track", params: ["track"] }, | |
| { name: "pause", params: [] }, | |
| { name: "skip", params: [] }, | |
| { name: "previous", params: [] }, | |
| { name: "set_volume", params: ["level"] }, | |
| { name: "create_playlist", params: ["name"] }, | |
| { name: "add_to_playlist", params: ["playlist", "track"] }, | |
| { name: "shuffle", params: ["on"] }, | |
| { name: "repeat", params: ["mode"] } | |
| ], | |
| fixed: [ | |
| ["skip this song", "skip()"], | |
| ["pause the music", "pause()"], | |
| ["go back to the previous song", "previous()"] | |
| ], | |
| templates: [ | |
| { req: "play some {genre}", macro: 't = find_track(query="{genre}")\nplay_track(track=t)' }, | |
| { req: "queue up {artist} after this", macro: 't = find_track(query="{artist}")\nqueue_track(track=t)' }, | |
| { req: "set the volume to {vol}", macro: "set_volume(level={vol})" }, | |
| { req: "make a playlist called {name}", macro: 'create_playlist(name="{name}")' }, | |
| { req: "add {artist} to my {name} playlist", macro: 't = find_track(query="{artist}")\nadd_to_playlist(playlist="{name}", track=t)' }, | |
| { req: "shuffle my {name} playlist", macro: 'shuffle(on=true)\nt = find_track(query="playlist:{name}")\nplay_track(track=t)' }, | |
| { req: "put on {artist} and turn it up", macro: 't = find_track(query="{artist}")\nplay_track(track=t)\nset_volume(level=80)' }, | |
| { req: "repeat this {mode}", macro: 'repeat(mode="{mode}")' } | |
| ], | |
| vocab: { | |
| genre: ["lo-fi beats", "deep house", "classic jazz", "pop hits", "ambient", "classical", "90s hip hop", "indie rock"], | |
| artist: ["Taylor Swift", "The Beatles", "Daft Punk", "Miles Davis", "Radiohead", "Bad Bunny", "Fleetwood Mac"], | |
| name: ["Focus", "Workout", "Dinner", "Chill", "Road Trip", "Sleep"], | |
| vol: ["10", "25", "40", "60", "75", "90"], | |
| mode: ["one", "all"] | |
| }, | |
| oos: ["email my boss", "what is the weather today?", "open an issue on the repo"] | |
| }, | |
| { | |
| key: "github", | |
| label: "GitHub", | |
| icon: "\u{1F419}", | |
| domain: "a GitHub operator", | |
| scope: "GitHub repositories, issues, and pull requests", | |
| desc: "Compiles dev requests into a macro over issues, pull requests, and repos; bounces anything that isn\u2019t GitHub.", | |
| suggest: 'Open an issue on the api repo titled "fix login redirect", then assign it to Dana.', | |
| ops: [ | |
| { name: "find_issue", params: ["query"], ret: "issue" }, | |
| { name: "create_issue", params: ["repo", "title", "body"] }, | |
| { name: "comment_issue", params: ["issue", "body"] }, | |
| { name: "close_issue", params: ["issue"] }, | |
| { name: "assign_issue", params: ["issue", "assignee"] }, | |
| { name: "label_issue", params: ["issue", "label"] }, | |
| { name: "find_pr", params: ["query"], ret: "pr" }, | |
| { name: "open_pr", params: ["repo", "title", "branch"] }, | |
| { name: "review_pr", params: ["pr", "verdict"] }, | |
| { name: "merge_pr", params: ["pr"] }, | |
| { name: "create_repo", params: ["name", "visibility"] }, | |
| { name: "star_repo", params: ["repo"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "open an issue on the api repo titled fix login redirect and assign it to Dana", | |
| 'i = create_issue(repo="api", title="fix login redirect", body="The login flow redirects to the wrong page.")\nassign_issue(issue=i, assignee="Dana")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "open an issue on {repo} titled {title}", macro: 'create_issue(repo="{repo}", title="{title}", body="{title}.")' }, | |
| { req: "close the {topic} issue", macro: 'i = find_issue(query="{topic}")\nclose_issue(issue=i)' }, | |
| { req: "comment {comment} on the {topic} issue", macro: 'i = find_issue(query="{topic}")\ncomment_issue(issue=i, body="{comment}")' }, | |
| { req: "assign the {topic} issue to {user}", macro: 'i = find_issue(query="{topic}")\nassign_issue(issue=i, assignee="{user}")' }, | |
| { req: "label the {topic} issue as {label}", macro: 'i = find_issue(query="{topic}")\nlabel_issue(issue=i, label="{label}")' }, | |
| { req: "open a pull request on {repo} from {branch} titled {title}", macro: 'open_pr(repo="{repo}", title="{title}", branch="{branch}")' }, | |
| { req: "approve the {topic} pull request", macro: 'p = find_pr(query="{topic}")\nreview_pr(pr=p, verdict="approve")' }, | |
| { req: "merge the {topic} PR", macro: 'p = find_pr(query="{topic}")\nmerge_pr(pr=p)' }, | |
| { req: "create a private repo called {repo}", macro: 'create_repo(name="{repo}", visibility="private")' }, | |
| { req: "star the {repo} repo", macro: 'star_repo(repo="{repo}")' } | |
| ], | |
| vocab: { | |
| repo: ["api", "frontend", "docs", "infra", "mobile-app", "design-system"], | |
| title: ["fix login redirect", "add dark mode", "update README", "flaky test fix", "bump dependencies", "improve error logs"], | |
| topic: ["login", "dark mode", "flaky test", "memory leak", "rate limiting", "docs typo"], | |
| comment: ["looks good to me", "can you add a test?", "I will pick this up", "reproduced on main", "duplicate of #42"], | |
| user: ["Dana", "Alex", "Priya", "the on-call", "Sam"], | |
| label: ["bug", "enhancement", "good first issue", "p1", "docs", "wontfix"], | |
| branch: ["feature/auth", "fix/cache", "chore/deps", "feat/ui", "hotfix/crash"] | |
| }, | |
| oos: ["play some music", "email my mom", "what is 2 + 2?"] | |
| }, | |
| { | |
| key: "slack", | |
| label: "Slack", | |
| icon: "\u{1F4AC}", | |
| domain: "a Slack operator", | |
| scope: "Slack messaging", | |
| desc: "Compiles team-chat requests into a macro over channels, DMs, threads, and reminders; bounces non-Slack asks.", | |
| suggest: "Post the release notes in #launch and DM Dana to review them.", | |
| ops: [ | |
| { name: "find_message", params: ["query"], ret: "message" }, | |
| { name: "send_message", params: ["channel", "text"] }, | |
| { name: "dm", params: ["user", "text"] }, | |
| { name: "reply_thread", params: ["message", "text"] }, | |
| { name: "react", params: ["message", "emoji"] }, | |
| { name: "set_status", params: ["text", "emoji"] }, | |
| { name: "create_channel", params: ["name"] }, | |
| { name: "invite", params: ["user", "channel"] }, | |
| { name: "remind", params: ["text", "when"] }, | |
| { name: "pin", params: ["message"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "post the release notes in #launch and dm Dana to review them", | |
| 'send_message(channel="launch", text="Release notes are up \u2014 please review.")\ndm(user="Dana", text="Can you review the release notes I posted in #launch?")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "post {text} in #{channel}", macro: 'send_message(channel="{channel}", text="{text}")' }, | |
| { req: "dm {user} {text}", macro: 'dm(user="{user}", text="{text}")' }, | |
| { req: "reply {text} to the {topic} thread", macro: 'm = find_message(query="{topic}")\nreply_thread(message=m, text="{text}")' }, | |
| { req: "react {emoji} to the {topic} message", macro: 'm = find_message(query="{topic}")\nreact(message=m, emoji="{emoji}")' }, | |
| { req: "set my status to {text}", macro: 'set_status(text="{text}", emoji="{emoji}")' }, | |
| { req: "create a channel called {channel}", macro: 'create_channel(name="{channel}")' }, | |
| { req: "invite {user} to #{channel}", macro: 'invite(user="{user}", channel="{channel}")' }, | |
| { req: "remind the team to {task} {when}", macro: 'remind(text="{task}", when="{when}")' }, | |
| { req: "pin the {topic} message", macro: 'm = find_message(query="{topic}")\npin(message=m)' } | |
| ], | |
| vocab: { | |
| channel: ["launch", "general", "engineering", "design", "random", "incidents"], | |
| user: ["Dana", "Alex", "Priya", "Sam", "the team lead"], | |
| text: ["standup in 5", "PR is ready for review", "deploy is green", "lunch at noon?", "great work today"], | |
| topic: ["deploy", "incident", "roadmap", "lunch", "release"], | |
| emoji: [":eyes:", ":white_check_mark:", ":tada:", ":fire:", ":+1:"], | |
| task: ["submit timesheets", "join the retro", "review the doc", "update the board"], | |
| when: WHENS | |
| }, | |
| oos: ["play a song", "order groceries", "what time is it in Tokyo?"] | |
| }, | |
| { | |
| key: "notion", | |
| label: "Notion", | |
| icon: "\u{1F4DD}", | |
| domain: "a Notion operator", | |
| scope: "Notion pages, notes, and tasks", | |
| desc: "Compiles note-taking requests into a macro over pages, blocks, tasks, and databases; bounces anything else.", | |
| suggest: 'Create a page titled "Trip plan" and add a task to book flights due Friday.', | |
| ops: [ | |
| { name: "find_page", params: ["query"], ret: "page" }, | |
| { name: "create_page", params: ["title", "body"] }, | |
| { name: "append_block", params: ["page", "text"] }, | |
| { name: "create_task", params: ["title", "due"] }, | |
| { name: "complete_task", params: ["task"] }, | |
| { name: "find_task", params: ["query"], ret: "task" }, | |
| { name: "add_to_database", params: ["database", "name"] }, | |
| { name: "set_property", params: ["page", "key", "value"] }, | |
| { name: "create_database", params: ["name"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "create a page titled Trip plan and add a task to book flights due Friday", | |
| 'create_page(title="Trip plan", body="Planning notes.")\ncreate_task(title="Book flights", due="Friday")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "create a page titled {title}", macro: 'create_page(title="{title}", body="{title} \u2014 notes.")' }, | |
| { req: "add a note {text} to the {topic} page", macro: 'p = find_page(query="{topic}")\nappend_block(page=p, text="{text}")' }, | |
| { req: "add a task to {task} due {when}", macro: 'create_task(title="{task}", due="{when}")' }, | |
| { req: "mark the {task} task done", macro: 't = find_task(query="{task}")\ncomplete_task(task=t)' }, | |
| { req: "add {name} to my {database} database", macro: 'add_to_database(database="{database}", name="{name}")' }, | |
| { req: "set the status of the {topic} page to {value}", macro: 'p = find_page(query="{topic}")\nset_property(page=p, key="status", value="{value}")' }, | |
| { req: "create a database called {database}", macro: 'create_database(name="{database}")' } | |
| ], | |
| vocab: { | |
| title: ["Trip plan", "Q3 goals", "Reading list", "Meeting notes", "Project brief", "Recipes"], | |
| text: ["remember to confirm the budget", "add the agenda", "link the spec", "note the blockers"], | |
| topic: ["trip", "goals", "project", "meeting", "reading"], | |
| task: ["book flights", "draft the brief", "email the vendor", "review the PR", "pay the invoice"], | |
| when: ["today", "tomorrow", "Friday", "next week", "end of month"], | |
| name: ["Acme Co", "Q3 launch", "Vendor X", "Idea: dark mode"], | |
| database: ["Projects", "CRM", "Tasks", "Reading", "Inventory"], | |
| value: ["in progress", "done", "blocked", "todo", "review"] | |
| }, | |
| oos: ["play music", "navigate home", "send a tweet"] | |
| }, | |
| { | |
| key: "x", | |
| label: "X", | |
| icon: "\u{1D54F}", | |
| domain: "an X (Twitter) operator", | |
| scope: "posting and engagement on X", | |
| desc: "Compiles social requests into a macro over posts, replies, reposts, follows, and DMs; bounces anything off-platform.", | |
| suggest: 'Post "shipping something fun today \u{1F680}" and schedule a follow-up for 5pm.', | |
| ops: [ | |
| { name: "find_post", params: ["query"], ret: "post" }, | |
| { name: "post", params: ["text"] }, | |
| { name: "reply", params: ["post", "text"] }, | |
| { name: "repost", params: ["post"] }, | |
| { name: "like", params: ["post"] }, | |
| { name: "follow", params: ["user"] }, | |
| { name: "dm", params: ["user", "text"] }, | |
| { name: "schedule_post", params: ["text", "when"] }, | |
| { name: "bookmark", params: ["post"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "post shipping something fun today and schedule a follow up for 5pm", | |
| 'post(text="shipping something fun today \u{1F680}")\nschedule_post(text="more details soon \u2014 stay tuned", when="today 17:00")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "post {text}", macro: 'post(text="{text}")' }, | |
| { req: "reply {text} to the {topic} post", macro: 'p = find_post(query="{topic}")\nreply(post=p, text="{text}")' }, | |
| { req: "repost the {topic} tweet", macro: 'p = find_post(query="{topic}")\nrepost(post=p)' }, | |
| { req: "like the {topic} post", macro: 'p = find_post(query="{topic}")\nlike(post=p)' }, | |
| { req: "follow {user}", macro: 'follow(user="{user}")' }, | |
| { req: "dm {user} {text}", macro: 'dm(user="{user}", text="{text}")' }, | |
| { req: "schedule a post {when} saying {text}", macro: 'schedule_post(text="{text}", when="{when}")' }, | |
| { req: "bookmark the {topic} thread", macro: 'p = find_post(query="{topic}")\nbookmark(post=p)' } | |
| ], | |
| vocab: { | |
| text: ["gm", "big news coming", "loved this talk", "hot take: tabs > spaces", "thanks for 10k followers"], | |
| topic: ["the launch", "the keynote", "the meme", "the thread on AI", "the announcement"], | |
| user: ["@levelsio", "@naval", "@swyx", "@dhh", "@karpathy"], | |
| when: WHENS | |
| }, | |
| oos: ["archive my inbox", "play a playlist", "open a GitHub issue"] | |
| }, | |
| { | |
| key: "instagram", | |
| label: "Instagram", | |
| icon: "\u{1F4F7}", | |
| domain: "an Instagram operator", | |
| scope: "Instagram posts, stories, and DMs", | |
| desc: "Compiles requests into a macro over photo posts, stories, comments, and DMs; bounces anything off-platform.", | |
| suggest: 'Post a photo with caption "sunset run \u{1F305}" and share it to my story.', | |
| ops: [ | |
| { name: "find_post", params: ["query"], ret: "post" }, | |
| { name: "post_photo", params: ["caption", "media"] }, | |
| { name: "post_story", params: ["media"] }, | |
| { name: "reply_dm", params: ["user", "text"] }, | |
| { name: "like_post", params: ["post"] }, | |
| { name: "comment", params: ["post", "text"] }, | |
| { name: "follow", params: ["user"] }, | |
| { name: "save_post", params: ["post"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "post a photo with caption sunset run and share it to my story", | |
| 'post_photo(caption="sunset run \u{1F305}", media="latest")\npost_story(media="latest")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "post a photo with caption {caption}", macro: 'post_photo(caption="{caption}", media="latest")' }, | |
| { req: "share {media} to my story", macro: 'post_story(media="{media}")' }, | |
| { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' }, | |
| { req: "like the {topic} post", macro: 'p = find_post(query="{topic}")\nlike_post(post=p)' }, | |
| { req: "reply {text} to {user} in DMs", macro: 'reply_dm(user="{user}", text="{text}")' }, | |
| { req: "follow {user}", macro: 'follow(user="{user}")' }, | |
| { req: "save the {topic} post", macro: 'p = find_post(query="{topic}")\nsave_post(post=p)' } | |
| ], | |
| vocab: { | |
| caption: ["sunset run \u{1F305}", "weekend vibes", "new kicks \u{1F45F}", "homemade pasta \u{1F35D}", "trail day"], | |
| media: ["latest", "the beach photo", "the reel", "the carousel"], | |
| text: ["love this!", "where is this?", "so good \u{1F525}", "congrats!", "need the recipe"], | |
| topic: ["the travel", "the food", "the fit check", "the puppy", "the launch"], | |
| user: ["@natgeo", "@nike", "@a_friend", "@the_chef"] | |
| }, | |
| oos: ["merge the pull request", "set a reminder", "navigate to work"] | |
| }, | |
| { | |
| key: "youtube", | |
| label: "YouTube", | |
| icon: "\u25B6", | |
| domain: "a YouTube operator", | |
| scope: "YouTube playback and library", | |
| desc: "Compiles requests into a macro over search, playback, playlists, and subscriptions; bounces anything else.", | |
| suggest: "Play a 10-minute beginner yoga video and add it to my Morning playlist.", | |
| ops: [ | |
| { name: "find_video", params: ["query"], ret: "video" }, | |
| { name: "play_video", params: ["video"] }, | |
| { name: "queue_video", params: ["video"] }, | |
| { name: "subscribe", params: ["channel"] }, | |
| { name: "like_video", params: ["video"] }, | |
| { name: "add_to_playlist", params: ["playlist", "video"] }, | |
| { name: "create_playlist", params: ["name"] }, | |
| { name: "comment", params: ["video", "text"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "play a beginner yoga video and add it to my Morning playlist", | |
| 'v = find_video(query="beginner yoga 10 minutes")\nplay_video(video=v)\nadd_to_playlist(playlist="Morning", video=v)' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "play a video about {query}", macro: 'v = find_video(query="{query}")\nplay_video(video=v)' }, | |
| { req: "queue a video about {query}", macro: 'v = find_video(query="{query}")\nqueue_video(video=v)' }, | |
| { req: "subscribe to {channel}", macro: 'subscribe(channel="{channel}")' }, | |
| { req: "like the {query} video", macro: 'v = find_video(query="{query}")\nlike_video(video=v)' }, | |
| { req: "add a {query} video to my {name} playlist", macro: 'v = find_video(query="{query}")\nadd_to_playlist(playlist="{name}", video=v)' }, | |
| { req: "make a playlist called {name}", macro: 'create_playlist(name="{name}")' }, | |
| { req: "comment {text} on the {query} video", macro: 'v = find_video(query="{query}")\ncomment(video=v, text="{text}")' } | |
| ], | |
| vocab: { | |
| query: ["lo-fi study mix", "rust tutorial", "marathon training", "pasta recipe", "guitar lesson", "space documentary"], | |
| channel: ["Veritasium", "Fireship", "MKBHD", "Kurzgesagt", "NileRed"], | |
| name: ["Morning", "Watch Later", "Cooking", "Workouts", "Learning"], | |
| text: ["great explanation!", "first", "this helped a lot", "please do a part 2"] | |
| }, | |
| oos: ["email the team", "open a PR", "set my Slack status"] | |
| }, | |
| { | |
| key: "maps", | |
| label: "Maps", | |
| icon: "\u{1F4CD}", | |
| domain: "a Maps operator", | |
| scope: "navigation and places", | |
| desc: "Compiles requests into a macro over places, directions, and navigation; bounces anything off-map.", | |
| suggest: "Find the nearest coffee shop and start navigation, then share my ETA with Alex.", | |
| ops: [ | |
| { name: "search_place", params: ["query"], ret: "place" }, | |
| { name: "find_nearby", params: ["category"], ret: "place" }, | |
| { name: "directions", params: ["to", "mode"] }, | |
| { name: "start_navigation", params: ["place"] }, | |
| { name: "save_place", params: ["place", "list"] }, | |
| { name: "share_eta", params: ["place", "contact"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "find the nearest coffee shop and start navigation then share my eta with Alex", | |
| 'p = find_nearby(category="coffee shop")\nstart_navigation(place=p)\nshare_eta(place=p, contact="Alex")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "navigate to {place}", macro: 'p = search_place(query="{place}")\nstart_navigation(place=p)' }, | |
| { req: "directions to {place} by {mode}", macro: 'directions(to="{place}", mode="{mode}")' }, | |
| { req: "find a {category} near me", macro: 'find_nearby(category="{category}")' }, | |
| { req: "find the nearest {category} and navigate there", macro: 'p = find_nearby(category="{category}")\nstart_navigation(place=p)' }, | |
| { req: "save {place} to my {list} list", macro: 'p = search_place(query="{place}")\nsave_place(place=p, list="{list}")' }, | |
| { req: "share my ETA to {place} with {contact}", macro: 'p = search_place(query="{place}")\nshare_eta(place=p, contact="{contact}")' } | |
| ], | |
| vocab: { | |
| place: ["the airport", "downtown", "the office", "Central Park", "the train station", "the stadium"], | |
| mode: ["driving", "walking", "transit", "cycling"], | |
| category: ["coffee shop", "gas station", "pharmacy", "grocery store", "ATM", "parking"], | |
| list: ["Favorites", "Want to go", "Trip", "Restaurants"], | |
| contact: ["Alex", "mom", "Dana", "the group"] | |
| }, | |
| oos: ["post a tweet", "play a song", "create a GitHub repo"] | |
| }, | |
| { | |
| key: "amazon", | |
| label: "Shopping", | |
| icon: "\u{1F6D2}", | |
| domain: "a shopping operator", | |
| scope: "shopping cart and orders", | |
| desc: "Compiles requests into a macro over product search, cart, orders, and lists; bounces anything that isn\u2019t shopping.", | |
| suggest: "Add two packs of AA batteries to my cart and track my last order.", | |
| ops: [ | |
| { name: "search_product", params: ["query"], ret: "product" }, | |
| { name: "add_to_cart", params: ["product", "qty"] }, | |
| { name: "buy_now", params: ["product"] }, | |
| { name: "find_order", params: ["query"], ret: "order" }, | |
| { name: "track_order", params: ["order"], ret: "status" }, | |
| { name: "reorder", params: ["query"] }, | |
| { name: "add_to_list", params: ["product", "list"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "add two packs of AA batteries to my cart and track my last order", | |
| 'p = search_product(query="AA batteries 2 pack")\nadd_to_cart(product=p, qty=2)\no = find_order(query="last order")\ntrack_order(order=o)' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "add {qty} {product} to my cart", macro: 'p = search_product(query="{product}")\nadd_to_cart(product=p, qty={qty})' }, | |
| { req: "buy {product} now", macro: 'p = search_product(query="{product}")\nbuy_now(product=p)' }, | |
| { req: "reorder {product}", macro: 'reorder(query="{product}")' }, | |
| { req: "track my {product} order", macro: 'o = find_order(query="{product}")\ntrack_order(order=o)' }, | |
| { req: "add {product} to my {list} list", macro: 'p = search_product(query="{product}")\nadd_to_list(product=p, list="{list}")' }, | |
| { req: "search for {product}", macro: 'search_product(query="{product}")' } | |
| ], | |
| vocab: { | |
| product: ["AA batteries", "USB-C cable", "olive oil", "running shoes", "paper towels", "a coffee grinder", "phone case"], | |
| qty: ["1", "2", "3", "4"], | |
| list: ["Wishlist", "Subscribe & Save", "Home", "Gifts"] | |
| }, | |
| oos: ["send an email", "play a video", "navigate to the office"] | |
| }, | |
| { | |
| key: "reddit", | |
| label: "Reddit", | |
| icon: "\u{1F47D}", | |
| domain: "a Reddit operator", | |
| scope: "Reddit posts and comments", | |
| desc: "Compiles requests into a macro over submissions, comments, votes, and subscriptions; bounces anything off-platform.", | |
| suggest: 'Post "What mechanical keyboard should I buy?" to r/keyboards and subscribe.', | |
| ops: [ | |
| { name: "find_post", params: ["query"], ret: "post" }, | |
| { name: "submit_post", params: ["subreddit", "title", "body"] }, | |
| { name: "comment", params: ["post", "text"] }, | |
| { name: "upvote", params: ["post"] }, | |
| { name: "reply_comment", params: ["comment", "text"] }, | |
| { name: "subscribe", params: ["subreddit"] }, | |
| { name: "save_post", params: ["post"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "post what mechanical keyboard should I buy to r/keyboards and subscribe", | |
| 'submit_post(subreddit="keyboards", title="What mechanical keyboard should I buy?", body="Budget is flexible \u2014 looking for recommendations.")\nsubscribe(subreddit="keyboards")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "post {title} to r/{subreddit}", macro: 'submit_post(subreddit="{subreddit}", title="{title}", body="{title}")' }, | |
| { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' }, | |
| { req: "upvote the {topic} post", macro: 'p = find_post(query="{topic}")\nupvote(post=p)' }, | |
| { req: "subscribe to r/{subreddit}", macro: 'subscribe(subreddit="{subreddit}")' }, | |
| { req: "save the {topic} post", macro: 'p = find_post(query="{topic}")\nsave_post(post=p)' } | |
| ], | |
| vocab: { | |
| subreddit: ["keyboards", "programming", "AskReddit", "buildapc", "cooking", "fitness"], | |
| title: ["What keyboard should I buy?", "Best beginner setup?", "How do I start running?", "Favorite pasta recipe?"], | |
| text: ["this is the way", "underrated take", "source?", "thanks for sharing", "happy cake day"], | |
| topic: ["the keyboard", "the build", "the recipe", "the AMA", "the discussion"] | |
| }, | |
| oos: ["email my mom", "play a song", "navigate home"] | |
| }, | |
| { | |
| key: "linkedin", | |
| label: "LinkedIn", | |
| icon: "\u{1F4BC}", | |
| domain: "a LinkedIn operator", | |
| scope: "LinkedIn networking and posts", | |
| desc: "Compiles requests into a macro over posts, connections, messages, and endorsements; bounces anything off-platform.", | |
| suggest: "Connect with Priya with a note, then endorse her for product management.", | |
| ops: [ | |
| { name: "find_person", params: ["query"], ret: "person" }, | |
| { name: "post_update", params: ["text"] }, | |
| { name: "connect", params: ["user", "note"] }, | |
| { name: "message", params: ["user", "text"] }, | |
| { name: "endorse", params: ["person", "skill"] }, | |
| { name: "find_post", params: ["query"], ret: "post" }, | |
| { name: "comment", params: ["post", "text"] } | |
| ], | |
| fixed: [ | |
| [ | |
| "connect with Priya with a note then endorse her for product management", | |
| 'connect(user="Priya", note="Great working with you \u2014 let us stay in touch!")\np = find_person(query="Priya")\nendorse(person=p, skill="product management")' | |
| ] | |
| ], | |
| templates: [ | |
| { req: "post an update saying {text}", macro: 'post_update(text="{text}")' }, | |
| { req: "connect with {user} and add a note {note}", macro: 'connect(user="{user}", note="{note}")' }, | |
| { req: "message {user} {text}", macro: 'message(user="{user}", text="{text}")' }, | |
| { req: "endorse {user} for {skill}", macro: 'p = find_person(query="{user}")\nendorse(person=p, skill="{skill}")' }, | |
| { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' } | |
| ], | |
| vocab: { | |
| text: ["excited to share I started a new role", "we are hiring engineers", "grateful for a great quarter", "thoughts on remote work"], | |
| user: ["Priya", "Alex", "a recruiter", "Dana", "my former manager"], | |
| note: ["Great working with you!", "Loved your talk", "Let us connect", "Fellow alum here"], | |
| skill: ["product management", "leadership", "TypeScript", "design", "data science"], | |
| topic: ["the hiring", "the milestone", "the article", "the announcement"] | |
| }, | |
| oos: ["play music", "open a github issue", "navigate to the airport"] | |
| } | |
| ]; | |
// One built skill per DEFS entry; the second argument (6) is forwarded to buildSkill,
// whose meaning is defined outside this view — NOTE(review): confirm what 6 controls.
var SKILLS = DEFS.map((d) => buildSkill(d, 6));
/*
 * Dock service catalogue, in display order (the dock iterates it front to back).
 * - Entries with a `skill` field are trainable: the dock links the tile to a saved run
 *   whose base resolves (via skillByKey) to that skill key, or offers to forge one.
 * - Entries without `skill` (the "broader armory" below) render as locked "coming soon" tiles.
 * Tile fields: bg = CSS background of the glyph (a color or gradient), fg = optional glyph
 * color (white when absent), glyph = text/emoji drawn on the tile, fs = glyph font size in px.
 */
var POPULAR_2026 = [
  { key: "inbox-calendar", name: "Inbox & Calendar", skill: "inbox-calendar", cat: "productivity", bg: "#2f72c4", glyph: "\u2709", fs: 22 },
  { key: "music", name: "Music", skill: "music", cat: "media", bg: "#1db954", glyph: "\u266A", fs: 24 },
  { key: "github", name: "GitHub", skill: "github", cat: "developer", bg: "#181717", glyph: "GH", fs: 15 },
  { key: "youtube", name: "YouTube", skill: "youtube", cat: "media", bg: "#FF0000", glyph: "\u25B6", fs: 18 },
  { key: "instagram", name: "Instagram", skill: "instagram", cat: "social", bg: "linear-gradient(135deg,#feda75,#d62976 48%,#4f5bd5)", glyph: "\u{1F4F7}", fs: 20 },
  { key: "x", name: "X", skill: "x", cat: "social", bg: "#000000", glyph: "\u{1D54F}", fs: 23 },
  { key: "slack", name: "Slack", skill: "slack", cat: "work", bg: "#4A154B", glyph: "S", fs: 24 },
  { key: "notion", name: "Notion", skill: "notion", cat: "productivity", bg: "#0f0f0f", glyph: "N", fs: 24 },
  { key: "maps", name: "Maps", skill: "maps", cat: "navigation", bg: "#34A853", glyph: "\u{1F4CD}", fs: 20 },
  { key: "amazon", name: "Amazon", skill: "amazon", cat: "shopping", bg: "#FF9900", fg: "#232F3E", glyph: "a", fs: 27 },
  { key: "reddit", name: "Reddit", skill: "reddit", cat: "social", bg: "#FF4500", glyph: "\u{1F47D}", fs: 20 },
  { key: "linkedin", name: "LinkedIn", skill: "linkedin", cat: "work", bg: "#0A66C2", glyph: "in", fs: 17 },
  // ── the broader armory (coming soon) ──
  // No `skill` field: these tiles are display-only placeholders.
  { key: "google", name: "Google", cat: "productivity", bg: "#4285F4", glyph: "G", fs: 25 },
  { key: "whatsapp", name: "WhatsApp", cat: "social", bg: "#25D366", glyph: "\u2706", fs: 22 },
  { key: "tiktok", name: "TikTok", cat: "social", bg: "#010101", glyph: "\u266B", fs: 22 },
  { key: "facebook", name: "Facebook", cat: "social", bg: "#1877F2", glyph: "f", fs: 27 },
  { key: "snapchat", name: "Snapchat", cat: "social", bg: "#FFFC00", fg: "#111", glyph: "\u{1F47B}", fs: 22 },
  { key: "messenger", name: "Messenger", cat: "social", bg: "#0084FF", glyph: "\u2726", fs: 22 },
  { key: "discord", name: "Discord", cat: "social", bg: "#5865F2", glyph: "D", fs: 24 },
  { key: "telegram", name: "Telegram", cat: "social", bg: "#229ED9", glyph: "\u2708", fs: 20 },
  { key: "netflix", name: "Netflix", cat: "media", bg: "#E50914", glyph: "NF", fs: 15 },
  { key: "twitch", name: "Twitch", cat: "media", bg: "#9146FF", glyph: "tw", fs: 16 },
  { key: "spotify", name: "Spotify", cat: "media", bg: "#1DB954", glyph: "\u25C9", fs: 20 },
  { key: "pinterest", name: "Pinterest", cat: "social", bg: "#E60023", glyph: "P", fs: 24 },
  { key: "threads", name: "Threads", cat: "social", bg: "#000000", glyph: "@", fs: 24 },
  { key: "uber", name: "Uber", cat: "travel", bg: "#000000", glyph: "U", fs: 24 },
  { key: "doordash", name: "DoorDash", cat: "food", bg: "#FF3008", glyph: "DD", fs: 14 },
  { key: "airbnb", name: "Airbnb", cat: "travel", bg: "#FF5A5F", glyph: "A", fs: 24 },
  { key: "paypal", name: "PayPal", cat: "finance", bg: "#003087", glyph: "P", fs: 23 },
  { key: "venmo", name: "Venmo", cat: "finance", bg: "#3D95CE", glyph: "V", fs: 24 },
  { key: "chatgpt", name: "ChatGPT", cat: "ai", bg: "#10A37F", glyph: "\u2738", fs: 20 },
  { key: "gemini", name: "Gemini", cat: "ai", bg: "#1C69FF", glyph: "\u2726", fs: 20 },
  { key: "perplexity", name: "Perplexity", cat: "ai", bg: "#1FB8CD", glyph: "\u273A", fs: 20 },
  { key: "cursor", name: "Cursor", cat: "developer", bg: "#0b0b0b", glyph: "\u25AE", fs: 18 }
];
// src/main.js
/** Shorthand for `document.getElementById`. */
var $ = /* @__PURE__ */ __name((id) => document.getElementById(id), "$");
/**
 * Status logger. Writes `msg` into the #railMsg element when it exists and
 * always echoes it to the console under an "[emberglass]" prefix.
 */
var log = /* @__PURE__ */ __name((msg) => {
  const railMsg = $("railMsg");
  if (railMsg) {
    railMsg.textContent = msg;
  }
  console.log("[emberglass]", msg);
}, "log");
/**
 * Collects the ".step" children of #id, keyed by their `data-s` attribute, and
 * returns helpers that toggle the "active", "done" and "loop" CSS classes.
 * Every helper silently ignores keys that have no matching step element.
 */
function steps(id) {
  const byKey = {};
  $(id).querySelectorAll(".step").forEach((node) => {
    byKey[node.dataset.s] = node;
  });
  const classesOf = /* @__PURE__ */ __name((k) => byKey[k]?.classList, "classesOf");
  const everyStep = /* @__PURE__ */ __name(() => Object.values(byKey), "everyStep");
  return {
    /** Clears active/done/loop from every step. */
    reset() {
      for (const node of everyStep()) node.classList.remove("active", "done", "loop");
    },
    /** Marks step `k` active without touching the others. */
    active(k) {
      classesOf(k)?.add("active");
    },
    /** Makes step `k` the only active step. */
    activeOnly(k) {
      for (const node of everyStep()) node.classList.remove("active");
      classesOf(k)?.add("active");
    },
    /** Marks step `k` done and drops its active/loop state. */
    done(k) {
      const classes = classesOf(k);
      classes?.remove("active", "loop");
      classes?.add("done");
    },
    /** Switches the "loop" class on or off for every key in `keys`. */
    loop(keys, on) {
      keys.forEach((k) => classesOf(k)?.toggle("loop", on));
    }
  };
}
__name(steps, "steps");
/**
 * Starts a live elapsed-seconds readout (one decimal, e.g. "3.2s") in the ".t"
 * child of #id, refreshed every animation frame, and adds the "on" class to #id.
 * Returns a stop function that ends the refresh loop and removes "on".
 */
function startClock(id) {
  const host = $(id);
  const readout = host.querySelector(".t");
  const startedAt = performance.now();
  let running = true;
  host.classList.add("on");
  const tick = /* @__PURE__ */ __name(function tick() {
    if (!running) return;
    readout.textContent = `${((performance.now() - startedAt) / 1e3).toFixed(1)}s`;
    requestAnimationFrame(tick);
  }, "tick");
  tick();
  return () => {
    running = false;
    host.classList.remove("on");
  };
}
__name(startClock, "startClock");
// Shared model session built from the Qwen2.5-3B config; `log` receives its status messages.
var session = new ModelSession({ cfg: QWEN25_3B, log });
// Trained adapters. runInference calls applyToRuntime(<#adapterSel value>, session.rt),
// where the value "none" stands for the base model.
var adapters = new AdapterRegistry();
/** Mutable app/UI state shared by the handlers below. */
var state = {
  // Model weights have finished loading.
  loaded: false,
  // false when idle, otherwise the running job: "load" | "gen" | "train".
  busy: false,
  // Message of the last load failure, or null.
  err: null,
  // Latest trained/applied adapter: { name, kind: 'guided'|'own', base, build(userText) -> messages[], suggest }
  // (runTraining additionally keeps its TrainingController as `ctrl`).
  tuned: null,
  // History run currently applied.
  activeRunId: null,
  // File System Access workspace folder.
  dirHandle: null
};
// Sampling settings forwarded to session.generate().
var GEN = { maxTokens: 2048, temperature: 0.6, topP: 0.95, topK: 64 };
/**
 * Returns the skill whose key equals `key`, or whose key followed by a space
 * prefixes it (so "reddit" also matches "reddit #2"). Undefined for a falsy
 * key or no match.
 */
var skillByKey = /* @__PURE__ */ __name((key) => {
  if (!key) return undefined;
  const text = String(key);
  return SKILLS.find((s) => key === s.key || text.startsWith(`${s.key} `));
}, "skillByKey");
// Key of the currently selected skill (initially the first one).
var selectedSkillKey = SKILLS[0].key;
// Per-step losses of the current training run (feeds the loss sparkline).
var trainLosses = [];
/**
 * Builds a training subset of `all` ([request, answer] pairs).
 *
 * - Every OUT_OF_SCOPE pair is always kept.
 * - The remaining budget (`n` minus the OUT_OF_SCOPE count, never below 0) is
 *   filled with in-scope pairs taken at a fixed stride from the front.
 *
 * In-scope picks come first and OUT_OF_SCOPE pairs last. The result can exceed
 * `n` when there are more than `n` OUT_OF_SCOPE pairs.
 */
function sampleExamples(all, n) {
  const isOos = ([, answer]) => answer === "OUT_OF_SCOPE";
  const oos = all.filter(isOos);
  const inscope = all.filter((pair) => !isOos(pair));
  const budget = Math.max(0, n - oos.length);
  const stride = Math.max(1, Math.floor(inscope.length / Math.max(1, budget)));
  const picked = [];
  for (let i = 0; picked.length < budget && i < inscope.length; i += stride) {
    picked.push(inscope[i]);
  }
  return picked.concat(oos);
}
__name(sampleExamples, "sampleExamples");
/**
 * Paints the status rail (#rail data-state + #railChip text) from `state`.
 * Precedence: load error > loading > not loaded > live (base or tuned adapter,
 * depending on #adapterSel). No-op if either rail element is missing.
 */
function setBadge() {
  const rail = $("rail");
  const chip = $("railChip");
  if (!rail || !chip) return;
  const paint = (kind, text) => {
    rail.dataset.state = kind;
    chip.textContent = text;
  };
  if (state.err) {
    paint("err", "Load failed");
  } else if (state.busy === "load") {
    paint("busy", "Loading\u2026");
  } else if (!state.loaded) {
    paint("idle", "Model not loaded");
  } else {
    const sel = $("adapterSel")?.value || "none";
    if (sel === "none") paint("ok", "Live \xB7 base");
    else paint("tuned", "Live \xB7 tuned: " + sel);
  }
}
__name(setBadge, "setBadge");
/**
 * Shows (`on` truthy) or hides the #inferLock overlay and updates #run:
 * the button stays disabled while locked, while the model is not loaded, or
 * while a generation is in progress.
 */
function lockInference(on) {
  const overlay = $("inferLock");
  overlay.style.display = on ? "flex" : "none";
  const runButton = $("run");
  runButton.disabled = on || !state.loaded || state.busy === "gen";
}
__name(lockInference, "lockInference");
/**
 * Enables/disables the main controls from `state`:
 * - run / train buttons need a loaded, idle model;
 * - "train own" additionally needs at least one parsed snippet;
 * - the load buttons are disabled while any job runs.
 * The ask section is visible only once the model is loaded.
 */
function gateButtons() {
  const notReady = !(state.loaded && !state.busy);
  for (const id of ["run", "trainGuided"]) $(id).disabled = notReady;
  $("trainOwn").disabled = notReady || ownExamples().length === 0;
  const anyJob = Boolean(state.busy);
  $("load").disabled = anyJob;
  $("loadHF").disabled = anyJob;
  const askSection = $("askSection");
  if (askSection) askSection.hidden = !state.loaded;
}
__name(gateButtons, "gateButtons");
/**
 * Loads the model through `session.loadWith(reader, label)`.
 * Marks the app busy ("load") while it runs. On success it sets
 * `state.loaded`; on failure it records the message in `state.err` and logs it.
 * The badge and buttons are refreshed before and after. Ignored while another
 * job is running.
 */
async function loadWith(reader, label) {
  if (state.busy) return;
  const refreshUi = () => {
    setBadge();
    gateButtons();
  };
  state.busy = "load";
  state.err = null;
  refreshUi();
  try {
    await session.loadWith(reader, label);
    state.loaded = true;
    log("Model ready. Ask it anything below \u2014 or hit Train to teach it something new.");
  } catch (err) {
    state.err = err.message;
    log("Load error: " + err.message);
    console.error(err);
  } finally {
    state.busy = false;
    refreshUi();
  }
}
__name(loadWith, "loadWith");
/**
 * Builds the chat messages for a prompt. The tuned adapter's own builder is
 * used only when it is the adapter currently selected in #adapterSel. Otherwise
 * the result is a single user turn.
 */
function buildMessages(userText) {
  const selected = $("adapterSel")?.value || "none";
  const tuned = state.tuned;
  const useTuned = selected !== "none" && tuned && tuned.name === selected;
  return useTuned ? tuned.build(userText) : [{ role: "user", content: userText }];
}
__name(buildMessages, "buildMessages");
/**
 * Runs one generation for the text in #prompt and streams it into #out.
 *
 * Uses the adapter selected in #adapterSel ("none" = base model). It drives the
 * inference step indicators, the clock and the tok/s readout. When the selected
 * adapter is the current tuned skill, it also verifies the output as a macro and
 * reports the result on the stage. It does nothing while the model is not
 * loaded or another job is running.
 *
 * Everything that can fail after `state.busy` is set runs inside the try, so
 * errors (including applying the adapter) are shown in the output. In every
 * case the clock stops and the busy flag is released.
 */
async function runInference() {
  if (!state.loaded || state.busy) return;
  const userText = $("prompt").value.trim();
  if (!userText) {
    log("type something to ask first");
    return;
  }
  state.busy = "gen";
  gateButtons();
  const sel = $("adapterSel")?.value || "none";
  const out = $("out");
  out.textContent = "";
  // Streamed tokens, and any error text, are appended to this single Text node.
  const node = document.createTextNode("");
  out.appendChild(node);
  const st = steps("inferSteps");
  st.reset();
  const cap = $("inferCap");
  const stop = startClock("inferClock");
  $("inferProc").classList.add("on");
  try {
    adapters.applyToRuntime(sel, session.rt);
    setMacroCheck(null);
    st.active("tok");
    cap.textContent = "Tokenizing your prompt with the VibeThinker tokenizer\u2026";
    const t0 = performance.now();
    let n = 0;
    let first = true;
    let acc = "";
    const msgs = buildMessages(userText);
    st.done("tok");
    st.active("prefill");
    cap.textContent = "Reading the prompt into the KV cache (prefill)\u2026";
    const genOpts = { maxTokens: GEN.maxTokens, temperature: GEN.temperature, topP: GEN.topP, topK: GEN.topK };
    for await (const d of session.generate(msgs, genOpts)) {
      if (first) {
        // The first streamed piece marks the end of prefill.
        first = false;
        st.done("prefill");
        st.active("decode");
        cap.textContent = "Generating the answer one token at a time\u2026";
      }
      node.appendData(d);
      acc += d;
      n++;
      $("tokps").textContent = `${n} tok \xB7 ${(n / ((performance.now() - t0) / 1e3)).toFixed(1)} tok/s`;
      out.scrollTop = out.scrollHeight;
    }
    const dt = (performance.now() - t0) / 1e3;
    $("tokps").textContent = `${n} tok \xB7 ${(n / dt).toFixed(1)} tok/s \xB7 ${dt.toFixed(1)}s`;
    // Also mark prefill done here in case the generator produced nothing.
    st.done("prefill");
    st.done("decode");
    st.done("done");
    cap.textContent = `Done \u2014 ${sel === "none" ? "base model" : 'tuned adapter "' + sel + '"'}.`;
    const skill = sel !== "none" && state.tuned && state.tuned.name === sel ? skillByKey(state.tuned.base) : null;
    if (skill) {
      const res = verifyMacro(acc, skill.spec);
      setMacroCheck(res, skill, acc);
      if (res.status === "ok") stageMsg(`Action resolved \u2014 compiled a ${res.n}-step plan on ${skill.label}.`);
      else if (res.status === "oos") stageMsg(`That request is off the map for ${skill.label}. Try one of its actions.`);
      else stageMsg(`The plan didn't validate \u2014 adjust the request and try again.`);
      if (state.activeRunId) {
        bumpUses(state.activeRunId);
        renderDock();
      }
    }
    log(`done (${sel === "none" ? "base model" : "tuned adapter"}).`);
  } catch (e) {
    const msg = e?.message ?? String(e);
    // `out` is an Element and has no appendData(); append to the streamed Text node instead.
    node.appendData("\n\n[error] " + msg);
    cap.textContent = "error: " + msg;
    console.error(e);
  } finally {
    stop();
    $("inferProc").classList.remove("on");
    state.busy = false;
    gateButtons();
  }
}
__name(runInference, "runInference");
/**
 * Fine-tunes a new adapter on `examples`, hot-swaps it in as the selected
 * adapter, and saves it to run history on a best-effort basis.
 *
 * @param {object} opts
 * @param {Array} opts.examples - training examples passed to TrainingController.train.
 * @param {number} opts.lr - learning rate.
 * @param {number} opts.epochs - passes over `examples`.
 * @param {number} opts.accum - gradient-accumulation window (examples per optimizer step).
 * @param {string} opts.base - base name of the adapter; uniqueName() derives the final name.
 * @param {string} opts.kind - run kind stored with the adapter ('guided' | 'own').
 * @param {string} [opts.system] - optional system prompt recorded in history.
 * @param {Function} opts.build - userText -> messages[] used when the adapter is selected.
 * @param {string} [opts.suggest] - suggested prompt shown in the "try it" panel.
 *
 * Everything after `state.busy` is set that can throw runs inside the try.
 * That includes building the TrainingController and initialising the adapter.
 * So a failure always releases the busy flag and the inference lock.
 */
async function runTraining({ examples, lr, epochs, accum, base, kind, system, build, suggest }) {
  if (!state.loaded) {
    log("load the model first (INFERENCE pane).");
    switchTab("infer");
    return;
  }
  if (state.busy) return;
  const name = uniqueName(base);
  const runId = newId();
  // LoRA shape, used both to initialise the adapter and in the saved run record.
  const lora = { rank: 16, alpha: 32 };
  state.busy = "train";
  lockInference(true);
  gateButtons();
  $("trainWidget").style.display = "";
  resetTrainTelemetry();
  // Total optimizer steps: one per `accum`-sized window of examples, for each epoch.
  const windows = Math.max(1, Math.ceil(examples.length / accum));
  const total = windows * epochs;
  let lastLoss = null;
  const st = steps("trainSteps");
  st.reset();
  const cap = $("trainCap");
  const stop = startClock("trainClock");
  try {
    const ctrl = new TrainingController({
      session,
      adapters,
      log: /* @__PURE__ */ __name(() => {
      }, "log"),
      trainerOptions: { lr, maxTrainSeq: 384, lmHeadBlock: 128, maxGradNorm: 1, weightDecay: 0, warmupSteps: Math.min(4, total), totalSteps: total, gradAccumSteps: accum }
    });
    st.active("prep");
    cap.textContent = "Building masked, shifted-label examples and tokenizing on the GPU\u2026";
    renderMaskPreview(ctrl, examples[0]);
    ctrl.initAdapter(name, { ...lora });
    trainProgress(0, total, null, "warming up\u2026");
    const t0 = performance.now();
    st.done("prep");
    st.loop(["fwd", "bwd", "opt"], true);
    cap.textContent = "Looping forward \u2192 backward \u2192 AdamW over your examples (full-network backprop)\u2026";
    await ctrl.train(examples, {
      epochs,
      onStep: /* @__PURE__ */ __name((r) => {
        const { step, loss } = r;
        lastLoss = loss;
        updateTrainTelemetry(step, total, r);
        trainProgress(step, total, loss, `teaching \xB7 step ${step}/${total} \xB7 loss ${loss.toFixed(3)} \xB7 ${fmtNum(r.trainTokPerSec)} tok/s`);
        cap.textContent = `Step ${step}/${total} \u2014 forward ${fmtMs(r.microStepMs)} \u2192 backward \u2192 AdamW ${fmtMs(r.optimizerStepMs)} \xB7 loss ${loss.toFixed(3)}`;
      }, "onStep")
    });
    const dt = ((performance.now() - t0) / 1e3).toFixed(1);
    st.loop(["fwd", "bwd", "opt"], false);
    st.done("fwd");
    st.done("bwd");
    st.done("opt");
    st.active("swap");
    state.tuned = { name, kind, base, build, suggest, ctrl };
    state.activeRunId = runId;
    addAdapterOption(name);
    $("adapterSel").value = name;
    // Setting .value from script fires no change event, so refresh the rail chip explicitly.
    setBadge();
    st.done("swap");
    trainProgress(total, total, null, `done in ${dt}s \u2014 adapter "${name}" is live`);
    cap.textContent = `Adapter "${name}" hot-swapped into inference \u2014 live. Trained in ${dt}s.`;
    $("downloadAdapter").style.display = "";
    showTryIt(suggest);
    // Persisting to history is best-effort: a failure here must not undo the live adapter.
    try {
      const files = await exportLoraAdapter(ctrl.trainer, { name });
      await saveRun(
        {
          id: runId,
          name,
          base,
          kind,
          system: system || null,
          suggest: suggest || "",
          createdAt: Date.now(),
          steps: total,
          epochs,
          durationSec: +dt,
          finalLoss: lastLoss,
          rank: lora.rank,
          alpha: lora.alpha
        },
        { safetensors: files.safetensors, configJson: files.configJson }
      );
      renderHistory();
    } catch (e) {
      console.warn("[history] save failed", e);
    }
    log(`Trained "${name}" in ${dt}s. Saved to your fine-tunes; switch to Inference to try it.`);
  } catch (e) {
    const msg = e?.message ?? String(e);
    st.loop(["fwd", "bwd", "opt"], false);
    trainProgress(0, total, null, "training error: " + msg);
    cap.textContent = "error: " + msg;
    console.error(e);
  } finally {
    stop();
    state.busy = false;
    lockInference(false);
    gateButtons();
  }
}
__name(runTraining, "runTraining");
// Limits for "teach it your own text":
// - input is truncated to MAX_CHARS characters;
// - at most MAX_CHUNKS snippets are produced;
// - snippets shorter than MIN_WORDS words are dropped;
// - the first HEAD_WORDS words of a snippet become its prompt.
var MAX_CHARS = 12e3;
var MAX_CHUNKS = 24;
var MIN_WORDS = 12;
var HEAD_WORDS = 6;
/**
 * Splits free text into training snippets.
 * Carriage returns are removed, and the text is cut at blank lines and at
 * periods followed by whitespace (the periods are dropped). Each kept piece is
 * split into `head` (first HEAD_WORDS words) and `rest` (remaining words),
 * both space-joined, with `full` holding the trimmed piece.
 * @param {string} text
 * @returns {{head: string, rest: string, full: string}[]}
 */
function chunkText(text) {
  const body = (text || "").replace(/\r/g, "").slice(0, MAX_CHARS);
  const pieces = body.split(/\n{2,}|\.(?=\s)/).map((s) => s.trim()).filter(Boolean);
  const chunks = [];
  for (const piece of pieces) {
    const words = piece.split(/\s+/).filter(Boolean);
    if (words.length >= MIN_WORDS) {
      chunks.push({
        head: words.slice(0, HEAD_WORDS).join(" "),
        rest: words.slice(HEAD_WORDS).join(" "),
        full: piece
      });
      if (chunks.length >= MAX_CHUNKS) break;
    }
  }
  return chunks;
}
__name(chunkText, "chunkText");
// Snippets parsed from the "own text" box; refreshOwn() reassigns this.
var _ownChunks = [];
/**
 * Turns each parsed snippet into a training example: the snippet's head words
 * are the user turn, and the rest (prefixed with one space) is the completion.
 */
function ownExamples() {
  return _ownChunks.map(({ head, rest }) => ({
    messages: [{ role: "user", content: head }],
    completion: " " + rest
  }));
}
__name(ownExamples, "ownExamples");
/**
 * Re-parses #ownText into snippets and shows either the snippet/character
 * count or a hint about the minimum length in #ownStats. Then re-gates the
 * buttons, since "train own" depends on having snippets.
 */
function refreshOwn() {
  const text = $("ownText").value;
  _ownChunks = chunkText(text);
  const count = _ownChunks.length;
  const chars = Math.min(MAX_CHARS, (text || "").length);
  $("ownStats").textContent = count
    ? `${count} snippet(s) \xB7 ${chars} chars (cap ${MAX_CHARS}) \xB7 ready to teach`
    : `paste/drop at least one paragraph (~${MIN_WORDS}+ words). 100% local.`;
  gateButtons();
}
__name(refreshOwn, "refreshOwn");
/**
 * Shows the inference pane when `which` is "infer"; any other value shows the
 * training pane. The tab buttons' "on" state follows the same choice.
 */
function switchTab(which) {
  const showInfer = which === "infer";
  const toggles = [
    ["paneInfer", "active", showInfer],
    ["paneTrain", "active", !showInfer],
    ["tabInfer", "on", showInfer],
    ["tabTrain", "on", !showInfer]
  ];
  for (const [id, cls, on] of toggles) $(id).classList.toggle(cls, on);
}
__name(switchTab, "switchTab");
/**
 * Appends an <option> for adapter `name` to #adapterSel unless one with that
 * value already exists, then reveals #adapterWrap (if present).
 */
function addAdapterOption(name) {
  const select = $("adapterSel");
  const alreadyListed = Array.from(select.options, (opt) => opt.value).includes(name);
  if (!alreadyListed) {
    const option = document.createElement("option");
    Object.assign(option, { value: name, textContent: name });
    select.appendChild(option);
  }
  const wrap = $("adapterWrap");
  if (wrap) wrap.hidden = false;
}
__name(addAdapterOption, "addAdapterOption");
/**
 * Sets the training progress bar to step/total (one decimal, as a percentage)
 * and the caption to `label`. `loss` is accepted for call-site symmetry but is
 * not used here.
 */
function trainProgress(step, total, loss, label) {
  const pct = 100 * step / Math.max(1, total);
  $("trainBar").style.width = `${pct.toFixed(1)}%`;
  $("trainLabel").textContent = label;
}
__name(trainProgress, "trainProgress");
/**
 * Clears training telemetry before a new run. It empties the loss history,
 * shows #trainMetrics, resets each metric cell to an em dash, clears the loss
 * sparkline and hides the mask preview. Missing elements are skipped.
 */
function resetTrainTelemetry() {
  trainLosses = [];
  const metrics = $("trainMetrics");
  if (metrics) metrics.hidden = false;
  ["tmLoss", "tmTokps", "tmActive", "tmOpt"].forEach((id) => {
    const cell = $(id);
    if (cell) cell.textContent = "\u2014";
  });
  const sparkline = $("lossLine");
  if (sparkline) sparkline.setAttribute("points", "");
  const maskPreview = $("maskPreview");
  if (maskPreview) maskPreview.hidden = true;
}
__name(resetTrainTelemetry, "resetTrainTelemetry");
/**
 * Records one optimizer step's loss and refreshes the metric cells (loss,
 * tok/s, active/total tokens, optimizer time) and the loss sparkline.
 * `step` and `total` are accepted but not used here.
 */
function updateTrainTelemetry(step, total, r) {
  const { loss, trainTokPerSec, numActive, tokens, optimizerStepMs } = r;
  trainLosses.push(loss);
  const cells = {
    tmLoss: loss.toFixed(4),
    tmTokps: `${fmtNum(trainTokPerSec)} tok/s`,
    tmActive: `${numActive || 0} / ${tokens || 0}`,
    tmOpt: fmtMs(optimizerStepMs)
  };
  for (const [id, text] of Object.entries(cells)) $(id).textContent = text;
  drawLossSpark();
}
__name(updateTrainTelemetry, "updateTrainTelemetry");
/**
 * Redraws the #lossLine polyline from `trainLosses`. Points are spread over
 * x = 0..300. Loss is min-max scaled onto y = 36 (lowest) .. 4 (highest).
 * Needs at least two samples; otherwise (or without #lossLine) nothing changes.
 */
function drawLossSpark() {
  const line = $("lossLine");
  const count = trainLosses.length;
  if (!line || count < 2) return;
  const lo = Math.min(...trainLosses);
  const hi = Math.max(...trainLosses);
  // Floor the range so a flat loss curve does not divide by zero.
  const range = Math.max(1e-6, hi - lo);
  const lastIndex = Math.max(1, count - 1);
  const coords = trainLosses.map((value, i) => {
    const x = i / lastIndex * 300;
    const y = 36 - (value - lo) / range * 32;
    return `${x.toFixed(1)},${y.toFixed(1)}`;
  });
  line.setAttribute("points", coords.join(" "));
}
__name(drawLossSpark, "drawLossSpark");
/**
 * Renders a token-level view of how one training example is masked.
 * It builds a 4-column grid (pos / segment / token / trained target) from
 * ctrl.inspectExample(example), showing at most the first 96 rows plus a
 * "truncated" row when more exist, and puts the token count and trained-label
 * count in #maskSummary.
 * Rows that train a next-token label get the "train" class. Every row also
 * carries its segment name as a class.
 * If inspectExample throws, the grid shows the error message instead; #maskSummary
 * is not updated in that case. No-op when #maskPreview, #maskRows or `example`
 * is missing.
 */
function renderMaskPreview(ctrl, example) {
  const box = $("maskPreview");
  const rows = $("maskRows");
  if (!box || !rows || !example) return;
  try {
    const preview = ctrl.inspectExample(example);
    $("maskSummary").textContent = `${preview.tokens.length} tokens \xB7 ${preview.trainPositions} trained next-token labels`;
    // Cap the rendered rows so long examples don't flood the DOM.
    const shown = preview.rows.slice(0, 96);
    // NOTE(review): r.index, r.id, r.targetId and r.segment (inside the class attribute) are
    // interpolated without esc() — presumably numbers / fixed segment names; confirm that
    // inspectExample guarantees this.
    rows.innerHTML = '<div class="hdr">pos</div><div class="hdr">segment</div><div class="hdr">token</div><div class="hdr target">trained target</div>' + shown.map((r) => {
      const cls = `${r.trainsNext ? "train" : ""} ${r.segment}`;
      const target = r.trainsNext ? `${r.targetId} ${clip(r.targetText, 24)}` : "";
      return `<div class="${cls}">${r.index}</div><div class="${cls}">${esc(r.segment)}</div><div class="${cls}">${r.id} ${esc(clip(r.text, 28))}</div><div class="${cls} target">${esc(target)}</div>`;
    }).join("") + (preview.rows.length > shown.length ? `<div class="prompt">\u2026</div><div class="prompt">truncated</div><div class="prompt">${preview.rows.length - shown.length} more rows</div><div class="prompt target"></div>` : "");
    box.hidden = false;
  } catch (e) {
    // Keep the grid shape (4 cells) so the error reads as a single row.
    rows.innerHTML = `<div class="prompt">preview</div><div class="prompt">error</div><div class="prompt">${esc(e.message)}</div><div class="prompt target"></div>`;
    box.hidden = false;
  }
}
__name(renderMaskPreview, "renderMaskPreview");
/**
 * Reveals the "try it" panel after training. Its button switches to the
 * inference tab, selects the tuned adapter, pre-fills `suggest` as the prompt
 * and runs it. Also refreshes the equip panel and announces the new skill on
 * the stage.
 */
function showTryIt(suggest) {
  $("tryIt").style.display = "flex";
  $("tryItBtn").onclick = () => {
    switchTab("infer");
    $("adapterSel").value = state.tuned.name;
    setBadge();
    $("prompt").value = suggest;
    // runInference reports its own errors in the output pane.
    void runInference();
  };
  renderEquipPanel();
  const learned = state.tuned?.name;
  if (learned) {
    stageMsg(`New skill learned: \u201C${learned}\u201D \u2014 it dropped into your inventory. Equip it to act.`);
  }
}
__name(showTryIt, "showTryIt");
/**
 * Fills the "equipped skill" bar for the current tuned adapter:
 * - the skill's icon, name and scope;
 * - one chip per op (its tooltip shows the op's signature);
 * - up to four drill buttons sampled at an even stride from the skill's
 *   in-scope examples. Clicking a drill puts it in #prompt and runs inference.
 * Hides the bar when the tuned adapter does not map to a skill with a spec.
 */
function renderEquipPanel() {
  const bar = $("equipBar");
  if (!bar) return;
  const skill = state.tuned ? skillByKey(state.tuned.base) : null;
  if (!skill?.spec) {
    bar.hidden = true;
    return;
  }
  bar.hidden = false;
  const { spec } = skill;
  const setText = /* @__PURE__ */ __name((id, value) => {
    const node = $(id);
    if (node) node.textContent = value;
  }, "setText");
  setText("equipIcon", skill.icon);
  setText("equipName", `${skill.label} skill`);
  setText("equipScope", `scope: ${spec.scope}`);
  const opsHost = $("equipOps");
  if (opsHost) {
    opsHost.innerHTML = "";
    for (const op of spec.ops) {
      const chip = document.createElement("span");
      chip.className = "equip__op";
      chip.textContent = op.name;
      chip.title = `${op.name}(${(op.params || []).join(", ")})`;
      opsHost.appendChild(chip);
    }
  }
  const drillHost = $("equipDrills");
  if (drillHost) {
    drillHost.innerHTML = "";
    const inscope = skill.examples.filter(([, answer]) => answer !== "OUT_OF_SCOPE");
    const stride = Math.max(1, Math.floor(inscope.length / 4));
    const prompts = [];
    for (let i = 0; prompts.length < 4 && i < inscope.length; i += stride) prompts.push(inscope[i][0]);
    for (const q of prompts) {
      const btn = document.createElement("button");
      btn.type = "button";
      btn.className = "drill";
      btn.textContent = q;
      btn.title = "Fire this drill";
      btn.onclick = () => {
        $("prompt").value = q;
        runInference();
      };
      drillHost.appendChild(btn);
    }
  }
}
__name(renderEquipPanel, "renderEquipPanel");
/**
 * Converts a macro listing into short human-readable step labels.
 * Each line shaped like `[var =] op_name(args)` becomes "op name". The first
 * two non-empty quoted argument values are appended after an em dash, joined
 * by " · ". Blank lines, OUT_OF_SCOPE and lines that are not calls are skipped.
 * @param {*} text - macro text (coerced with String()).
 * @returns {string[]}
 */
function humanizePlan(text) {
  const CALL = /^(?:[A-Za-z_]\w*\s*=\s*)?([A-Za-z_]\w*)\s*\((.*)\)\s*;?\s*$/;
  const QUOTED_ARG = /([A-Za-z_]\w*)\s*=\s*"([^"]*)"/g;
  return String(text)
    .split("\n")
    .map((raw) => raw.trim())
    .filter((line) => line && line !== "OUT_OF_SCOPE")
    .map((line) => line.match(CALL))
    .filter(Boolean)
    .map(([, opName, argText]) => {
      const op = opName.replace(/_/g, " ");
      const values = Array.from(argText.matchAll(QUOTED_ARG), (hit) => hit[2]).filter(Boolean);
      const summary = values.slice(0, 2).join(" \xB7 ");
      return summary ? `${op} \u2014 ${summary}` : op;
    });
}
__name(humanizePlan, "humanizePlan");
/**
 * Returns `base` if no saved run uses that name. Otherwise it returns the first
 * free "`base` #N", trying N = 2, 3, ….
 */
function uniqueName(base) {
  const taken = new Set(listRuns().map((r) => r.name));
  let candidate = base;
  for (let i = 2; taken.has(candidate); i++) candidate = `${base} #${i}`;
  return candidate;
}
__name(uniqueName, "uniqueName");
/**
 * Returns a userText -> messages[] builder for a saved run. It prepends
 * meta.system as a system turn when the run recorded one. Otherwise it yields
 * a single user turn. Whether a system turn is used is decided once, here;
 * its text is read from `meta` on each call.
 */
function buildFromMeta(meta) {
  if (!meta.system) {
    return (u) => [{ role: "user", content: u }];
  }
  return (u) => [
    { role: "system", content: meta.system },
    { role: "user", content: u }
  ];
}
__name(buildFromMeta, "buildFromMeta");
/**
 * Formats a saved run's summary line, e.g. "loss 0.123 · 36 steps · 42s · Jan 5".
 * Each part is included only when its field is present: finalLoss (non-null),
 * steps (truthy), durationSec (non-null), createdAt (non-null and a valid date).
 * @param {{finalLoss?: number|null, steps?: number, durationSec?: number|null, createdAt?: number|string|null}} m
 * @returns {string}
 */
function fmtRunMeta(m) {
  const parts = [];
  if (m.finalLoss != null) parts.push("loss " + Number(m.finalLoss).toFixed(3));
  if (m.steps) parts.push(m.steps + " steps");
  if (m.durationSec != null) parts.push(Math.round(m.durationSec) + "s");
  // toLocaleDateString() does not throw for a bad timestamp; it returns the literal
  // "Invalid Date" (and new Date(null) is 1 Jan 1970). So validate explicitly rather
  // than relying on try/catch.
  if (m.createdAt != null) {
    const created = new Date(m.createdAt);
    if (!Number.isNaN(created.getTime())) {
      parts.push(created.toLocaleDateString(void 0, { month: "short", day: "numeric" }));
    }
  }
  return parts.join(" \xB7 ");
}
__name(fmtRunMeta, "fmtRunMeta");
/**
 * Rebuilds the fine-tune inventory list (#historyList) from listRuns().
 * It also updates #historyCount and the #historyEmpty placeholder, then
 * refreshes the dock and the stage.
 * Each run becomes an <li class="item"> card with level, rarity, meta line and
 * XP bar. The run whose id equals state.activeRunId is marked "EQUIPPED".
 * Clicking the card or its Equip button calls applyRun(id). The export and
 * scrap buttons call exportRun(id) / delRun(id).
 */
function renderHistory() {
  const runs = listRuns();
  $("historyCount").textContent = String(runs.length);
  $("historyEmpty").style.display = runs.length ? "none" : "";
  const ul = $("historyList");
  ul.innerHTML = "";
  for (const m of runs) {
    const { lv, xp } = skillLevel(m);
    const rar = rarityOf(lv);
    const active = m.id === state.activeRunId;
    const li = document.createElement("li");
    li.className = "item" + (active ? " active" : "");
    li.dataset.id = m.id;
    li.dataset.kind = m.kind || "own";
    li.dataset.rarity = rar.key;
    li.title = `${m.name} \u2014 click to equip`;
    // Run name, type label and meta line go through esc(); runIcon() output, rar.label,
    // lv and xp are inserted as-is. NOTE(review): those come from fixed tables/numbers
    // here; keep them that way or escape them too.
    li.innerHTML = `<div class="item__frame"><span class="item__icon">${runIcon(m)}</span><span class="item__lv">L${lv}</span></div><div class="item__body"><div class="item__name">${esc(m.name)}</div><div class="item__rar">${rar.label} \xB7 ${esc(itemTypeLabel(m))}</div><div class="item__meta">${esc(fmtRunMeta(m))}</div><div class="item__xp"><i style="width:${xp}%"></i></div></div>` + (active ? `<div class="item__tag">EQUIPPED</div>` : "") + `<div class="item__acts"><button data-act="apply" class="tiny primary">${active ? "\u2713 Equipped" : "\u25B6 Equip"}</button><button data-act="export" class="tiny secondary" title="Export adapter">\u2B07</button><button data-act="del" class="tiny danger" title="Scrap">\u2715</button></div>`;
    // Button handlers stop propagation so the card's own click (applyRun) doesn't also fire.
    li.querySelector("[data-act=apply]").onclick = (e) => {
      e.stopPropagation();
      applyRun(m.id);
    };
    li.querySelector("[data-act=export]").onclick = (e) => {
      e.stopPropagation();
      exportRun(m.id);
    };
    li.querySelector("[data-act=del]").onclick = (e) => {
      e.stopPropagation();
      delRun(m.id);
    };
    li.onclick = () => applyRun(m.id);
    ul.appendChild(li);
  }
  renderDock();
  renderStage();
}
__name(renderHistory, "renderHistory");
// Fallback inventory icons by run kind, used when a run's base is not a known skill.
var SKILL_ICON = { guided: "\u2694", own: "\u{1F4DC}" };
// In-memory count of inferences made with each run, keyed by run id
// (nothing visible here persists it).
var usesByRun = /* @__PURE__ */ new Map();
/** Increments the use counter for run `id`. */
function bumpUses(id) {
  const previous = usesByRun.get(id) || 0;
  usesByRun.set(id, previous + 1);
}
__name(bumpUses, "bumpUses");
/** Icon for a run: its skill's icon, else the kind fallback, else a dagger. */
function runIcon(m) {
  const skill = skillByKey(m.base);
  if (skill) return skill.icon;
  return SKILL_ICON[m.kind] || "\u{1F5E1}";
}
__name(runIcon, "runIcon");
/**
 * Derives display stats for a run.
 * - lv: steps / 12, rounded and clamped to 1..9 (missing or 0 steps count as 12).
 * - xp: (3 - finalLoss) / 3 as a percentage, rounded and clamped to 6..100
 *   (a missing loss counts as 1.5).
 * @returns {{lv: number, xp: number}}
 */
function skillLevel(m) {
  const stepCount = m.steps || 12;
  const lv = Math.min(9, Math.max(1, Math.round(stepCount / 12)));
  const loss = m.finalLoss == null ? 1.5 : Number(m.finalLoss);
  const xp = Math.min(100, Math.max(6, Math.round(100 * (3 - loss) / 3)));
  return { lv, xp };
}
__name(skillLevel, "skillLevel");
/**
 * Maps a level to its rarity tier:
 * 9+ legendary, 7+ epic, 5+ rare, 3+ uncommon, otherwise (including NaN) common.
 * @returns {{key: string, label: string}}
 */
function rarityOf(lv) {
  const tiers = [
    [9, "legendary", "Legendary"],
    [7, "epic", "Epic"],
    [5, "rare", "Rare"],
    [3, "uncommon", "Uncommon"]
  ];
  for (const [min, key, label] of tiers) {
    if (lv >= min) return { key, label };
  }
  return { key: "common", label: "Common" };
}
__name(rarityOf, "rarityOf");
/** Type label for an inventory item: the skill label, else "Skill" for guided runs, else "Custom note". */
function itemTypeLabel(m) {
  const skill = skillByKey(m.base);
  if (skill) return skill.label;
  if (m.kind === "guided") return "Skill";
  return "Custom note";
}
__name(itemTypeLabel, "itemTypeLabel");
// Tile style for bring-your-own-text runs. NOTE(review): presumably used by the dock for
// "own" runs; its use is outside this view.
var BYOD_TILE = {
  bg: "#6b6256",
  fg: "#fff",
  glyph: "\u{1F4DC}",
  fs: 20
};
// Dock catalogue (alias of POPULAR_2026), iterated in display order.
var SERVICES = POPULAR_2026;
// Run ids in the order the dock placed them; the first nine get a numbered key badge.
var dockRuns = [];
/**
 * Rebuild the dock tray (#dockSlots).
 *
 * One tile per SERVICES entry:
 *  - service with a skill and a saved run  -> "equipped"/"owned" tile that applies the run;
 *  - service with a skill but no run       -> "forge" tile that opens the Train tab on that skill;
 *  - service without a skill               -> locked "soon" tile.
 * Saved runs not claimed by any service follow a separator, drawn with BYOD_TILE.
 * Every run tile is recorded in `dockRuns` (dock order); the first nine get a
 * 1-9 hotkey badge. No-op when #dockSlots is missing.
 */
function renderDock() {
  const tray = $("dockSlots");
  if (!tray) return;
  const runs = listRuns();
  tray.innerHTML = "";
  dockRuns = [];
  const placed = /* @__PURE__ */ new Set();

  // Build a <span class=cls> holding `text`.
  const span = (cls, text) => {
    const node = document.createElement("span");
    node.className = cls;
    node.textContent = text;
    return node;
  };

  // Append a run id to the dock order; returns its hotkey number (1-9) or null.
  const enlist = (runId) => {
    dockRuns.push(runId);
    return dockRuns.length <= 9 ? dockRuns.length : null;
  };

  // Tooltip for a tile backed by a saved run.
  const ownedTip = (name, lv, equipped, uses, keyN) => {
    let tip = `${name} \xB7 Lv ${lv}`;
    if (equipped) tip += " \xB7 equipped";
    if (uses) tip += " \xB7 " + uses + "\xD7";
    if (keyN) tip += " \xB7 [" + keyN + "]";
    return tip;
  };

  const addTile = /* @__PURE__ */ __name((svc, opts) => {
    const tile = document.createElement("div");
    tile.className = "dock__tile";
    tile.tabIndex = 0;
    tile.setAttribute("role", "button");
    tile.dataset.state = opts.state;
    tile.dataset.key = svc.key;
    if (opts.runid) tile.dataset.runid = opts.runid;

    const glyph = document.createElement("span");
    glyph.className = "dock__glyph";
    glyph.style.background = svc.bg;
    glyph.style.color = svc.fg || "#fff";
    glyph.style.fontSize = (svc.fs || 21) + "px";
    glyph.textContent = svc.glyph;
    tile.appendChild(glyph);

    if (opts.lv != null) tile.appendChild(span("dock__lv", "L" + opts.lv));
    if (opts.keyN != null) tile.appendChild(span("dock__key", opts.keyN));
    if (opts.forge) tile.appendChild(span("dock__forge", "+"));
    if (opts.lock) tile.appendChild(span("dock__lock", "\u{1F512}"));
    tile.appendChild(span("dock__tip", opts.tip));

    tile.setAttribute("aria-label", opts.tip);
    tile.onclick = opts.onClick;
    // Keyboard activation mirrors a click for Enter and Space.
    tile.onkeydown = (ev) => {
      if (ev.key !== "Enter" && ev.key !== " ") return;
      ev.preventDefault();
      opts.onClick();
    };
    tray.appendChild(tile);
  }, "addTile");

  for (const svc of SERVICES) {
    if (!svc.skill) {
      addTile(svc, {
        state: "soon",
        lock: true,
        tip: `${svc.name} \u2014 coming soon`,
        onClick: /* @__PURE__ */ __name(() => {
          switchTab("train");
          log(`\u201C${svc.name}\u201D skill \u2014 coming soon. The armory grows as we add action spaces.`);
        }, "onClick")
      });
      continue;
    }
    const run = runs.find((r) => skillByKey(r.base)?.key === svc.skill);
    if (!run) {
      addTile(svc, {
        state: "forge",
        forge: true,
        tip: `${svc.name} \u2014 forge this skill`,
        onClick: /* @__PURE__ */ __name(() => {
          switchTab("train");
          selectSkill(svc.skill);
        }, "onClick")
      });
      continue;
    }
    placed.add(run.id);
    const { lv } = skillLevel(run);
    const equipped = run.id === state.activeRunId;
    const keyN = enlist(run.id);
    const uses = usesByRun.get(run.id) || 0;
    addTile(svc, {
      state: equipped ? "equipped" : "owned",
      runid: run.id,
      lv,
      keyN,
      tip: ownedTip(svc.name, lv, equipped, uses, keyN),
      onClick: /* @__PURE__ */ __name(() => applyRun(run.id), "onClick")
    });
  }

  const leftovers = runs.filter((r) => !placed.has(r.id));
  if (leftovers.length) {
    const divider = document.createElement("div");
    divider.className = "dock__sep";
    tray.appendChild(divider);
  }
  for (const r of leftovers) {
    const { lv } = skillLevel(r);
    const equipped = r.id === state.activeRunId;
    const keyN = enlist(r.id);
    addTile({ key: "byod-" + r.id, name: r.name, ...BYOD_TILE }, {
      state: equipped ? "equipped" : "owned",
      runid: r.id,
      lv,
      keyN,
      tip: ownedTip(r.name, lv, equipped, 0, keyN),
      onClick: /* @__PURE__ */ __name(() => applyRun(r.id), "onClick")
    });
  }
}
__name(renderDock, "renderDock");
// Run id most recently requested through equipByIndex (read by the
// window.__eg devtools/test surface).
var lastEquipIntent = null;
/**
 * Equip the run at position `i` (0-based) in the current dock order.
 * Indices below 0 or at/after the end of `dockRuns` are ignored.
 * applyRun is started but not awaited.
 * @param {number} i - dock position
 */
function equipByIndex(i) {
  const outOfRange = i < 0 || i >= dockRuns.length;
  if (outOfRange) return;
  const runId = dockRuns[i];
  lastEquipIntent = runId;
  applyRun(runId);
}
__name(equipByIndex, "equipByIndex");
/**
 * Show the result of a macro check in the #macroCheck panel.
 *
 * @param {object|null} res - check result. `null`/undefined or status "empty"
 *   hides and clears the panel. Status "ok" reports `res.n` calls and the ops
 *   in `res.calls`; "oos" reports an OUT_OF_SCOPE refusal; any other status
 *   shows the first two entries of `res.issues`.
 * @param {{label: string}} skill - skill named in the message (unused when hiding)
 * @param {string} [text] - raw macro text; when truthy, humanizePlan(text)
 *   is rendered as an ordered list under an "ok" result
 */
function setMacroCheck(res, skill, text) {
  const panel = $("macroCheck");
  if (!panel) return;
  const nothingToShow = !res || res.status === "empty";
  if (nothingToShow) {
    panel.hidden = true;
    panel.textContent = "";
    panel.removeAttribute("data-state");
    return;
  }
  panel.hidden = false;
  switch (res.status) {
    case "ok": {
      panel.dataset.state = "ok";
      const opList = res.calls.map((call) => call.op).join(", ");
      const steps = text ? humanizePlan(text) : [];
      const planHtml = steps.length ? `<ol class="macrochk__plan">${steps.map((step) => `<li>${esc(step)}</li>`).join("")}</ol>` : "";
      const plural = res.n === 1 ? "" : "s";
      panel.innerHTML = `<b>\u2713 valid macro</b> \xB7 ${res.n} call${plural} on the ${esc(skill.label)} action space \xB7 <code>${esc(opList)}</code>${planHtml}`;
      break;
    }
    case "oos":
      panel.dataset.state = "oos";
      panel.innerHTML = `<b>\u26D4 OUT_OF_SCOPE</b> \xB7 the ${esc(skill.label)} skill correctly refused \u2014 that request is outside its actions`;
      break;
    default:
      panel.dataset.state = "bad";
      panel.innerHTML = `<b>\u2717 invalid macro</b> \xB7 ${esc(res.issues.slice(0, 2).join("; "))}`;
  }
}
__name(setMacroCheck, "setMacroCheck");
// Rank titles as [minimum number of saved runs, title], highest threshold first.
// renderStage takes the first entry whose threshold runs.length meets.
var RANKS = [[12, "Grandmaster"], [9, "Master"], [6, "Artisan"], [4, "Adept"], [2, "Journeyman"], [1, "Apprentice"], [0, "Initiate"]];
/**
 * Extract the first hex colour (#rgb up to #rrggbbaa) from a CSS background value.
 * If no hex colour is found, a value starting with "#" is returned unchanged;
 * anything else (including empty/null) yields null.
 * @param {*} bg - CSS background value, e.g. "linear-gradient(#abc, #def)"
 * @returns {string|null}
 */
function firstColor(bg) {
  if (!bg) return null;
  const text = String(bg);
  const hit = /#[0-9a-f]{3,8}/i.exec(text);
  if (hit) return hit[0];
  return text.startsWith("#") ? bg : null;
}
__name(firstColor, "firstColor");
/**
 * Show a one-line narration message in #stageMsg, prefixed with "» ".
 * Silently does nothing when the element is absent.
 * @param {string} text
 */
function stageMsg(text) {
  const line = $("stageMsg");
  if (!line) return;
  line.textContent = "\xBB " + text;
}
__name(stageMsg, "stageMsg");
/**
 * Refresh the stage header from the saved runs.
 *
 * Updates the acquired-skills score (distinct catalogue skills with a run /
 * SKILLS.length), the overall level and XP bar (derived from the summed
 * `steps` of all runs), the rank title (RANKS by run count), and the
 * "sign" showing the equipped fine-tune or "The open web" when none is equipped.
 * Does nothing if #stage is missing; other missing elements are skipped.
 */
function renderStage() {
  // Summed run steps needed per overall level.
  const STEPS_PER_LEVEL = 120;
  const stage = $("stage");
  if (!stage) return;
  const runs = listRuns();
  const acquired = new Set(runs.map((r) => skillByKey(r.base)?.key).filter(Boolean));
  const totalSteps = runs.reduce((sum, r) => sum + (r.steps || 0), 0);
  const lvl = 1 + Math.floor(totalSteps / STEPS_PER_LEVEL);
  const xpPct = Math.round(totalSteps % STEPS_PER_LEVEL / STEPS_PER_LEVEL * 100);
  // RANKS is sorted by descending threshold, so the first match is the highest earned rank.
  const rank = (RANKS.find(([t]) => runs.length >= t) || [0, "Initiate"])[1];
  const active = runs.find((r) => r.id === state.activeRunId);
  const skill = active ? skillByKey(active.base) : null;
  const d = skill ? dockOf(skill.key) : null;
  const set = /* @__PURE__ */ __name((id, v) => {
    const e = $(id);
    if (e) e.textContent = v;
  }, "set");
  set("stageScore", `${acquired.size} / ${SKILLS.length}`);
  set("stageLv", String(lvl));
  set("stageRank", rank);
  const xp = $("stageXp");
  if (xp) xp.style.width = xpPct + "%";
  const scene = $("stageScene");
  const icon = $("stageSignIcon");
  if (active) {
    set("stageSignName", active.name);
    if (icon) {
      icon.textContent = d?.glyph || skill?.icon || "\u25C6";
      icon.style.background = d?.bg || "#6b6256";
      icon.style.color = d?.fg || "#fff";
      icon.style.fontSize = Math.round((d?.fs || 18) * 0.8) + "px";
    }
    if (scene) scene.style.setProperty("--scene", firstColor(d?.bg) || "#1d6f6a");
    stage.dataset.where = "in";
  } else {
    set("stageSignName", "The open web");
    if (icon) {
      icon.textContent = "\u{1F310}";
      icon.style.background = "#13393f";
      icon.style.color = "#cdeeea";
      icon.style.fontSize = "17px";
    }
    if (scene) scene.style.setProperty("--scene", "#1d6f6a");
    stage.dataset.where = "out";
  }
}
__name(renderStage, "renderStage");
// Look up the dock/service entry for a skill key in POPULAR_2026; returns an
// empty object when there is none, so callers can read optional fields freely.
var dockOf = /* @__PURE__ */ __name((key) => {
  const entry = POPULAR_2026.find((svc) => svc.key === key);
  return entry || {};
}, "dockOf");
/**
 * Rebuild the skill picker (#skillPicker): one button per SKILLS entry.
 * A button gets the "on" class when it is the selected skill and "forged"
 * (plus an "L<level>" badge) when a saved run exists for that skill.
 * Clicking a button calls selectSkill(key). No-op if #skillPicker is missing.
 */
function renderSkillPicker() {
  const host = $("skillPicker");
  if (!host) return;
  const runs = listRuns();
  host.innerHTML = "";
  for (const sk of SKILLS) {
    const look = dockOf(sk.key);
    const run = runs.find((r) => skillByKey(r.base)?.key === sk.key);
    const lv = run ? skillLevel(run).lv : 0;

    let cls = "skillpick__btn";
    if (sk.key === selectedSkillKey) cls += " on";
    if (lv) cls += " forged";

    const iconStyle = `background:${look.bg || "#6b6256"};color:${look.fg || "#fff"};font-size:${Math.round((look.fs || 18) * 0.78)}px`;
    const iconHtml = `<span class="skillpick__icon" style="${iconStyle}">${look.glyph || sk.icon}</span>`;
    const textHtml = `<span class="skillpick__txt"><b>${esc(sk.label)}</b><i>${sk.spec.ops.length} actions \xB7 ${sk.examples.length} examples</i></span>`;
    const badgeHtml = lv ? `<span class="skillpick__lv">L${lv}</span>` : "";

    const btn = document.createElement("button");
    btn.type = "button";
    btn.className = cls;
    btn.dataset.key = sk.key;
    btn.innerHTML = iconHtml + textHtml + badgeHtml;
    btn.onclick = () => selectSkill(sk.key);
    host.appendChild(btn);
  }
}
__name(renderSkillPicker, "renderSkillPicker");
/**
 * Make `key` the selected skill (falling back to SKILLS[0] for an unknown key).
 * Updates `selectedSkillKey`, the picker's "on" state, the skill title and
 * description, and the guided example list. The list shows up to 5 in-scope
 * examples plus 1 OUT_OF_SCOPE example, followed by a count of the rest.
 * @param {string} key - skill key
 */
function selectSkill(key) {
  const sk = skillByKey(key) || SKILLS[0];
  selectedSkillKey = sk.key;
  for (const btn of document.querySelectorAll("#skillPicker .skillpick__btn")) {
    btn.classList.toggle("on", btn.dataset.key === sk.key);
  }
  const title = $("skillTitle");
  if (title) title.innerHTML = `${sk.icon} ${esc(sk.label)} skill`;
  const desc = $("skillDesc");
  if (desc) desc.textContent = sk.desc;
  const list = $("guidedList");
  if (!list) return;
  const isOutOfScope = ([, answer]) => answer === "OUT_OF_SCOPE";
  const inScope = sk.examples.filter((ex) => !isOutOfScope(ex));
  const outOfScope = sk.examples.filter(isOutOfScope);
  const sample = inScope.slice(0, 5).concat(outOfScope.slice(0, 1));
  const remaining = sk.examples.length - sample.length;
  const rows = sample.map(([q, a]) => `<li><span class="skill-req">${esc(q)}</span><pre class="skill-macro">${esc(a)}</pre></li>`).join("");
  const footer = remaining > 0 ? `<li class="skill-more">+ ${remaining} more spec-valid pairs forge with this skill</li>` : "";
  list.innerHTML = rows + footer;
}
__name(selectSkill, "selectSkill");
/**
 * Equip a saved fine-tune run: load (or reuse) its LoRA adapter, make it the
 * active adapter, refresh the UI, and switch to the Infer tab.
 *
 * Returns without doing anything when `id` is unknown to getRun. When the base
 * model is not loaded (`state.loaded` falsy), it only logs a hint and switches
 * to the Infer tab. Returns silently while `state.busy` is set by another
 * operation. Errors thrown inside the try block are logged (log +
 * console.error) rather than rethrown, and `state.busy` is always cleared.
 * @param {*} id - saved run id
 */
async function applyRun(id) {
  const meta = getRun(id);
  if (!meta) return;
  if (!state.loaded) {
    log("Load VibeThinker-3B first (Step 1), then tap a fine-tune to use it.");
    switchTab("infer");
    return;
  }
  if (state.busy) return;
  // Claim the busy flag for the duration of the apply; gateButtons()
  // presumably disables controls based on it (confirm).
  state.busy = "apply";
  gateButtons();
  try {
    log(`Applying "${meta.name}"\u2026`);
    // Adapters are cached by run *name*, not id.
    // NOTE(review): the lookup uses adapters.get(name) while the insert below
    // writes adapters.adapters[name]; presumably get() reads that same map — confirm.
    let adapter = adapters.get(meta.name);
    if (!adapter) {
      // Cache miss: load the run's stored files and build the adapter against
      // the QWEN25_3B config on the loaded session's device.
      const files = await loadRunFiles(id);
      adapter = await loadLoraAdapterGPU(session.rt.dev, files, QWEN25_3B);
      adapter.name = meta.name;
      adapters.adapters[meta.name] = adapter;
    }
    addAdapterOption(meta.name);
    state.tuned = { name: meta.name, kind: meta.kind, base: meta.base, build: buildFromMeta(meta), suggest: meta.suggest };
    state.activeRunId = id;
    $("adapterSel").value = meta.name;
    // Hide any macro-check panel left over from the previous adapter.
    setMacroCheck(null);
    setBadge();
    renderHistory();
    renderEquipPanel();
    switchTab("infer");
    // Prefill the prompt with the run's suggested request, if it has one.
    if (meta.suggest) $("prompt").value = meta.suggest;
    stageMsg(`You step into \u201C${meta.name}\u201D. Pick an action below and act.`);
    log(`Now serving fine-tune "${meta.name}". Ask away.`);
  } catch (e) {
    log("Could not apply: " + e.message);
    console.error(e);
  } finally {
    state.busy = false;
    gateButtons();
  }
}
__name(applyRun, "applyRun");
/**
 * Export a saved run's adapter as `<stem>.safetensors` plus
 * `<stem>.adapter_config.json`. The stem is the run name with every run of
 * characters outside [A-Za-z0-9_.-] replaced by "_", or "adapter" if unnamed.
 *
 * Writes into the connected folder when `state.dirHandle` exists and
 * ensurePermission grants access; otherwise triggers two browser downloads.
 * Failures are reported through log() and not rethrown.
 * @param {*} id - saved run id; unknown ids are ignored
 */
async function exportRun(id) {
  const meta = getRun(id);
  if (!meta) return;
  try {
    const { safetensors, configJson } = await getRunBlobs(id);
    const stem = (meta.name || "adapter").replace(/[^\w.-]+/g, "_");
    const weightsFile = stem + ".safetensors";
    const configFile = stem + ".adapter_config.json";
    const useFolder = state.dirHandle && await ensurePermission(state.dirHandle);
    if (useFolder) {
      await writeFileToDir(state.dirHandle, weightsFile, safetensors);
      await writeFileToDir(state.dirHandle, configFile, configJson);
      log(`Saved "${meta.name}" to your connected folder.`);
      return;
    }
    triggerBlob(safetensors, weightsFile);
    triggerBlob(new Blob([configJson], { type: "application/json" }), configFile);
    log(`Exported "${meta.name}".`);
  } catch (e) {
    log("Export failed: " + e.message);
  }
}
__name(exportRun, "exportRun");
/**
 * Delete a saved run, clear `state.activeRunId` if it pointed at that run,
 * and re-render the history. Errors from deleteRun propagate to the caller.
 * @param {*} id - saved run id
 */
async function delRun(id) {
  await deleteRun(id);
  if (id === state.activeRunId) {
    state.activeRunId = null;
  }
  renderHistory();
}
__name(delRun, "delRun");
/**
 * Start a browser download of `data` under `filename`.
 * Non-Blob data is wrapped in a Blob. The object URL is revoked after a
 * 1-second delay rather than immediately, presumably so the download can
 * start before the URL disappears.
 * @param {Blob|BlobPart} data
 * @param {string} filename
 */
function triggerBlob(data, filename) {
  const payload = data instanceof Blob ? data : new Blob([data]);
  const href = URL.createObjectURL(payload);
  const link = Object.assign(document.createElement("a"), { href, download: filename });
  document.body.appendChild(link);
  link.click();
  link.remove();
  setTimeout(() => URL.revokeObjectURL(href), 1e3);
}
__name(triggerBlob, "triggerBlob");
/**
 * Format a millisecond duration: one decimal below 100 ms, none from 100 ms up.
 * Non-finite or non-number input renders as an em dash.
 * @param {number} ms
 * @returns {string}
 */
function fmtMs(ms) {
  if (!Number.isFinite(ms)) return "\u2014";
  const digits = ms >= 100 ? 0 : 1;
  return `${ms.toFixed(digits)}ms`;
}
__name(fmtMs, "fmtMs");
/**
 * Format a number: one decimal below 100, none from 100 up.
 * Non-finite or non-number input renders as an em dash.
 * @param {number} n
 * @returns {string}
 */
function fmtNum(n) {
  if (!Number.isFinite(n)) return "\u2014";
  return n.toFixed(n >= 100 ? 0 : 1);
}
__name(fmtNum, "fmtNum");
/**
 * Collapse whitespace runs to single spaces and truncate to at most `n`
 * UTF-16 units, ending truncated text with an ellipsis.
 * null/undefined become "".
 * @param {*} s - value to display
 * @param {number} n - maximum length
 * @returns {string}
 */
function clip(s, n) {
  const flat = String(s ?? "").replace(/\s+/g, " ");
  const overflows = flat.length > n;
  if (!overflows) return flat;
  return flat.slice(0, Math.max(0, n - 1)) + "\u2026";
}
__name(clip, "clip");
/**
 * Set `document.body.dataset.layout` to "foldable", "mobile" or "desktop"
 * from media queries. A dual-segment/spanning viewport wins over a narrow
 * (<= 700px) one. Any matchMedia failure counts as "no match".
 */
function applyLayout() {
  function mq(query) {
    try {
      return window.matchMedia(query).matches;
    } catch {
      return false;
    }
  }
  const segmented = mq("(horizontal-viewport-segments: 2)") || mq("(spanning: single-fold-vertical)");
  const narrow = mq("(max-width: 700px)");
  let layout;
  if (segmented) {
    layout = "foldable";
  } else if (narrow) {
    layout = "mobile";
  } else {
    layout = "desktop";
  }
  document.body.dataset.layout = layout;
}
__name(applyLayout, "applyLayout");
/**
 * Set up the "connect a folder" panel.
 *
 * When `fsSupported` is falsy the #fsBlock panel is hidden and nothing else
 * happens. Otherwise the panel is shown, a previously saved directory handle
 * is restored (best-effort), and the Connect / Forget / Import-from-folder
 * buttons are wired.
 */
async function initFs() {
  $("fsBlock").hidden = !fsSupported;
  if (!fsSupported) return;

  // Record a connected directory handle and reflect it in the panel.
  const useDirectory = (handle) => {
    state.dirHandle = handle;
    $("fsForget").hidden = false;
    $("ownImportDir").hidden = false;
    $("fsStatus").textContent = `connected: ${handle.name || "folder"} \u2014 adapters can save here; import text below.`;
  };

  try {
    const remembered = await savedDirectory();
    if (remembered) useDirectory(remembered);
  } catch {
    // Restoring a remembered folder is optional; failures are deliberately ignored.
  }

  $("fsConnect").onclick = async () => {
    try {
      useDirectory(await connectDirectory());
    } catch (err) {
      // AbortError presumably means the user dismissed the picker; don't log that.
      if (err.name !== "AbortError") log("folder: " + err.message);
    }
  };

  $("fsForget").onclick = async () => {
    await forgetDirectory();
    state.dirHandle = null;
    $("fsForget").hidden = true;
    $("ownImportDir").hidden = true;
    $("fsStatus").textContent = "not connected \u2014 import training text & save adapters straight to a folder you pick.";
  };

  $("ownImportDir").onclick = async () => {
    if (!state.dirHandle) return;
    const granted = await ensurePermission(state.dirHandle, "read");
    if (!granted) {
      log("permission denied for folder");
      return;
    }
    try {
      const { text, names } = await readDirText(state.dirHandle);
      if (!text.trim()) {
        $("ownStats").textContent = "no .txt/.md/.json/.csv files found in that folder";
        return;
      }
      // Prepend the imported text to whatever is already in the box, capped at MAX_CHARS.
      const box = $("ownText");
      box.value = (text + "\n" + box.value).slice(0, MAX_CHARS);
      refreshOwn();
      const stats = $("ownStats");
      stats.textContent = `imported ${names.length} file(s) \xB7 ` + stats.textContent;
    } catch (err) {
      log("import failed: " + err.message);
    }
  };
}
__name(initFs, "initFs");
// UI bootstrap, run once the DOM is ready. It wires the skill picker, the
// tabs, the settings gear and the model loaders (URL / Hugging Face repo /
// local files). It also wires inference, the dock hotkeys, both training
// flows and folder access, then sets up layout tracking and the
// window.__eg devtools/test surface, and performs the initial render.
window.addEventListener("DOMContentLoaded", () => {
  renderSkillPicker();
  selectSkill(selectedSkillKey);
  $("tabInfer").onclick = () => switchTab("infer");
  $("tabTrain").onclick = () => switchTab("train");
  // The gear toggles the #settings panel and its own "on" class.
  $("gear").onclick = () => {
    const open = $("settings").hidden;
    $("settings").hidden = !open;
    $("gear").classList.toggle("on", open);
  };
  $("adapterSel").onchange = setBadge;
  // Model sources: a plain base URL, a Hugging Face repo (optional token), or picked local files.
  $("load").onclick = () => loadWith(urlReader($("modelUrl").value.trim()), $("modelUrl").value.trim());
  $("loadHF").onclick = () => {
    const repo = $("hfRepo").value.trim();
    const token = ($("hfToken")?.value || "").trim();
    if (!repo) return log("enter a Hugging Face repo id, e.g. WeiboAI/VibeThinker-3B");
    loadWith(hfReader(repo, token), "HF: " + repo);
  };
  $("modelFiles").onchange = (ev) => {
    const files = [...ev.target.files];
    if (!files.length) return;
    // Index the picked files by bare file name for fileReader.
    const map = {};
    for (const f of files) map[f.name] = f;
    loadWith(fileReader(map), `${files.length} local files`);
  };
  $("run").onclick = runInference;
  // Cmd/Ctrl+Enter in the prompt box runs inference.
  $("prompt").addEventListener("keydown", (e) => {
    if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) runInference();
  });
  // Global hotkeys "1"-"9" equip the matching dock run (equipByIndex(key - 1)).
  // They are skipped while a modifier is held or while typing in an INPUT,
  // TEXTAREA or contenteditable element.
  // NOTE(review): the digit test is a string comparison on e.key, and SELECT
  // elements are not excluded.
  document.addEventListener("keydown", (e) => {
    if (e.metaKey || e.ctrlKey || e.altKey) return;
    const tag = e.target && e.target.tagName || "";
    if (tag === "INPUT" || tag === "TEXTAREA" || e.target && e.target.isContentEditable) return;
    if (e.key >= "1" && e.key <= "9") equipByIndex(+e.key - 1);
  });
  // Guided training: fine-tune on sampleExamples(examples, 32) of the selected
  // skill, each as system + user messages with the macro as the completion.
  $("trainGuided").onclick = () => {
    const sk = skillByKey(selectedSkillKey) || SKILLS[0];
    const pool = sampleExamples(sk.examples, 32);
    const ex = pool.map(([q, a]) => ({ messages: [{ role: "system", content: sk.system }, { role: "user", content: q }], completion: " " + a }));
    // Epochs ~= 280 / windows, clamped to [6, 14], where windows = ceil(examples / 2).
    // NOTE(review): the /2 presumably mirrors accum: 2 — confirm in runTraining.
    const windows = Math.ceil(ex.length / 2);
    runTraining({
      examples: ex,
      lr: 3e-4,
      epochs: Math.max(6, Math.min(14, Math.round(280 / windows))),
      accum: 2,
      base: sk.key,
      kind: "guided",
      system: sk.system,
      build: /* @__PURE__ */ __name((u) => [{ role: "system", content: sk.system }, { role: "user", content: u }], "build"),
      suggest: sk.suggest
    });
  };
  $("ownText").addEventListener("input", refreshOwn);
  // Prepend the text of up to 5 picked files to the notes box, capped at MAX_CHARS.
  $("ownFiles").onchange = async (ev) => {
    const files = [...ev.target.files].slice(0, 5);
    let txt = "";
    for (const f of files) {
      try {
        txt += await f.text() + "\n\n";
      } catch {
        // Files that cannot be read as text are skipped.
      }
    }
    $("ownText").value = (txt + "\n" + $("ownText").value).slice(0, MAX_CHARS);
    refreshOwn();
  };
  // Replace the notes box with a page's readable text, fetched through the
  // r.jina.ai reader proxy and capped at MAX_CHARS.
  $("ownFetch").onclick = async () => {
    const url = $("ownUrl").value.trim();
    if (!url) return;
    $("ownStats").textContent = "fetching readable text via reader proxy\u2026";
    try {
      const r = await fetch("https://r.jina.ai/" + url);
      if (!r.ok) throw new Error("HTTP " + r.status);
      const t = await r.text();
      $("ownText").value = t.slice(0, MAX_CHARS);
      refreshOwn();
    } catch (e) {
      $("ownStats").textContent = "could not fetch (CORS/blocked) \u2014 paste the text instead. " + e.message;
    }
  };
  // "Own notes" training: examples from ownExamples(); epochs ~= 50 / windows,
  // clamped to [3, 8]; no system prompt.
  $("trainOwn").onclick = () => {
    const ex = ownExamples();
    if (!ex.length) return;
    const windows = Math.ceil(ex.length / 2);
    runTraining({
      examples: ex,
      lr: 3e-4,
      accum: 2,
      epochs: Math.max(3, Math.min(8, Math.round(50 / windows))),
      base: "my-notes",
      kind: "own",
      system: null,
      build: /* @__PURE__ */ __name((u) => [{ role: "user", content: u }], "build"),
      suggest: _ownChunks[0]?.head || ""
    });
  };
  // Download is only possible while the current fine-tune still holds a trainer.
  $("downloadAdapter").onclick = () => {
    if (state.tuned?.ctrl?.trainer) downloadLoraAdapter(state.tuned.ctrl.trainer, { name: state.tuned.name });
  };
  // Pick a layout now and again whenever one of these media queries changes;
  // failures to subscribe are ignored.
  applyLayout();
  for (const q of ["(max-width: 700px)", "(horizontal-viewport-segments: 2)", "(spanning: single-fold-vertical)"]) {
    try {
      window.matchMedia(q).addEventListener("change", applyLayout);
    } catch {
    }
  }
  // Manual layout override: sets body[data-layout] directly, until applyLayout
  // runs again on the next media-query change.
  window.__layout = (m) => {
    document.body.dataset.layout = m;
  };
  // Devtools/test surface. The getters read the current values of the
  // module-level selectedSkillKey / lastEquipIntent bindings.
  window.__eg = {
    store: store_exports,
    renderHistory,
    renderDock,
    renderStage,
    stageMsg,
    renderEquipPanel,
    humanizePlan,
    applyRun,
    exportRun,
    delRun,
    state,
    // devtools/test surface
    SKILLS,
    POPULAR_2026,
    selectSkill,
    renderSkillPicker,
    verifyMacro,
    setMacroCheck,
    equipByIndex,
    skillByKey,
    sampleExamples,
    get selectedSkillKey() {
      return selectedSkillKey;
    },
    get lastEquipIntent() {
      return lastEquipIntent;
    }
  };
  // Initial render. initFs is async and is started without being awaited.
  initFs();
  renderHistory();
  switchTab("infer");
  setBadge();
  refreshOwn();
  gateButtons();
});
/**
 * Escape a value for safe interpolation into HTML, both element content and
 * quoted attribute values.
 *
 * The previous map replaced each special character with itself ("&" -> "&",
 * "<" -> "<", ">" -> ">"), so nothing was escaped. Every `innerHTML` built with
 * esc(...) therefore rendered its input as live markup. Quotes are now escaped
 * too, so the result is also safe inside attribute values.
 * @param {*} s - any value; coerced with String()
 * @returns {string}
 */
function esc(s) {
  const entities = { "&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;", "'": "&#39;" };
  return String(s).replace(/[&<>"']/g, (c) => entities[c]);
}
__name(esc, "esc");