Source code for turbograph.backend.igraph_backend

"""Graph backend implementation using the py:mod:`igraph` library.

This module provides a concrete implementation of the :py:class:`GraphWrapper`
interface using :py:class:`igraph.Graph`.
"""

from __future__ import annotations

import sys
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal, cast

if sys.version_info >= (3, 11):
    from typing import Unpack
else:  # pragma: no cover
    from typing_extensions import Unpack

try:
    from igraph import Graph, Vertex  # type: ignore[import-untyped]
except ImportError as e:  # pragma: no cover
    msg = (
        "The igraph library is not available. "
        "Please install it through 'pip install igraph', "
        "or use the 'networkx' backend instead."
    )
    raise ImportError(msg) from e

from igraph._igraph import InternalError  # type: ignore[import-untyped]

from ..core.constant import NA, V
from ..core.graphwrapper import DEFAULT_EDGE_DIRECTION, EdgeDirection, GraphWrapper

if TYPE_CHECKING:
    from collections.abc import Iterable, Mapping

    from ..core.attribute import VertexAttributeName, VertexAttributes
    from ..core.funccall import CallMode

logger = getLogger(__name__)


[docs] class IGraphWrapper(GraphWrapper[Graph, V]): """Graph backend for the igraph library.""" # Private methods ================================================================ def _get_v(self, vertex: V) -> Vertex: """Get the vertex object from its name. Args: vertex: the name of the vertex. Returns: the vertex object. Raises: VertexError: if the vertex is not found. """ try: return self.graph.vs.find(name=vertex) except ValueError as e: self._raise_vertex_not_found(vertex, e) def _get_edge_index(self, source: V, target: V) -> int: """Get the index of an edge. Args: source: the source vertex. target: the target vertex. Returns: the index of the edge. Raises: VertexError: if the edge is not found. """ source_index = self._get_v(source).index target_index = self._get_v(target).index try: return self.graph.get_eid(source_index, target_index) except InternalError as e: if "Cannot get edge ID, no such edge." in str(e): self._raise_edge_not_found(source, target, e) else: raise # pragma: no cover # Overridden methods ============================================================= # Initialization -----------------------------------------------------------------
[docs] @classmethod def initialize_empty(cls) -> Graph: return Graph(directed=True)
[docs] def get_graph_copy(self) -> Graph: return self.graph.copy()
# Construction ------------------------------------------------------------------
[docs] def add_vertex(self, vertex: V, **attributes: Unpack[VertexAttributes[V]]) -> None: self.graph.add_vertex(name=vertex, **attributes)
[docs] def add_edge(self, source: V, target: V) -> None: v1 = self._get_v(source) v2 = self._get_v(target) self.graph.add_edge(v1.index, v2.index)
# Destruction -------------------------------------------------------------------
[docs] def delete_vertex(self, *vertices: V) -> None: vs = self.graph.vs.select(name_in=vertices) if len(vs) != len(vertices): missing_vertices = set(vertices) - set(vs["name"]) self._raise_vertex_not_found(missing_vertices.pop()) self.graph.delete_vertices(vs)
[docs] def delete_edge(self, source: V, target: V) -> None: edge_index = self._get_edge_index(source, target) self.graph.delete_edges(edge_index)
# Vertex attributes -------------------------------------------------------------
[docs] def get_vertex_attribute(self, vertex: V, key: VertexAttributeName) -> object: return self._get_v(vertex)[key]
[docs] def get_vertex_attributes(self, vertex: V) -> VertexAttributes[V]: return cast( "VertexAttributes[V]", { attr: value for attr, value in self._get_v(vertex).attributes().items() if attr != "name" }, )
[docs] def set_vertex_attribute( self, vertex: V, key: VertexAttributeName, value: object ) -> None: v = self._get_v(vertex) v[key] = value
[docs] def update_vertex_attributes( self, vertex: V, attributes: Mapping[VertexAttributeName, Any] ) -> None: v = self._get_v(vertex) for attr, value in attributes.items(): v[attr] = value
# call_mode attribute ----------------------------------------------------------- @property def call_mode(self) -> CallMode | None: graph = self.graph if "call_mode" not in graph.attributes(): graph["call_mode"] = None return None return cast("CallMode | None", graph["call_mode"]) @call_mode.setter def call_mode(self, call_mode: CallMode | None) -> None: self._check_call_mode(call_mode) self.graph["call_mode"] = call_mode # Nodes and edges --------------------------------------------------------------- @property def _vertices(self) -> Iterable[V]: return self.graph.vs["name"] @property def _edges(self) -> Iterable[tuple[V, V]]: return ( (self.graph.vs[e.source]["name"], self.graph.vs[e.target]["name"]) for e in self.graph.es )
[docs] def get_neighbors( self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION ) -> Iterable[V]: self._check_edge_direction(direction) vertex_idx = self._get_v(vertex).index return self.graph.vs[self.graph.neighbors(vertex_idx, mode=direction)]["name"]
[docs] def get_degree( self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION ) -> int: self._check_edge_direction(direction) vertex_index = self._get_v(vertex).index return self.graph.degree(vertex_index, mode=direction)
[docs] def get_subcomponent( self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION ) -> Iterable[V]: self._check_edge_direction(direction) vertex_idx = self._get_v(vertex).index return self.graph.vs[self.graph.subcomponent(vertex_idx, mode=direction)][ "name" ]
# Graph operations -------------------------------------------------------------- def _subgraph(self, vertices: Iterable[V]) -> Graph: return self.graph.subgraph([self._get_v(v).index for v in vertices])
[docs] def is_dag(self) -> bool: return self.graph.is_dag()
[docs] def reset(self) -> None: graph = self.graph n_vertices = len(graph.vs) logger.info("Resetting graph with %d vertices", n_vertices) graph.vs["func"] = [None] * n_vertices graph.vs["value"] = [NA] * n_vertices for attribute in graph.attributes(): del graph[attribute]
[docs] def get_sorted_vertices(self, direction: Literal["in", "out"]) -> Iterable[V]: self._check_edge_direction(direction, ("in", "out")) if len(self.graph.vs) == 0: return [] try: return self.graph.vs[self.graph.topological_sorting(mode=direction)]["name"] except InternalError as e: if "The graph has cycles" in str(e): self._raise_not_dag(e) raise # pragma: no cover