# Source code for turbograph.core.graphwrapper

"""Define the graph wrapper base class.

This module defines the abstract base class :py:class:`GraphWrapper`
and related exceptions for managing and manipulating graph structures.

The :py:class:`GraphWrapper` class provides a common interface for graph operations used
in the TurboGraph library. It allows defining various backends for different graph
libraries, implemented in the :py:mod:`turbograph.backend` package.
"""

from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from logging import getLogger
from typing import TYPE_CHECKING, Any, Generic, Literal, NoReturn, TypeVar, get_args

if sys.version_info >= (3, 11):
    from typing import Self, Unpack
else:  # pragma: no cover
    from typing_extensions import Self, Unpack

from .constant import NA, NACLS, V, VertexFunc, VertexValue
from .exception import NotFoundError
from .funccall import CALL_MODES, CallMode, CallModeError

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Mapping

    from .attribute import VertexAttributeName, VertexAttributes

logger = getLogger(__name__)


# Type parameter for the backend's native graph object; GraphWrapper is generic
# over it (``GraphWrapper[G, V]``) and returns it from ``initialize_empty`` /
# ``get_graph_copy``.
G = TypeVar("G")
"""Type representing a graph (e.g., a networkx graph)."""


class VertexError(NotFoundError):
    """Raised when a requested vertex does not exist in the graph."""

    def __init__(self, vertex: object, valid_vertices: Iterable[object]) -> None:
        """Build the error from the offending vertex and the graph's vertices.

        Args:
            vertex: The vertex that could not be found.
            valid_vertices: The vertices that do exist in the graph.
        """
        # "vertex" has an irregular plural, so it is spelled out explicitly
        # (EdgeError and EdgeDirectionError do not pass a plural form).
        singular, plural = "vertex", "vertices"
        super().__init__(singular, vertex, valid_vertices, plural)
class EdgeError(NotFoundError):
    """Raised when a requested edge does not exist in the graph."""

    def __init__(self, edge: object, valid_edges: Iterable[object]) -> None:
        """Build the error from the offending edge and the graph's edges.

        Args:
            edge: The edge that could not be found.
            valid_edges: The edges that do exist in the graph.
        """
        label = "edge"
        super().__init__(label, edge, valid_edges)
EdgeDirection = Literal["in", "out", "all"]
"""Represents the direction of an edge in the graph."""

# Derived from the Literal above so the runtime tuple and the static type can never
# drift apart; ``get_args`` preserves declaration order ("in", "out", "all").
EDGE_DIRECTIONS: tuple[EdgeDirection, ...] = get_args(EdgeDirection)
"""Tuple of all possible edge directions in the graph."""

DEFAULT_EDGE_DIRECTION: Literal["all"] = "all"
"""Default edge direction used in graph wrapper methods."""
class EdgeDirectionError(NotFoundError):
    """Raised when an edge direction is not one of the accepted values."""

    def __init__(
        self, direction: object, valid_directions: Iterable[str] | None = None
    ) -> None:
        """Build the error from the rejected direction and the accepted ones.

        Args:
            direction: The edge direction that was rejected.
            valid_directions: The accepted edge directions. When omitted,
                every direction in ``EDGE_DIRECTIONS`` is reported as valid.
        """
        accepted = EDGE_DIRECTIONS if valid_directions is None else valid_directions
        super().__init__("edge direction", direction, accepted)
