Dynamic rules

This commit is contained in:
Shadowfacts 2024-12-29 13:37:54 -05:00
parent 9cb6a8c6ce
commit d92ebf11b2
7 changed files with 803 additions and 103 deletions

View File

@ -1,8 +1,8 @@
use crate::node::{
AsyncConstNode, AsyncRuleNode, ConstNode, ErasedNode, InvalidatableConstNode, Node, NodeValue,
RuleNode,
AsyncConstNode, AsyncDynamicRuleNode, AsyncRuleNode, ConstNode, DynamicRuleNode, ErasedNode,
InvalidatableConstNode, Node, NodeValue, RuleNode,
};
use crate::rule::{AsyncRule, Input, Rule};
use crate::rule::{AsyncDynamicRule, AsyncRule, DynamicInput, DynamicRule, Input, Rule};
use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous};
use crate::util;
use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal};
@ -73,7 +73,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
let erased = ErasedNode::new(node);
let idx = self.node_graph.borrow_mut().add_node(erased);
Input {
node_idx: idx,
node_idx: Rc::new(Cell::new(Some(idx))),
value,
}
}
@ -174,19 +174,42 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
}
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
let node_idx = input.node_idx;
let node_idx = Rc::clone(&input.node_idx);
let graph = Rc::clone(&self.node_graph);
let graph_is_valid = Rc::clone(&self.is_valid);
InvalidationSignal {
do_invalidate: Rc::new(Box::new(move || {
graph_is_valid.set(false);
let mut graph = graph.borrow_mut();
let node = &mut graph[node_idx];
let node = &mut graph[node_idx.get().unwrap()];
node.invalidate();
})),
}
}
/// Adds a node to the graph whose output is additional nodes produced by the given rule.
pub fn add_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
where
R: DynamicRule,
{
let input = self.add_node(DynamicRuleNode::<R, R::ChildOutput, S>::new(rule));
DynamicInput { input }
}
/// Adds an externally-invalidatable node to the graph whose output is additional
/// nodes produced by the given rule.
pub fn add_invalidatable_dynamic_rule<R>(
&mut self,
rule: R,
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
where
R: DynamicRule,
{
let input = self.add_dynamic_rule(rule);
let signal = self.make_invalidation_signal(&input.input);
(input, signal)
}
/// Creates a graph from this builder, consuming the builder.
///
/// To successfully build a graph, there must be an output node set (using either
@ -217,7 +240,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
graph.add_edge(source, dest, ());
}
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx);
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap());
drop(graph);
@ -319,6 +342,29 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
let signal = self.make_invalidation_signal(&input);
(input, signal)
}
/// Adds a node to the graph whose output is additional nodes produced asynchronously by the given rule.
pub fn add_async_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
where
R: AsyncDynamicRule,
{
let input = self.add_node(AsyncDynamicRuleNode::<R, R::ChildOutput>::new(rule));
DynamicInput { input }
}
/// Adds an externally-invalidatable node to the graph whose output is additional nodes produced
/// asynchronously by the given rule.
pub fn add_invalidatable_async_dynamic_rule<R>(
&mut self,
rule: R,
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
where
R: AsyncDynamicRule,
{
let input = self.add_async_dynamic_rule(rule);
let signal = self.make_invalidation_signal(&input.input);
(input, signal)
}
}
/// A reason why a [`GraphBuilder`] can fail to build a graph.
@ -383,8 +429,18 @@ mod tests {
builder.set_output(Double::new(b.clone()));
match builder.build() {
Err(super::BuildGraphError::Cycle(cycle)) => {
let a_start = cycle == vec![a.node_idx, b.node_idx, a.node_idx];
let b_start = cycle == vec![b.node_idx, a.node_idx, b.node_idx];
let a_start = cycle
== vec![
a.node_idx.get().unwrap(),
b.node_idx.get().unwrap(),
a.node_idx.get().unwrap(),
];
let b_start = cycle
== vec![
b.node_idx.get().unwrap(),
a.node_idx.get().unwrap(),
b.node_idx.get().unwrap(),
];
// either is a permisisble way of describing the cycle
assert!(a_start || b_start);
}

View File

@ -49,10 +49,10 @@ pub mod synchronicity;
mod util;
use builder::{BuildGraphError, GraphBuilder};
use node::{ErasedNode, NodeValue};
use node::{ErasedNode, NodeUpdateContext, NodeValue};
use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef};
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
use rule::{AsyncRule, Input, InputVisitor, Rule};
use rule::{Input, InputVisitor};
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::collections::VecDeque;
@ -127,7 +127,15 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
/// Because building a graph can fail and this method mutates the underlying graph, it takes
/// ownership of the current graph to prevent the graph being left in an invalid state.
/// It returns either the new, modified graph or an error.
pub fn modify<F>(mut self, mut f: F) -> Result<Self, BuildGraphError>
pub fn modify<F>(mut self, f: F) -> Result<Self, BuildGraphError>
where
F: FnMut(&mut GraphBuilder<O, S>) -> (),
{
self._modify(f)?;
Ok(self)
}
fn _modify<F>(&mut self, mut f: F) -> Result<(), BuildGraphError>
where
F: FnMut(&mut GraphBuilder<O, S>) -> (),
{
@ -142,12 +150,12 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
}
drop(graph);
let old_output = self.output.node_idx;
let old_output = self.output.node_idx.get();
// Modify
let mut builder = self.into_builder();
let mut builder = self.to_builder();
f(&mut builder);
self = builder.build()?;
*self = builder.build()?;
// Any new inboud edges invalidate their target nodes.
let mut graph = self.node_graph.borrow_mut();
@ -164,7 +172,7 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
}
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
// there are no edges but we still need to invalidate.
if !to_invalidate.is_empty() || self.output.node_idx != old_output {
if !to_invalidate.is_empty() || self.output.node_idx.get() != old_output {
self.is_valid.set(false);
for idx in to_invalidate {
let node = &mut graph[idx];
@ -173,13 +181,17 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
}
drop(graph);
Ok(self)
Ok(())
}
/// Convert this graph back into a builder for further modifications.
///
/// Returns a builder with the same output and synchronicity types.
pub fn into_builder(self) -> GraphBuilder<O, S> {
self.to_builder()
}
fn to_builder(&self) -> GraphBuilder<O, S> {
// Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges.
let mut graph = self.node_graph.borrow_mut();
graph.clear_edges();
@ -250,13 +262,51 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
impl<O: 'static> Graph<O, Synchronous> {
fn update_invalid_nodes(&mut self) {
let mut graph = self.node_graph.borrow_mut();
for &idx in self.sorted_nodes.iter() {
let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let value_changed = node.update();
let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx);
if value_changed {
let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
.sorted_nodes
.iter()
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
}
for (added_node, id_cell) in ctx.added_nodes {
let id = graph.add_node(added_node);
id_cell.set(Some(id));
nodes_changed = true;
}
if nodes_changed {
// Update the graph before invalidating downstream nodes.
drop(graph);
self._modify(|_| {})
.expect("modifying graph during evaluation must produce valid graph");
graph = self.node_graph.borrow_mut();
}
if ctx.invalidate_dependent_nodes {
// Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph).
let dependents = graph
@ -270,14 +320,25 @@ impl<O: 'static> Graph<O, Synchronous> {
dependent.invalidate();
}
}
if nodes_changed {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
continue;
}
}
i += 1;
}
// Consistency check: after updating in the topological sort order, we should be left with
// no invalid nodes
// no invalid nodes.
debug_assert!(self
.sorted_nodes
.iter()
.all(|&idx| { (&graph[idx]).is_valid() }));
self.is_valid.set(true);
}
@ -300,13 +361,51 @@ impl<O: 'static> Graph<O, Asynchronous> {
async fn update_invalid_nodes(&mut self) {
// TODO: consider whether this can be done in parallel to any degree.
let mut graph = self.node_graph.borrow_mut();
for &idx in self.sorted_nodes.iter() {
let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let value_changed = node.update().await;
let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx).await;
if value_changed {
let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
.sorted_nodes
.iter()
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
}
for (added_node, id_cell) in ctx.added_nodes {
let id = graph.add_node(added_node);
id_cell.set(Some(id));
nodes_changed = true;
}
if nodes_changed {
// Update the graph before invalidating downstream nodes.
drop(graph);
self._modify(|_| {})
.expect("modifying graph during evaluation must produce valid graph");
graph = self.node_graph.borrow_mut();
}
if ctx.invalidate_dependent_nodes {
// Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph).
let dependents = graph
@ -320,14 +419,25 @@ impl<O: 'static> Graph<O, Asynchronous> {
dependent.invalidate();
}
}
if nodes_changed {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
continue;
}
}
i += 1;
}
// Consistency check: after updating in the topological sort order, we should be left with
// no invalid nodes
debug_assert!(self
.sorted_nodes
.iter()
.all(|&idx| { (&graph[idx]).is_valid() }));
self.is_valid.set(true);
}
@ -420,7 +530,9 @@ impl<V> Clone for ValueInvalidationSignal<V> {
#[cfg(test)]
mod tests {
use super::*;
use crate::rule::{ConstantRule, InputVisitable};
use crate::rule::{
AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule,
};
#[test]
fn rule_output_with_no_inputs() {
@ -713,11 +825,106 @@ mod tests {
r#"digraph {
0 [label="ConstNode<i32> (id=0)"]
1 [label="ConstNode<i32> (id=1)"]
2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"]
2 [label="RuleNode<AddWithLabel>(test) (id=2)"]
0 -> 2 []
1 -> 2 []
}
"#
)
}
#[test]
fn dynamic_rule() {
let mut builder = GraphBuilder::new();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl DynamicRule for CountUpTo {
type ChildOutput = i32;
fn evaluate(
&mut self,
ctx: &mut impl rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.0.value();
assert!(count >= self.1.len() as i32);
while (self.1.len() as i32) < count {
self.1
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
}
self.1.clone()
}
}
let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![]));
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
set_count.set_value(2);
assert_eq!(*graph.evaluate(), 3);
set_count.set_value(4);
assert_eq!(*graph.evaluate(), 10);
println!("{}", graph.as_dot_string());
}
#[tokio::test]
async fn async_dynamic_rule() {
let mut builder = GraphBuilder::new_async();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl AsyncDynamicRule for CountUpTo {
type ChildOutput = i32;
async fn evaluate<'a>(
&'a mut self,
ctx: &'a mut impl rule::AsyncDynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.0.value();
assert!(count >= self.1.len() as i32);
while (self.1.len() as i32) < count {
self.1
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
}
self.1.clone()
}
}
let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![]));
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate_async().await, 1);
set_count.set_value(2);
assert_eq!(*graph.evaluate_async().await, 3);
set_count.set_value(4);
assert_eq!(*graph.evaluate_async().await, 10);
}
}

