diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index 1a24245..3c14bb4 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -1,4 +1,6 @@ -use crate::node::{AsyncRuleNode, ConstNode, InvalidatableConstNode, Node, NodeValue, RuleNode}; +use crate::node::{ + AsyncConstNode, AsyncRuleNode, ConstNode, InvalidatableConstNode, Node, NodeValue, RuleNode, +}; use crate::rule::{AsyncRule, Rule}; use crate::util; use crate::{ @@ -6,6 +8,7 @@ use crate::{ Synchronous, ValueInvalidationSignal, }; use std::cell::{Cell, RefCell}; +use std::future::Future; use std::rc::Rc; /// Builds a [`Graph`]. @@ -228,7 +231,19 @@ impl GraphBuilder { self.output = Some(input); } - // TODO: add_async_value? + /// Adds a constant node whose value is computed by the given function to the graph. + /// + /// The function is not called until the node is evaluated by the graph. + /// + /// Returns an [`Input`] representing the newly-added node, which can be used to construct rules. + pub fn add_async_value(&mut self, value_provider: P) -> Input + where + V: NodeValue, + P: FnOnce() -> F + 'static, + F: Future + 'static, + { + self.add_node(AsyncConstNode::new(value_provider)) + } /// Adds a node whose value is produced using the given rule to the graph. /// diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 1616c98..0fbeb63 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -627,4 +627,14 @@ mod tests { assert!(!graph.is_output_valid()); assert_eq!(*graph.evaluate(), 43); } + + #[tokio::test] + async fn async_value() { + let mut builder = GraphBuilder::new_async(); + let a = builder.add_async_value(|| async { 42 }); + let b = builder.add_value(1); + builder.set_output(Add(a, b)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate_async().await, 43); + } } diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index 1e6157b..140692f 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -2,6 +2,7 @@ use crate::synchronicity::{Asynchronous, Synchronicity}; use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; use std::any::Any; use std::cell::RefCell; +use std::future::Future; use std::rc::Rc; pub(crate) struct ErasedNode { @@ -229,6 +230,52 @@ impl Node for RuleNode } } +pub(crate) struct AsyncConstNode F, F: Future> { + provider: Option

, + value: Rc>>, + valid: bool, +} + +impl F, F: Future> AsyncConstNode { + pub(crate) fn new(provider: P) -> Self { + Self { + provider: Some(provider), + value: Rc::new(RefCell::new(None)), + valid: false, + } + } + + async fn do_update(&mut self) -> bool { + self.valid = true; + let mut provider = None; + std::mem::swap(&mut self.provider, &mut provider); + *self.value.borrow_mut() = Some(provider.unwrap()().await); + true + } +} + +impl F, F: Future> Node + for AsyncConstNode +{ + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + unreachable!() + } + + fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} + + fn update(&mut self) -> ::UpdateResult<'_> { + Box::pin(self.do_update()) + } + + fn value_rc(&self) -> &Rc>> { + &self.value + } +} + pub(crate) struct AsyncRuleNode { rule: R, value: Rc>>, @@ -243,6 +290,22 @@ impl AsyncRuleNode { valid: false, } } + + async fn do_update(&mut self) -> bool { + self.valid = true; + + let new_value = self.rule.evaluate().await; + let mut value = self.value.borrow_mut(); + let value_changed = value + .as_ref() + .map_or(true, |v| !v.node_value_eq(&new_value)); + + if value_changed { + *value = Some(new_value); + } + + value_changed + } } impl Node for AsyncRuleNode { @@ -272,21 +335,3 @@ impl Node for AsyncRuleNode &self.value } } - -impl AsyncRuleNode { - async fn do_update(&mut self) -> bool { - self.valid = true; - - let new_value = self.rule.evaluate().await; - let mut value = self.value.borrow_mut(); - let value_changed = value - .as_ref() - .map_or(true, |v| !v.node_value_eq(&new_value)); - - if value_changed { - *value = Some(new_value); - } - - value_changed - } -}