class GraphWrapper(ABC, Generic[G, V]):
    """Abstract base class for graph operations in TurboGraph.

    This class defines the essential methods for interacting with graph data
    structures, used to construct and manipulate dependency graphs in TurboGraph.
    Backends implement the abstract methods; the concrete methods below are built
    solely on top of them.
    """

    def __init__(self, graph: G | None = None) -> None:
        """Initialize the graph wrapper with an optional graph instance.

        Args:
            graph: An optional graph instance to initialize the wrapper with.
                If not provided, an empty graph is initialized.
        """
        self.graph = graph if graph is not None else self.initialize_empty()
        """The actual graph instance."""

    # Abstract methods ==============================================================

    @classmethod
    @abstractmethod
    def initialize_empty(cls) -> G:
        """Initialize and return an empty graph."""
        ...

    @abstractmethod
    def get_graph_copy(self) -> G:
        """Return a copy of the internal graph."""
        ...

    # Construction ------------------------------------------------------------------

    @abstractmethod
    def add_vertex(self, vertex: V, **attributes: Unpack[VertexAttributes[V]]) -> None:
        r"""Add a vertex with specified attributes to the graph.

        Args:
            vertex: The vertex to add.
            **attributes: Keyword arguments representing the vertex attributes.
        """
        ...

    @abstractmethod
    def add_edge(self, source: V, target: V) -> None:
        """Add an edge between the source and target vertices.

        Args:
            source: The source vertex.
            target: The target vertex.
        """
        ...

    # Destruction -------------------------------------------------------------------

    @abstractmethod
    def delete_vertex(self, *vertices: V) -> None:
        r"""Delete specified vertices from the graph.

        Args:
            *vertices: The vertices to delete.

        Raises:
            VertexError: If a vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def delete_edge(self, source: V, target: V) -> None:
        """Delete an edge between the source and target vertices.

        Args:
            source: The source vertex.
            target: The target vertex.

        Raises:
            EdgeError: If the edge is not found in the graph.
        """
        ...

    # Vertex attributes -------------------------------------------------------------

    @abstractmethod
    def get_vertex_attribute(self, vertex: V, key: VertexAttributeName) -> object:
        """Get the value of a specific attribute of a vertex.

        Args:
            vertex: The vertex whose attribute is to be retrieved.
            key: The attribute key.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def get_vertex_attributes(self, vertex: V) -> VertexAttributes[V]:
        """Get all attributes of a vertex.

        Args:
            vertex: The vertex whose attributes are to be retrieved.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def set_vertex_attribute(
        self, vertex: V, key: VertexAttributeName, value: object
    ) -> None:
        """Set the value of a specific attribute of a vertex.

        Args:
            vertex: The vertex whose attribute is to be set.
            key: The attribute key.
            value: The value to set.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def update_vertex_attributes(
        self, vertex: V, attributes: Mapping[VertexAttributeName, Any]
    ) -> None:
        """Update multiple attributes of a vertex.

        Args:
            vertex: The vertex whose attributes are to be updated.
            attributes: A mapping of attribute keys to values.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    # Nodes and edges ---------------------------------------------------------------

    @property
    @abstractmethod
    def _vertices(self) -> Iterable[V]:
        """Get all vertices of the graph."""
        ...

    @property
    @abstractmethod
    def _edges(self) -> Iterable[tuple[V, V]]:
        """Get all edges of the graph."""
        ...

    @abstractmethod
    def get_neighbors(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Get the neighbors of a vertex in the specified direction.

        Args:
            vertex: The vertex whose neighbors are to be retrieved.
            direction: The direction of the edges to consider.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def get_degree(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> int:
        """Get the degree of a vertex in the specified direction.

        Args:
            vertex: The vertex whose degree is to be retrieved.
            direction: The direction of the edges to consider.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def get_subcomponent(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Get the subcomponent of a vertex in the specified direction.

        Args:
            vertex: The vertex whose subcomponent is to be retrieved.
            direction: The direction of the edges to consider.

        Raises:
            VertexError: If the vertex is not found in the graph.
        """
        ...

    # Call mode ---------------------------------------------------------------------

    @property
    @abstractmethod
    def call_mode(self) -> CallMode | None:
        """Return the current call mode of the graph."""
        ...

    @call_mode.setter
    @abstractmethod
    def call_mode(self, call_mode: CallMode | None) -> None:
        """Set the call mode of the graph."""
        ...

    # Graph operations --------------------------------------------------------------

    @abstractmethod
    def _subgraph(self, vertices: Iterable[V]) -> G:
        """Create a subgraph containing the specified vertices.

        Args:
            vertices: The vertices to include in the subgraph.

        Raises:
            KeyError: If a vertex is not found in the graph.
        """
        ...

    @abstractmethod
    def is_dag(self) -> bool:
        """Check if the graph is a directed acyclic graph (DAG)."""
        ...

    @abstractmethod
    def get_sorted_vertices(self, direction: Literal["in", "out"]) -> Iterable[V]:
        """Get the vertices sorted topologically in the specified direction.

        Args:
            direction: The direction of the edges to consider.
        """
        ...

    # Private methods ===============================================================
    # Exception handling ------------------------------------------------------------

    def _check_call_mode(self, call_mode: object) -> None:
        """Check if the call mode is valid.

        ``None`` is always accepted (it means no call mode is set).

        Args:
            call_mode: The call mode to check.

        Raises:
            CallModeError: If the call mode is neither ``None`` nor one of
                ``CALL_MODES``.
        """
        if call_mode is not None and call_mode not in CALL_MODES:
            raise CallModeError(call_mode)

    def _check_edge_direction(
        self,
        direction: EdgeDirection,
        valid_directions: Iterable[EdgeDirection] | None = None,
    ) -> None:
        """Check that an edge direction is one of the accepted directions.

        Args:
            direction: The edge direction to check.
            valid_directions: The accepted directions. Defaults to
                ``EDGE_DIRECTIONS``.

        Raises:
            EdgeDirectionError: If ``direction`` is not an accepted direction.
        """
        # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
        # be consumed by the membership test and reach the error message exhausted.
        valid: tuple[EdgeDirection, ...] = (
            EDGE_DIRECTIONS if valid_directions is None else tuple(valid_directions)
        )
        if direction not in valid:
            raise EdgeDirectionError(direction, valid)

    def _raise_vertex_not_found(
        self, vertex: V, origin: Exception | None = None
    ) -> NoReturn:
        """Raise an exception indicating that the vertex was not found.

        Args:
            vertex: The vertex that was not found.
            origin: The original exception that caused the error.

        Raises:
            VertexError: Always, chained from ``origin``.
        """
        raise VertexError(vertex, self._vertices) from origin

    def _raise_edge_not_found(
        self, source: V, target: V, origin: Exception | None = None
    ) -> NoReturn:
        """Raise an exception indicating that the edge was not found.

        Args:
            source: The source vertex of the edge.
            target: The target vertex of the edge.
            origin: The original exception that caused the error.

        Raises:
            EdgeError: Always, chained from ``origin``.
        """
        raise EdgeError((source, target), self._edges) from origin

    def _raise_not_dag(self, origin: Exception | None = None) -> NoReturn:
        """Raise an exception indicating that the graph is not a DAG.

        Args:
            origin: The original exception that caused the error.

        Raises:
            ValueError: Always, chained from ``origin``.
        """
        msg = "The graph is not a directed acyclic graph (DAG)."
        raise ValueError(msg) from origin

    # Concrete methods ==============================================================
    # Vertices and edges ------------------------------------------------------------

    @property
    def vertices(self) -> set[V]:
        """Return a set of all vertices in the graph."""
        return set(self._vertices)

    @property
    def edges(self) -> set[tuple[V, V]]:
        """Return a set of all edges in the graph."""
        return set(self._edges)

    def has_vertex(self, vertex: V) -> bool:
        """Check if the vertex exists in the graph.

        Note that this builds the full vertex set on each call.

        Args:
            vertex: The vertex to check.
        """
        return vertex in self.vertices

    def has_edge(self, source: V, target: V) -> bool:
        """Check if the edge exists in the graph.

        Note that this builds the full edge set on each call.

        Args:
            source: The source vertex of the edge.
            target: The target vertex of the edge.
        """
        return (source, target) in self.edges

    # Vertex attributes -------------------------------------------------------------

    def get_vertices_with_known_value(self) -> set[V]:
        """Return vertices that have a known value (i.e., not :py:data:`NA`)."""
        return {
            vertex
            for vertex in self._vertices
            if self.get_vertex_attribute(vertex, "value") is not NA
        }

    def set_attribute_to_vertices(
        self, key: VertexAttributeName, vertex_to_value: Mapping[V, Any]
    ) -> None:
        """Set a specific attribute for multiple vertices.

        Vertices are updated one at a time in mapping order, so vertices processed
        before a failure keep their new value.

        Args:
            key: The attribute key.
            vertex_to_value: A mapping of vertices to values.

        Raises:
            VertexError: If a vertex is not found in the graph (raised by
                :py:meth:`set_vertex_attribute`).
        """
        for vertex, value in vertex_to_value.items():
            logger.debug("Setting attribute %r to vertex %r", key, vertex)
            self.set_vertex_attribute(vertex, key, value)

    def get_all_vertex_attributes(self) -> dict[V, VertexAttributes[V]]:
        """Return a dictionary of all vertex attributes."""
        return {vertex: self.get_vertex_attributes(vertex) for vertex in self._vertices}

    # Graph operations --------------------------------------------------------------

    def copy(self) -> Self:
        """Return a copy of the graph wrapper, including the graph instance."""
        return self.__class__(self.get_graph_copy())

    def subgraph(self, vertices: Iterable[V]) -> Self:
        """Create and return a subgraph containing the specified vertices.

        Args:
            vertices: The vertices to include in the subgraph.

        Raises:
            KeyError: If a vertex is not found in the graph (per the contract of
                :py:meth:`_subgraph`).
        """
        return self.__class__(self._subgraph(vertices))

    def reset(self) -> None:
        """Reset the graph while preserving its structure.

        This method clears the stored functions and values from all vertices
        (``func`` becomes ``None``, ``value`` becomes :py:data:`NA`) and unsets the
        call mode, while preserving the graph's structure and dependencies.
        """
        # Clear vertex attributes.
        for vertex in self._vertices:
            self.set_vertex_attribute(vertex, "func", None)
            self.set_vertex_attribute(vertex, "value", NA)

        # Clear graph attributes.
        self.call_mode = None

    def __eq__(self, other: object) -> bool:
        """Check if the graph wrapper is equal to another object.

        ``other`` must be an instance of this wrapper's class and have the same
        edges, the same call mode and the same attributes for every vertex. Because
        ``__hash__`` is not defined alongside this method, instances are
        unhashable.
        """
        if not isinstance(other, self.__class__):
            return False
        return (
            self.edges == other.edges
            and self.call_mode == other.call_mode
            and self.get_all_vertex_attributes() == other.get_all_vertex_attributes()
        )

    # Run operations ----------------------------------------------------------------

    def rebuild(
        self,
        vertices: Iterable[V] | None = None,
        values: Mapping[V, Any | NACLS] | None = None,
        funcs: Mapping[V, Callable[..., Any] | None] | None = None,
        *,
        reduce: bool = False,
        call_mode: CallMode | None = None,
    ) -> Self:
        """Alias for :py:func:`turbograph.rebuild_graph`.

        This method is a direct alias for :py:func:`turbograph.rebuild_graph`,
        forwarding all arguments to it. Refer to its documentation for usage details.
        """
        # Imported at call time; presumably avoids a circular import between
        # the core and run packages.
        from turbograph.run.graphupdating import rebuild_graph

        return rebuild_graph(
            self,
            vertices=vertices,
            values=values,
            funcs=funcs,
            reduce=reduce,
            call_mode=call_mode,
        )

    def compute(
        self,
        vertices: Iterable[V] | None = None,
        values: Mapping[V, VertexValue] | None = None,
        funcs: Mapping[V, VertexFunc] | None = None,
        *,
        call_mode: CallMode | None = None,
        auto_prune: bool = True,
    ) -> dict[V, Any]:
        """Alias for :py:func:`turbograph.compute_from_graph`.

        This method is a direct alias for :py:func:`turbograph.compute_from_graph`,
        forwarding all arguments to it. Refer to its documentation for usage details.
        """
        # Imported at call time; presumably avoids a circular import between
        # the core and run packages.
        from turbograph.run.graphcomputing import compute_from_graph

        return compute_from_graph(
            self,
            vertices=vertices,
            values=values,
            funcs=funcs,
            call_mode=call_mode,
            auto_prune=auto_prune,
        )
# Type variable bounded by GraphWrapper (any graph type, any vertex type);
# presumably used by helpers that accept and return the same concrete wrapper
# subclass — confirm against callers.
GW = TypeVar("GW", bound=GraphWrapper[Any, Any])
"""Type of the graph wrapper."""