Modify graph
This commit is contained in:
parent
67ddf2f254
commit
29838e2113
@ -1,12 +1,9 @@
|
||||
use petgraph::{
|
||||
graph::{DiGraph, NodeIndex},
|
||||
visit::EdgeRef,
|
||||
};
|
||||
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::rc::Rc;
|
||||
use std::{any::Any, collections::VecDeque};
|
||||
|
||||
type NodeGraph = DiGraph<ErasedNode, (), u32>;
|
||||
type NodeGraph = StableDiGraph<ErasedNode, (), u32>;
|
||||
|
||||
pub struct Graph<Output> {
|
||||
// we treat this as a StableGraph, since nodes are never removed
|
||||
@ -18,14 +15,13 @@ pub struct Graph<Output> {
|
||||
impl<Output: Clone + 'static> Graph<Output> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
node_graph: Rc::new(RefCell::new(DiGraph::new())),
|
||||
node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
|
||||
output: None,
|
||||
output_type: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_output<R: Rule<Output> + 'static>(&mut self, rule: R) {
|
||||
assert!(self.output.is_none(), "cannot replace graph output");
|
||||
let input = self.add_rule(rule);
|
||||
self.output = Some(input.node_idx);
|
||||
}
|
||||
@ -70,7 +66,7 @@ impl<Output: Clone + 'static> Graph<Output> {
|
||||
}
|
||||
|
||||
let graph = self.node_graph.borrow();
|
||||
let indices = graph.node_indices();
|
||||
let indices = graph.node_indices().collect::<Vec<_>>();
|
||||
drop(graph);
|
||||
let mut edges = vec![];
|
||||
for idx in indices {
|
||||
@ -84,6 +80,8 @@ impl<Output: Clone + 'static> Graph<Output> {
|
||||
self.node_graph.borrow_mut().add_edge(source, dest, ());
|
||||
}
|
||||
|
||||
// TODO: remove nodes not connected to output
|
||||
|
||||
let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None);
|
||||
if let Err(_cycle) = sorted {
|
||||
// TODO: actually build a vec describing the cycle path for debugging
|
||||
@ -148,6 +146,30 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
|
||||
let node = &graph[self.output].expect_type::<Output>();
|
||||
node.value()
|
||||
}
|
||||
|
||||
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
|
||||
where
|
||||
F: FnMut(&mut Graph<Output>) -> (),
|
||||
{
|
||||
// Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges.
|
||||
let mut graph = self.node_graph.borrow_mut();
|
||||
graph.clear_edges();
|
||||
drop(graph);
|
||||
|
||||
let mut graph = Graph {
|
||||
node_graph: Rc::clone(&self.node_graph),
|
||||
output: Some(self.output),
|
||||
output_type: std::marker::PhantomData,
|
||||
};
|
||||
f(&mut graph);
|
||||
match graph.freeze() {
|
||||
Ok(g) => {
|
||||
*self = g;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Input<T> {
|
||||
@ -484,4 +506,18 @@ mod tests {
|
||||
Ok(_) => assert!(false, "shouldn't have frozen graph"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn modify_graph() {
|
||||
let mut graph = Graph::new();
|
||||
graph.set_output(ConstantRule(1));
|
||||
let mut frozen = graph.freeze().unwrap();
|
||||
assert_eq!(frozen.evaluate(), 1);
|
||||
frozen
|
||||
.modify(|g| {
|
||||
g.set_output(ConstantRule(2));
|
||||
})
|
||||
.expect("modify");
|
||||
assert_eq!(frozen.evaluate(), 2);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user