# Source code for imgreg.util.solver

"""
Solvermodel primitives.

Author: Fabian A. Preiss
"""
from __future__ import annotations

from enum import Enum
from typing import (
    Callable,
    Dict,
    Hashable,
    ItemsView,
    Iterable,
    Optional,
    Sequence,
    Set,
    Union,
    cast,
)

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from graphviz import Digraph

from imgreg.util.graph import DAGraph
from imgreg.util.params import ImageParameter, Parameter


class SolverError(Exception):
    """Exception raised by the solver primitives defined in this module."""
def dependency_graph(parameters: Set[Parameter], invert=False) -> DAGraph:
    """
    Construct a `DAGraph` dependency graph given a `Parameter` set.

    Parameters
    ----------
    parameters : set[Parameter]
        The parameters that become the vertices of the graph
    invert : bool
        If false (default), every parameter is mapped to the set of its
        parents; if true, every parameter is mapped to the set of its
        children instead

    Returns
    -------
    DAGraph
        The graph built from the parameter -> neighbours mapping
    """
    if invert:
        adjacency = {
            param: set(param.children.values()) for param in parameters
        }
    else:
        adjacency = {
            param: set(param.parents.values()) for param in parameters
        }
    return DAGraph(cast(Dict[Hashable, Set[Hashable]], adjacency))
def dot_shape_func(parameter: Parameter) -> Dict[str, str]:
    """
    Generate the shape argument for a dot graph depending on the given Parameter

    Parameters
    ----------
    parameter : Parameter
        The parameter of a node

    Returns
    -------
    dict[str, str]
        ``{"shape": "box"}`` for an `ImageParameter`, ``{"shape": "oval"}``
        for every other parameter
    """
    return {"shape": "box" if isinstance(parameter, ImageParameter) else "oval"}
