Modify graph

This commit is contained in:
Shadowfacts 2024-10-29 14:18:53 -04:00
parent 67ddf2f254
commit 29838e2113

View File

@ -1,12 +1,9 @@
use petgraph::{
graph::{DiGraph, NodeIndex},
visit::EdgeRef,
};
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::{any::Any, collections::VecDeque};
type NodeGraph = DiGraph<ErasedNode, (), u32>;
type NodeGraph = StableDiGraph<ErasedNode, (), u32>;
pub struct Graph<Output> {
// we treat this as a StableGraph, since nodes are never removed
@ -18,14 +15,13 @@ pub struct Graph<Output> {
impl<Output: Clone + 'static> Graph<Output> {
pub fn new() -> Self {
Self {
node_graph: Rc::new(RefCell::new(DiGraph::new())),
node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
output: None,
output_type: std::marker::PhantomData,
}
}
pub fn set_output<R: Rule<Output> + 'static>(&mut self, rule: R) {
assert!(self.output.is_none(), "cannot replace graph output");
let input = self.add_rule(rule);
self.output = Some(input.node_idx);
}
@ -70,7 +66,7 @@ impl<Output: Clone + 'static> Graph<Output> {
}
let graph = self.node_graph.borrow();
let indices = graph.node_indices();
let indices = graph.node_indices().collect::<Vec<_>>();
drop(graph);
let mut edges = vec![];
for idx in indices {
@ -84,6 +80,8 @@ impl<Output: Clone + 'static> Graph<Output> {
self.node_graph.borrow_mut().add_edge(source, dest, ());
}
// TODO: remove nodes not connected to output
let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None);
if let Err(_cycle) = sorted {
// TODO: actually build a vec describing the cycle path for debugging
@ -148,6 +146,30 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
let node = &graph[self.output].expect_type::<Output>();
node.value()
}
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
where
F: FnMut(&mut Graph<Output>) -> (),
{
// Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges.
let mut graph = self.node_graph.borrow_mut();
graph.clear_edges();
drop(graph);
let mut graph = Graph {
node_graph: Rc::clone(&self.node_graph),
output: Some(self.output),
output_type: std::marker::PhantomData,
};
f(&mut graph);
match graph.freeze() {
Ok(g) => {
*self = g;
Ok(())
}
Err(e) => Err(e),
}
}
}
pub struct Input<T> {
@ -484,4 +506,18 @@ mod tests {
Ok(_) => assert!(false, "shouldn't have frozen graph"),
}
}
#[test]
fn modify_graph() {
let mut graph = Graph::new();
graph.set_output(ConstantRule(1));
let mut frozen = graph.freeze().unwrap();
assert_eq!(frozen.evaluate(), 1);
frozen
.modify(|g| {
g.set_output(ConstantRule(2));
})
.expect("modify");
assert_eq!(frozen.evaluate(), 2);
}
}