| | from enum import Enum |
| | from typing import List, Any, Optional, Union, Tuple, Dict |
| | import numpy as np |
| | from modules import scripts, processing, shared |
| | from scripts import global_state |
| | from scripts.processor import preprocessor_sliders_config, model_free_preprocessors |
| | from scripts.logging import logger |
| |
|
| | from modules.api import api |
| |
|
| |
|
def get_api_version() -> int:
    """Return the version number of this external API."""
    api_version = 2
    return api_version
| |
|
| |
|
class ControlMode(Enum):
    """
    Control modes, the improved replacement for guess mode.

    Each member's value is a human-readable string.
    """

    BALANCED = "Balanced"
    PROMPT = "My prompt is more important"
    CONTROL = "ControlNet is more important"
| |
|
| |
|
class ResizeMode(Enum):
    """
    Strategies for resizing ControlNet input images.
    """

    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Return this mode's integer code: RESIZE is 0, INNER_FIT is 1, OUTER_FIT is 2."""
        ordered_modes = (ResizeMode.RESIZE, ResizeMode.INNER_FIT, ResizeMode.OUTER_FIT)
        for code, mode in enumerate(ordered_modes):
            if self == mode:
                return code
        # Every member is listed above, so control never gets here.
        assert False, "NOTREACHED"
| |
|
| |
|
# Alternative resize-mode labels, each mapped to the `ResizeMode` value string
# it stands for.
resize_mode_aliases = {
    "Inner Fit (Scale to Fit)": "Crop and Resize",
    "Outer Fit (Shrink to Fit)": "Resize and Fill",
    "Scale to Fit (Inner Fit)": "Crop and Resize",
    "Envelope (Outer Fit)": "Resize and Fill",
}
| |
|
| |
|
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """
    Convert a string, an integer, or a `ResizeMode` into a `ResizeMode`.

    Strings are translated through `resize_mode_aliases` and then looked up by
    enum value. Integers select a member by its position in the declaration
    order of `ResizeMode`; the value 3 maps to `ResizeMode.RESIZE`, and any
    other out-of-range integer logs a warning and falls back to
    `ResizeMode.RESIZE`. Negative integers fail an assertion. Any other input
    is returned unchanged.
    """
    if isinstance(value, str):
        canonical = resize_mode_aliases.get(value, value)
        return ResizeMode(canonical)

    if not isinstance(value, int):
        return value

    assert value >= 0
    if value == 3:
        return ResizeMode.RESIZE

    modes = list(ResizeMode)
    if value >= len(modes):
        logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
        return ResizeMode.RESIZE

    return modes[value]
| |
|
| |
|
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """
    Convert a string, an integer, or a `ControlMode` into a `ControlMode`.

    Strings are looked up by enum value; integers index the members in
    declaration order. Any other input is returned unchanged.
    """
    if isinstance(value, str):
        return ControlMode(value)

    if isinstance(value, int):
        members = list(ControlMode)
        return members[value]

    return value
| |
|
| |
|
def visualize_inpaint_mask(img):
    """
    Remap the fourth channel of a four-channel image for display.

    When `img` is a 3-D array whose last axis has length 4, a contiguous copy
    is returned in which channel 3 is replaced by `255 - channel // 2`.
    Any other input is returned as is, without copying.
    """
    if img.ndim != 3 or img.shape[2] != 4:
        return img

    preview = img.copy()
    preview[:, :, 3] = 255 - preview[:, :, 3] // 2
    return np.ascontiguousarray(preview.copy())
| |
|
| |
|
def pixel_perfect_resolution(
    image: np.ndarray,
    target_H: int,
    target_W: int,
    resize_mode: ResizeMode,
) -> int:
    """
    Estimate the resolution to use when resizing an image with its aspect ratio preserved.

    Two scale factors are computed, target height over image height and target
    width over image width. For `ResizeMode.OUTER_FIT` the smaller factor is
    used, so the whole image fits inside the target; for every other mode the
    larger factor is used, so the target is fully covered. The chosen factor
    is multiplied by the shorter side of the image and rounded.

    The inputs and the unrounded estimate are logged at debug level.

    Args:
        image (np.ndarray): A 3D array laid out as [height, width, channels].
        target_H (int): The target height.
        target_W (int): The target width.
        resize_mode (ResizeMode): The resize mode.

    Returns:
        int: The estimated resolution, rounded to the nearest integer.
    """
    source_h, source_w, _ = image.shape

    scale_h = float(target_H) / float(source_h)
    scale_w = float(target_W) / float(source_w)

    # OUTER_FIT must fit inside the target; every other mode must cover it.
    choose_scale = min if resize_mode == ResizeMode.OUTER_FIT else max
    estimation = choose_scale(scale_h, scale_w) * float(min(source_h, source_w))

    logger.debug("Pixel Perfect Computation:")
    debug_fields = (
        ("resize_mode", resize_mode),
        ("raw_H", source_h),
        ("raw_W", source_w),
        ("target_H", target_H),
        ("target_W", target_W),
        ("estimation", estimation),
    )
    for label, shown in debug_fields:
        logger.debug(f"{label} = {shown}")

    return int(np.round(estimation))
