"""Graph backend implementation using the :py:mod:`networkx` library.
This module provides a concrete implementation of the :py:class:`GraphWrapper`
interface using :py:class:`networkx.DiGraph`.
"""
from __future__ import annotations
from itertools import chain
from logging import getLogger
from typing import TYPE_CHECKING, Any, Generic, Literal, cast
try:
import networkx as nx
except ImportError as e: # pragma: no cover
msg = (
"The networkx library is not available. "
"Please install it through 'pip install networkx', "
"or use the 'igraph' backend instead."
)
raise ImportError(msg) from e
from networkx import NetworkXError
from ..core.constant import V
from ..core.graphwrapper import (
DEFAULT_EDGE_DIRECTION,
EdgeDirection,
EdgeDirectionError,
GraphWrapper,
)
# Module-level logger, named after this module.
logger = getLogger(__name__)

# Handling Generic Type Support in NetworkX (with the types-networkx library)
#
# ``_DiGraph`` is the graph class used by the wrapper defined in this module.
# It is declared generic over the vertex type ``V`` so that annotations such as
# ``_DiGraph[V]`` are valid both for static type checkers and at runtime.
if TYPE_CHECKING:
    # This branch is only evaluated by static type checkers
    # (``TYPE_CHECKING`` is ``False`` at runtime), so the imports below are
    # available for annotations only.
    from collections.abc import Iterable, Mapping

    from ..core.attribute import VertexAttributeName, VertexAttributes
    from ..core.funccall import CallMode

    # Prefer the subscripted ``nx.DiGraph[V]`` base; the ``TypeError`` fallback
    # presumably covers environments where ``nx.DiGraph`` cannot be
    # subscripted (e.g. without the types-networkx stubs) -- TODO confirm.
    try:

        class _DiGraph(nx.DiGraph[V], Generic[V]):  # type: ignore[no-redef]
            pass

    except TypeError:

        class _DiGraph(nx.DiGraph, Generic[V]):  # type: ignore[no-redef]
            pass

else:
    # Runtime definition: a plain ``nx.DiGraph`` subclass that also inherits
    # from ``Generic[V]`` so that ``_DiGraph[V]`` can be subscripted.

    class _DiGraph(nx.DiGraph, Generic[V]):
        pass
