diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 81d9f1e..3e0d946 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -59,7 +59,7 @@ use node::{ErasedNode, NodeValue}; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; use rule::{AsyncRule, Input, InputVisitor, Rule}; -use std::cell::{Cell, Ref, RefCell}; +use std::cell::{Cell, RefCell}; use std::collections::HashMap; use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index 140692f..20401a5 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -10,7 +10,7 @@ pub(crate) struct ErasedNode { is_valid: Box) -> bool>, invalidate: Box) -> ()>, visit_inputs: Box, &mut dyn FnMut(NodeId) -> ()) -> ()>, - update: Synch::UpdateFn, + update: Box Fn(&'a mut Box) -> Synch::UpdateResult<'a>>, } impl ErasedNode { @@ -32,7 +32,10 @@ impl ErasedNode { let x = any.downcast_ref::>>().unwrap(); x.visit_inputs(visitor); }), - update: S::make_update_fn::(), + update: Box::new(|any| { + let x = any.downcast_mut::>>().unwrap(); + x.update() + }), } } @@ -164,7 +167,7 @@ impl Node for InvalidatableConstNode self.valid = true; // This node is only invalidate when node_value_eq between the old/new value is false, // so it is always the case that the update method has changed the value. - S::make_update_result(true) + S::make_update_result(true, crate::synchronicity::private::Token) } fn value_rc(&self) -> &Rc>> { @@ -222,7 +225,7 @@ impl Node for RuleNode *value = Some(new_value); } - S::make_update_result(value_changed) + S::make_update_result(value_changed, crate::synchronicity::private::Token) } fn value_rc(&self) -> &Rc>> { diff --git a/crates/compute_graph/src/synchronicity.rs b/crates/compute_graph/src/synchronicity.rs index 4cb6317..90ab33c 100644 --- a/crates/compute_graph/src/synchronicity.rs +++ b/crates/compute_graph/src/synchronicity.rs @@ -4,41 +4,31 @@ //! The [`Synchronicity`] trait is sealed to outside implementors, and you generally do not need to refer //! directly to the [`Synchronous`] or [`Asynchronous`] types. -use crate::node::{Node, NodeValue}; -use std::any::Any; use std::future::Future; use std::pin::Pin; -mod private { +pub(crate) mod private { pub trait Sealed {} impl Sealed for super::Synchronous {} impl Sealed for super::Asynchronous {} + impl Sealed for bool {} + impl<'a> Sealed for ::UpdateResult<'a> {} + pub struct Token; } pub trait Synchronicity: private::Sealed + 'static { - // TODO: do we need the associated type? can this just be a method update_node(Box>)? - type UpdateFn; - fn make_update_fn() -> Self::UpdateFn; - - type UpdateResult<'a>; - fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a>; + type UpdateResult<'a>: private::Sealed; + // Necessary for synchronous nodes that can be part of an async graph to return the + // appropriate result based on the type of graph they're in. + fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a>; } pub struct Synchronous; impl Synchronicity for Synchronous { - type UpdateFn = Box) -> bool>; - - fn make_update_fn() -> Self::UpdateFn { - Box::new(|any| { - let x = any.downcast_mut::>>().unwrap(); - x.update() - }) - } - type UpdateResult<'a> = bool; - fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> { + fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { result } } @@ -46,21 +36,9 @@ impl Synchronicity for Synchronous { pub struct Asynchronous; impl Synchronicity for Asynchronous { - type UpdateFn = - Box Fn(&'a mut Box) -> Pin + 'a>>>; - - fn make_update_fn() -> Self::UpdateFn { - Box::new(|any| { - Box::pin({ - let x = any.downcast_mut::>>().unwrap(); - x.update() - }) - }) - } - type UpdateResult<'a> = Pin + 'a>>; - fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> { + fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { Box::pin(std::future::ready(result)) } }