Add GraphBuilder::add_async_value

This commit is contained in:
Shadowfacts 2024-11-02 18:49:29 -04:00
parent 88dfef75fd
commit a556b14188
3 changed files with 90 additions and 20 deletions

View File

@ -1,4 +1,6 @@
use crate::node::{AsyncRuleNode, ConstNode, InvalidatableConstNode, Node, NodeValue, RuleNode};
use crate::node::{
AsyncConstNode, AsyncRuleNode, ConstNode, InvalidatableConstNode, Node, NodeValue, RuleNode,
};
use crate::rule::{AsyncRule, Rule};
use crate::util;
use crate::{
@ -6,6 +8,7 @@ use crate::{
Synchronous, ValueInvalidationSignal,
};
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::rc::Rc;
/// Builds a [`Graph`].
@ -228,7 +231,19 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
self.output = Some(input);
}
// TODO: add_async_value?
/// Adds a constant node whose value is computed by the given function to the graph.
///
/// The function is not called until the node is evaluated by the graph.
///
/// Returns an [`Input`] representing the newly-added node, which can be used to construct rules.
pub fn add_async_value<V, P, F>(&mut self, value_provider: P) -> Input<V>
where
V: NodeValue,
P: FnOnce() -> F + 'static,
F: Future<Output = V> + 'static,
{
self.add_node(AsyncConstNode::new(value_provider))
}
/// Adds a node whose value is produced using the given rule to the graph.
///

View File

@ -627,4 +627,14 @@ mod tests {
assert!(!graph.is_output_valid());
assert_eq!(*graph.evaluate(), 43);
}
#[tokio::test]
async fn async_value() {
let mut builder = GraphBuilder::new_async();
let a = builder.add_async_value(|| async { 42 });
let b = builder.add_value(1);
builder.set_output(Add(a, b));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate_async().await, 43);
}
}

View File

@ -2,6 +2,7 @@ use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous};
use std::any::Any;
use std::cell::RefCell;
use std::future::Future;
use std::rc::Rc;
pub(crate) struct ErasedNode<Synch: Synchronicity> {
@ -229,6 +230,52 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
}
}
pub(crate) struct AsyncConstNode<V, P: FnOnce() -> F, F: Future<Output = V>> {
provider: Option<P>,
value: Rc<RefCell<Option<V>>>,
valid: bool,
}
impl<V, P: FnOnce() -> F, F: Future<Output = V>> AsyncConstNode<V, P, F> {
pub(crate) fn new(provider: P) -> Self {
Self {
provider: Some(provider),
value: Rc::new(RefCell::new(None)),
valid: false,
}
}
async fn do_update(&mut self) -> bool {
self.valid = true;
let mut provider = None;
std::mem::swap(&mut self.provider, &mut provider);
*self.value.borrow_mut() = Some(provider.unwrap()().await);
true
}
}
impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous>
for AsyncConstNode<V, P, F>
{
fn is_valid(&self) -> bool {
self.valid
}
fn invalidate(&mut self) {
unreachable!()
}
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
Box::pin(self.do_update())
}
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
&self.value
}
}
pub(crate) struct AsyncRuleNode<R, V> {
rule: R,
value: Rc<RefCell<Option<V>>>,
@ -243,6 +290,22 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
valid: false,
}
}
async fn do_update(&mut self) -> bool {
self.valid = true;
let new_value = self.rule.evaluate().await;
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
value_changed
}
}
impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output> {
@ -272,21 +335,3 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
&self.value
}
}
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
async fn do_update(&mut self) -> bool {
self.valid = true;
let new_value = self.rule.evaluate().await;
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
value_changed
}
}