[docs]
class NetworkXWrapper(GraphWrapper[_DiGraph[V], V]):
    """Graph backend for the networkx library using a directed graph (DiGraph).

    Every operation is delegated to the directed graph stored in
    ``self.graph``.  Missing vertices or edges, cyclic graphs and invalid
    call modes are reported through the ``_raise_*`` / ``_check_*`` helpers,
    which are not defined in this class (presumably inherited from
    :py:class:`GraphWrapper`).  Invalid edge directions raise
    :py:class:`EdgeDirectionError`.
    """

    # Overridden methods =============================================================

    @classmethod
    def initialize_empty(cls) -> _DiGraph[V]:
        """Return a new, empty directed graph."""
        return _DiGraph()

    def get_graph_copy(self) -> _DiGraph[V]:
        """Return a copy of the wrapped graph, as produced by ``graph.copy()``."""
        return cast("_DiGraph[V]", self.graph.copy())

    # Construction ------------------------------------------------------------------

    def add_vertex(self, vertex: V, **attributes: Any) -> None:
        """Add ``vertex`` to the graph, storing ``attributes`` on the node."""
        self.graph.add_node(vertex, **attributes)

    def add_edge(self, source: V, target: V) -> None:
        """Add the directed edge ``source -> target``.

        Both endpoints must already be in the graph; a missing endpoint is
        reported through ``_raise_vertex_not_found``.
        """
        # Check the existence of the vertices
        # (so that the behaviour is consistent with the igraph backend).
        if not self.graph.has_node(source):
            self._raise_vertex_not_found(source)
        if not self.graph.has_node(target):
            self._raise_vertex_not_found(target)

        # Add the edge.
        self.graph.add_edge(source, target)

    # Destruction -------------------------------------------------------------------

    def delete_vertex(self, *vertices: V) -> None:
        """Remove each of ``vertices`` from the graph, in the given order.

        A vertex that is not in the graph is reported through
        ``_raise_vertex_not_found``; vertices processed before it have
        already been removed at that point.
        """
        for vertex in vertices:
            try:
                self.graph.remove_node(vertex)
            except NetworkXError as e:  # noqa: PERF203
                self._raise_vertex_not_found(vertex, e)

    def delete_edge(self, source: V, target: V) -> None:
        """Remove the directed edge ``source -> target``.

        A missing edge is reported through ``_raise_edge_not_found``.
        """
        try:
            self.graph.remove_edge(source, target)
        except NetworkXError as e:
            self._raise_edge_not_found(source, target, e)

    # Vertex attributes -------------------------------------------------------------

    def get_vertex_attribute(self, vertex: V, key: VertexAttributeName) -> object:
        """Return the value stored under ``key`` in the attributes of ``vertex``."""
        vertex_attributes = self.get_vertex_attributes(vertex)
        return vertex_attributes[key]

    def get_vertex_attributes(self, vertex: V) -> VertexAttributes[V]:
        """Return the attribute mapping of ``vertex``.

        The mapping is the one held by the graph itself (no copy is made), so
        writing to it changes the attributes stored on the vertex.  A missing
        vertex is reported through ``_raise_vertex_not_found``.
        """
        try:
            vertex_attributes = self.graph.nodes[vertex]
        except KeyError as e:
            # The helper is expected not to return (presumably it raises);
            # otherwise ``vertex_attributes`` would be unbound below.
            self._raise_vertex_not_found(vertex, e)
        return cast("VertexAttributes[V]", vertex_attributes)

    def set_vertex_attribute(
        self, vertex: V, key: VertexAttributeName, value: object
    ) -> None:
        """Store ``value`` under ``key`` in the attributes of ``vertex``."""
        vertex_attributes = self.get_vertex_attributes(vertex)
        vertex_attributes[key] = value  # type: ignore[assignment]

    def update_vertex_attributes(
        self, vertex: V, attributes: Mapping[VertexAttributeName, Any]
    ) -> None:
        """Copy every key/value pair of ``attributes`` onto ``vertex``."""
        graph_attributes = self.get_vertex_attributes(vertex)
        graph_attributes.update(attributes)

    # call_mode attribute -----------------------------------------------------------

    @property
    def call_mode(self) -> CallMode | None:
        """Call mode stored in the graph-level attributes.

        Reading the property initialises the ``"call_mode"`` entry to ``None``
        when it is not set yet.  The setter validates the new value with
        ``_check_call_mode`` before storing it.
        """
        return cast("CallMode | None", self.graph.graph.setdefault("call_mode", None))

    @call_mode.setter
    def call_mode(self, call_mode: CallMode | None) -> None:
        self._check_call_mode(call_mode)
        self.graph.graph["call_mode"] = call_mode

    # Nodes and edges ---------------------------------------------------------------

    @property
    def _vertices(self) -> Iterable[V]:
        """Vertices of the graph, as returned by ``graph.nodes()``."""
        return self.graph.nodes()

    @property
    def _edges(self) -> Iterable[tuple[V, V]]:
        """Edges of the graph, as returned by ``graph.edges()``."""
        return self.graph.edges()

    def get_neighbors(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Return the neighbours of ``vertex`` in the given ``direction``.

        ``"out"`` yields successors, ``"in"`` yields predecessors and
        ``"all"`` yields the successors followed by the predecessors (a
        vertex that is both appears twice).  Any other value raises
        :py:class:`EdgeDirectionError`.  A ``NetworkXError`` raised by the
        lookup is reported through ``_raise_vertex_not_found``.
        """
        try:
            if direction == "out":
                return self.graph.successors(vertex)
            elif direction == "in":
                return self.graph.predecessors(vertex)
            elif direction == "all":
                return chain(
                    self.graph.successors(vertex), self.graph.predecessors(vertex)
                )
            else:
                raise EdgeDirectionError(direction)
        except NetworkXError as e:
            self._raise_vertex_not_found(vertex, e)

    def get_degree(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> int:
        """Return the out-, in- or total degree of ``vertex``.

        A missing vertex is reported through ``_raise_vertex_not_found``; an
        unknown ``direction`` raises :py:class:`EdgeDirectionError`.
        """
        if not self.graph.has_node(vertex):
            self._raise_vertex_not_found(vertex)

        if direction == "out":
            degree = self.graph.out_degree(vertex)
        elif direction == "in":
            degree = self.graph.in_degree(vertex)
        elif direction == "all":
            degree = self.graph.degree(vertex)
        else:
            raise EdgeDirectionError(direction)

        # The degree accessors can also return views; narrow to ``int``.
        assert isinstance(degree, int)
        return degree

    def get_subcomponent(
        self, vertex: V, direction: EdgeDirection = DEFAULT_EDGE_DIRECTION
    ) -> Iterable[V]:
        """Return the vertices reachable from ``vertex``, including ``vertex``.

        ``"out"`` follows edges forwards (descendants), ``"in"`` follows them
        backwards (ancestors) and ``"all"`` ignores edge orientation.  A
        missing vertex is reported through ``_raise_vertex_not_found``; an
        unknown ``direction`` raises :py:class:`EdgeDirectionError`.
        """
        if not self.graph.has_node(vertex):
            self._raise_vertex_not_found(vertex)

        if direction == "out":
            return {vertex} | nx.descendants(self.graph, vertex)
        elif direction == "in":
            return {vertex} | nx.ancestors(self.graph, vertex)
        elif direction == "all":
            # A read-only undirected *view* is enough for the traversal and
            # avoids deep-copying the graph together with all its attributes.
            undirected = self.graph.to_undirected(as_view=True)
            return nx.node_connected_component(undirected, vertex)
        else:
            raise EdgeDirectionError(direction)

    # Graph operations --------------------------------------------------------------

    def _subgraph(self, vertices: Iterable[V]) -> _DiGraph[V]:
        """Return an independent copy of the subgraph induced by ``vertices``.

        A vertex that is not in the graph is reported through
        ``_raise_vertex_not_found``.
        """
        # Materialise once: ``vertices`` may be a one-shot iterator, which the
        # validation loop would otherwise exhaust before the subgraph is built.
        vertex_list = list(vertices)
        for vertex in vertex_list:
            if not self.graph.has_node(vertex):
                self._raise_vertex_not_found(vertex)
        return cast("_DiGraph[V]", self.graph.subgraph(vertex_list).copy())

    def is_dag(self) -> bool:
        """Return whether the graph is a directed acyclic graph."""
        return nx.is_directed_acyclic_graph(self.graph)

    def get_sorted_vertices(self, direction: Literal["in", "out"]) -> Iterable[V]:
        """Return the vertices in topological order.

        ``"out"`` returns the order computed by ``nx.topological_sort``;
        ``"in"`` returns the same order reversed.  A graph that cannot be
        sorted is reported through ``_raise_not_dag``; any other
        ``direction`` raises :py:class:`EdgeDirectionError`.
        """
        try:
            sorted_vertices = list(nx.topological_sort(self.graph))
        except nx.NetworkXUnfeasible as e:
            self._raise_not_dag(e)

        if direction == "in":
            return reversed(sorted_vertices)
        elif direction == "out":
            return sorted_vertices
        else:
            raise EdgeDirectionError(direction, ("in", "out"))