diff --git a/crates/graph/src/lib.rs b/crates/graph/src/lib.rs index 8fa324a..d33f238 100644 --- a/crates/graph/src/lib.rs +++ b/crates/graph/src/lib.rs @@ -1,3 +1,5 @@ +mod util; + use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; use std::cell::{Cell, RefCell}; use std::rc::Rc; @@ -61,9 +63,10 @@ impl Graph { } pub fn freeze(self) -> Result, GraphFreezeError> { - if self.output.is_none() { - return Err(GraphFreezeError::NoOutput); - } + let output: NodeIndex = match self.output { + None => return Err(GraphFreezeError::NoOutput), + Some(idx) => idx, + }; let graph = self.node_graph.borrow(); let indices = graph.node_indices().collect::>(); @@ -80,7 +83,9 @@ impl Graph { self.node_graph.borrow_mut().add_edge(source, dest, ()); } - // TODO: remove nodes not connected to output + let mut graph = self.node_graph.borrow_mut(); + util::remove_nodes_not_connected_to(&mut *graph, output); + drop(graph); let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None); if let Err(_cycle) = sorted { @@ -147,6 +152,10 @@ impl FrozenGraph { node.value() } + pub fn node_count(&self) -> usize { + self.node_graph.borrow().node_count() + } + pub fn modify(&mut self, mut f: F) -> Result<(), GraphFreezeError> where F: FnMut(&mut Graph) -> (), @@ -172,6 +181,7 @@ impl FrozenGraph { } } +#[derive(Clone)] pub struct Input { node_idx: NodeIndex, value: Rc>>, @@ -498,8 +508,8 @@ mod tests { let a = graph.add_rule(DeferredInput(Rc::clone(&a_input))); let b_input = Rc::new(RefCell::new(Some(a))); let b = graph.add_rule(DeferredInput(b_input)); - *a_input.borrow_mut() = Some(b); - graph.set_output(Inc(0)); + *a_input.borrow_mut() = Some(b.clone()); + graph.set_output(Double(b)); match graph.freeze() { Err(GraphFreezeError::Cyclic(_)) => (), Err(e) => assert!(false, "unexpected error {:?}", e), @@ -519,5 +529,6 @@ mod tests { }) .expect("modify"); assert_eq!(frozen.evaluate(), 2); + assert_eq!(frozen.node_count(), 1); } } diff --git a/crates/graph/src/util.rs b/crates/graph/src/util.rs new file mode 100644 index 0000000..94ed6fc --- /dev/null +++ b/crates/graph/src/util.rs @@ -0,0 +1,47 @@ +use petgraph::{ + stable_graph::{IndexType, NodeIndex, StableGraph}, + unionfind::UnionFind, + visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}, + EdgeType, +}; + +pub fn remove_nodes_not_connected_to( + g: &mut StableGraph, + node: NodeIndex, +) { + // based on petgraph's connected_components + let mut vertex_sets = UnionFind::new(g.node_bound()); + for edge in g.edge_references() { + vertex_sets.union(g.to_index(edge.source()), g.to_index(edge.target())); + } + let to_remove = g + .node_indices() + .filter(|other_idx| !vertex_sets.equiv(g.to_index(node), g.to_index(*other_idx))) + .collect::>(); + for idx in to_remove { + g.remove_node(idx); + } +} + +#[cfg(test)] +mod tests { + use petgraph::stable_graph::StableGraph; + + #[test] + fn remove_nodes_not_connected_to() { + let mut graph: StableGraph<(), ()> = Default::default(); + let a = graph.add_node(()); + let b = graph.add_node(()); + let c = graph.add_node(()); + let d = graph.add_node(()); + let e = graph.add_node(()); + graph.extend_with_edges(&[(a, b), (a, c), (d, e)]); + super::remove_nodes_not_connected_to(&mut graph, a); + assert_eq!(graph.node_count(), 3); + assert!(graph.contains_node(a)); + assert!(graph.contains_node(b)); + assert!(graph.contains_node(c)); + assert!(!graph.contains_node(d)); + assert!(!graph.contains_node(e)); + } +}