| |
|
| |
|
# A single image: either a numpy array or a string.
InputImage = Union[np.ndarray, str]
# Widened in place: a dict from string keys to single images, a pair of single
# images, or a single image on its own.
InputImage = Union[
    Dict[str, InputImage],
    Tuple[InputImage, InputImage],
    InputImage,
]
| |
|
| |
|
class ControlNetUnit:
    """
    Represents an entire ControlNet processing unit.

    Every named constructor argument is stored unchanged on the instance under
    the same name. Unknown keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        enabled: bool=True,
        module: Optional[str]=None,
        model: Optional[str]=None,
        weight: float=1.0,
        image: Optional[InputImage]=None,
        resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT,
        low_vram: bool=False,
        processor_res: int=-1,
        threshold_a: float=-1,
        threshold_b: float=-1,
        guidance_start: float=0.0,
        guidance_end: float=1.0,
        pixel_perfect: bool=False,
        control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED,
        **_kwargs,
    ):
        self.enabled = enabled
        self.module = module
        self.model = model
        self.weight = weight
        self.image = image
        self.resize_mode = resize_mode
        self.low_vram = low_vram
        self.processor_res = processor_res
        self.threshold_a = threshold_a
        self.threshold_b = threshold_b
        self.guidance_start = guidance_start
        self.guidance_end = guidance_end
        self.pixel_perfect = pixel_perfect
        self.control_mode = control_mode

    def __eq__(self, other):
        """
        Return True when `other` is a ControlNetUnit whose attributes all match.

        Attributes are compared with `_values_equal` rather than a plain
        `vars(self) == vars(other)`: comparing two distinct multi-element numpy
        arrays with `==` yields an array, and taking its truth value raises
        ValueError, so units holding array images could not be compared.
        """
        if not isinstance(other, ControlNetUnit):
            return False

        return ControlNetUnit._values_equal(vars(self), vars(other))

    @staticmethod
    def _values_equal(left: Any, right: Any) -> bool:
        """Compare two attribute values, descending into dicts, lists and tuples and comparing numpy arrays by content."""
        if left is right:
            return True

        left_is_array = isinstance(left, np.ndarray)
        right_is_array = isinstance(right, np.ndarray)
        if left_is_array or right_is_array:
            # An array only ever equals another array of the same shape and content.
            return left_is_array and right_is_array and bool(np.array_equal(left, right))

        if isinstance(left, dict) and isinstance(right, dict):
            return left.keys() == right.keys() and all(
                ControlNetUnit._values_equal(left[key], right[key]) for key in left
            )

        both_lists = isinstance(left, list) and isinstance(right, list)
        both_tuples = isinstance(left, tuple) and isinstance(right, tuple)
        if both_lists or both_tuples:
            return len(left) == len(right) and all(
                ControlNetUnit._values_equal(a, b) for a, b in zip(left, right)
            )

        return bool(left == right)
| |
|
| |
|
def to_base64_nparray(encoding: str):
    """
    Decode a base64 image into a uint8 numpy array, the image type the extension uses.
    """
    decoded_image = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded_image)
    return pixels.astype('uint8')
| |
|
| |
|
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
    """
    Fetch the ControlNet processing units held by a StableDiffusionProcessing.

    Delegates to `get_all_units` with the script runner and script arguments of `p`.
    """
    script_runner, script_args = p.scripts, p.script_args
    return get_all_units(script_runner, script_args)
| |
|
| |
|
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from an existing script runner.
    Use this function to fetch units from the list of all scripts arguments.

    Returns an empty list when the runner has no ControlNet script.
    """
    cn_script = find_cn_script(script_runner)
    if not cn_script:
        return []

    # Only the slice of arguments that belongs to the ControlNet script is inspected.
    cn_args = script_args[cn_script.args_from:cn_script.args_to]
    return get_all_units_from(cn_args)