View File

@ -1,8 +1,12 @@
use crate::rule::{
AsyncDynamicRule, AsyncDynamicRuleContext, AsyncRule, DynamicInput, DynamicRule,
DynamicRuleContext, InputVisitable, Rule,
};
use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous};
use crate::{Input, InputVisitor, NodeId, Synchronous};
use quote::ToTokens;
use std::any::Any;
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::rc::Rc;
@ -11,10 +15,35 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>,
update: Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Synch::UpdateResult<'a>>,
update: Box<
dyn for<'a> Fn(
&'a mut Box<dyn Any>,
&'a mut NodeUpdateContext<Synch>,
) -> Synch::UpdateResult<'a>,
>,
debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>,
}
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
pub(crate) invalidate_dependent_nodes: bool,
pub(crate) removed_nodes: Vec<NodeId>,
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
}
impl<S: Synchronicity> NodeUpdateContext<S> {
pub(crate) fn new() -> Self {
Self {
invalidate_dependent_nodes: false,
removed_nodes: vec![],
added_nodes: vec![],
}
}
fn invalidate_dependent_nodes(&mut self) {
self.invalidate_dependent_nodes = true;
}
}
impl<S: Synchronicity> ErasedNode<S> {
pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
// i don't love the double boxing, but i'm not sure how else to do this
@ -34,9 +63,9 @@ impl<S: Synchronicity> ErasedNode<S> {
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
x.visit_inputs(visitor);
}),
update: Box::new(|any| {
update: Box::new(|any, ctx| {
let x = any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap();
x.update()
x.update(ctx)
}),
debug_fmt: Box::new(|any, f| {
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
@ -57,14 +86,14 @@ impl<S: Synchronicity> ErasedNode<S> {
}
impl ErasedNode<Synchronous> {
pub(crate) fn update(&mut self) -> bool {
(self.update)(&mut self.any)
pub(crate) fn update(&mut self, ctx: &mut NodeUpdateContext<Synchronous>) {
(self.update)(&mut self.any, ctx)
}
}
impl ErasedNode<Asynchronous> {
pub(crate) async fn update(&mut self) -> bool {
(self.update)(&mut self.any).await
pub(crate) async fn update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
(self.update)(&mut self.any, ctx).await
}
}
@ -78,7 +107,7 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
fn is_valid(&self) -> bool;
fn invalidate(&mut self);
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
fn update(&mut self) -> Synch::UpdateResult<'_>;
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<Synch>) -> Synch::UpdateResult<'a>;
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
}
@ -139,7 +168,7 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> S::UpdateResult<'_> {
fn update<'a>(&'a mut self, _ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
unreachable!()
}
@ -181,11 +210,12 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for InvalidatableConstNode<V, S>
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> S::UpdateResult<'_> {
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true;
// This node is only invalidate when node_value_eq between the old/new value is false,
// so it is always the case that the update method has changed the value.
S::make_update_result(true, crate::synchronicity::private::Token)
ctx.invalidate_dependent_nodes();
S::make_update_result(crate::synchronicity::private::Token)
}
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
@ -217,6 +247,32 @@ impl<R: Rule, S> RuleNode<R, R::Output, S> {
}
}
fn visit_inputs<V: InputVisitable>(visitable: &V, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx.get().unwrap());
}
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>) {
// Visit the dynamic node itself
self.visit(&input.input);
// And visit all the nodes it produces
let maybe_dynamic_output = input.input.value.borrow();
if let Some(dynamic_output) = maybe_dynamic_output.as_ref() {
for input in dynamic_output.inputs.iter() {
self.visit(input);
}
} else {
// Haven't evaluated the dynamic node for the first time yet.
// Upon doing so, if the nodes it produces change, we'll modify the graph
// and end up back here in the other branch.
}
}
}
visitable.visit_inputs(&mut InputIndexVisitor(visitor));
}
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
fn is_valid(&self) -> bool {
self.valid
@ -227,16 +283,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx);
}
}
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
visit_inputs(&self.rule, visitor);
}
fn update(&mut self) -> S::UpdateResult<'_> {
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true;
let new_value = self.rule.evaluate();
@ -247,9 +297,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
S::make_update_result(value_changed, crate::synchronicity::private::Token)
S::make_update_result(crate::synchronicity::private::Token)
}
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -290,12 +341,12 @@ impl<V, P: FnOnce() -> F, F: Future<Output = V>> AsyncConstNode<V, P, F> {
}
}
async fn do_update(&mut self) -> bool {
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true;
let mut provider = None;
std::mem::swap(&mut self.provider, &mut provider);
*self.value.borrow_mut() = Some(provider.unwrap()().await);
true
ctx.invalidate_dependent_nodes();
}
}
@ -312,8 +363,11 @@ impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
Box::pin(self.do_update())
fn update<'a>(
&'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
}
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
@ -342,7 +396,7 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
}
}
async fn do_update(&mut self) -> bool {
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true;
let new_value = self.rule.evaluate().await;
@ -353,9 +407,8 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
value_changed
}
}
@ -369,17 +422,14 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx);
}
}
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
visit_inputs(&self.rule, visitor);
}
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
Box::pin(self.do_update())
fn update<'a>(
&'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
}
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -405,6 +455,236 @@ impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
}
}
// todo: better name for this
pub struct DynamicRuleOutput<O> {
pub inputs: Vec<Input<O>>,
}
impl<O: 'static> NodeValue for DynamicRuleOutput<O> {
fn node_value_eq(&self, other: &Self) -> bool {
if self.inputs.len() != other.inputs.len() {
return false;
}
self.inputs
.iter()
.zip(other.inputs.iter())
.all(|(s, o)| s.node_idx == o.node_idx)
}
}
impl<O> std::fmt::Debug for DynamicRuleOutput<O> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(std::any::type_name::<Self>())
.field("inputs", &self.inputs)
.finish()
}
}
pub(crate) struct DynamicRuleNode<R, O, S> {
rule: R,
valid: bool,
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
synchronicity: std::marker::PhantomData<S>,
}
impl<R, O, S> DynamicRuleNode<R, O, S> {
pub(crate) fn new(rule: R) -> Self {
Self {
rule,
valid: false,
value: Rc::new(RefCell::new(None)),
synchronicity: std::marker::PhantomData,
}
}
}
impl<R: DynamicRule, S: Synchronicity> Node<DynamicRuleOutput<R::ChildOutput>, S>
for DynamicRuleNode<R, R::ChildOutput, S>
{
fn is_valid(&self) -> bool {
self.valid
}
fn invalidate(&mut self) {
self.valid = false;
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
visit_inputs(&self.rule, visitor);
}
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true;
let new_value = DynamicRuleOutput {
inputs: self.rule.evaluate(&mut DynamicRuleUpdateContext(ctx)),
};
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
S::make_update_result(crate::synchronicity::private::Token)
}
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
&self.value
}
}
struct DynamicRuleUpdateContext<'a, Synch: Synchronicity>(&'a mut NodeUpdateContext<Synch>);
impl<'a, S: Synchronicity> DynamicRuleUpdateContext<'a, S> {
fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
let node_idx = Rc::new(Cell::new(None));
let value = Rc::clone(node.value_rc());
let erased = ErasedNode::new(node);
self.0.added_nodes.push((erased, Rc::clone(&node_idx)));
Input { node_idx, value }
}
}
impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S> {
fn remove_node(&mut self, id: NodeId) {
self.0.removed_nodes.push(id);
}
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
{
self.add_node(RuleNode::new(rule))
}
}
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
impl<'a, R: DynamicRule> std::fmt::Display for DynamicRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.node_label(f)
}
}
impl<R: DynamicRule, O, V> std::fmt::Debug for DynamicRuleNode<R, O, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"DynamicRuleNode<{}>({})",
pretty_type_name::<R>(),
DynamicRuleLabel(&self.rule)
)
}
}
pub(crate) struct AsyncDynamicRuleNode<R, O> {
rule: R,
valid: bool,
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
}
impl<R: AsyncDynamicRule> AsyncDynamicRuleNode<R, R::ChildOutput> {
pub(crate) fn new(rule: R) -> Self {
Self {
rule,
valid: false,
value: Rc::new(RefCell::new(None)),
}
}
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true;
let new_value = DynamicRuleOutput {
inputs: self
.rule
.evaluate(&mut AsyncDynamicRuleUpdateContext(ctx))
.await,
};
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
}
}
impl<R: AsyncDynamicRule> Node<DynamicRuleOutput<R::ChildOutput>, Asynchronous>
for AsyncDynamicRuleNode<R, R::ChildOutput>
{
fn is_valid(&self) -> bool {
self.valid
}
fn invalidate(&mut self) {
self.valid = false;
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
visit_inputs(&self.rule, visitor);
}
fn update<'a>(
&'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
}
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
&self.value
}
}
struct AsyncDynamicRuleUpdateContext<'a>(&'a mut NodeUpdateContext<Asynchronous>);
impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
fn remove_node(&mut self, id: NodeId) {
DynamicRuleUpdateContext(self.0).remove_node(id);
}
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
{
DynamicRuleUpdateContext(self.0).add_rule(rule)
}
}
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: AsyncRule,
{
DynamicRuleUpdateContext(self.0).add_node(AsyncRuleNode::new(rule))
}
}
struct AsyncDynamicRuleLabel<'a, R: AsyncDynamicRule>(&'a R);
impl<'a, R: AsyncDynamicRule> std::fmt::Display for AsyncDynamicRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.node_label(f)
}
}
impl<R: AsyncDynamicRule> std::fmt::Debug for AsyncDynamicRuleNode<R, R::ChildOutput> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"AsyncDynamicRuleNode<{}>({})",
pretty_type_name::<R>(),
AsyncDynamicRuleLabel(&self.rule)
)
}
}
fn pretty_type_name<T>() -> String {
let s = std::any::type_name::<T>();
let ty = syn::parse_str::<syn::Type>(s).unwrap();

View File

@ -1,7 +1,7 @@
use crate::node::NodeValue;
use crate::node::{DynamicRuleOutput, NodeValue};
use crate::NodeId;
pub use compute_graph_macros::InputVisitable;
use std::cell::{Ref, RefCell};
use std::cell::{Cell, Ref, RefCell};
use std::future::Future;
use std::ops::Deref;
use std::rc::Rc;
@ -76,6 +76,75 @@ pub trait AsyncRule: InputVisitable + 'static {
}
}
/// A rule whose output is further nodes in the graph.
///
/// Types implementing this rule should track which nodes they previously output and not
/// add additional equivalent nodes (for whatever domain-specific definition of equivalence)
/// on susbequent evaluations.
pub trait DynamicRule: InputVisitable + 'static {
/// The type of the output value of each of the child nodes that this rule produces.
type ChildOutput: NodeValue;
/// Evaluates this rule, producing additional nodes.
///
/// Use the methods on [`DynamicRuleContext`] to add or remove nodes from the graph.
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>>;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// Facilities for adding/removing nodes in the graph during the update of a [`DynamicRule`].
// todo: better abstracion for this
// something that handles diffing and does the add/remove automatically
pub trait DynamicRuleContext {
/// Removes the node with the given ID from the graph.
///
/// Be careful when removing nodes. Removing a node that is still depended-upon by another node
/// (i.e., is an input in some other node's [`InputVisitable::visit_inputs`]) is an error.
fn remove_node(&mut self, id: NodeId);
/// Adds a node whose value is produced using the given rule to the graph.
///
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule;
}
/// An asynchronous rule whose output is further nodes in the graph.
///
/// See [`DynamicRule`].
pub trait AsyncDynamicRule: InputVisitable + 'static {
/// The type of the output value of each of the child nodes that this rule produces.
type ChildOutput: NodeValue;
/// Evaluates this rule asynchronously, producing additional nodes.
///
/// Use the methods on [`AsyncDynamicRuleContext`] to add or remove nodes from the graph.
fn evaluate<'a>(
&'a mut self,
ctx: &'a mut impl AsyncDynamicRuleContext,
) -> impl Future<Output = Vec<Input<Self::ChildOutput>>> + 'a;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// Facilities for adding/removing nodes in the graph during the update of an [`AsyncDynamicRule`].
pub trait AsyncDynamicRuleContext: DynamicRuleContext {
/// Adds a node whose value is produced using the given rule to the graph.
///
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: AsyncRule;
}
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
///
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
@ -93,13 +162,13 @@ pub trait InputVisitable {
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
}
/// A placeholder for the output of one node to be used as an input for another.
/// A placeholder for the output of one node, to be used as an input for another.
///
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
///
/// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes.
pub struct Input<T> {
pub(crate) node_idx: NodeId,
pub(crate) node_idx: Rc<Cell<Option<NodeId>>>,
pub(crate) value: Rc<RefCell<Option<T>>>,
}
@ -119,7 +188,7 @@ impl<T> Input<T> {
impl<T> Clone for Input<T> {
fn clone(&self) -> Self {
Self {
node_idx: self.node_idx,
node_idx: Rc::clone(&self.node_idx),
value: Rc::clone(&self.value),
}
}
@ -136,6 +205,25 @@ impl<T> std::fmt::Debug for Input<T> {
}
}
/// A placeholder for the output of a dynamic rule node, to be used as an input for another.
///
/// See [`GraphBuilder::add_dynamic_rule`](`crate::builder::GraphBuilder::add_dynamic_rule`).
///
/// A dependency on a dynamic input represents both a dependency on the dynamic node itself,
/// as well as dependencies on each of the nodes that are the output of the dynamic node.
#[derive(Clone)]
pub struct DynamicInput<T> {
pub(crate) input: Input<DynamicRuleOutput<T>>,
}
impl<T> DynamicInput<T> {
/// Retrieves a reference to the current value of the dynamic node (i.e., the set of inputs
/// representing the nodes that are the outputs of the dynamic node).
pub fn value(&self) -> impl Deref<Target = DynamicRuleOutput<T>> + '_ {
self.input.value()
}
}
// TODO: i really want Input to be able to implement Deref somehow
/// A type that can visit arbitrary [`Input`]s.
@ -145,6 +233,9 @@ impl<T> std::fmt::Debug for Input<T> {
pub trait InputVisitor {
/// Visit an input whose value is of type `T`.
fn visit<T>(&mut self, input: &Input<T>);
/// Visit a dynamic input whose child value is of type `T`.
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>);
}
/// A simple rule that provides a constant value.