class Solver:
    r"""
    Interface for a Solvermodel constructed from a set of parameters.

    Constructs a dependency graph for the `Parameter`\ s of the model and allows
    for lazy evaluation of the `Parameter`\ s.

    Every registered parameter is reachable via ``solver[enum_id]`` and, when
    the parameters are passed to the constructor, also as an attribute named
    after ``enum_id.name``.

    Notes
    -----
    The state of this class is cached, when a parameter is changed on which
    later states depend, the properties ascendant to said parameter are removed
    from the cache.

    `Solver._get_dep_graph()` allows access to the internal dependency graph
    representation.
    """

    def __init__(self, parameters: Optional[Set[Parameter]] = None):
        """
        Register the given parameters and build their dependency graphs.

        Parameters
        ----------
        parameters : set[Parameter], optional
            Parameters of the model; when omitted the solver starts empty and
            no dependency graph is generated yet
        """
        self.__params: Dict[Enum, Parameter] = dict()
        self.__dependency_graph: Optional[DAGraph] = None
        self.__idependency_graph: Optional[DAGraph] = None
        if parameters is not None:
            self._register_params(parameters)
            self._generate_dep_graphs()
            # Expose each parameter as an attribute named after its enum member.
            for param in parameters:
                setattr(self, param.enum_id.name, self[param.enum_id])
        # TODO in concrete solver unittest if all Parameter objects are registered

    def __getitem__(self, item: Enum) -> Parameter:
        """Return the parameter registered under the enum member `item`."""
        return self.__params[item]

    def _register_params(self, params: Iterable[Parameter]) -> None:
        """Register every parameter of `params`, one after another."""
        for param in params:
            self._register_param(param)

    def _register_param(self, param: Parameter) -> None:
        """
        Register a single parameter and link it as a child of its parents.

        Raises
        ------
        SolverError
            If a parameter with the same ``enum_id`` is already registered
        """
        if param.enum_id in self.__params:
            raise SolverError(
                f"Cannot register parameter '{param.enum_id.value}' two times."
            )
        self.__params[param.enum_id] = param
        for parent in param.parents.values():
            parent.add_child(param)
        # TODO insert descendants into each parameter here, then replace add_child methods etc.

    def _generate_dep_graphs(self) -> None:
        """
        (Re)build the dependency graph and its inverse from the registry.

        Afterwards, every entry of the inverted graph's
        ``vertex_ascendants_dict`` is passed to the corresponding parameter
        through ``add_descendant``.
        """
        registered = set(self.__params.values())
        self.__dependency_graph = dependency_graph(registered, invert=False)
        self.__idependency_graph = dependency_graph(registered, invert=True)
        for parameter, ascendants in cast(
            ItemsView[Parameter, Set[Parameter]],
            self.__idependency_graph.vertex_ascendants_dict.items(),
        ):
            for ascendant in ascendants:
                parameter.add_descendant(ascendant)

    def _get_dep_graph(self, invert: bool = False) -> DAGraph:
        """
        Return the internal dependency graph, generating it on first use.

        Parameters
        ----------
        invert : bool
            If true, return the inverted graph (built from the children of
            each parameter) instead of the one built from the parents
        """
        if self.__dependency_graph is None:
            self._generate_dep_graphs()
        result = self.__idependency_graph if invert else self.__dependency_graph
        return cast(DAGraph, result)

    def _resolve_image_param(
        self, param: Union[Enum, ImageParameter]
    ) -> ImageParameter:
        """Map a `display` entry to its ImageParameter or raise SolverError."""
        if isinstance(param, ImageParameter):
            return param
        try:
            resolved = self[param]
        except KeyError as err:
            raise SolverError(
                f"{param} is not registered in this solver."
            ) from err
        if not isinstance(resolved, ImageParameter):
            raise SolverError(f"{param} does not reference an ImageParameter.")
        return resolved

    def display(
        self,
        param_list: Sequence[Union[Enum, ImageParameter]],
        title: Optional[str] = None,
    ) -> None:
        """
        Fancy plot functionality for registered ImageParameters.

        The images are laid out two per row; with an odd number of entries
        the last image is centered in a row of its own.  Nothing is plotted
        for an empty `param_list`.

        Parameters
        ----------
        param_list : sequence
            sequence of `ImageParameter` objects and/or enum members used to
            access the ImageParameters registered in solver
        title : str
            str, contains title of overall plot

        Raises
        ------
        SolverError
            If an enum member is not registered in the solver or does not
            reference an `ImageParameter`
        """
        if len(param_list) == 0:
            return
        # Resolve and validate everything up front so that no figure is
        # created when one of the entries is invalid.
        image_params = [self._resolve_image_param(param) for param in param_list]

        n_rows = (1 + len(image_params)) // 2
        # Start from an empty figure: creating a full grid with plt.subplots()
        # and overdrawing it relied on pyplot silently removing overlapping
        # axes, which newer matplotlib releases no longer do.
        fig = plt.figure(figsize=(8, 4 * n_rows))
        gs = gridspec.GridSpec(n_rows, 4, figure=fig)
        fig.subplots_adjust(wspace=0.4, hspace=0.3)

        has_unpaired = len(image_params) % 2 == 1
        n_paired_rows = n_rows - 1 if has_unpaired else n_rows
        for row in range(n_paired_rows):
            # Each image of a pair spans two of the four grid columns.
            image_params[2 * row].display(fig.add_subplot(gs[row, :2]))
            image_params[2 * row + 1].display(fig.add_subplot(gs[row, 2:]))
        if has_unpaired:
            # The leftover image uses the two middle columns of the last row.
            image_params[-1].display(fig.add_subplot(gs[-1, 1:3]))

        if title is not None:
            fig.suptitle(title)
        plt.show()

    def dot_graph(
        self,
        node_args_func: Callable[[Parameter], Dict[str, str]] = dot_shape_func,
    ) -> Digraph:
        """
        Return a dot graph representation of the solver model.

        Parameters
        ----------
        node_args_func : Callable[[Parameter], dict[str, str]]
            Function generating the keyword arguments of each dot node from
            its parameter, defaults to `dot_shape_func`
        """
        vertex_parent_dict = self._get_dep_graph().vertex_parent_dict
        return vertex_parent_dict_to_dot(
            cast(Dict[Parameter, Set[Parameter]], vertex_parent_dict), node_args_func
        )
def vertex_parent_dict_to_dot(
    vertex_parent_dict: Dict[Parameter, Set[Parameter]],
    node_args_func: Optional[Callable[[Parameter], Dict[str, str]]] = None,
    invert=False,
) -> Digraph:
    """
    Convert a directed graph to a dot graph.

    Nodes are named after ``enum_id.name`` of the corresponding vertex.

    Parameters
    ----------
    vertex_parent_dict : dict[Hashable, set[Hashable]]
        A dictionary representation of the directed graph
    node_args_func : Callable[[Hashable], dict[str, str]]
        A function handle to generate keyword arguments for the node of the dot
        graph depending on the current vertex; without it no extra keyword
        arguments are passed
    invert : bool
        If false (default), edges point from each entry of a vertex's set to
        the vertex itself; if true, the direction is reversed

    Returns
    -------
    Digraph
        The dot graph, created with the comment "dependencies"
    """
    if node_args_func is not None:
        make_node_kwargs = node_args_func
    else:
        make_node_kwargs = lambda _vertex: {}

    graph = Digraph(comment="dependencies")

    # Declare all nodes first, then connect them.
    for vertex in vertex_parent_dict:
        graph.node(vertex.enum_id.name, **make_node_kwargs(vertex))

    for vertex, parents in vertex_parent_dict.items():
        for parent in parents:
            if invert:
                tail, head = vertex, parent
            else:
                tail, head = parent, vertex
            graph.edge(tail.enum_id.name, head.enum_id.name)
    return graph