// ChatIPC := Chat Incremental Pattern Constructor
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <functional>
#include <iostream>
#include <mutex>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#else
inline int omp_get_max_threads(){ return 1; }
inline int omp_get_thread_num(){ return 0; }
#endif

extern unsigned char dictionary_json[];   // provide dictionary.cpp to embed dictionary JSON bytes
extern unsigned int dictionary_json_len;

// --------------------------- Short utility functions ----------------------
static inline bool is_space(char c){ return std::isspace(static_cast<unsigned char>(c)) != 0; }
static inline char to_low(char c){ return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); }
static inline void safe_flush(std::ostream &os){ os.flush(); }

// NEW: dictionary model, normalization, and English-rule helpers.
struct DictionaryEntry {
    std::string pos;
    std::string word;
    std::vector<std::string> definitions;
};

static std::vector<DictionaryEntry> global_dictionary_entries;
static std::unordered_map<std::string, std::vector<std::string>> global_def_tokens_cache;
static std::unordered_map<std::string, std::vector<std::string>> global_pos_cache;

static inline bool is_word_char_for_key(char c){
    unsigned char uc = static_cast<unsigned char>(c);
    return std::isalnum(uc) != 0 || c == '\'' || c == '-';
}

static std::string normalize_dictionary_key(const std::string &s){
    size_t b = 0, e = s.size();
    while (b < e && !is_word_char_for_key(s[b])) ++b;
    while (e > b && !is_word_char_for_key(s[e - 1])) --e;
    std::string out;
    out.reserve(e - b);
    for (size_t i = b; i < e; ++i) out.push_back(to_low(s[i]));
    return out;
}

static std::string normalize_pos_tag(const std::string &s){
    std::string out;
    out.reserve(s.size());
    for (char c : s){
        unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc) != 0) out.push_back(to_low(c));
    }
    return out;
}

enum class PosClass { Unknown, Noun, Verb, Adj, Adv, Pron, Prep, Conj, Det, Num, Interj };

static PosClass pos_class_from_tag(const std::string &tag){
    if (tag == "n" || tag == "noun") return PosClass::Noun;
    if (tag == "v" || tag == "verb" || tag == "part" || tag == "participle" || tag == "p") return PosClass::Verb;
    if (tag == "a" || tag == "adj" || tag == "adjective") return PosClass::Adj;
    if (tag == "adv" || tag == "adverb") return PosClass::Adv;
    if (tag == "pron" || tag == "pronoun") return PosClass::Pron;
    if (tag == "prep" || tag == "preposition") return PosClass::Prep;
    if (tag == "conj" || tag == "conjunction") return PosClass::Conj;
    if (tag == "art" || tag == "article" || tag == "det" || tag == "determiner") return PosClass::Det;
    if (tag == "num" || tag == "number") return PosClass::Num;
    if (tag == "interj" || tag == "interjection") return PosClass::Interj;
    return PosClass::Unknown;
}

static bool has_pos_class(const std::vector<std::string> &tags, PosClass cls){
    for (const auto &t : tags){ if (pos_class_from_tag(t) == cls) return true; }
    return false;
}

static const std::vector<std::string> &dictionary_pos_for_token(const std::string &surface){
    static const std::vector<std::string> empty;
    auto key = normalize_dictionary_key(surface);
    if (key.empty()) return empty;
    auto it = global_pos_cache.find(key);
    return (it == global_pos_cache.end()) ? empty : it->second;
}

static bool first_alpha_is_upper(const std::string &s){
    for (char c : s){
        unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc) != 0) return std::isupper(uc) != 0;
    }
    return false;
}

static bool first_alpha_is_lower(const std::string &s){
    for (char c : s){
        unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc) != 0) return std::islower(uc) != 0;
    }
    return false;
}

static bool is_sentence_boundary_token(const std::string &s){
    if (s.empty()) return false;
    char c = s.back();
    return c == '.' || c == '!' || c == '?';
}

static bool is_open_punct_token(const std::string &s){
    return s == "(" || s == "[" || s == "{" || s == "\"" || s == "'";
}

static bool is_punctuation_only_token(const std::string &s){
    if (s.empty()) return false;
    for (char c : s){
        unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalnum(uc) != 0) return false;
    }
    return true;
}

static bool is_common_determiner(const std::string &s){
    return s == "a" || s == "an" || s == "the" || s == "this" || s == "that" || s == "these"
        || s == "those" || s == "my" || s == "your" || s == "his" || s == "her" || s == "its"
        || s == "our" || s == "their";
}

static bool is_common_preposition(const std::string &s){
    return s == "of" || s == "in" || s == "on" || s == "at" || s == "by" || s == "for"
        || s == "from" || s == "with" || s == "into" || s == "onto" || s == "about"
        || s == "over" || s == "under" || s == "after" || s == "before" || s == "between"
        || s == "through" || s == "during" || s == "without" || s == "within"
        || s == "across" || s == "against" || s == "among" || s == "around";
}

static bool is_common_aux_or_modal(const std::string &s){
    return s == "to" || s == "be" || s == "am" || s == "is" || s == "are" || s == "was"
        || s == "were" || s == "been" || s == "being" || s == "have" || s == "has"
        || s == "had" || s == "do" || s == "does" || s == "did" || s == "can"
        || s == "could" || s == "may" || s == "might" || s == "must" || s == "shall"
        || s == "should" || s == "will" || s == "would";
}

static bool begins_with_vowel_sound(const std::string &s){
    if (s.empty()) return false;
    // Words with a silent leading consonant take "an".
    if (s.rfind("hour", 0) == 0 || s.rfind("honest", 0) == 0 || s.rfind("honor", 0) == 0
        || s.rfind("heir", 0) == 0 || s.rfind("herb", 0) == 0) {
        return true;
    }
    // Words spelled with a vowel but pronounced with a consonant sound take "a".
    if (s.rfind("uni", 0) == 0 || s.rfind("use", 0) == 0 || s.rfind("user", 0) == 0
        || s.rfind("one", 0) == 0 || s.rfind("once", 0) == 0 || s.rfind("euro", 0) == 0) {
        return false;
    }
    char c = s[0];
    return c == 'a' || c == 'e' || c == 'i' || c == 'o' || c == 'u';
}

static double english_rule_bonus(const std::string &context_tok, const std::string &cand){
    const std::string ctx_key = normalize_dictionary_key(context_tok);
    const std::string cand_key = normalize_dictionary_key(cand);
    const auto &ctx_tags = dictionary_pos_for_token(context_tok);
    const auto &cand_tags = dictionary_pos_for_token(cand);
    const bool sentence_start = context_tok.empty() || is_sentence_boundary_token(context_tok)
                             || is_open_punct_token(context_tok);
    const bool cand_nounish = has_pos_class(cand_tags, PosClass::Noun) || has_pos_class(cand_tags, PosClass::Adj)
                           || has_pos_class(cand_tags, PosClass::Pron) || has_pos_class(cand_tags, PosClass::Num);
    const bool cand_verbish = has_pos_class(cand_tags, PosClass::Verb);
    const bool cand_advish  = has_pos_class(cand_tags, PosClass::Adv);
    const bool cand_prepish = has_pos_class(cand_tags, PosClass::Prep);
    const bool cand_detish  = has_pos_class(cand_tags, PosClass::Det);
    double bonus = 0.0;
    if (!cand_key.empty()){
        if (sentence_start){ bonus += first_alpha_is_upper(cand) ? 0.22 : -0.08; }
        else if (first_alpha_is_upper(cand)){ bonus -= 0.03; }
    }
    if (ctx_key == "a" || ctx_key == "an"){
        const bool vowel = begins_with_vowel_sound(cand_key.empty() ? cand : cand_key);
        bonus += ((ctx_key == "an") == vowel) ? 0.28 : -0.18;
    }
    const bool ctx_det  = has_pos_class(ctx_tags, PosClass::Det)  || is_common_determiner(ctx_key);
    const bool ctx_prep = has_pos_class(ctx_tags, PosClass::Prep) || is_common_preposition(ctx_key);
    const bool ctx_aux  = is_common_aux_or_modal(ctx_key);
    if (ctx_det){
        if (cand_nounish) bonus += 0.20;
        if (cand_verbish || cand_advish || cand_prepish) bonus -= 0.08;
    }
    if (ctx_prep){
        if (cand_nounish) bonus += 0.16;
        if (cand_verbish) bonus -= 0.06;
    }
    if (ctx_aux){
        if (cand_verbish) bonus += 0.18;
        if (cand_detish) bonus -= 0.04;
    }
    if (has_pos_class(ctx_tags, PosClass::Pron) || has_pos_class(ctx_tags, PosClass::Noun)){
        if (cand_verbish) bonus += 0.05;
    }
    if (!context_tok.empty() && (context_tok.back() == ',' || context_tok.back() == ';' || context_tok.back() == ':')){
        if (!cand.empty() && first_alpha_is_lower(cand)) bonus += 0.04;
    }
    if (is_punctuation_only_token(cand)){
        if (sentence_start) bonus -= 0.05;
        else if (!context_tok.empty() && std::isalnum(static_cast<unsigned char>(context_tok.back())) != 0) bonus += 0.03;
    }
    if (is_sentence_boundary_token(cand)) bonus += 0.06;
    return bonus;
}

// Tokenize by whitespace
static std::vector<std::string> tokenize_whitespace(const std::string &s){
    std::istringstream iss(s);
    std::vector<std::string> out;
    std::string t;
    while (iss >> t) out.push_back(t);
    return out;
}

// Tokenize by non-alphanumeric characters (for definitions)
static std::vector<std::string> tokenize_non_alnum(const std::string &s){
    std::vector<std::string> out;
    std::string cur;
    for (char ch : s){
        if (std::isalnum(static_cast<unsigned char>(ch)) || ch == '-' || ch == '\''){
            cur.push_back(to_low(ch));
        } else if (!cur.empty()){
            out.push_back(cur);
            cur.clear();
        }
    }
    if (!cur.empty()) out.push_back(cur);
    return out;
}

// --------------------------- String interning (short methods) --------------
struct StringInterner {
    std::unordered_set<std::string> pool;
    std::mutex m;
    const std::string* intern(const std::string &s){
        std::lock_guard<std::mutex> lk(m);
        auto [it, inserted] = pool.emplace(s);
        return &*it;
    }
};

// ---------- Global parsed dictionary (populated once in main) ----------
static void build_def_tokens_cache(){
    global_def_tokens_cache.clear();
    global_pos_cache.clear();
    global_def_tokens_cache.reserve(global_dictionary_entries.size());
    global_pos_cache.reserve(global_dictionary_entries.size());
    for (const auto &entry : global_dictionary_entries){
        const std::string key = normalize_dictionary_key(entry.word);
        if (key.empty()) continue;
        std::string pos = normalize_pos_tag(entry.pos);
        if (!pos.empty()) global_pos_cache[key].push_back(std::move(pos));
        auto &defs = global_def_tokens_cache[key];
        for (const auto &def : entry.definitions){
            auto toks = tokenize_non_alnum(def);
            defs.insert(defs.end(), toks.begin(), toks.end());
        }
    }
    for (auto &pr : global_def_tokens_cache){
        auto &v = pr.second;
        std::sort(v.begin(), v.end());
        v.erase(std::unique(v.begin(), v.end()), v.end());
    }
    for (auto &pr : global_pos_cache){
        auto &v = pr.second;
        std::sort(v.begin(), v.end());
        v.erase(std::unique(v.begin(), v.end()), v.end());
    }
}

// --------------------------- Knowledge base (short methods) ----------------
using StrPtr = const std::string*;
struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
struct PtrEq   { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
using NextSet = std::vector<StrPtr>;

struct KnowledgeBase {
    StringInterner interner;
    std::unordered_map<StrPtr, NextSet, PtrHash, PtrEq> next;
    std::unordered_map<std::string, StrPtr> next_key_index;
    mutable std::mutex m;
    // def-index: for each interned word pointer -> list of interned tokens (definition expansion)
    std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
    mutable std::mutex def_m;
    int def_depth = 0;

    void add_pair_interned(StrPtr k, StrPtr v){
        std::lock_guard<std::mutex> lk(m);
        next_key_index.emplace(*k, k);
        auto &vec = next[k];
        for (auto p : vec) if (p == v) return;
        vec.push_back(v);
    }

    // Set def depth; if changed, drop previously computed def expansions.
    void set_def_depth(int D){
        std::lock_guard<std::mutex> lk(def_m);
        if (D != def_depth){ def_index.clear(); def_depth = D; }
    }

    void ensure_def_for_interned(StrPtr wp){
        if (wp == nullptr) return;
        if (def_depth <= 0) return;
        {
            std::lock_guard<std::mutex> lk(def_m);
            if (def_index.find(wp) != def_index.end()) return;
        }
        std::unordered_set<StrPtr> acc;
        std::vector<StrPtr> frontier;
        const std::string start_key = normalize_dictionary_key(*wp);
        if (!start_key.empty()){
            auto it_def = global_def_tokens_cache.find(start_key);
            if (it_def != global_def_tokens_cache.end()){
                for (const auto &tok : it_def->second){
                    StrPtr tp = interner.intern(tok);
                    if (acc.insert(tp).second) frontier.push_back(tp);
                }
            }
        }
        for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
            std::vector<StrPtr> next_frontier;
            for (StrPtr w : frontier){
                const std::string key = normalize_dictionary_key(*w);
                if (key.empty()) continue;
                auto it2 = global_def_tokens_cache.find(key);
                if (it2 == global_def_tokens_cache.end()) continue;
                for (const auto &tok : it2->second){
                    StrPtr tp = interner.intern(tok);
                    if (acc.insert(tp).second) next_frontier.push_back(tp);
                }
            }
            frontier.swap(next_frontier);
        }
        std::vector<StrPtr> out;
        out.reserve(acc.size());
        for (StrPtr p : acc) out.push_back(p);
        {
            std::lock_guard<std::mutex> lk(def_m);
            def_index.emplace(wp, std::move(out));
        }
    }

    // Public add_pair; builds the def-expansion for both words as soon as they are seen.
    void add_pair(const std::string &k, const std::string &v){
        StrPtr kp = interner.intern(k);
        StrPtr vp = interner.intern(v);
        ensure_def_for_interned(kp);
        ensure_def_for_interned(vp);
        add_pair_interned(kp, vp);
    }

    std::optional<NextSet> lookup_by_string(const std::string &k) const {
        std::lock_guard<std::mutex> lk(m);
        auto kit = next_key_index.find(k);
        if (kit == next_key_index.end()) return std::nullopt;
        auto it = next.find(kit->second);
        if (it == next.end()) return std::nullopt;
        return it->second;
    }

    std::optional<NextSet> lookup_by_ptr(StrPtr k) const {
        std::lock_guard<std::mutex> lk(m);
        auto it = next.find(k);
        if (it == next.end()) return std::nullopt;
        return it->second;
    }
};

static std::vector<StrPtr> intern_tokens(KnowledgeBase &kb, const std::vector<std::string> &tokens)
{
    std::vector<StrPtr> out;
    out.reserve(tokens.size());
    for (const auto &t : tokens) out.push_back(kb.interner.intern(t));
    return out;
}

static std::unordered_set<std::string> aggregate_sets(
    const std::vector<StrPtr> &tokens,
    const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index)
{
    std::unordered_set<std::string> agg;
    for (StrPtr t : tokens){
        const std::string tk = normalize_dictionary_key(*t);
        if (!tk.empty()) agg.insert(tk);
        auto it = def_index.find(t);
        if (it != def_index.end()){
            for (StrPtr d : it->second){
                const std::string dk = normalize_dictionary_key(*d);
                if (!dk.empty()) agg.insert(dk);
            }
        }
    }
    return agg;
}

// --------------------------- Small JSON parse helpers ----------------------
static inline bool json_valid_index(size_t i, size_t n){ return i < n; }

static std::string parse_quoted_string(const std::string &text, size_t &i){
    std::string out;
    if (!json_valid_index(i, text.size()) || text[i] != '"') throw std::runtime_error("expected '\"'");
    ++i;
    while (json_valid_index(i, text.size())){
        char c = text[i++];
        if (c == '"') break;
        if (c == '\\'){
            if (!json_valid_index(i, text.size())) break;
            char e = text[i++];
            if (e == 'n') out.push_back('\n');
            else if (e == 't') out.push_back('\t');
            else out.push_back(e);
        } else out.push_back(c);
    }
    return out;
}

static void skip_spaces(const std::string &s, size_t &i){
    while (json_valid_index(i, s.size()) && is_space(s[i])) ++i;
}

// Very small JSON-like parser tailored to the dictionary_json structure
static void skip_json_value(const std::string &s, size_t &i);

static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
    std::vector<std::string> out;
    if (!json_valid_index(i, text.size()) || text[i] != '[') return out;
    ++i;
    while (true){
        skip_spaces(text, i);
        if (!json_valid_index(i, text.size())) break;
        if (text[i] == ']'){ ++i; break; }
        if (text[i] == '"') out.push_back(parse_quoted_string(text, i));
        else skip_json_value(text, i);
        skip_spaces(text, i);
        if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
        if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
    }
    return out;
}

static void skip_json_value(const std::string &s, size_t &i){
    skip_spaces(s, i);
    if (!json_valid_index(i, s.size())) return;
    if (s[i] == '"'){ (void)parse_quoted_string(s, i); return; }
    if (s[i] == '['){
        ++i;
        while (json_valid_index(i, s.size())){
            skip_spaces(s, i);
            if (!json_valid_index(i, s.size())) break;
            if (s[i] == ']'){ ++i; break; }
            skip_json_value(s, i);
            skip_spaces(s, i);
            if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
            if (json_valid_index(i, s.size()) && s[i] == ']'){ ++i; break; }
        }
        return;
    }
    if (s[i] == '{'){
        ++i;
        while (json_valid_index(i, s.size())){
            skip_spaces(s, i);
            if (!json_valid_index(i, s.size())) break;
            if (s[i] == '}'){ ++i; break; }
            if (s[i] == '"'){
                (void)parse_quoted_string(s, i);
                skip_spaces(s, i);
                if (json_valid_index(i, s.size()) && s[i] == ':') ++i;
                skip_json_value(s, i);
                skip_spaces(s, i);
                if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
                if (json_valid_index(i, s.size()) && s[i] == '}'){ ++i; break; }
            } else { ++i; }
        }
        return;
    }
    while (json_valid_index(i, s.size())){
        char c = s[i];
        if (c == ',' || c == ']' || c == '}' || is_space(c)) break;
        ++i;
    }
}

static std::vector<DictionaryEntry> parse_dictionary_json(){
    std::vector<DictionaryEntry> dict;
    if (dictionary_json_len == 0) return dict;
    std::string text;
    text.reserve(dictionary_json_len);
    for (unsigned int b = 0; b < dictionary_json_len; ++b){
        text.push_back(static_cast<char>(dictionary_json[b]));
    }
    size_t i = 0;
    skip_spaces(text, i);
    if (!json_valid_index(i, text.size()) || text[i] != '[') return dict;
    ++i;
    while (true){
        skip_spaces(text, i);
        if (!json_valid_index(i, text.size())) break;
        if (text[i] == ']'){ ++i; break; }
        if (text[i] != '{'){
            skip_json_value(text, i);
            skip_spaces(text, i);
            if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
            if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
            continue;
        }
        ++i;
        DictionaryEntry entry;
        while (true){
            skip_spaces(text, i);
            if (!json_valid_index(i, text.size()))
                break;
            if (text[i] == '}'){ ++i; break; }
            std::string field = parse_quoted_string(text, i);
            skip_spaces(text, i);
            if (!json_valid_index(i, text.size()) || text[i] != ':') break;
            ++i;
            skip_spaces(text, i);
            if (field == "word"){ entry.word = parse_quoted_string(text, i); }
            else if (field == "pos"){ entry.pos = parse_quoted_string(text, i); }
            else if (field == "definitions"){ entry.definitions = parse_json_string_array(text, i); }
            else { skip_json_value(text, i); }
            skip_spaces(text, i);
            if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
            if (json_valid_index(i, text.size()) && text[i] == '}'){ ++i; break; }
        }
        if (!entry.word.empty()) dict.push_back(std::move(entry));
        skip_spaces(text, i);
        if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
        if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
    }
    return dict;
}

static std::string best_candidate_by_similarity(
    const NextSet &cands,
    const std::vector<StrPtr> &prompt_ptrs,
    const std::vector<StrPtr> &resp_ptrs,
    const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
    const std::unordered_map<std::string, int> &recent_counts,
    double repeat_penalty,
    const std::string &context_tok)
{
    if (cands.empty()) return std::string();
    if (cands.size() == 1) return *cands[0];
    auto agg = aggregate_sets(prompt_ptrs, def_index);
    for (StrPtr r : resp_ptrs){
        auto it = def_index.find(r);
        if (it != def_index.end()){
            for (StrPtr d : it->second){
                const std::string dk = normalize_dictionary_key(*d);
                if (!dk.empty()) agg.insert(dk);
            }
        }
    }
    double best = -1e9;
    std::string best_tok;
    size_t M = cands.size();
    std::vector<double> scores(M, 0.0);
#pragma omp parallel for schedule(static)
    for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(M); ++i){
        const StrPtr cand = cands[(size_t)i];
        const std::string cand_key = normalize_dictionary_key(*cand);
        size_t inter = (!cand_key.empty() && agg.count(cand_key)) ? 1 : 0;
        size_t cand_size = 1;
        auto it = def_index.find(cand);
        if (it != def_index.end()){
            cand_size += it->second.size();
            for (StrPtr d : it->second){
                const std::string dk = normalize_dictionary_key(*d);
                if (!dk.empty() && agg.count(dk)) ++inter;
            }
            if (std::find(it->second.begin(), it->second.end(), cand) != it->second.end()){ --cand_size; }
        }
        size_t uni = agg.size() + cand_size - inter;
        double s = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
        scores[(size_t)i] = s;
    }
    for (size_t i = 0; i < M; ++i){
        const std::string &tok = *cands[i];
        const std::string tok_key = normalize_dictionary_key(tok);
        const std::string count_key = tok_key.empty() ? tok : tok_key;
        double s = scores[i];
        auto rc_it = recent_counts.find(count_key);
        int cnt = (rc_it == recent_counts.end() ? 0 : rc_it->second);
        double adjusted = s + english_rule_bonus(context_tok, tok) - repeat_penalty * static_cast<double>(cnt);
        if (adjusted > best || (adjusted == best && tok < best_tok)){
            best = adjusted;
            best_tok = tok;
        }
    }
    return best_tok;
}

static std::vector<std::string> construct_response(KnowledgeBase &kb,
                                                   const std::vector<std::string> &prompt_toks,
                                                   size_t maxlen, double repeat_penalty)
{
    std::vector<std::string> resp;
    if (prompt_toks.empty() || maxlen == 0) return resp;
    auto prompt_ptrs = intern_tokens(kb, prompt_toks);
    std::vector<StrPtr> resp_ptrs;
    std::unordered_map<std::string, int> recent_counts;
    auto would_create_2_cycle = [&](const std::string &cand) -> bool {
        if (resp.size() < 3) return false;
        return normalize_dictionary_key(cand) == normalize_dictionary_key(resp[resp.size() - 2])
            && normalize_dictionary_key(resp.back()) == normalize_dictionary_key(resp[resp.size() - 3]);
    };
    std::string last_printed;
    for (size_t step = 0; step < maxlen; ++step){
        NextSet candidates;
        bool found = false;
        std::string context_tok;
        if (step == 0){
            for (std::ptrdiff_t p = static_cast<std::ptrdiff_t>(prompt_toks.size()) - 1; p >= 0; --p){
                auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
                if (opt){ candidates = *opt; found = true; context_tok = prompt_toks[(size_t)p]; break; }
            }
        } else {
            auto opt = kb.lookup_by_string(last_printed);
            if (opt){
                candidates = *opt; found = true; context_tok = last_printed;
            } else {
                for (std::ptrdiff_t p = static_cast<std::ptrdiff_t>(prompt_toks.size()) - 1; p >= 0; --p){
                    auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
                    if (opt2){ candidates = *opt2; found = true; context_tok = prompt_toks[(size_t)p]; break; }
                }
            }
        }
        if (!found || candidates.empty()) break;
        if (candidates.size() == 1){
            std::string only = *candidates[0];
            std::string only_key = normalize_dictionary_key(only);
            if (recent_counts[only_key.empty() ? only : only_key] > 0) break;
            resp.push_back(only);
            resp_ptrs.push_back(kb.interner.intern(only));
            recent_counts[only_key.empty() ? only : only_key] += 1;
            last_printed = only;
            std::cout << only << ' ' << std::flush;
            continue;
        }
        std::string chosen = best_candidate_by_similarity(candidates, prompt_ptrs, resp_ptrs, kb.def_index,
                                                          recent_counts, repeat_penalty, context_tok);
        if (chosen.empty()) break;
        if (would_create_2_cycle(chosen)) break;
        resp.push_back(chosen);
        resp_ptrs.push_back(kb.interner.intern(chosen));
        std::string chosen_key = normalize_dictionary_key(chosen);
        recent_counts[chosen_key.empty() ? chosen : chosen_key] += 1;
        last_printed = chosen;
        std::cout << chosen << ' ' << std::flush;
    }
    return resp;
}

// --------------------------- Learning from files (short) -------------------
static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
    std::ifstream ifs(fname);
    if (!ifs) return;
    std::string tok;
    std::string prev;
    bool have_prev = false;
    while (ifs >> tok){
        if (have_prev) kb.add_pair(prev, tok);
        prev = tok;
        have_prev = true;
    }
}

static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::string> &files){
#pragma omp parallel for schedule(dynamic)
    for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(files.size()); ++i)
        learn_from_file(kb, files[(size_t)i]);
}

// --------------------------- Serialization (binary, versioned) --------------
static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL; // "CHPCKSVA"
static constexpr std::uint64_t KB_VERSION = 1ULL;

static void write_u64(std::ostream &os, std::uint64_t v){
    os.write(reinterpret_cast<const char*>(&v), sizeof(v));
    if (!os) throw std::runtime_error("write_u64 failed");
}

static std::uint64_t read_u64(std::istream &is){
    std::uint64_t v = 0;
    is.read(reinterpret_cast<char*>(&v), sizeof(v));
    if (!is) throw std::runtime_error("read_u64 failed");
    return v;
}

static void write_string(std::ostream &os, const std::string &s){
    write_u64(os, static_cast<std::uint64_t>(s.size()));
    if (!s.empty()){
        os.write(s.data(), static_cast<std::streamsize>(s.size()));
        if (!os) throw std::runtime_error("write_string failed");
    }
}

static std::string read_string(std::istream &is){
    std::uint64_t n = read_u64(is);
    if (n > (1ULL << 30)) throw std::runtime_error("corrupt save file: string too large");
    std::string s;
    s.resize(static_cast<size_t>(n));
    if (n != 0){
        is.read(&s[0], static_cast<std::streamsize>(n));
        if (!is) throw std::runtime_error("read_string failed");
    }
    return s;
}

static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
    const std::string temp = fname + ".tmp";
    {
        std::ofstream ofs(temp.c_str(), std::ios::binary | std::ios::trunc);
        if (!ofs) throw std::runtime_error("cannot open temp save file");
        std::vector<std::string> pool;
        pool.reserve(kb.interner.pool.size());
        for (const auto &s : kb.interner.pool) pool.push_back(s);
        std::sort(pool.begin(), pool.end());
        std::unordered_map<std::string, std::uint64_t> id;
        id.reserve(pool.size());
        for (std::uint64_t i = 0; i < static_cast<std::uint64_t>(pool.size()); ++i){
            id.emplace(pool[(size_t)i], i);
        }
        write_u64(ofs, KB_MAGIC);
        write_u64(ofs, KB_VERSION);
        write_u64(ofs, static_cast<std::uint64_t>(kb.def_depth));
        write_u64(ofs, static_cast<std::uint64_t>(pool.size()));
        for (const auto &s : pool) write_string(ofs, s);
        write_u64(ofs, static_cast<std::uint64_t>(kb.next.size()));
        for (const auto &pr : kb.next){
            write_u64(ofs, id.at(*pr.first));
            write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
            for (StrPtr nxt : pr.second){ write_u64(ofs, id.at(*nxt)); }
        }
        write_u64(ofs, static_cast<std::uint64_t>(kb.def_index.size()));
        for (const auto &pr : kb.def_index){
            write_u64(ofs, id.at(*pr.first));
            write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
            for (StrPtr tok : pr.second){ write_u64(ofs, id.at(*tok)); }
        }
        ofs.flush();
        if (!ofs) throw std::runtime_error("failed while writing temp save file");
    }
    std::remove(fname.c_str());
    if (std::rename(temp.c_str(), fname.c_str()) != 0){
        std::remove(temp.c_str());
        throw std::runtime_error("failed to commit save file");
    }
}

static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
    std::ifstream ifs(fname, std::ios::binary);
    if (!ifs) throw std::runtime_error("cannot open load file");
    const std::uint64_t magic = read_u64(ifs);
    if (magic != KB_MAGIC) throw std::runtime_error("bad save file magic");
    const std::uint64_t version = read_u64(ifs);
    if (version != KB_VERSION) throw std::runtime_error("unsupported save file version");
    const std::uint64_t file_def_depth = read_u64(ifs);
    const std::uint64_t N = read_u64(ifs);
    if (N > (1ULL << 26)) throw std::runtime_error("corrupt save file: pool too large");
    std::vector<std::string> strings;
    strings.reserve(static_cast<size_t>(N));
    for (std::uint64_t i = 0; i < N; ++i){ strings.push_back(read_string(ifs)); }
    kb.interner.pool.clear();
    kb.interner.pool.reserve(static_cast<size_t>(N));
    std::vector<StrPtr> ptrs;
    ptrs.reserve(static_cast<size_t>(N));
    for (const auto &s : strings){ ptrs.push_back(kb.interner.intern(s)); }
    // Rebuild next
    const std::uint64_t E = read_u64(ifs);
    if (E > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph too large");
    {
        std::lock_guard<std::mutex> lk(kb.m);
        kb.next.clear();
        kb.next_key_index.clear();
        kb.next.reserve(static_cast<size_t>(E));
        kb.next_key_index.reserve(static_cast<size_t>(E));
    }
    for (std::uint64_t i = 0; i < E; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);
        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph degree too large");
        StrPtr key_ptr = ptrs[(size_t)key_idx];
        NextSet vec;
        vec.reserve(static_cast<size_t>(M));
        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph value");
            vec.push_back(ptrs[(size_t)v_idx]);
        }
        {
            std::lock_guard<std::mutex> lk(kb.m);
            kb.next.emplace(key_ptr, std::move(vec));
            kb.next_key_index.emplace(*key_ptr, key_ptr);
        }
    }
    // Rebuild def_index from file
    const std::uint64_t K = read_u64(ifs);
    if (K > (1ULL << 26)) throw std::runtime_error("corrupt save file: def_index too large");
    {
        std::lock_guard<std::mutex> lk(kb.def_m);
        kb.def_index.clear();
        kb.def_index.reserve(static_cast<size_t>(K));
        kb.def_depth = static_cast<int>(file_def_depth);
    }
    for (std::uint64_t i = 0; i < K; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);
        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: def list too large");
        std::vector<StrPtr> toks;
        toks.reserve(static_cast<size_t>(M));
        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def value");
            toks.push_back(ptrs[(size_t)v_idx]);
        }
        {
            std::lock_guard<std::mutex> lk(kb.def_m);
            kb.def_index.emplace(ptrs[(size_t)key_idx], std::move(toks));
        }
    }
    // If the caller asks for a different dict depth, recompute with the current embedded dictionary.
    if (cli_dict_depth != static_cast<int>(file_def_depth)){
        kb.set_def_depth(cli_dict_depth);
        std::vector<StrPtr> targets;
        targets.reserve(ptrs.size() + kb.next.size() * 2);
        std::unordered_set<StrPtr> seen;
        seen.reserve(ptrs.size() + kb.next.size() * 2);
        for (StrPtr p : ptrs){ if (seen.insert(p).second) targets.push_back(p); }
        {
            std::lock_guard<std::mutex> lk(kb.m);
            for (const auto &pr : kb.next){
                if (seen.insert(pr.first).second) targets.push_back(pr.first);
                for (StrPtr v : pr.second){ if (seen.insert(v).second) targets.push_back(v); }
            }
        }
#pragma omp parallel for schedule(dynamic)
        for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(targets.size()); ++i){
            kb.ensure_def_for_interned(targets[(size_t)i]);
        }
    }
}

// --------------------------- CLI + Interactive loop (short) ----------------
static void print_usage(const char *p){
    std::cout << "Usage: " << p
              << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D]"
                 " [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
    std::cout << "  --maxlen N          Maximum number of tokens constructed in a response.\n";
    std::cout << "  --save FILE         Save the knowledge base and dictionary expansions to a binary file.\n";
    std::cout << "  --load-kb FILE      Load a previously saved knowledge base (and dictionary expansions) from a binary file.\n";
    std::cout << "  --dict-depth D      Depth of dictionary-definition expansion used during learning.\n";
    std::cout << "  --learn f1 f2 ...   Learn from one or more text files to update the knowledge base.\n";
    std::cout << "  --repeat-penalty P  Penalize repeated tokens during response generation (higher values discourage repetition).\n";
    std::cout << "  --help              Show command-line interface options for ChatIPC usage.\n";
}

int main(int argc, char **argv){
    size_t maxlen = 100;
    std::string savefile;
    std::string load_kb;
    int dict_depth = 2;
    double repeat_penalty = 0.7; // default λ
    std::vector<std::string> learn_files;
    // Parse command-line flags (see print_usage).
    for (int i = 1; i < argc; ++i){
        const std::string a = argv[i];
        if (a == "--maxlen" && i + 1 < argc){ maxlen = static_cast<size_t>(std::stoul(argv[++i])); }
        else if (a == "--save" && i + 1 < argc){ savefile = argv[++i]; }
        else if (a == "--load-kb" && i + 1 < argc){ load_kb = argv[++i]; }
        else if (a == "--dict-depth" && i + 1 < argc){ dict_depth = std::stoi(argv[++i]); }
        else if (a == "--repeat-penalty" && i + 1 < argc){ repeat_penalty = std::stod(argv[++i]); }
        else if (a == "--learn"){
            while (i + 1 < argc && argv[i + 1][0] != '-') learn_files.push_back(argv[++i]);
        }
        else if (a == "--help"){ print_usage(argv[0]); return 0; }
        else { print_usage(argv[0]); return 1; }
    }

    global_dictionary_entries = parse_dictionary_json();
    build_def_tokens_cache();

    KnowledgeBase kb;
    kb.set_def_depth(dict_depth);
    if (!load_kb.empty()){
        try { load_kb_binary(kb, load_kb, dict_depth); }
        catch (const std::exception &e){ std::cerr << "load failed: " << e.what() << "\n"; return 1; }
    }
    if (!learn_files.empty()) learn_files_parallel(kb, learn_files);

    std::string line;
    while (std::cout << "> ", std::getline(std::cin, line)){
        if (line.empty()){ std::cout << "\n"; continue; }
        auto prompt_toks = tokenize_whitespace(line);
        // Learn adjacent-token pairs from the prompt itself.
        for (size_t i = 1; i < prompt_toks.size(); ++i) kb.add_pair(prompt_toks[i - 1], prompt_toks[i]);
        auto resp = construct_response(kb, prompt_toks, maxlen, repeat_penalty);
        std::cout << "\n";
        // Also learn across the prompt/response boundary.
        std::vector<std::string> combined = prompt_toks;
        combined.insert(combined.end(), resp.begin(), resp.end());
        for (size_t i = 1; i < combined.size(); ++i) kb.add_pair(combined[i - 1], combined[i]);
    }
    if (!savefile.empty()){
        try { save_kb_binary(kb, savefile); }
        catch (const std::exception &e){ std::cerr << "save failed: " << e.what() << "\n"; }
    }
    return 0;
}
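// ---------------------------------------------------------------------------
// Example invocations (illustrative only; the file names below are
// placeholder assumptions, not shipped data):
//   ./chatipc --learn corpus_a.txt corpus_b.txt --dict-depth 2 --save kb.bin
//   ./chatipc --load-kb kb.bin --maxlen 80 --repeat-penalty 0.7
//
// Save-file layout as written by save_kb_binary (all integers are u64 in
// native endianness; strings are length-prefixed):
//   magic, version, def_depth,
//   pool_size, pool strings...,
//   next_size,      then per key: key_id, degree, value_ids...,
//   def_index_size, then per key: key_id, count,  token_ids...
// ---------------------------------------------------------------------------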