Remove parts of the graph unused by the output node

This commit is contained in:
Shadowfacts 2024-10-29 14:54:11 -04:00
parent bd2cdba5bc
commit 140c6a67fd
2 changed files with 64 additions and 6 deletions

View File

@ -1,3 +1,5 @@
mod util;
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::rc::Rc; use std::rc::Rc;
@ -61,9 +63,10 @@ impl<Output: Clone + 'static> Graph<Output> {
} }
pub fn freeze(self) -> Result<FrozenGraph<Output>, GraphFreezeError> { pub fn freeze(self) -> Result<FrozenGraph<Output>, GraphFreezeError> {
if self.output.is_none() { let output: NodeIndex<u32> = match self.output {
return Err(GraphFreezeError::NoOutput); None => return Err(GraphFreezeError::NoOutput),
} Some(idx) => idx,
};
let graph = self.node_graph.borrow(); let graph = self.node_graph.borrow();
let indices = graph.node_indices().collect::<Vec<_>>(); let indices = graph.node_indices().collect::<Vec<_>>();
@ -80,7 +83,9 @@ impl<Output: Clone + 'static> Graph<Output> {
self.node_graph.borrow_mut().add_edge(source, dest, ()); self.node_graph.borrow_mut().add_edge(source, dest, ());
} }
// TODO: remove nodes not connected to output let mut graph = self.node_graph.borrow_mut();
util::remove_nodes_not_connected_to(&mut *graph, output);
drop(graph);
let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None); let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None);
if let Err(_cycle) = sorted { if let Err(_cycle) = sorted {
@ -147,6 +152,10 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
node.value() node.value()
} }
pub fn node_count(&self) -> usize {
self.node_graph.borrow().node_count()
}
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError> pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
where where
F: FnMut(&mut Graph<Output>) -> (), F: FnMut(&mut Graph<Output>) -> (),
@ -172,6 +181,7 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
} }
} }
#[derive(Clone)]
pub struct Input<T> { pub struct Input<T> {
node_idx: NodeIndex<u32>, node_idx: NodeIndex<u32>,
value: Rc<RefCell<Option<T>>>, value: Rc<RefCell<Option<T>>>,
@ -498,8 +508,8 @@ mod tests {
let a = graph.add_rule(DeferredInput(Rc::clone(&a_input))); let a = graph.add_rule(DeferredInput(Rc::clone(&a_input)));
let b_input = Rc::new(RefCell::new(Some(a))); let b_input = Rc::new(RefCell::new(Some(a)));
let b = graph.add_rule(DeferredInput(b_input)); let b = graph.add_rule(DeferredInput(b_input));
*a_input.borrow_mut() = Some(b); *a_input.borrow_mut() = Some(b.clone());
graph.set_output(Inc(0)); graph.set_output(Double(b));
match graph.freeze() { match graph.freeze() {
Err(GraphFreezeError::Cyclic(_)) => (), Err(GraphFreezeError::Cyclic(_)) => (),
Err(e) => assert!(false, "unexpected error {:?}", e), Err(e) => assert!(false, "unexpected error {:?}", e),
@ -519,5 +529,6 @@ mod tests {
}) })
.expect("modify"); .expect("modify");
assert_eq!(frozen.evaluate(), 2); assert_eq!(frozen.evaluate(), 2);
assert_eq!(frozen.node_count(), 1);
} }
} }

47
crates/graph/src/util.rs Normal file
View File

@ -0,0 +1,47 @@
use petgraph::{
stable_graph::{IndexType, NodeIndex, StableGraph},
unionfind::UnionFind,
visit::{EdgeRef, IntoEdgeReferences, NodeIndexable},
EdgeType,
};
pub fn remove_nodes_not_connected_to<N, E, Ty: EdgeType, Ix: IndexType>(
g: &mut StableGraph<N, E, Ty, Ix>,
node: NodeIndex<Ix>,
) {
// based on petgraph's connected_components
let mut vertex_sets = UnionFind::new(g.node_bound());
for edge in g.edge_references() {
vertex_sets.union(g.to_index(edge.source()), g.to_index(edge.target()));
}
let to_remove = g
.node_indices()
.filter(|other_idx| !vertex_sets.equiv(g.to_index(node), g.to_index(*other_idx)))
.collect::<Vec<_>>();
for idx in to_remove {
g.remove_node(idx);
}
}
#[cfg(test)]
mod tests {
use petgraph::stable_graph::StableGraph;
#[test]
fn remove_nodes_not_connected_to() {
let mut graph: StableGraph<(), ()> = Default::default();
let a = graph.add_node(());
let b = graph.add_node(());
let c = graph.add_node(());
let d = graph.add_node(());
let e = graph.add_node(());
graph.extend_with_edges(&[(a, b), (a, c), (d, e)]);
super::remove_nodes_not_connected_to(&mut graph, a);
assert_eq!(graph.node_count(), 3);
assert!(graph.contains_node(a));
assert!(graph.contains_node(b));
assert!(graph.contains_node(c));
assert!(!graph.contains_node(d));
assert!(!graph.contains_node(e));
}
}