View File

@ -11,7 +11,7 @@ pub(crate) mod private {
pub trait Sealed {}
impl Sealed for super::Synchronous {}
impl Sealed for super::Asynchronous {}
impl Sealed for bool {}
impl Sealed for () {}
impl<'a> Sealed for <super::Asynchronous as super::Synchronicity>::UpdateResult<'a> {}
pub struct Token;
}
@ -20,25 +20,23 @@ pub trait Synchronicity: private::Sealed + 'static {
type UpdateResult<'a>: private::Sealed;
// Necessary for synchronous nodes that can be part of an async graph to return the
// appropriate result based on the type of graph they're in.
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a>;
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a>;
}
pub struct Synchronous;
impl Synchronicity for Synchronous {
type UpdateResult<'a> = bool;
type UpdateResult<'a> = ();
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> {
result
}
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {}
}
pub struct Asynchronous;
impl Synchronicity for Asynchronous {
type UpdateResult<'a> = Pin<Box<dyn Future<Output = bool> + 'a>>;
type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(result))
fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(()))
}
}

View File

@ -1,6 +1,6 @@
use proc_macro::TokenStream;
use proc_macro2::Literal;
use quote::{format_ident, quote};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
PathArguments, Type,
@ -10,8 +10,8 @@ extern crate proc_macro;
/// Derive an implementation of the `InputVisitable` trait and helper methods.
///
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T.
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_inputs` method that
/// calls `visit` on each field of the struct that is of type `Input<T>` or `DynamicInput<T>` for any `T`.
///
/// The macro also generates helper methods for accessing the value of each input less verbosely.
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
@ -56,20 +56,34 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
Fields::Named(ref named) => named
.named
.iter()
.filter(|field| input_value_type(field).is_some())
.map(|field| {
.flat_map(|field| {
if let Some((_ty, is_dynamic)) = input_value_type(field) {
let ident = field.ident.as_ref().unwrap();
quote!(visitor.visit(&self.#ident);)
if is_dynamic {
Some(quote!(visitor.visit_dynamic(&self.#ident);))
} else {
Some(quote!(visitor.visit(&self.#ident);))
}
} else {
None
}
})
.collect::<Vec<_>>(),
Fields::Unnamed(ref unnamed) => unnamed
.unnamed
.iter()
.enumerate()
.filter(|(_, field)| input_value_type(field).is_some())
.map(|(i, _)| {
.flat_map(|(i, field)| {
if let Some((_ty, is_dynamic)) = input_value_type(field) {
let idx_lit = Literal::usize_unsuffixed(i);
quote!(visitor.visit(&self.#idx_lit);)
if is_dynamic {
Some(quote!(visitor.visit_dynamic(&self.#idx_lit);))
} else {
Some(quote!(visitor.visit(&self.#idx_lit);))
}
} else {
None
}
})
.collect::<Vec<_>>(),
Fields::Unit => vec![],
@ -79,12 +93,19 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
Fields::Named(ref named) => named
.named
.iter()
.filter_map(|field| input_value_type(field).map(|ty| (field, ty)))
.map(|(field, ty)| {
.filter_map(|field| {
input_value_type(field).map(|(ty, is_dynamic)| (field, ty, is_dynamic))
})
.map(|(field, ty, is_dynamic)| {
let ident = field.ident.as_ref().unwrap();
let target = if is_dynamic {
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
} else {
ty.to_token_stream()
};
quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
self.#ident.value()
}
@ -95,13 +116,20 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
.unnamed
.iter()
.enumerate()
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty)))
.map(|(i, ty)| {
.filter_map(|(i, field)| {
input_value_type(field).map(|(ty, is_dynamic)| (i, ty, is_dynamic))
})
.map(|(i, ty, is_dynamic)| {
let idx_lit = Literal::usize_unsuffixed(i);
let ident = format_ident!("input_{i}");
let target = if is_dynamic {
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
} else {
ty.to_token_stream()
};
quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
self.#idx_lit.value()
}
@ -126,14 +154,15 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
))
}
fn input_value_type(field: &Field) -> Option<&Type> {
fn input_value_type(field: &Field) -> Option<(&Type, bool)> {
if let Type::Path(ref path) = field.ty {
let last_segment = path.path.segments.last().unwrap();
if last_segment.ident == "Input" {
if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" {
let is_dynamic = last_segment.ident == "DynamicInput";
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
if args.args.len() == 1 {
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
Some(ty)
Some((ty, is_dynamic))
} else {
None
}

View File

@ -1,5 +1,5 @@
use compute_graph::node::NodeValue;
use compute_graph::rule::{Input, InputVisitable, Rule};
use compute_graph::rule::{DynamicInput, Input, InputVisitable, Rule};
#[derive(InputVisitable)]
struct Add(Input<i32>, Input<i32>, i32);
@ -34,9 +34,25 @@ impl<T: NodeValue + Clone> Rule for Passthrough<T> {
}
}
#[derive(InputVisitable)]
struct Sum(DynamicInput<i32>);
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.input_0()
.inputs
.iter()
.map(|input| *input.value())
.sum()
}
}
#[cfg(test)]
mod tests {
use compute_graph::builder::GraphBuilder;
use compute_graph::{
builder::GraphBuilder,
rule::{ConstantRule, DynamicRule},
};
use super::*;
@ -59,4 +75,27 @@ mod tests {
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 6);
}
#[test]
fn test_sum() {
#[derive(InputVisitable)]
struct Dynamic;
impl DynamicRule for Dynamic {
type ChildOutput = i32;
fn evaluate(
&mut self,
ctx: &mut impl compute_graph::rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
vec![
ctx.add_rule(ConstantRule::new(1)),
ctx.add_rule(ConstantRule::new(2)),
]
}
}
let mut builder = GraphBuilder::new();
let dynamic_input = builder.add_dynamic_rule(Dynamic);
builder.set_output(Sum(dynamic_input));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 3);
}
}