Dynamic rules
This commit is contained in:
parent
9cb6a8c6ce
commit
d92ebf11b2
@ -1,8 +1,8 @@
|
|||||||
use crate::node::{
|
use crate::node::{
|
||||||
AsyncConstNode, AsyncRuleNode, ConstNode, ErasedNode, InvalidatableConstNode, Node, NodeValue,
|
AsyncConstNode, AsyncDynamicRuleNode, AsyncRuleNode, ConstNode, DynamicRuleNode, ErasedNode,
|
||||||
RuleNode,
|
InvalidatableConstNode, Node, NodeValue, RuleNode,
|
||||||
};
|
};
|
||||||
use crate::rule::{AsyncRule, Input, Rule};
|
use crate::rule::{AsyncDynamicRule, AsyncRule, DynamicInput, DynamicRule, Input, Rule};
|
||||||
use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous};
|
use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous};
|
||||||
use crate::util;
|
use crate::util;
|
||||||
use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal};
|
use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal};
|
||||||
@ -73,7 +73,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
|||||||
let erased = ErasedNode::new(node);
|
let erased = ErasedNode::new(node);
|
||||||
let idx = self.node_graph.borrow_mut().add_node(erased);
|
let idx = self.node_graph.borrow_mut().add_node(erased);
|
||||||
Input {
|
Input {
|
||||||
node_idx: idx,
|
node_idx: Rc::new(Cell::new(Some(idx))),
|
||||||
value,
|
value,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,19 +174,42 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
|
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
|
||||||
let node_idx = input.node_idx;
|
let node_idx = Rc::clone(&input.node_idx);
|
||||||
let graph = Rc::clone(&self.node_graph);
|
let graph = Rc::clone(&self.node_graph);
|
||||||
let graph_is_valid = Rc::clone(&self.is_valid);
|
let graph_is_valid = Rc::clone(&self.is_valid);
|
||||||
InvalidationSignal {
|
InvalidationSignal {
|
||||||
do_invalidate: Rc::new(Box::new(move || {
|
do_invalidate: Rc::new(Box::new(move || {
|
||||||
graph_is_valid.set(false);
|
graph_is_valid.set(false);
|
||||||
let mut graph = graph.borrow_mut();
|
let mut graph = graph.borrow_mut();
|
||||||
let node = &mut graph[node_idx];
|
let node = &mut graph[node_idx.get().unwrap()];
|
||||||
node.invalidate();
|
node.invalidate();
|
||||||
})),
|
})),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Adds a node to the graph whose output is additional nodes produced by the given rule.
|
||||||
|
pub fn add_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
|
||||||
|
where
|
||||||
|
R: DynamicRule,
|
||||||
|
{
|
||||||
|
let input = self.add_node(DynamicRuleNode::<R, R::ChildOutput, S>::new(rule));
|
||||||
|
DynamicInput { input }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adds an externally-invalidatable node to the graph whose output is additional
|
||||||
|
/// nodes produced by the given rule.
|
||||||
|
pub fn add_invalidatable_dynamic_rule<R>(
|
||||||
|
&mut self,
|
||||||
|
rule: R,
|
||||||
|
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
|
||||||
|
where
|
||||||
|
R: DynamicRule,
|
||||||
|
{
|
||||||
|
let input = self.add_dynamic_rule(rule);
|
||||||
|
let signal = self.make_invalidation_signal(&input.input);
|
||||||
|
(input, signal)
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a graph from this builder, consuming the builder.
|
/// Creates a graph from this builder, consuming the builder.
|
||||||
///
|
///
|
||||||
/// To successfully build a graph, there must be an output node set (using either
|
/// To successfully build a graph, there must be an output node set (using either
|
||||||
@ -217,7 +240,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
|||||||
graph.add_edge(source, dest, ());
|
graph.add_edge(source, dest, ());
|
||||||
}
|
}
|
||||||
|
|
||||||
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx);
|
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap());
|
||||||
|
|
||||||
drop(graph);
|
drop(graph);
|
||||||
|
|
||||||
@ -319,6 +342,29 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
|||||||
let signal = self.make_invalidation_signal(&input);
|
let signal = self.make_invalidation_signal(&input);
|
||||||
(input, signal)
|
(input, signal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Adds a node to the graph whose output is additional nodes produced asynchronously by the given rule.
|
||||||
|
pub fn add_async_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
|
||||||
|
where
|
||||||
|
R: AsyncDynamicRule,
|
||||||
|
{
|
||||||
|
let input = self.add_node(AsyncDynamicRuleNode::<R, R::ChildOutput>::new(rule));
|
||||||
|
DynamicInput { input }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adds an externally-invalidatable node to the graph whose output is additional nodes produced
|
||||||
|
/// asynchronously by the given rule.
|
||||||
|
pub fn add_invalidatable_async_dynamic_rule<R>(
|
||||||
|
&mut self,
|
||||||
|
rule: R,
|
||||||
|
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
|
||||||
|
where
|
||||||
|
R: AsyncDynamicRule,
|
||||||
|
{
|
||||||
|
let input = self.add_async_dynamic_rule(rule);
|
||||||
|
let signal = self.make_invalidation_signal(&input.input);
|
||||||
|
(input, signal)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A reason why a [`GraphBuilder`] can fail to build a graph.
|
/// A reason why a [`GraphBuilder`] can fail to build a graph.
|
||||||
@ -383,8 +429,18 @@ mod tests {
|
|||||||
builder.set_output(Double::new(b.clone()));
|
builder.set_output(Double::new(b.clone()));
|
||||||
match builder.build() {
|
match builder.build() {
|
||||||
Err(super::BuildGraphError::Cycle(cycle)) => {
|
Err(super::BuildGraphError::Cycle(cycle)) => {
|
||||||
let a_start = cycle == vec![a.node_idx, b.node_idx, a.node_idx];
|
let a_start = cycle
|
||||||
let b_start = cycle == vec![b.node_idx, a.node_idx, b.node_idx];
|
== vec![
|
||||||
|
a.node_idx.get().unwrap(),
|
||||||
|
b.node_idx.get().unwrap(),
|
||||||
|
a.node_idx.get().unwrap(),
|
||||||
|
];
|
||||||
|
let b_start = cycle
|
||||||
|
== vec![
|
||||||
|
b.node_idx.get().unwrap(),
|
||||||
|
a.node_idx.get().unwrap(),
|
||||||
|
b.node_idx.get().unwrap(),
|
||||||
|
];
|
||||||
// either is a permisisble way of describing the cycle
|
// either is a permisisble way of describing the cycle
|
||||||
assert!(a_start || b_start);
|
assert!(a_start || b_start);
|
||||||
}
|
}
|
||||||
|
@ -49,10 +49,10 @@ pub mod synchronicity;
|
|||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
use builder::{BuildGraphError, GraphBuilder};
|
use builder::{BuildGraphError, GraphBuilder};
|
||||||
use node::{ErasedNode, NodeValue};
|
use node::{ErasedNode, NodeUpdateContext, NodeValue};
|
||||||
use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef};
|
use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef};
|
||||||
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
|
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
|
||||||
use rule::{AsyncRule, Input, InputVisitor, Rule};
|
use rule::{Input, InputVisitor};
|
||||||
use std::cell::{Cell, RefCell};
|
use std::cell::{Cell, RefCell};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
@ -127,7 +127,15 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
/// Because building a graph can fail and this method mutates the underlying graph, it takes
|
/// Because building a graph can fail and this method mutates the underlying graph, it takes
|
||||||
/// ownership of the current graph to prevent the graph being left in an invalid state.
|
/// ownership of the current graph to prevent the graph being left in an invalid state.
|
||||||
/// It returns either the new, modified graph or an error.
|
/// It returns either the new, modified graph or an error.
|
||||||
pub fn modify<F>(mut self, mut f: F) -> Result<Self, BuildGraphError>
|
pub fn modify<F>(mut self, f: F) -> Result<Self, BuildGraphError>
|
||||||
|
where
|
||||||
|
F: FnMut(&mut GraphBuilder<O, S>) -> (),
|
||||||
|
{
|
||||||
|
self._modify(f)?;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _modify<F>(&mut self, mut f: F) -> Result<(), BuildGraphError>
|
||||||
where
|
where
|
||||||
F: FnMut(&mut GraphBuilder<O, S>) -> (),
|
F: FnMut(&mut GraphBuilder<O, S>) -> (),
|
||||||
{
|
{
|
||||||
@ -142,12 +150,12 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
}
|
}
|
||||||
drop(graph);
|
drop(graph);
|
||||||
|
|
||||||
let old_output = self.output.node_idx;
|
let old_output = self.output.node_idx.get();
|
||||||
|
|
||||||
// Modify
|
// Modify
|
||||||
let mut builder = self.into_builder();
|
let mut builder = self.to_builder();
|
||||||
f(&mut builder);
|
f(&mut builder);
|
||||||
self = builder.build()?;
|
*self = builder.build()?;
|
||||||
|
|
||||||
// Any new inboud edges invalidate their target nodes.
|
// Any new inboud edges invalidate their target nodes.
|
||||||
let mut graph = self.node_graph.borrow_mut();
|
let mut graph = self.node_graph.borrow_mut();
|
||||||
@ -164,7 +172,7 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
}
|
}
|
||||||
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
|
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
|
||||||
// there are no edges but we still need to invalidate.
|
// there are no edges but we still need to invalidate.
|
||||||
if !to_invalidate.is_empty() || self.output.node_idx != old_output {
|
if !to_invalidate.is_empty() || self.output.node_idx.get() != old_output {
|
||||||
self.is_valid.set(false);
|
self.is_valid.set(false);
|
||||||
for idx in to_invalidate {
|
for idx in to_invalidate {
|
||||||
let node = &mut graph[idx];
|
let node = &mut graph[idx];
|
||||||
@ -173,13 +181,17 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
drop(graph);
|
drop(graph);
|
||||||
Ok(self)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert this graph back into a builder for further modifications.
|
/// Convert this graph back into a builder for further modifications.
|
||||||
///
|
///
|
||||||
/// Returns a builder with the same output and synchronicity types.
|
/// Returns a builder with the same output and synchronicity types.
|
||||||
pub fn into_builder(self) -> GraphBuilder<O, S> {
|
pub fn into_builder(self) -> GraphBuilder<O, S> {
|
||||||
|
self.to_builder()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_builder(&self) -> GraphBuilder<O, S> {
|
||||||
// Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges.
|
// Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges.
|
||||||
let mut graph = self.node_graph.borrow_mut();
|
let mut graph = self.node_graph.borrow_mut();
|
||||||
graph.clear_edges();
|
graph.clear_edges();
|
||||||
@ -232,7 +244,7 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
for node in self.0.node_references() {
|
for node in self.0.node_references() {
|
||||||
let id = self.0.to_index(node.id());
|
let id = self.0.to_index(node.id());
|
||||||
let label = Escaped(node.weight());
|
let label = Escaped(node.weight());
|
||||||
writeln!(f, "\t{id} [label =\"{label:?} (id={id})\"]")?;
|
writeln!(f, "\t{id} [label=\"{label:?} (id={id})\"]")?;
|
||||||
}
|
}
|
||||||
for edge in self.0.edge_references() {
|
for edge in self.0.edge_references() {
|
||||||
let source = self.0.to_index(edge.source());
|
let source = self.0.to_index(edge.source());
|
||||||
@ -250,13 +262,51 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|||||||
impl<O: 'static> Graph<O, Synchronous> {
|
impl<O: 'static> Graph<O, Synchronous> {
|
||||||
fn update_invalid_nodes(&mut self) {
|
fn update_invalid_nodes(&mut self) {
|
||||||
let mut graph = self.node_graph.borrow_mut();
|
let mut graph = self.node_graph.borrow_mut();
|
||||||
for &idx in self.sorted_nodes.iter() {
|
let mut i = 0;
|
||||||
|
while i < self.sorted_nodes.len() {
|
||||||
|
let idx = self.sorted_nodes[i];
|
||||||
let node = &mut graph[idx];
|
let node = &mut graph[idx];
|
||||||
if !node.is_valid() {
|
if !node.is_valid() {
|
||||||
// Update this node
|
// Update this node
|
||||||
let value_changed = node.update();
|
let mut ctx = NodeUpdateContext::new();
|
||||||
|
node.update(&mut ctx);
|
||||||
|
|
||||||
if value_changed {
|
let mut nodes_changed = false;
|
||||||
|
for idx_to_remove in ctx.removed_nodes {
|
||||||
|
assert!(
|
||||||
|
idx_to_remove != idx,
|
||||||
|
"cannot remove node curently being evaluated"
|
||||||
|
);
|
||||||
|
let (index_to_remove_in_sorted, _) = self
|
||||||
|
.sorted_nodes
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, idx)| **idx == idx_to_remove)
|
||||||
|
.expect("removed node must have been already added");
|
||||||
|
assert!(
|
||||||
|
index_to_remove_in_sorted > i,
|
||||||
|
"cannot remove already evaluated node"
|
||||||
|
);
|
||||||
|
graph.remove_node(idx_to_remove);
|
||||||
|
self.sorted_nodes.remove(index_to_remove_in_sorted);
|
||||||
|
nodes_changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (added_node, id_cell) in ctx.added_nodes {
|
||||||
|
let id = graph.add_node(added_node);
|
||||||
|
id_cell.set(Some(id));
|
||||||
|
nodes_changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if nodes_changed {
|
||||||
|
// Update the graph before invalidating downstream nodes.
|
||||||
|
drop(graph);
|
||||||
|
self._modify(|_| {})
|
||||||
|
.expect("modifying graph during evaluation must produce valid graph");
|
||||||
|
graph = self.node_graph.borrow_mut();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.invalidate_dependent_nodes {
|
||||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
||||||
// we're iterating over a topological sort of the graph).
|
// we're iterating over a topological sort of the graph).
|
||||||
let dependents = graph
|
let dependents = graph
|
||||||
@ -270,14 +320,25 @@ impl<O: 'static> Graph<O, Synchronous> {
|
|||||||
dependent.invalidate();
|
dependent.invalidate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nodes_changed {
|
||||||
|
// If we added/removed nodes, the sorted order has changed, so start evaluating
|
||||||
|
// from the beginning, in case of changes before i.
|
||||||
|
i = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consistency check: after updating in the topological sort order, we should be left with
|
// Consistency check: after updating in the topological sort order, we should be left with
|
||||||
// no invalid nodes
|
// no invalid nodes.
|
||||||
debug_assert!(self
|
debug_assert!(self
|
||||||
.sorted_nodes
|
.sorted_nodes
|
||||||
.iter()
|
.iter()
|
||||||
.all(|&idx| { (&graph[idx]).is_valid() }));
|
.all(|&idx| { (&graph[idx]).is_valid() }));
|
||||||
|
|
||||||
self.is_valid.set(true);
|
self.is_valid.set(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,13 +361,51 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
|||||||
async fn update_invalid_nodes(&mut self) {
|
async fn update_invalid_nodes(&mut self) {
|
||||||
// TODO: consider whether this can be done in parallel to any degree.
|
// TODO: consider whether this can be done in parallel to any degree.
|
||||||
let mut graph = self.node_graph.borrow_mut();
|
let mut graph = self.node_graph.borrow_mut();
|
||||||
for &idx in self.sorted_nodes.iter() {
|
let mut i = 0;
|
||||||
|
while i < self.sorted_nodes.len() {
|
||||||
|
let idx = self.sorted_nodes[i];
|
||||||
let node = &mut graph[idx];
|
let node = &mut graph[idx];
|
||||||
if !node.is_valid() {
|
if !node.is_valid() {
|
||||||
// Update this node
|
// Update this node
|
||||||
let value_changed = node.update().await;
|
let mut ctx = NodeUpdateContext::new();
|
||||||
|
node.update(&mut ctx).await;
|
||||||
|
|
||||||
if value_changed {
|
let mut nodes_changed = false;
|
||||||
|
for idx_to_remove in ctx.removed_nodes {
|
||||||
|
assert!(
|
||||||
|
idx_to_remove != idx,
|
||||||
|
"cannot remove node curently being evaluated"
|
||||||
|
);
|
||||||
|
let (index_to_remove_in_sorted, _) = self
|
||||||
|
.sorted_nodes
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, idx)| **idx == idx_to_remove)
|
||||||
|
.expect("removed node must have been already added");
|
||||||
|
assert!(
|
||||||
|
index_to_remove_in_sorted > i,
|
||||||
|
"cannot remove already evaluated node"
|
||||||
|
);
|
||||||
|
graph.remove_node(idx_to_remove);
|
||||||
|
self.sorted_nodes.remove(index_to_remove_in_sorted);
|
||||||
|
nodes_changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (added_node, id_cell) in ctx.added_nodes {
|
||||||
|
let id = graph.add_node(added_node);
|
||||||
|
id_cell.set(Some(id));
|
||||||
|
nodes_changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if nodes_changed {
|
||||||
|
// Update the graph before invalidating downstream nodes.
|
||||||
|
drop(graph);
|
||||||
|
self._modify(|_| {})
|
||||||
|
.expect("modifying graph during evaluation must produce valid graph");
|
||||||
|
graph = self.node_graph.borrow_mut();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.invalidate_dependent_nodes {
|
||||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
||||||
// we're iterating over a topological sort of the graph).
|
// we're iterating over a topological sort of the graph).
|
||||||
let dependents = graph
|
let dependents = graph
|
||||||
@ -320,14 +419,25 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
|||||||
dependent.invalidate();
|
dependent.invalidate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nodes_changed {
|
||||||
|
// If we added/removed nodes, the sorted order has changed, so start evaluating
|
||||||
|
// from the beginning, in case of changes before i.
|
||||||
|
i = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consistency check: after updating in the topological sort order, we should be left with
|
// Consistency check: after updating in the topological sort order, we should be left with
|
||||||
// no invalid nodes
|
// no invalid nodes
|
||||||
debug_assert!(self
|
debug_assert!(self
|
||||||
.sorted_nodes
|
.sorted_nodes
|
||||||
.iter()
|
.iter()
|
||||||
.all(|&idx| { (&graph[idx]).is_valid() }));
|
.all(|&idx| { (&graph[idx]).is_valid() }));
|
||||||
|
|
||||||
self.is_valid.set(true);
|
self.is_valid.set(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -420,7 +530,9 @@ impl<V> Clone for ValueInvalidationSignal<V> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::rule::{ConstantRule, InputVisitable};
|
use crate::rule::{
|
||||||
|
AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule,
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rule_output_with_no_inputs() {
|
fn rule_output_with_no_inputs() {
|
||||||
@ -711,13 +823,108 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
graph.as_dot_string(),
|
graph.as_dot_string(),
|
||||||
r#"digraph {
|
r#"digraph {
|
||||||
0 [label ="ConstNode<i32> (id=0)"]
|
0 [label="ConstNode<i32> (id=0)"]
|
||||||
1 [label ="ConstNode<i32> (id=1)"]
|
1 [label="ConstNode<i32> (id=1)"]
|
||||||
2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"]
|
2 [label="RuleNode<AddWithLabel>(test) (id=2)"]
|
||||||
0 -> 2 []
|
0 -> 2 []
|
||||||
1 -> 2 []
|
1 -> 2 []
|
||||||
}
|
}
|
||||||
"#
|
"#
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dynamic_rule() {
|
||||||
|
let mut builder = GraphBuilder::new();
|
||||||
|
let (count, set_count) = builder.add_invalidatable_value(1);
|
||||||
|
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
|
||||||
|
impl InputVisitable for CountUpTo {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
|
visitor.visit(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl DynamicRule for CountUpTo {
|
||||||
|
type ChildOutput = i32;
|
||||||
|
fn evaluate(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut impl rule::DynamicRuleContext,
|
||||||
|
) -> Vec<Input<Self::ChildOutput>> {
|
||||||
|
let count = *self.0.value();
|
||||||
|
assert!(count >= self.1.len() as i32);
|
||||||
|
while (self.1.len() as i32) < count {
|
||||||
|
self.1
|
||||||
|
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
|
||||||
|
}
|
||||||
|
self.1.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![]));
|
||||||
|
struct Sum(DynamicInput<i32>);
|
||||||
|
impl InputVisitable for Sum {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
|
visitor.visit_dynamic(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Rule for Sum {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
self.0.value().inputs.iter().map(|i| *i.value()).sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.set_output(Sum(all_inputs));
|
||||||
|
let mut graph = builder.build().unwrap();
|
||||||
|
assert_eq!(*graph.evaluate(), 1);
|
||||||
|
set_count.set_value(2);
|
||||||
|
assert_eq!(*graph.evaluate(), 3);
|
||||||
|
set_count.set_value(4);
|
||||||
|
assert_eq!(*graph.evaluate(), 10);
|
||||||
|
println!("{}", graph.as_dot_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn async_dynamic_rule() {
|
||||||
|
let mut builder = GraphBuilder::new_async();
|
||||||
|
let (count, set_count) = builder.add_invalidatable_value(1);
|
||||||
|
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
|
||||||
|
impl InputVisitable for CountUpTo {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
|
visitor.visit(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl AsyncDynamicRule for CountUpTo {
|
||||||
|
type ChildOutput = i32;
|
||||||
|
async fn evaluate<'a>(
|
||||||
|
&'a mut self,
|
||||||
|
ctx: &'a mut impl rule::AsyncDynamicRuleContext,
|
||||||
|
) -> Vec<Input<Self::ChildOutput>> {
|
||||||
|
let count = *self.0.value();
|
||||||
|
assert!(count >= self.1.len() as i32);
|
||||||
|
while (self.1.len() as i32) < count {
|
||||||
|
self.1
|
||||||
|
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
|
||||||
|
}
|
||||||
|
self.1.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![]));
|
||||||
|
struct Sum(DynamicInput<i32>);
|
||||||
|
impl InputVisitable for Sum {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
|
visitor.visit_dynamic(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Rule for Sum {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
self.0.value().inputs.iter().map(|i| *i.value()).sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.set_output(Sum(all_inputs));
|
||||||
|
let mut graph = builder.build().unwrap();
|
||||||
|
assert_eq!(*graph.evaluate_async().await, 1);
|
||||||
|
set_count.set_value(2);
|
||||||
|
assert_eq!(*graph.evaluate_async().await, 3);
|
||||||
|
set_count.set_value(4);
|
||||||
|
assert_eq!(*graph.evaluate_async().await, 10);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
|
use crate::rule::{
|
||||||
|
AsyncDynamicRule, AsyncDynamicRuleContext, AsyncRule, DynamicInput, DynamicRule,
|
||||||
|
DynamicRuleContext, InputVisitable, Rule,
|
||||||
|
};
|
||||||
use crate::synchronicity::{Asynchronous, Synchronicity};
|
use crate::synchronicity::{Asynchronous, Synchronicity};
|
||||||
use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous};
|
use crate::{Input, InputVisitor, NodeId, Synchronous};
|
||||||
use quote::ToTokens;
|
use quote::ToTokens;
|
||||||
use std::any::Any;
|
use std::any::Any;
|
||||||
use std::cell::RefCell;
|
use std::cell::{Cell, RefCell};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
@ -11,10 +15,35 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
|
|||||||
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
|
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
|
||||||
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
|
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
|
||||||
visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>,
|
visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>,
|
||||||
update: Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Synch::UpdateResult<'a>>,
|
update: Box<
|
||||||
|
dyn for<'a> Fn(
|
||||||
|
&'a mut Box<dyn Any>,
|
||||||
|
&'a mut NodeUpdateContext<Synch>,
|
||||||
|
) -> Synch::UpdateResult<'a>,
|
||||||
|
>,
|
||||||
debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>,
|
debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
|
||||||
|
pub(crate) invalidate_dependent_nodes: bool,
|
||||||
|
pub(crate) removed_nodes: Vec<NodeId>,
|
||||||
|
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Synchronicity> NodeUpdateContext<S> {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
invalidate_dependent_nodes: false,
|
||||||
|
removed_nodes: vec![],
|
||||||
|
added_nodes: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invalidate_dependent_nodes(&mut self) {
|
||||||
|
self.invalidate_dependent_nodes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S: Synchronicity> ErasedNode<S> {
|
impl<S: Synchronicity> ErasedNode<S> {
|
||||||
pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
|
pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
|
||||||
// i don't love the double boxing, but i'm not sure how else to do this
|
// i don't love the double boxing, but i'm not sure how else to do this
|
||||||
@ -34,9 +63,9 @@ impl<S: Synchronicity> ErasedNode<S> {
|
|||||||
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
|
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
|
||||||
x.visit_inputs(visitor);
|
x.visit_inputs(visitor);
|
||||||
}),
|
}),
|
||||||
update: Box::new(|any| {
|
update: Box::new(|any, ctx| {
|
||||||
let x = any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap();
|
let x = any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap();
|
||||||
x.update()
|
x.update(ctx)
|
||||||
}),
|
}),
|
||||||
debug_fmt: Box::new(|any, f| {
|
debug_fmt: Box::new(|any, f| {
|
||||||
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
|
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
|
||||||
@ -57,14 +86,14 @@ impl<S: Synchronicity> ErasedNode<S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ErasedNode<Synchronous> {
|
impl ErasedNode<Synchronous> {
|
||||||
pub(crate) fn update(&mut self) -> bool {
|
pub(crate) fn update(&mut self, ctx: &mut NodeUpdateContext<Synchronous>) {
|
||||||
(self.update)(&mut self.any)
|
(self.update)(&mut self.any, ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ErasedNode<Asynchronous> {
|
impl ErasedNode<Asynchronous> {
|
||||||
pub(crate) async fn update(&mut self) -> bool {
|
pub(crate) async fn update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
|
||||||
(self.update)(&mut self.any).await
|
(self.update)(&mut self.any, ctx).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +107,7 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
|
|||||||
fn is_valid(&self) -> bool;
|
fn is_valid(&self) -> bool;
|
||||||
fn invalidate(&mut self);
|
fn invalidate(&mut self);
|
||||||
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
|
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
|
||||||
fn update(&mut self) -> Synch::UpdateResult<'_>;
|
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<Synch>) -> Synch::UpdateResult<'a>;
|
||||||
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
|
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,7 +168,7 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
|
|||||||
|
|
||||||
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
||||||
|
|
||||||
fn update(&mut self) -> S::UpdateResult<'_> {
|
fn update<'a>(&'a mut self, _ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,11 +210,12 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for InvalidatableConstNode<V, S>
|
|||||||
|
|
||||||
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
||||||
|
|
||||||
fn update(&mut self) -> S::UpdateResult<'_> {
|
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
|
||||||
self.valid = true;
|
self.valid = true;
|
||||||
// This node is only invalidate when node_value_eq between the old/new value is false,
|
// This node is only invalidate when node_value_eq between the old/new value is false,
|
||||||
// so it is always the case that the update method has changed the value.
|
// so it is always the case that the update method has changed the value.
|
||||||
S::make_update_result(true, crate::synchronicity::private::Token)
|
ctx.invalidate_dependent_nodes();
|
||||||
|
S::make_update_result(crate::synchronicity::private::Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
|
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
|
||||||
@ -217,6 +247,32 @@ impl<R: Rule, S> RuleNode<R, R::Output, S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn visit_inputs<V: InputVisitable>(visitable: &V, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
||||||
|
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
|
||||||
|
impl<'a> InputVisitor for InputIndexVisitor<'a> {
|
||||||
|
fn visit<T>(&mut self, input: &Input<T>) {
|
||||||
|
self.0(input.node_idx.get().unwrap());
|
||||||
|
}
|
||||||
|
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>) {
|
||||||
|
// Visit the dynamic node itself
|
||||||
|
self.visit(&input.input);
|
||||||
|
|
||||||
|
// And visit all the nodes it produces
|
||||||
|
let maybe_dynamic_output = input.input.value.borrow();
|
||||||
|
if let Some(dynamic_output) = maybe_dynamic_output.as_ref() {
|
||||||
|
for input in dynamic_output.inputs.iter() {
|
||||||
|
self.visit(input);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Haven't evaluated the dynamic node for the first time yet.
|
||||||
|
// Upon doing so, if the nodes it produces change, we'll modify the graph
|
||||||
|
// and end up back here in the other branch.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
visitable.visit_inputs(&mut InputIndexVisitor(visitor));
|
||||||
|
}
|
||||||
|
|
||||||
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
|
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
|
||||||
fn is_valid(&self) -> bool {
|
fn is_valid(&self) -> bool {
|
||||||
self.valid
|
self.valid
|
||||||
@ -227,16 +283,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
||||||
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
|
visit_inputs(&self.rule, visitor);
|
||||||
impl<'a> InputVisitor for InputIndexVisitor<'a> {
|
|
||||||
fn visit<T>(&mut self, input: &Input<T>) {
|
|
||||||
self.0(input.node_idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update(&mut self) -> S::UpdateResult<'_> {
|
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
|
||||||
self.valid = true;
|
self.valid = true;
|
||||||
|
|
||||||
let new_value = self.rule.evaluate();
|
let new_value = self.rule.evaluate();
|
||||||
@ -247,9 +297,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
|
|||||||
|
|
||||||
if value_changed {
|
if value_changed {
|
||||||
*value = Some(new_value);
|
*value = Some(new_value);
|
||||||
|
ctx.invalidate_dependent_nodes();
|
||||||
}
|
}
|
||||||
|
|
||||||
S::make_update_result(value_changed, crate::synchronicity::private::Token)
|
S::make_update_result(crate::synchronicity::private::Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
|
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
|
||||||
@ -290,12 +341,12 @@ impl<V, P: FnOnce() -> F, F: Future<Output = V>> AsyncConstNode<V, P, F> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn do_update(&mut self) -> bool {
|
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
|
||||||
self.valid = true;
|
self.valid = true;
|
||||||
let mut provider = None;
|
let mut provider = None;
|
||||||
std::mem::swap(&mut self.provider, &mut provider);
|
std::mem::swap(&mut self.provider, &mut provider);
|
||||||
*self.value.borrow_mut() = Some(provider.unwrap()().await);
|
*self.value.borrow_mut() = Some(provider.unwrap()().await);
|
||||||
true
|
ctx.invalidate_dependent_nodes();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,8 +363,11 @@ impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous
|
|||||||
|
|
||||||
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
|
||||||
|
|
||||||
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
|
fn update<'a>(
|
||||||
Box::pin(self.do_update())
|
&'a mut self,
|
||||||
|
ctx: &'a mut NodeUpdateContext<Asynchronous>,
|
||||||
|
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
|
||||||
|
Box::pin(self.do_update(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
|
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
|
||||||
@ -342,7 +396,7 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn do_update(&mut self) -> bool {
|
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
|
||||||
self.valid = true;
|
self.valid = true;
|
||||||
|
|
||||||
let new_value = self.rule.evaluate().await;
|
let new_value = self.rule.evaluate().await;
|
||||||
@ -353,9 +407,8 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
|
|||||||
|
|
||||||
if value_changed {
|
if value_changed {
|
||||||
*value = Some(new_value);
|
*value = Some(new_value);
|
||||||
|
ctx.invalidate_dependent_nodes();
|
||||||
}
|
}
|
||||||
|
|
||||||
value_changed
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,17 +422,14 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
||||||
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
|
visit_inputs(&self.rule, visitor);
|
||||||
impl<'a> InputVisitor for InputIndexVisitor<'a> {
|
|
||||||
fn visit<T>(&mut self, input: &Input<T>) {
|
|
||||||
self.0(input.node_idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
|
fn update<'a>(
|
||||||
Box::pin(self.do_update())
|
&'a mut self,
|
||||||
|
ctx: &'a mut NodeUpdateContext<Asynchronous>,
|
||||||
|
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
|
||||||
|
Box::pin(self.do_update(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
|
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
|
||||||
@ -405,6 +455,236 @@ impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: better name for this
|
||||||
|
pub struct DynamicRuleOutput<O> {
|
||||||
|
pub inputs: Vec<Input<O>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O: 'static> NodeValue for DynamicRuleOutput<O> {
|
||||||
|
fn node_value_eq(&self, other: &Self) -> bool {
|
||||||
|
if self.inputs.len() != other.inputs.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
self.inputs
|
||||||
|
.iter()
|
||||||
|
.zip(other.inputs.iter())
|
||||||
|
.all(|(s, o)| s.node_idx == o.node_idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O> std::fmt::Debug for DynamicRuleOutput<O> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct(std::any::type_name::<Self>())
|
||||||
|
.field("inputs", &self.inputs)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct DynamicRuleNode<R, O, S> {
|
||||||
|
rule: R,
|
||||||
|
valid: bool,
|
||||||
|
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
|
||||||
|
synchronicity: std::marker::PhantomData<S>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, O, S> DynamicRuleNode<R, O, S> {
|
||||||
|
pub(crate) fn new(rule: R) -> Self {
|
||||||
|
Self {
|
||||||
|
rule,
|
||||||
|
valid: false,
|
||||||
|
value: Rc::new(RefCell::new(None)),
|
||||||
|
synchronicity: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: DynamicRule, S: Synchronicity> Node<DynamicRuleOutput<R::ChildOutput>, S>
|
||||||
|
for DynamicRuleNode<R, R::ChildOutput, S>
|
||||||
|
{
|
||||||
|
fn is_valid(&self) -> bool {
|
||||||
|
self.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invalidate(&mut self) {
|
||||||
|
self.valid = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
||||||
|
visit_inputs(&self.rule, visitor);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
|
||||||
|
self.valid = true;
|
||||||
|
|
||||||
|
let new_value = DynamicRuleOutput {
|
||||||
|
inputs: self.rule.evaluate(&mut DynamicRuleUpdateContext(ctx)),
|
||||||
|
};
|
||||||
|
let mut value = self.value.borrow_mut();
|
||||||
|
let value_changed = value
|
||||||
|
.as_ref()
|
||||||
|
.map_or(true, |v| !v.node_value_eq(&new_value));
|
||||||
|
|
||||||
|
if value_changed {
|
||||||
|
*value = Some(new_value);
|
||||||
|
ctx.invalidate_dependent_nodes();
|
||||||
|
}
|
||||||
|
|
||||||
|
S::make_update_result(crate::synchronicity::private::Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
|
||||||
|
&self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DynamicRuleUpdateContext<'a, Synch: Synchronicity>(&'a mut NodeUpdateContext<Synch>);
|
||||||
|
|
||||||
|
impl<'a, S: Synchronicity> DynamicRuleUpdateContext<'a, S> {
|
||||||
|
fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
|
||||||
|
let node_idx = Rc::new(Cell::new(None));
|
||||||
|
let value = Rc::clone(node.value_rc());
|
||||||
|
let erased = ErasedNode::new(node);
|
||||||
|
self.0.added_nodes.push((erased, Rc::clone(&node_idx)));
|
||||||
|
Input { node_idx, value }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S> {
|
||||||
|
fn remove_node(&mut self, id: NodeId) {
|
||||||
|
self.0.removed_nodes.push(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||||
|
where
|
||||||
|
R: Rule,
|
||||||
|
{
|
||||||
|
self.add_node(RuleNode::new(rule))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
|
||||||
|
impl<'a, R: DynamicRule> std::fmt::Display for DynamicRuleLabel<'a, R> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
self.0.node_label(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: DynamicRule, O, V> std::fmt::Debug for DynamicRuleNode<R, O, V> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"DynamicRuleNode<{}>({})",
|
||||||
|
pretty_type_name::<R>(),
|
||||||
|
DynamicRuleLabel(&self.rule)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct AsyncDynamicRuleNode<R, O> {
|
||||||
|
rule: R,
|
||||||
|
valid: bool,
|
||||||
|
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncDynamicRule> AsyncDynamicRuleNode<R, R::ChildOutput> {
|
||||||
|
pub(crate) fn new(rule: R) -> Self {
|
||||||
|
Self {
|
||||||
|
rule,
|
||||||
|
valid: false,
|
||||||
|
value: Rc::new(RefCell::new(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
|
||||||
|
self.valid = true;
|
||||||
|
|
||||||
|
let new_value = DynamicRuleOutput {
|
||||||
|
inputs: self
|
||||||
|
.rule
|
||||||
|
.evaluate(&mut AsyncDynamicRuleUpdateContext(ctx))
|
||||||
|
.await,
|
||||||
|
};
|
||||||
|
let mut value = self.value.borrow_mut();
|
||||||
|
let value_changed = value
|
||||||
|
.as_ref()
|
||||||
|
.map_or(true, |v| !v.node_value_eq(&new_value));
|
||||||
|
|
||||||
|
if value_changed {
|
||||||
|
*value = Some(new_value);
|
||||||
|
ctx.invalidate_dependent_nodes();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncDynamicRule> Node<DynamicRuleOutput<R::ChildOutput>, Asynchronous>
|
||||||
|
for AsyncDynamicRuleNode<R, R::ChildOutput>
|
||||||
|
{
|
||||||
|
fn is_valid(&self) -> bool {
|
||||||
|
self.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invalidate(&mut self) {
|
||||||
|
self.valid = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
|
||||||
|
visit_inputs(&self.rule, visitor);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update<'a>(
|
||||||
|
&'a mut self,
|
||||||
|
ctx: &'a mut NodeUpdateContext<Asynchronous>,
|
||||||
|
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
|
||||||
|
Box::pin(self.do_update(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
|
||||||
|
&self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AsyncDynamicRuleUpdateContext<'a>(&'a mut NodeUpdateContext<Asynchronous>);
|
||||||
|
|
||||||
|
impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
|
||||||
|
fn remove_node(&mut self, id: NodeId) {
|
||||||
|
DynamicRuleUpdateContext(self.0).remove_node(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||||
|
where
|
||||||
|
R: Rule,
|
||||||
|
{
|
||||||
|
DynamicRuleUpdateContext(self.0).add_rule(rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
|
||||||
|
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||||
|
where
|
||||||
|
R: AsyncRule,
|
||||||
|
{
|
||||||
|
DynamicRuleUpdateContext(self.0).add_node(AsyncRuleNode::new(rule))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AsyncDynamicRuleLabel<'a, R: AsyncDynamicRule>(&'a R);
|
||||||
|
impl<'a, R: AsyncDynamicRule> std::fmt::Display for AsyncDynamicRuleLabel<'a, R> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
self.0.node_label(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncDynamicRule> std::fmt::Debug for AsyncDynamicRuleNode<R, R::ChildOutput> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"AsyncDynamicRuleNode<{}>({})",
|
||||||
|
pretty_type_name::<R>(),
|
||||||
|
AsyncDynamicRuleLabel(&self.rule)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn pretty_type_name<T>() -> String {
|
fn pretty_type_name<T>() -> String {
|
||||||
let s = std::any::type_name::<T>();
|
let s = std::any::type_name::<T>();
|
||||||
let ty = syn::parse_str::<syn::Type>(s).unwrap();
|
let ty = syn::parse_str::<syn::Type>(s).unwrap();
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::node::NodeValue;
|
use crate::node::{DynamicRuleOutput, NodeValue};
|
||||||
use crate::NodeId;
|
use crate::NodeId;
|
||||||
pub use compute_graph_macros::InputVisitable;
|
pub use compute_graph_macros::InputVisitable;
|
||||||
use std::cell::{Ref, RefCell};
|
use std::cell::{Cell, Ref, RefCell};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
@ -76,6 +76,75 @@ pub trait AsyncRule: InputVisitable + 'static {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A rule whose output is further nodes in the graph.
|
||||||
|
///
|
||||||
|
/// Types implementing this rule should track which nodes they previously output and not
|
||||||
|
/// add additional equivalent nodes (for whatever domain-specific definition of equivalence)
|
||||||
|
/// on susbequent evaluations.
|
||||||
|
pub trait DynamicRule: InputVisitable + 'static {
|
||||||
|
/// The type of the output value of each of the child nodes that this rule produces.
|
||||||
|
type ChildOutput: NodeValue;
|
||||||
|
|
||||||
|
/// Evaluates this rule, producing additional nodes.
|
||||||
|
///
|
||||||
|
/// Use the methods on [`DynamicRuleContext`] to add or remove nodes from the graph.
|
||||||
|
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>>;
|
||||||
|
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Facilities for adding/removing nodes in the graph during the update of a [`DynamicRule`].
|
||||||
|
// todo: better abstracion for this
|
||||||
|
// something that handles diffing and does the add/remove automatically
|
||||||
|
pub trait DynamicRuleContext {
|
||||||
|
/// Removes the node with the given ID from the graph.
|
||||||
|
///
|
||||||
|
/// Be careful when removing nodes. Removing a node that is still depended-upon by another node
|
||||||
|
/// (i.e., is an input in some other node's [`InputVisitable::visit_inputs`]) is an error.
|
||||||
|
fn remove_node(&mut self, id: NodeId);
|
||||||
|
|
||||||
|
/// Adds a node whose value is produced using the given rule to the graph.
|
||||||
|
///
|
||||||
|
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
|
||||||
|
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||||
|
where
|
||||||
|
R: Rule;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An asynchronous rule whose output is further nodes in the graph.
|
||||||
|
///
|
||||||
|
/// See [`DynamicRule`].
|
||||||
|
pub trait AsyncDynamicRule: InputVisitable + 'static {
|
||||||
|
/// The type of the output value of each of the child nodes that this rule produces.
|
||||||
|
type ChildOutput: NodeValue;
|
||||||
|
|
||||||
|
/// Evaluates this rule asynchronously, producing additional nodes.
|
||||||
|
///
|
||||||
|
/// Use the methods on [`AsyncDynamicRuleContext`] to add or remove nodes from the graph.
|
||||||
|
fn evaluate<'a>(
|
||||||
|
&'a mut self,
|
||||||
|
ctx: &'a mut impl AsyncDynamicRuleContext,
|
||||||
|
) -> impl Future<Output = Vec<Input<Self::ChildOutput>>> + 'a;
|
||||||
|
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Facilities for adding/removing nodes in the graph during the update of an [`AsyncDynamicRule`].
|
||||||
|
pub trait AsyncDynamicRuleContext: DynamicRuleContext {
|
||||||
|
/// Adds a node whose value is produced using the given rule to the graph.
|
||||||
|
///
|
||||||
|
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
|
||||||
|
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||||
|
where
|
||||||
|
R: AsyncRule;
|
||||||
|
}
|
||||||
|
|
||||||
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
||||||
///
|
///
|
||||||
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
|
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
|
||||||
@ -93,13 +162,13 @@ pub trait InputVisitable {
|
|||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A placeholder for the output of one node to be used as an input for another.
|
/// A placeholder for the output of one node, to be used as an input for another.
|
||||||
///
|
///
|
||||||
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
||||||
///
|
///
|
||||||
/// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes.
|
/// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes.
|
||||||
pub struct Input<T> {
|
pub struct Input<T> {
|
||||||
pub(crate) node_idx: NodeId,
|
pub(crate) node_idx: Rc<Cell<Option<NodeId>>>,
|
||||||
pub(crate) value: Rc<RefCell<Option<T>>>,
|
pub(crate) value: Rc<RefCell<Option<T>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,7 +188,7 @@ impl<T> Input<T> {
|
|||||||
impl<T> Clone for Input<T> {
|
impl<T> Clone for Input<T> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
node_idx: self.node_idx,
|
node_idx: Rc::clone(&self.node_idx),
|
||||||
value: Rc::clone(&self.value),
|
value: Rc::clone(&self.value),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,6 +205,25 @@ impl<T> std::fmt::Debug for Input<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A placeholder for the output of a dynamic rule node, to be used as an input for another.
|
||||||
|
///
|
||||||
|
/// See [`GraphBuilder::add_dynamic_rule`](`crate::builder::GraphBuilder::add_dynamic_rule`).
|
||||||
|
///
|
||||||
|
/// A dependency on a dynamic input represents both a dependency on the dynamic node itself,
|
||||||
|
/// as well as dependencies on each of the nodes that are the output of the dynamic node.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct DynamicInput<T> {
|
||||||
|
pub(crate) input: Input<DynamicRuleOutput<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> DynamicInput<T> {
|
||||||
|
/// Retrieves a reference to the current value of the dynamic node (i.e., the set of inputs
|
||||||
|
/// representing the nodes that are the outputs of the dynamic node).
|
||||||
|
pub fn value(&self) -> impl Deref<Target = DynamicRuleOutput<T>> + '_ {
|
||||||
|
self.input.value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: i really want Input to be able to implement Deref somehow
|
// TODO: i really want Input to be able to implement Deref somehow
|
||||||
|
|
||||||
/// A type that can visit arbitrary [`Input`]s.
|
/// A type that can visit arbitrary [`Input`]s.
|
||||||
@ -145,6 +233,9 @@ impl<T> std::fmt::Debug for Input<T> {
|
|||||||
pub trait InputVisitor {
|
pub trait InputVisitor {
|
||||||
/// Visit an input whose value is of type `T`.
|
/// Visit an input whose value is of type `T`.
|
||||||
fn visit<T>(&mut self, input: &Input<T>);
|
fn visit<T>(&mut self, input: &Input<T>);
|
||||||
|
|
||||||
|
/// Visit a dynamic input whose child value is of type `T`.
|
||||||
|
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A simple rule that provides a constant value.
|
/// A simple rule that provides a constant value.
|
||||||
|
@ -11,7 +11,7 @@ pub(crate) mod private {
|
|||||||
pub trait Sealed {}
|
pub trait Sealed {}
|
||||||
impl Sealed for super::Synchronous {}
|
impl Sealed for super::Synchronous {}
|
||||||
impl Sealed for super::Asynchronous {}
|
impl Sealed for super::Asynchronous {}
|
||||||
impl Sealed for bool {}
|
impl Sealed for () {}
|
||||||
impl<'a> Sealed for <super::Asynchronous as super::Synchronicity>::UpdateResult<'a> {}
|
impl<'a> Sealed for <super::Asynchronous as super::Synchronicity>::UpdateResult<'a> {}
|
||||||
pub struct Token;
|
pub struct Token;
|
||||||
}
|
}
|
||||||
@ -20,25 +20,23 @@ pub trait Synchronicity: private::Sealed + 'static {
|
|||||||
type UpdateResult<'a>: private::Sealed;
|
type UpdateResult<'a>: private::Sealed;
|
||||||
// Necessary for synchronous nodes that can be part of an async graph to return the
|
// Necessary for synchronous nodes that can be part of an async graph to return the
|
||||||
// appropriate result based on the type of graph they're in.
|
// appropriate result based on the type of graph they're in.
|
||||||
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a>;
|
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Synchronous;
|
pub struct Synchronous;
|
||||||
|
|
||||||
impl Synchronicity for Synchronous {
|
impl Synchronicity for Synchronous {
|
||||||
type UpdateResult<'a> = bool;
|
type UpdateResult<'a> = ();
|
||||||
|
|
||||||
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> {
|
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {}
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Asynchronous;
|
pub struct Asynchronous;
|
||||||
|
|
||||||
impl Synchronicity for Asynchronous {
|
impl Synchronicity for Asynchronous {
|
||||||
type UpdateResult<'a> = Pin<Box<dyn Future<Output = bool> + 'a>>;
|
type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
|
||||||
|
|
||||||
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> {
|
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {
|
||||||
Box::pin(std::future::ready(result))
|
Box::pin(std::future::ready(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use proc_macro2::Literal;
|
use proc_macro2::Literal;
|
||||||
use quote::{format_ident, quote};
|
use quote::{format_ident, quote, ToTokens};
|
||||||
use syn::{
|
use syn::{
|
||||||
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
|
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
|
||||||
PathArguments, Type,
|
PathArguments, Type,
|
||||||
@ -10,8 +10,8 @@ extern crate proc_macro;
|
|||||||
|
|
||||||
/// Derive an implementation of the `InputVisitable` trait and helper methods.
|
/// Derive an implementation of the `InputVisitable` trait and helper methods.
|
||||||
///
|
///
|
||||||
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that
|
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_inputs` method that
|
||||||
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T.
|
/// calls `visit` on each field of the struct that is of type `Input<T>` or `DynamicInput<T>` for any `T`.
|
||||||
///
|
///
|
||||||
/// The macro also generates helper methods for accessing the value of each input less verbosely.
|
/// The macro also generates helper methods for accessing the value of each input less verbosely.
|
||||||
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
|
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
|
||||||
@ -56,20 +56,34 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
|||||||
Fields::Named(ref named) => named
|
Fields::Named(ref named) => named
|
||||||
.named
|
.named
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|field| input_value_type(field).is_some())
|
.flat_map(|field| {
|
||||||
.map(|field| {
|
if let Some((_ty, is_dynamic)) = input_value_type(field) {
|
||||||
let ident = field.ident.as_ref().unwrap();
|
let ident = field.ident.as_ref().unwrap();
|
||||||
quote!(visitor.visit(&self.#ident);)
|
if is_dynamic {
|
||||||
|
Some(quote!(visitor.visit_dynamic(&self.#ident);))
|
||||||
|
} else {
|
||||||
|
Some(quote!(visitor.visit(&self.#ident);))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
Fields::Unnamed(ref unnamed) => unnamed
|
Fields::Unnamed(ref unnamed) => unnamed
|
||||||
.unnamed
|
.unnamed
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter(|(_, field)| input_value_type(field).is_some())
|
.flat_map(|(i, field)| {
|
||||||
.map(|(i, _)| {
|
if let Some((_ty, is_dynamic)) = input_value_type(field) {
|
||||||
let idx_lit = Literal::usize_unsuffixed(i);
|
let idx_lit = Literal::usize_unsuffixed(i);
|
||||||
quote!(visitor.visit(&self.#idx_lit);)
|
if is_dynamic {
|
||||||
|
Some(quote!(visitor.visit_dynamic(&self.#idx_lit);))
|
||||||
|
} else {
|
||||||
|
Some(quote!(visitor.visit(&self.#idx_lit);))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
Fields::Unit => vec![],
|
Fields::Unit => vec![],
|
||||||
@ -79,12 +93,19 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
|||||||
Fields::Named(ref named) => named
|
Fields::Named(ref named) => named
|
||||||
.named
|
.named
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|field| input_value_type(field).map(|ty| (field, ty)))
|
.filter_map(|field| {
|
||||||
.map(|(field, ty)| {
|
input_value_type(field).map(|(ty, is_dynamic)| (field, ty, is_dynamic))
|
||||||
|
})
|
||||||
|
.map(|(field, ty, is_dynamic)| {
|
||||||
let ident = field.ident.as_ref().unwrap();
|
let ident = field.ident.as_ref().unwrap();
|
||||||
|
let target = if is_dynamic {
|
||||||
|
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
|
||||||
|
} else {
|
||||||
|
ty.to_token_stream()
|
||||||
|
};
|
||||||
quote!(
|
quote!(
|
||||||
|
|
||||||
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
|
||||||
self.#ident.value()
|
self.#ident.value()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,13 +116,20 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
|||||||
.unnamed
|
.unnamed
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty)))
|
.filter_map(|(i, field)| {
|
||||||
.map(|(i, ty)| {
|
input_value_type(field).map(|(ty, is_dynamic)| (i, ty, is_dynamic))
|
||||||
|
})
|
||||||
|
.map(|(i, ty, is_dynamic)| {
|
||||||
let idx_lit = Literal::usize_unsuffixed(i);
|
let idx_lit = Literal::usize_unsuffixed(i);
|
||||||
let ident = format_ident!("input_{i}");
|
let ident = format_ident!("input_{i}");
|
||||||
|
let target = if is_dynamic {
|
||||||
|
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
|
||||||
|
} else {
|
||||||
|
ty.to_token_stream()
|
||||||
|
};
|
||||||
quote!(
|
quote!(
|
||||||
|
|
||||||
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
|
||||||
self.#idx_lit.value()
|
self.#idx_lit.value()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,14 +154,15 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn input_value_type(field: &Field) -> Option<&Type> {
|
fn input_value_type(field: &Field) -> Option<(&Type, bool)> {
|
||||||
if let Type::Path(ref path) = field.ty {
|
if let Type::Path(ref path) = field.ty {
|
||||||
let last_segment = path.path.segments.last().unwrap();
|
let last_segment = path.path.segments.last().unwrap();
|
||||||
if last_segment.ident == "Input" {
|
if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" {
|
||||||
|
let is_dynamic = last_segment.ident == "DynamicInput";
|
||||||
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
|
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
|
||||||
if args.args.len() == 1 {
|
if args.args.len() == 1 {
|
||||||
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
|
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
|
||||||
Some(ty)
|
Some((ty, is_dynamic))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use compute_graph::node::NodeValue;
|
use compute_graph::node::NodeValue;
|
||||||
use compute_graph::rule::{Input, InputVisitable, Rule};
|
use compute_graph::rule::{DynamicInput, Input, InputVisitable, Rule};
|
||||||
|
|
||||||
#[derive(InputVisitable)]
|
#[derive(InputVisitable)]
|
||||||
struct Add(Input<i32>, Input<i32>, i32);
|
struct Add(Input<i32>, Input<i32>, i32);
|
||||||
@ -34,9 +34,25 @@ impl<T: NodeValue + Clone> Rule for Passthrough<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(InputVisitable)]
|
||||||
|
struct Sum(DynamicInput<i32>);
|
||||||
|
impl Rule for Sum {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
self.input_0()
|
||||||
|
.inputs
|
||||||
|
.iter()
|
||||||
|
.map(|input| *input.value())
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use compute_graph::builder::GraphBuilder;
|
use compute_graph::{
|
||||||
|
builder::GraphBuilder,
|
||||||
|
rule::{ConstantRule, DynamicRule},
|
||||||
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@ -59,4 +75,27 @@ mod tests {
|
|||||||
let mut graph = builder.build().unwrap();
|
let mut graph = builder.build().unwrap();
|
||||||
assert_eq!(*graph.evaluate(), 6);
|
assert_eq!(*graph.evaluate(), 6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sum() {
|
||||||
|
#[derive(InputVisitable)]
|
||||||
|
struct Dynamic;
|
||||||
|
impl DynamicRule for Dynamic {
|
||||||
|
type ChildOutput = i32;
|
||||||
|
fn evaluate(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut impl compute_graph::rule::DynamicRuleContext,
|
||||||
|
) -> Vec<Input<Self::ChildOutput>> {
|
||||||
|
vec![
|
||||||
|
ctx.add_rule(ConstantRule::new(1)),
|
||||||
|
ctx.add_rule(ConstantRule::new(2)),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut builder = GraphBuilder::new();
|
||||||
|
let dynamic_input = builder.add_dynamic_rule(Dynamic);
|
||||||
|
builder.set_output(Sum(dynamic_input));
|
||||||
|
let mut graph = builder.build().unwrap();
|
||||||
|
assert_eq!(*graph.evaluate(), 3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user