Source code for turbograph.run.graphcomputing

"""Compute values from a dependency graph.

A module that provides functions to compute values in a directed acyclic graph (DAG)
based on defined dependencies and vertex-specific functions and values.
"""

from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from logging import getLogger
from typing import TYPE_CHECKING, Any, TypeVar

from ..core.constant import NA, V, VertexFunc, VertexValue
from ..core.funccall import DEFAULT_CALL_MODE, CallMode, call_func
from .graphupdating import rebuild_graph

if TYPE_CHECKING:
    from ..core.attribute import VertexAttributes
    from ..core.graphwrapper import GraphWrapper

logger = getLogger(__name__)

NAType = TypeVar("NAType")
"""Type representing a missing value in vertex attributes."""


def _get_vertex_func(
    vertex: V, vertex_attrs: VertexAttributes[V]
) -> Callable[..., Any]:
    """Retrieve the computation function associated with a vertex, with error handling.

    Args:
        vertex: The vertex whose function is being retrieved.
        vertex_attrs: The attributes of the vertex, which may include a function.

    Returns:
        The function used to compute the vertex value.

    Raises:
        ValueError: If no function is associated with the vertex.

    """
    if (func := vertex_attrs.get("func")) is None:
        msg = (
            f"Function to compute vertex {vertex!r} is not provided. "
            "Please provide a function to compute the vertex."
        )
        raise ValueError(msg)
    return func


def _compute_vertex(
    graph: GraphWrapper[Any, V],
    vertex: V,
    func: Callable[..., Any],
    *,
    call_mode: CallMode,
) -> object:
    r"""Compute the value of a vertex based on its dependencies.

    This function retrieves the values of a vertex's predecessors and uses them
    as inputs to compute the vertex's value.

    Args:
        graph: The dependency graph.
        vertex: The vertex to compute.
        func: The function that defines how to compute the vertex value.
        call_mode: The mode used to pass inputs to the function.

            - ``"args"``: Pass values as positional arguments.
            - ``"kwargs"``: Pass values as keyword arguments.
            - ``"arg"``: Pass values as a single dictionary argument.

    Returns:
        The computed value of the vertex.

    Raises:
        TypeError: If there is an issue with the function signature.

    """
    # Get the values of the predecessors, to be passed to the function
    ordered_precedecessors = graph.get_vertex_attribute(vertex, "predecessors")
    assert isinstance(ordered_precedecessors, Iterable), ordered_precedecessors
    arguments = {
        predecessor: graph.get_vertex_attribute(predecessor, "value")
        for predecessor in ordered_precedecessors
    }

    # Compute the value of the vertex
    logger.debug("Computing vertex %r with arguments %r", vertex, arguments)
    try:
        return call_func(func, arguments, call_mode)
    except Exception as e:
        msg = (
            f"Error when computing vertex {vertex!r} with arguments {arguments!r}. "
            "Please look at the stack trace for more information. "
            f"Exception {e.__class__.__name__}: {e}"
        )
        raise TypeError(msg) from e


[docs] def compute_from_graph( graph: GraphWrapper[Any, V], vertices: Iterable[V] | None = None, values: Mapping[V, VertexValue] | None = None, funcs: Mapping[V, VertexFunc] | None = None, *, call_mode: CallMode | None = None, auto_prune: bool = True, ) -> dict[V, Any]: r"""Compute values in a dependency graph. This function evaluates the specified vertices in a directed acyclic graph (DAG), ensuring that computations follow the correct dependency order. It allows overriding values, updating functions dynamically, and pruning unnecessary intermediate values. Args: graph: The dependency graph. vertices: The set of vertices to compute. - If ``None``, all vertices in the graph are computed. - If specified, only the listed vertices and their dependencies are computed. values: A mapping from vertex names to precomputed values. If provided, these values override the existing ones in the graph. funcs: A mapping from vertex names to updated functions. If provided, these functions override the existing ones in the graph. call_mode: The mode used to pass dependencies to functions. If ``None``, the mode is inferred from the graph. - ``"args"``: Dependencies are passed as positional arguments. - ``"kwargs"``: Dependencies are passed as keyword arguments. - ``"arg"``: Dependencies are passed as a single dictionary. auto_prune: If ``True``, automatically removes intermediate vertices that are no longer needed after computation to optimize memory usage. Returns: A dictionary mapping vertex names to their computed values. Raises: ValueError: If cycles are detected in the graph. Example: Compute values from a dependency graph. >>> from turbograph import compute_from_graph, build_graph >>> deps = {"a": [], "b": [], "c": ["a", "b"], "d": ["c"]} >>> funcs = { ... "a": lambda: 5, ... "b": lambda: 10, ... "c": lambda a, b: a + b, ... "d": lambda c: c * 2, ... } >>> graph = build_graph(deps, funcs=funcs) >>> compute_from_graph(graph) {'a': 5, 'b': 10, 'c': 15, 'd': 30} Override a vertex value before computation: >>> compute_from_graph(graph, values={"a": 7}) {'a': 7, 'b': 10, 'c': 17, 'd': 34} Modify a function before computation: >>> compute_from_graph(graph, funcs={"c": lambda a, b: a * b}) {'a': 5, 'b': 10, 'c': 50, 'd': 100} Compute only a specific vertex: >>> compute_from_graph(graph, vertices=["c"]) {'c': 50} Notes: If both a function update and a precomputed value are provided for a vertex, the precomputed value takes precedence. """ vertices = set(vertices) if vertices is not None else graph.vertices if not vertices: logger.warning("No vertex to compute.") return {} requested_values: dict[V, Any] = {} """A mapping from vertex names to their computed values. Vertices that are not requested are not stored in this dictionary. """ graph = rebuild_graph( graph, vertices, values, funcs, call_mode=call_mode, reduce=True ) call_mode = ( call_mode if (call_mode := graph.call_mode) is not None else DEFAULT_CALL_MODE ) assert call_mode is not None # Compute the quantities ordered_vertices = graph.get_sorted_vertices(direction="out") logger.info( "Compute vertices %s in call mode %r. " "Vertices are computed in the following order: %r", vertices, call_mode, ordered_vertices, ) for vertex in ordered_vertices: vertex_attrs = graph.get_vertex_attributes(vertex) logger.debug("Handling vertex %r", vertex_attrs) if (found_value := vertex_attrs.get("value", NA)) is not NA: logger.debug( "Found already computed value %r for vertex %r", found_value, vertex ) value = found_value else: func = _get_vertex_func(vertex, vertex_attrs) value = _compute_vertex(graph, vertex, func, call_mode=call_mode) # Store the value of the computed vertex in the graph and in the output graph.set_vertex_attribute(vertex, "value", value) if vertex in vertices: logger.debug("Storing value %r of vertex %r", value, vertex) requested_values[vertex] = value # Remove the predecessors of the computed vertex # if they are not needed anymore # i.e., remove the values stored in the graph if auto_prune: for predecessor in list(graph.get_neighbors(vertex, direction="in")): # Remove the vertex if it is isolated if graph.get_degree(predecessor) == 1: logger.debug( "Remove predecessor %r of vertex %r", predecessor, vertex ) graph.delete_vertex(predecessor) else: # Remove the edge predecessor -> vertex logger.debug("Remove edge %r -> %r", predecessor, vertex) graph.delete_edge(predecessor, vertex) return requested_values