"""Graph backend implementation using the :py:mod:`igraph` library.

This module provides a concrete implementation of the :py:class:`GraphWrapper`
interface using :py:class:`igraph.Graph`.
"""
from __future__ import annotations

import sys
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal, cast

# ``typing.Unpack`` only exists from Python 3.11 on; fall back to the backport.
if sys.version_info >= (3, 11):
    from typing import Unpack
else:  # pragma: no cover
    from typing_extensions import Unpack

# igraph is only needed by this backend: fail at import time with an actionable
# message that points to the install command and to the alternative backend.
try:
    from igraph import Graph, Vertex  # type: ignore[import-untyped]
except ImportError as e:  # pragma: no cover
    msg = (
        "The igraph library is not available. "
        "Please install it through 'pip install igraph', "
        "or use the 'networkx' backend instead."
    )
    raise ImportError(msg) from e

# Exception type igraph raises for failures such as a missing edge or a cyclic
# graph during topological sorting (both are matched on below).
from igraph._igraph import InternalError  # type: ignore[import-untyped]

from ..core.constant import NA, V
from ..core.graphwrapper import DEFAULT_EDGE_DIRECTION, EdgeDirection, GraphWrapper

# Names used only in annotations (lazy thanks to ``from __future__``).
if TYPE_CHECKING:
    from collections.abc import Iterable, Mapping

    from ..core.attribute import VertexAttributeName, VertexAttributes
    from ..core.funccall import CallMode