| |
|
| |
|
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from ControlNet script arguments.
    Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.

    An argument is treated as a unit when it is a `ControlNetUnit`, a dict, or
    any object whose instance attributes include every attribute of a
    default-constructed `ControlNetUnit`. Each such argument is converted with
    `to_processing_unit`; all other arguments are skipped.

    Logs a warning when no unit is found, and a debug message when an argument
    looks like a `ControlNetUnit` created before a module reload.
    """
    # Attribute names of the current ControlNetUnit, computed once per call
    # instead of once per inspected argument.
    expected_fields = set(vars(ControlNetUnit()))

    def is_stale_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is potentially a stale version of
        ControlNetUnit created before module reload."""
        return (
            'ControlNetUnit' in type(script_arg).__name__ and
            not isinstance(script_arg, ControlNetUnit)
        )

    def is_controlnet_unit(script_arg: Any) -> bool:
        """ Returns whether the script_arg is ControlNetUnit or anything that
        can be treated like ControlNetUnit. """
        if isinstance(script_arg, (ControlNetUnit, dict)):
            return True
        return (
            hasattr(script_arg, '__dict__') and
            expected_fields.issubset(vars(script_arg))
        )

    all_units = [
        to_processing_unit(script_arg)
        for script_arg in script_args
        if is_controlnet_unit(script_arg)
    ]
    if not all_units:
        # The adjacent literals are concatenated verbatim, so each one carries
        # its own separating space.
        logger.warning(
            "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict. "
            f"Here are args received by ControlNet: {script_args}."
        )
    if any(is_stale_unit(script_arg) for script_arg in script_args):
        logger.debug(
            "Stale version of ControlNetUnit detected. The ControlNetUnit received "
            "by ControlNet is created before the newest load of ControlNet extension. "
            "They will still be used by ControlNet as long as they provide same fields "
            "defined in the newest version of ControlNetUnit."
        )

    return all_units
| |
|
| |
|
def get_single_unit_from(script_args: List[Any], index: int=0) -> Optional[ControlNetUnit]:
    """
    Fetch a single ControlNet processing unit from ControlNet script arguments.
    The list must not contain script positional arguments. It must only contain processing units.

    Returns None when `index` is negative or past the end of `script_args`,
    or when the argument at `index` is None.
    """
    if index < 0 or index >= len(script_args):
        return None

    candidate = script_args[index]
    if candidate is None:
        return None

    return to_processing_unit(candidate)
| |
|
def get_max_models_num():
    """
    Fetch the maximum number of allowed ControlNet models.

    Reads the "control_net_max_models_num" option and falls back to 1 when it is not set.
    """
    return shared.opts.data.get("control_net_max_models_num", 1)
