From 29838e21135f5e08b6a0f84895a3e2c1c1815ac7 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Tue, 29 Oct 2024 14:18:53 -0400 Subject: [PATCH] Modify graph --- src/graph/mod.rs | 52 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 6671c8a..8fa324a 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,12 +1,9 @@ -use petgraph::{ - graph::{DiGraph, NodeIndex}, - visit::EdgeRef, -}; +use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; use std::cell::{Cell, RefCell}; use std::rc::Rc; use std::{any::Any, collections::VecDeque}; -type NodeGraph = DiGraph; +type NodeGraph = StableDiGraph; pub struct Graph { // we treat this as a StableGraph, since nodes are never removed @@ -18,14 +15,13 @@ pub struct Graph { impl Graph { pub fn new() -> Self { Self { - node_graph: Rc::new(RefCell::new(DiGraph::new())), + node_graph: Rc::new(RefCell::new(StableDiGraph::new())), output: None, output_type: std::marker::PhantomData, } } pub fn set_output + 'static>(&mut self, rule: R) { - assert!(self.output.is_none(), "cannot replace graph output"); let input = self.add_rule(rule); self.output = Some(input.node_idx); } @@ -70,7 +66,7 @@ impl Graph { } let graph = self.node_graph.borrow(); - let indices = graph.node_indices(); + let indices = graph.node_indices().collect::>(); drop(graph); let mut edges = vec![]; for idx in indices { @@ -84,6 +80,8 @@ impl Graph { self.node_graph.borrow_mut().add_edge(source, dest, ()); } + // TODO: remove nodes not connected to output + let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None); if let Err(_cycle) = sorted { // TODO: actually build a vec describing the cycle path for debugging @@ -148,6 +146,30 @@ impl FrozenGraph { let node = &graph[self.output].expect_type::(); node.value() } + + pub fn modify(&mut self, mut f: F) -> Result<(), GraphFreezeError> + where + F: FnMut(&mut Graph) -> (), + { + // Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges. + let mut graph = self.node_graph.borrow_mut(); + graph.clear_edges(); + drop(graph); + + let mut graph = Graph { + node_graph: Rc::clone(&self.node_graph), + output: Some(self.output), + output_type: std::marker::PhantomData, + }; + f(&mut graph); + match graph.freeze() { + Ok(g) => { + *self = g; + Ok(()) + } + Err(e) => Err(e), + } + } } pub struct Input { @@ -484,4 +506,18 @@ mod tests { Ok(_) => assert!(false, "shouldn't have frozen graph"), } } + + #[test] + fn modify_graph() { + let mut graph = Graph::new(); + graph.set_output(ConstantRule(1)); + let mut frozen = graph.freeze().unwrap(); + assert_eq!(frozen.evaluate(), 1); + frozen + .modify(|g| { + g.set_output(ConstantRule(2)); + }) + .expect("modify"); + assert_eq!(frozen.evaluate(), 2); + } }