from dataclasses import dataclass
from itertools import product
import re
from typing import Union, List, Tuple
import numpy as np
import open_clip
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase as CLIP
from modules import prompt_parser, shared
from scripts.cutofflib.utils import log

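# ClipWrapper gives the rest of the module a single token/id lookup interface,
# regardless of whether the wrapped text encoder exposes a transformers tokenizer
# (the `v1` case below) or relies on open_clip's tokenizer.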
class ClipWrapper:
    def __init__(self, te: CLIP):
        self.te = te
        self.v1 = hasattr(te.wrapped, 'tokenizer')
        self.t = (
            te.wrapped.tokenizer if self.v1
            else open_clip.tokenizer._tokenizer
        )

    def token_to_id(self, token: str) -> int:
        if self.v1:
            return self.t._convert_token_to_id(token)
        else:
            return self.t.encoder[token]

    def id_to_token(self, id: int) -> str:
        if self.v1:
            return self.t.convert_ids_to_tokens(id)
        else:
            return self.t.decoder[id]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        if self.v1:
            return self.t.convert_ids_to_tokens(ids)
        else:
            return [self.t.decoder[id] for id in ids]

    def token(self, token: Union[int, str]):
        if isinstance(token, int):
            return Token(token, self.id_to_token(token))
        else:
            return Token(self.token_to_id(token), token)

@dataclass
class Token:
    id: int
    token: str

class CutoffPrompt:

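    # _cutoff splits the prompt into comma-separated blocks and, inside each block,
    # replaces every target word with `padding` repeated once per token the word
    # occupies, so the cut variant keeps roughly the same token count.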
    @staticmethod
    def _cutoff(prompt: str, clip: CLIP, tokens: List[str], padding: str):
        def token_count(text: str):
            tt = token_to_block(clip, text)
            for index, (t, _) in enumerate(tt):
                if t.id == clip.id_end:
                    # first end-of-text marker; subtract one for the leading BOS token
                    return index - 1
            return 0

        re_targets = [re.compile(r'\b' + re.escape(x) + r'\b') for x in tokens]
        replacer = [' ' + ' '.join([padding] * token_count(x)) + ' ' for x in tokens]

        rows: List[Tuple[str, str]] = []
        for block in prompt.split(','):
            b0 = block
            for r, p in zip(re_targets, replacer):
                block = r.sub(p, block)
            b1 = block
            rows.append((b0, b1))

        return rows

    def __init__(self, prompt: str, clip: CLIP, tokens: List[str], padding: str):
        self.prompt = prompt
        rows = CutoffPrompt._cutoff(prompt, clip, tokens, padding)
        self.base = np.array([x[0] for x in rows])
        self.cut = np.array([x[1] for x in rows])
        self.sw = np.array([False] * len(rows))

    @property
    def block_count(self):
        return self.base.shape[0]

    def switch(self, block_index: int, to: Union[bool, None] = None):
        if to is None:
            to = not self.sw[block_index]
        self.sw[block_index] = to
        return to

    def text(self, sw=None):
        if sw is None:
            sw = self.sw
        blocks = np.where(sw, self.cut, self.base)
        return ','.join(blocks)

    def active_blocks(self) -> np.ndarray:
        indices, = (self.base != self.cut).nonzero()
        return indices

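    # generate() walks every on/off combination of the blocks that actually changed
    # and yields (switch tuple, resulting prompt text) for each one.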
    def generate(self):
        indices = self.active_blocks()
        for diff_sw in product([False, True], repeat=indices.shape[0]):
            sw = np.full_like(self.sw, False)
            sw[indices] = diff_sw
            yield diff_sw, self.text(sw)

def generate_prompts(
    clip: CLIP,
    prompt: str,
    targets: List[str],
    padding: Union[str, int, Token],
) -> CutoffPrompt:
    te = ClipWrapper(clip)

    if not isinstance(padding, Token):
        o_pad = padding
        padding = te.token(padding)
        if padding.id == clip.id_end:
            raise ValueError(f'`{o_pad}` is not a valid token.')

    result = CutoffPrompt(prompt, clip, targets, padding.token.replace('</w>', ''))

    log(f'[Cutoff] replace: {", ".join(targets)}')
    log(f'[Cutoff] to: {padding.token} ({padding.id})')
    log(f'[Cutoff] original: {prompt}')
    for i, (_, pp) in enumerate(result.generate()):
        log(f'[Cutoff] #{i}: {pp}')

    return result

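# token_to_block tokenizes the prompt the same way the webui's CLIP hijack does
# (75-token chunks wrapped with BOS/EOS, comma backtracking, BREAK handling, and
# textual-inversion embeddings) and tags each token with the index of the
# comma-separated block it belongs to (-1 for markers such as BOS/EOS and commas).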
def token_to_block(clip: CLIP, prompt: str) -> List[Tuple[Token, int]]:
    te = ClipWrapper(clip)

    parsed = prompt_parser.parse_prompt_attention(prompt)
    tokenized: List[List[int]] = clip.tokenize([text for text, _ in parsed])

    CHUNK_LENGTH = 75
    id_start = te.token(clip.id_start)
    id_end = te.token(clip.id_end)
    comma = te.token(',</w>')

    last_comma = -1
    current_block = 0
    current_tokens: List[Tuple[Token, int]] = []
    result: List[Tuple[Token, int]] = []

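    # next_chunk pads the chunk in progress with end-of-text tokens up to CHUNK_LENGTH,
    # wraps it with the start/end markers, and flushes it into `result`.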
    def next_chunk():
        nonlocal current_tokens, last_comma

        to_add = CHUNK_LENGTH - len(current_tokens)
        if 0 < to_add:
            current_tokens += [(id_end, -1)] * to_add

        current_tokens = [(id_start, -1)] + current_tokens + [(id_end, -1)]

        last_comma = -1
        result.extend(current_tokens)
        current_tokens = []

    for tokens, (text, weight) in zip(tokenized, parsed):
        if text == 'BREAK' and weight == -1:
            next_chunk()
            continue

        p = 0
        while p < len(tokens):
            token = tokens[p]

            if token == comma.id:
                last_comma = len(current_tokens)
                current_block += 1

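            # Comma backtracking (mirrors shared.opts.comma_padding_backtrack in the
            # webui): if the chunk fills up shortly after a comma, move everything
            # after that comma into the next chunk rather than splitting the phrase.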
            elif (
                shared.opts.comma_padding_backtrack != 0
                and len(current_tokens) == CHUNK_LENGTH
                and last_comma != -1
                and len(current_tokens) - last_comma <= shared.opts.comma_padding_backtrack
            ):
                break_location = last_comma + 1
                reloc_tokens = current_tokens[break_location:]
                current_tokens = current_tokens[:break_location]
                next_chunk()
                current_tokens = reloc_tokens

            if len(current_tokens) == CHUNK_LENGTH:
                next_chunk()

            embedding, embedding_length_in_tokens = clip.hijack.embedding_db.find_embedding_at_position(tokens, p)
            if embedding is None:
                if token == comma.id:
                    current_tokens.append((te.token(token), -1))
                else:
                    current_tokens.append((te.token(token), current_block))
                p += 1
                continue

            # a textual-inversion embedding starts here; fill its span with
            # placeholder tokens tagged with the current block index
            emb_len = int(embedding.vec.shape[0])
            if len(current_tokens) + emb_len > CHUNK_LENGTH:
                next_chunk()

            current_tokens += [(te.token(0), current_block)] * emb_len
            p += embedding_length_in_tokens

    if len(current_tokens) > 0:
        next_chunk()

    return result
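
# A minimal usage sketch (illustrative only; `clip` must be the webui's hijacked CLIP
# embedder available at generation time, and the prompt, targets and padding below
# are made-up examples rather than defaults of this module):
#
#     cp = generate_prompts(clip, 'a cute girl, white shirt, red shoes',
#                           targets=['white', 'red'], padding='_')
#     for diff_sw, text in cp.generate():
#         print(diff_sw, text)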