logger = getLogger(__name__)
class IGraphWrapper(GraphWrapper[Graph, V]):
    """Graph backend for the igraph library.

    Each vertex is an :py:class:`igraph.Vertex` whose ``name`` attribute stores
    the user-facing vertex identifier. Public methods accept and return these
    names; igraph's integer vertex indices are only used internally.
    """

    # Private methods ================================================================

    def _get_v(self, vertex: V) -> Vertex:
        """Get the vertex object from its name.

        Args:
            vertex: the name of the vertex.

        Returns:
            the vertex object.

        Raises:
            VertexError: if the vertex is not found.
        """
        try:
            return self.graph.vs.find(name=vertex)
        except ValueError as e:
            # ``VertexSeq.find`` reports a missing vertex with a ValueError.
            self._raise_vertex_not_found(vertex, e)

    def _get_edge_index(self, source: V, target: V) -> int:
        """Get the index of the edge going from ``source`` to ``target``.

        Args:
            source: the source vertex.
            target: the target vertex.

        Returns:
            the index of the edge.

        Raises:
            VertexError: if ``source`` or ``target`` is not found, or (as
                reported by ``_raise_edge_not_found``) if there is no edge
                between them.
        """
        source_index = self._get_v(source).index
        target_index = self._get_v(target).index
        try:
            return self.graph.get_eid(source_index, target_index)
        except InternalError as e:
            # igraph reports a missing edge through a generic InternalError;
            # match on its message so unrelated internal failures still propagate.
            if "Cannot get edge ID, no such edge." in str(e):
                self._raise_edge_not_found(source, target, e)
            else:
                raise  # pragma: no cover

    # Overridden methods =============================================================
    # Initialization -----------------------------------------------------------------

    @classmethod
    def initialize_empty(cls) -> Graph:
        """Return a new, empty, directed :py:class:`igraph.Graph`."""
        return Graph(directed=True)

    def get_graph_copy(self) -> Graph:
        """Return a copy of the wrapped graph (see :py:meth:`igraph.Graph.copy`)."""
        return self.graph.copy()

    # Construction ------------------------------------------------------------------

    def add_vertex(self, vertex: V, **attributes: Unpack[VertexAttributes[V]]) -> None:
        """Add a vertex named ``vertex`` carrying the given attributes.

        No duplicate-name check is performed here.
        """
        self.graph.add_vertex(name=vertex, **attributes)

    def add_edge(self, source: V, target: V) -> None:
        """Add an edge from ``source`` to ``target``.

        Raises:
            VertexError: if either endpoint is not found.
        """
        source_index = self._get_v(source).index
        target_index = self._get_v(target).index
        self.graph.add_edge(source_index, target_index)

    # Destruction -------------------------------------------------------------------

    def delete_vertex(self, *vertices: V) -> None:
        """Delete the given vertices.

        Args:
            *vertices: names of the vertices to delete. Repeating a name is
                allowed; it is deleted once.

        Raises:
            VertexError: if any of the vertices is not found. The first missing
                name (in argument order) is reported and nothing is deleted.
        """
        if not vertices:
            return
        graph = self.graph
        if graph.vcount() == 0:
            # An empty graph cannot contain any of them, and it may not even
            # have a ``name`` attribute to filter on.
            self._raise_vertex_not_found(vertices[0])
        vs = graph.vs.select(name_in=vertices)
        # Check membership rather than comparing counts: repeated names in
        # ``vertices`` (or duplicate names in the graph) make the counts differ
        # even though nothing is missing.
        found = set(vs["name"])
        for vertex in vertices:
            if vertex not in found:
                self._raise_vertex_not_found(vertex)
        graph.delete_vertices(vs)

    def delete_edge(self, source: V, target: V) -> None:
        """Delete the edge from ``source`` to ``target``.

        Missing endpoints or a missing edge are reported exactly as in
        :py:meth:`_get_edge_index`.
        """
        self.graph.delete_edges(self._get_edge_index(source, target))

    # Vertex attributes -------------------------------------------------------------

    def get_vertex_attribute(self, vertex: V, key: VertexAttributeName) -> object:
        """Return the value of attribute ``key`` on ``vertex``."""
        return self._get_v(vertex)[key]

    def get_vertex_attributes(self, vertex: V) -> VertexAttributes[V]:
        """Return all attributes of ``vertex`` except ``name``.

        ``name`` is excluded because it holds the vertex identifier itself.
        """
        attributes = self._get_v(vertex).attributes()
        return cast(
            "VertexAttributes[V]",
            {key: value for key, value in attributes.items() if key != "name"},
        )

    def set_vertex_attribute(
        self, vertex: V, key: VertexAttributeName, value: object
    ) -> None:
        """Set attribute ``key`` of ``vertex`` to ``value``."""
        self._get_v(vertex)[key] = value

    def update_vertex_attributes(
        self, vertex: V, attributes: Mapping[VertexAttributeName, Any]
    ) -> None:
        """Set every ``key: value`` pair of ``attributes`` on ``vertex``."""
        v = self._get_v(vertex)
        for key, value in attributes.items():
            v[key] = value

    # call_mode attribute -----------------------------------------------------------

    @property
    def call_mode(self) -> CallMode | None:
        """The call mode, stored as the ``call_mode`` graph attribute.

        If the attribute is missing (e.g. after :py:meth:`reset`), it is created
        with value ``None`` on read.
        """
        graph = self.graph
        if "call_mode" not in graph.attributes():
            graph["call_mode"] = None
            return None
        return cast("CallMode | None", graph["call_mode"])

    @call_mode.setter
    def call_mode(self, call_mode: CallMode | None) -> None:
        self._check_call_mode(call_mode)
        self.graph["call_mode"] = call_mode

    # Nodes and edges ---------------------------------------------------------------

    @property
    def _vertices(self) -> Iterable[V]:
        """Names of all vertices, in igraph index order."""
        graph = self.graph
        if graph.vcount() == 0:
            # Same guard as ``get_sorted_vertices``: with no vertices the
            # ``name`` attribute may not exist yet, so do not read it.
            return []
        return graph.vs["name"]

    @property
    def _edges(self) -> Iterable[tuple[V, V]]:
        """All edges as ``(source_name, target_name)`` pairs.

        Returned as a list snapshot, so it can be iterated more than once.
        """
        graph = self.graph
        edge_list = graph.get_edgelist()
        if not edge_list:
            return []
        # Read all names once and translate index pairs, instead of building
        # two single-vertex lookups per edge.
        names = graph.vs["name"]
        return [(names[source], names[target]) for source, target in edge_list]

    def get_neighbors(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Return the names of the neighbors of ``vertex`` along ``direction``."""
        self._check_edge_direction(direction)
        vertex_index = self._get_v(vertex).index
        neighbor_indices = self.graph.neighbors(vertex_index, mode=direction)
        return self.graph.vs[neighbor_indices]["name"]

    def get_degree(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> int:
        """Return the degree of ``vertex`` along ``direction``."""
        self._check_edge_direction(direction)
        vertex_index = self._get_v(vertex).index
        return self.graph.degree(vertex_index, mode=direction)

    def get_subcomponent(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Return the names of the vertices reachable from ``vertex``.

        Reachability follows ``direction`` (see
        :py:meth:`igraph.Graph.subcomponent`).
        """
        self._check_edge_direction(direction)
        vertex_index = self._get_v(vertex).index
        component = self.graph.subcomponent(vertex_index, mode=direction)
        return self.graph.vs[component]["name"]

    # Graph operations --------------------------------------------------------------

    def _subgraph(self, vertices: Iterable[V]) -> Graph:
        """Return the subgraph on ``vertices`` (see :py:meth:`igraph.Graph.subgraph`)."""
        return self.graph.subgraph([self._get_v(v).index for v in vertices])

    def is_dag(self) -> bool:
        """Return whether the wrapped graph is a directed acyclic graph."""
        return self.graph.is_dag()

    def reset(self) -> None:
        """Reset per-vertex state and drop all graph-level attributes.

        Every vertex's ``func`` is set to ``None`` and its ``value`` to ``NA``.
        Every graph attribute (including ``call_mode``) is deleted.
        """
        graph = self.graph
        n_vertices = graph.vcount()
        logger.info("Resetting graph with %d vertices", n_vertices)
        graph.vs["func"] = [None] * n_vertices
        graph.vs["value"] = [NA] * n_vertices
        for attribute in graph.attributes():
            del graph[attribute]

    def get_sorted_vertices(self, direction: Literal["in", "out"]) -> Iterable[V]:
        """Return the vertex names in topological order.

        Args:
            direction: ``"out"`` or ``"in"``, forwarded as igraph's ``mode``.

        Returns:
            the vertex names, sorted topologically (empty for an empty graph).

        Raises:
            the error produced by ``_raise_not_dag`` if the graph has cycles.
        """
        self._check_edge_direction(direction, ("in", "out"))
        graph = self.graph
        if graph.vcount() == 0:
            # Short-circuit: on an empty graph the ``name`` attribute may not
            # exist yet, so the lookup below is not safe.
            return []
        try:
            order = graph.topological_sorting(mode=direction)
        except InternalError as e:
            if "The graph has cycles" in str(e):
                self._raise_not_dag(e)
            raise  # pragma: no cover
        return graph.vs[order]["name"]