| |
|
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Convert different types to processing unit.
    If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.

    For a dict, the keys are first renamed through `ext_compat_keys`. A 'mask'
    entry is removed and, when it is not None and 'image' is not already a
    dict, merged with the image into `{'image': ..., 'mask': ...}`. Without a
    mask, an empty image (None, an empty string, an empty array, ...) becomes
    None. A warning is logged when 'guess_mode' is present. The caller's dict
    is not modified. Anything that is not a dict is returned unchanged.
    """

    ext_compat_keys = {
        'guessmode': 'guess_mode',
        'guidance': 'guidance_end',
        'lowvram': 'low_vram',
        'input_image': 'image'
    }

    if not isinstance(unit, dict):
        return unit

    fields = {ext_compat_keys.get(key, key): value for key, value in unit.items()}

    mask = fields.pop('mask', None)

    if 'image' in fields and not isinstance(fields['image'], dict):
        image = fields['image']
        if mask is not None:
            fields['image'] = {'image': image, 'mask': mask}
        elif isinstance(image, np.ndarray):
            # A multi-element array has no single truth value (`bool(arr)`
            # raises ValueError), so emptiness is decided by its size instead.
            fields['image'] = image if image.size else None
        else:
            fields['image'] = image if image else None

    if 'guess_mode' in fields:
        logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

    return ControlNetUnit(**fields)
| |
|
| |
|
def update_cn_script_in_processing(
    p: processing.StableDiffusionProcessing,
    cn_units: List[ControlNetUnit],
    **_kwargs,
):
    """
    Update the arguments of the ControlNet script in `p.script_args`, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    The arguments are copied into a list, updated with `update_cn_script_in_place`,
    and assigned back to `p.script_args` as a list or tuple, following the type of `cn_units`
    (any other type falls back to list).

    The content of `p.script_args` is left unchanged if any of the following is true:
    - ControlNet is not present in `p.scripts`
    - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet
    """
    container = type(cn_units)
    if container not in (list, tuple):
        container = list

    working_args = list(p.script_args)
    update_cn_script_in_place(p.scripts, working_args, cn_units)
    p.script_args = container(working_args)
| |
|
| |
|
def update_cn_script_in_place(
    script_runner: scripts.ScriptRunner,
    script_args: List[Any],
    cn_units: List[ControlNetUnit],
    **_kwargs,
):
    """
    Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    `cn_units` may be a list or a tuple. It is padded with disabled units up to
    the "control_net_max_models_num" option (default 1). The `args_from` /
    `args_to` bounds of the ControlNet script, and of every always-on script
    after it, are shifted to account for the change in argument count.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are processed before ControlNet
    """

    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return

    max_models = shared.opts.data.get("control_net_max_models_num", 1)
    padding_count = max(max_models - len(cn_units), 0)
    # Copy into a new list so tuples are accepted and the caller's sequence is
    # untouched; each padding slot gets its own instance so the slots do not
    # alias one shared object.
    cn_units = list(cn_units) + [ControlNetUnit(enabled=False) for _ in range(padding_count)]

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Zero until the ControlNet script has been visited, so scripts
            # before it keep their bounds.
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff
| |
|
| |
|
def get_models(update: bool=False) -> List[str]:
    """
    Fetch the list of available models.
    Each value is a valid candidate of `ControlNetUnit.model`.

    Keyword arguments:
    update -- Whether to refresh the list from disk. (default False)
    """
    if update:
        global_state.update_cn_models()

    model_names = global_state.cn_models_names.values()
    return list(model_names)
| |
|
| |
|
def get_modules(alias_names: bool = False) -> List[str]:
    """
    Fetch the list of available preprocessors.
    Each value is a valid candidate of `ControlNetUnit.module`.

    Keyword arguments:
    alias_names -- Whether to get the ui alias names instead of internal keys
    """
    internal_keys = list(global_state.cn_preprocessor_modules.keys())
    if not alias_names:
        return internal_keys

    # Keys without a registered alias are passed through unchanged.
    return [global_state.preprocessor_aliases.get(key, key) for key in internal_keys]
| |
|
| |
|
def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]:
    """
    Get the detail of all preprocessors.

    Each entry holds:
    model_free: whether the preprocessor is listed in `model_free_preprocessors`
    sliders: the slider config in Auto1111 webUI

    Preprocessors without a slider config get `model_free` False and an empty slider list.

    Keyword arguments:
    alias_names -- Whether to get the module detail with alias names instead of internal keys
    """
    internal_keys = get_modules(False)
    display_names = get_modules(True)
    output_names = display_names if alias_names else internal_keys

    details = {}
    for internal_key, output_name in zip(internal_keys, output_names):
        if internal_key in preprocessor_sliders_config:
            # NOTE(review): model_free is checked with the output name while the
            # sliders are looked up by internal key; with alias_names=True these
            # can differ. Confirm which form `model_free_preprocessors` holds.
            entry = {
                "model_free": output_name in model_free_preprocessors,
                "sliders": preprocessor_sliders_config[internal_key]
            }
        else:
            entry = {
                "model_free": False,
                "sliders": []
            }
        details[output_name] = entry

    return details
| |
|
| |
|
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
    """
    Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` is `None`
    or does not contain a ControlNet script.

    Only the runner's always-on scripts are searched; the first match is returned.
    """
    if script_runner is None:
        return None

    matches = (script for script in script_runner.alwayson_scripts if is_cn_script(script))
    return next(matches, None)
| |
|
| |
|
def is_cn_script(script: scripts.Script) -> bool:
    """
    Determine whether `script` is a ControlNet script.

    A script qualifies when its title, compared case-insensitively, is 'controlnet'.
    """
    normalized_title = script.title().lower()
    return normalized_title == 'controlnet'
| |
|