2024-12-31 14:29:11 -05:00

735 lines
22 KiB
Rust

use crate::rule::{
AsyncDynamicRule, AsyncDynamicRuleContext, AsyncRule, DynamicInput, DynamicRule,
DynamicRuleContext, InputVisitable, Rule,
};
use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{Input, InputVisitor, NodeId, Synchronous};
use quote::ToTokens;
use std::any::Any;
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::rc::Rc;
/// A type-erased graph node.
///
/// The concrete node is stored as a `Box<dyn Node<V, Synch>>` that is itself boxed as
/// `Box<dyn Any>`, so nodes with different value types `V` can share one type. Each field below
/// other than `any` is a closure, built in `ErasedNode::new`, that downcasts `any` back to
/// `Box<dyn Node<V, Synch>>` and forwards the corresponding `Node` operation.
pub(crate) struct ErasedNode<Synch: Synchronicity> {
/// The wrapped node (a `Box<dyn Node<V, Synch>>` erased to `dyn Any`).
any: Box<dyn Any>,
/// Forwards to `Node::is_valid`.
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
/// Forwards to `Node::invalidate`.
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
/// Forwards to `Node::visit_inputs`, passing the visitor callback through.
visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>,
/// Forwards to `Node::update`. The returned result borrows both the node and the context for
/// the same lifetime `'a`.
update: Box<
dyn for<'a> Fn(
&'a mut Box<dyn Any>,
&'a mut NodeUpdateContext<Synch>,
) -> Synch::UpdateResult<'a>,
>,
/// Forwards to the wrapped node's `Debug` implementation.
debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>,
}
/// Collects the side effects produced while updating a single node.
///
/// Nodes only record requests here (they never mutate the graph directly).
/// NOTE(review): the graph presumably applies these after the update returns — confirm against
/// the caller.
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
/// Set when the updated node's value changed, so nodes depending on it must be invalidated.
pub(crate) invalidate_dependent_nodes: bool,
/// Nodes a dynamic rule asked to remove.
pub(crate) removed_nodes: Vec<NodeId>,
/// Nodes a dynamic rule created, each paired with the id slot shared with the `Input` handle
/// returned to the rule. The slot starts as `None` (presumably filled in when the graph
/// inserts the node — confirm).
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
}
impl<S: Synchronicity> NodeUpdateContext<S> {
    /// Creates an empty context: no invalidation requested and no nodes queued for addition or
    /// removal.
    pub(crate) fn new() -> Self {
        let removed_nodes = Vec::new();
        let added_nodes = Vec::new();
        Self {
            invalidate_dependent_nodes: false,
            removed_nodes,
            added_nodes,
        }
    }

    /// Records that nodes depending on the node currently being updated must be invalidated.
    fn invalidate_dependent_nodes(&mut self) {
        self.invalidate_dependent_nodes = true;
    }
}
impl<S: Synchronicity> ErasedNode<S> {
    /// Type-erases `base`, capturing closures that downcast back to it for each operation.
    ///
    /// The node is boxed twice: once as `Box<dyn Node<V, S>>`, and that box again as
    /// `Box<dyn Any>`. The stored closures recover the inner box with a downcast.
    pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
        // Recover the `Box<dyn Node<V, S>>` stored in the erased box. The downcast can only fail
        // if `any` was not produced by this constructor.
        fn node_ref<V: NodeValue, S: Synchronicity>(any: &Box<dyn Any>) -> &Box<dyn Node<V, S>> {
            any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap()
        }
        fn node_mut<V: NodeValue, S: Synchronicity>(
            any: &mut Box<dyn Any>,
        ) -> &mut Box<dyn Node<V, S>> {
            any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap()
        }

        // i don't love the double boxing, but i'm not sure how else to do this
        let node: Box<dyn Node<V, S>> = Box::new(base);
        Self {
            any: Box::new(node),
            is_valid: Box::new(|any| node_ref::<V, S>(any).is_valid()),
            invalidate: Box::new(|any| node_mut::<V, S>(any).invalidate()),
            visit_inputs: Box::new(|any, visitor| node_ref::<V, S>(any).visit_inputs(visitor)),
            update: Box::new(|any, ctx| node_mut::<V, S>(any).update(ctx)),
            debug_fmt: Box::new(|any, f| node_ref::<V, S>(any).fmt(f)),
        }
    }

    /// Forwards to the wrapped node's `Node::is_valid`.
    pub(crate) fn is_valid(&self) -> bool {
        let check = &self.is_valid;
        check(&self.any)
    }

    /// Forwards to the wrapped node's `Node::invalidate`.
    pub(crate) fn invalidate(&mut self) {
        let invalidate = &self.invalidate;
        invalidate(&mut self.any);
    }

    /// Calls `f` with the id of every input of the wrapped node.
    pub(crate) fn visit_inputs(&self, f: &mut dyn FnMut(NodeId) -> ()) {
        let visit = &self.visit_inputs;
        visit(&self.any, f);
    }
}
impl ErasedNode<Synchronous> {
    /// Runs the wrapped node's synchronous update, recording its side effects in `ctx`.
    pub(crate) fn update(&mut self, ctx: &mut NodeUpdateContext<Synchronous>) {
        let update = &self.update;
        update(&mut self.any, ctx);
    }
}
impl ErasedNode<Asynchronous> {
    /// Drives the wrapped node's asynchronous update to completion, recording its side effects
    /// in `ctx`.
    pub(crate) async fn update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
        let pending = (self.update)(&mut self.any, ctx);
        pending.await
    }
}
impl<S: Synchronicity> std::fmt::Debug for ErasedNode<S> {
    /// Delegates to the `Debug` implementation of the wrapped node.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { any, debug_fmt, .. } = self;
        debug_fmt(any, f)
    }
}
/// A node in the compute graph whose value has type `Value`.
///
/// `Synch` selects how `update` completes: its result type is `Synch::UpdateResult`. For
/// `Asynchronous` nodes in this file, that result is a boxed, pinned future.
pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
/// Whether the node's current value is up to date.
fn is_valid(&self) -> bool;
/// Marks the node as out of date. Some implementations ignore this (constants) or treat it
/// as unreachable (one-shot async constants).
fn invalidate(&mut self);
/// Calls `visitor` with the id of every node this node reads from.
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
/// Brings the node up to date, recording side effects (dependent invalidation, added or
/// removed nodes) in `ctx`. The result may borrow both `self` and `ctx` for `'a`.
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<Synch>) -> Synch::UpdateResult<'a>;
/// The shared cell holding the node's value. It is `None` until a value has been produced;
/// constant nodes start populated.
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
}
/// A value that can be used as the value of a node in the graph.
///
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed
/// and thus whether downstream nodes need to be invalidated too.
///
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided, so
/// such types get it automatically and cannot implement it manually.
pub trait NodeValue: 'static {
/// Returns whether `self` is equal to `other` for the purposes of graph invalidation.
///
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
/// determined, this method should return `false`.
///
/// The default implementation of this method always returns `false`, so any non-`PartialEq` type can
/// implement this trait simply:
///
/// ```rust
/// # use compute_graph::node::NodeValue;
/// struct MyType;
/// impl NodeValue for MyType {}
/// ```
///
/// Note that always returning `false` may result in more node invalidations than strictly necessary.
// The default body ignores `other`. It is presumably kept named (rather than `_other`) so the
// signature reads cleanly in the rendered documentation.
#[allow(unused_variables)]
fn node_value_eq(&self, other: &Self) -> bool {
false
}
}
/// Every `PartialEq` type compares with `==` for invalidation purposes.
impl<T> NodeValue for T
where
    T: PartialEq + 'static,
{
    fn node_value_eq(&self, other: &Self) -> bool {
        PartialEq::eq(self, other)
    }
}
/// A node whose value is fixed at construction and never recomputed.
pub(crate) struct ConstNode<V, S> {
/// Set to `Some` at construction; this node never changes it.
value: Rc<RefCell<Option<V>>>,
/// Records the synchronicity `S` without storing a value of it.
synchronicity: std::marker::PhantomData<S>,
}
impl<V, S> ConstNode<V, S> {
    /// Creates a constant node holding `value`.
    pub(crate) fn new(value: V) -> Self {
        let shared = Rc::new(RefCell::new(Some(value)));
        Self {
            synchronicity: std::marker::PhantomData,
            value: shared,
        }
    }
}
impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
    fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
        &self.value
    }

    /// A constant is never out of date.
    fn is_valid(&self) -> bool {
        true
    }

    /// No-op: a constant cannot become stale.
    fn invalidate(&mut self) {}

    /// A constant has no inputs, so `_visitor` is never called.
    fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}

    /// Not expected to be called: the node always reports itself valid, so there is nothing to
    /// update.
    fn update<'a>(&'a mut self, _ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
        unreachable!()
    }
}
impl<V, S> std::fmt::Debug for ConstNode<V, S> {
    /// Formats as `ConstNode<ValueType>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value_type = pretty_type_name::<V>();
        write!(f, "ConstNode<{value_type}>")
    }
}
/// A constant node whose value can be replaced from outside the node.
///
/// Its `update` computes nothing: it only marks the node valid and invalidates dependents.
/// NOTE(review): the replacement value is presumably written into `value` by the graph
/// before the node is invalidated — confirm against the caller.
pub(crate) struct InvalidatableConstNode<V, S> {
/// Set to `Some` at construction; shared with `Input` handles via `value_rc`.
value: Rc<RefCell<Option<V>>>,
/// Whether dependents have already been told about the current value.
valid: bool,
/// Records the synchronicity `S` without storing a value of it.
synchronicity: std::marker::PhantomData<S>,
}
impl<V, S> InvalidatableConstNode<V, S> {
    /// Creates a node holding `value` that starts out valid.
    pub(crate) fn new(value: V) -> Self {
        let shared = Rc::new(RefCell::new(Some(value)));
        Self {
            synchronicity: std::marker::PhantomData,
            valid: true,
            value: shared,
        }
    }
}
impl<V: NodeValue, S: Synchronicity> Node<V, S> for InvalidatableConstNode<V, S> {
    fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    fn invalidate(&mut self) {
        self.valid = false;
    }

    /// A constant has no inputs, so `_visitor` is never called.
    fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}

    /// Marks the node valid and tells dependents to recompute.
    fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
        self.valid = true;
        // This node is only invalidated when `node_value_eq` reports the old and new values as
        // different. Reaching this update therefore always means the value has changed, so
        // dependents must be invalidated unconditionally.
        ctx.invalidate_dependent_nodes();
        S::make_update_result(crate::synchronicity::private::Token)
    }
}
impl<V, S> std::fmt::Debug for InvalidatableConstNode<V, S> {
    /// Formats as `InvalidatableConstNode<ValueType>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value_type = pretty_type_name::<V>();
        write!(f, "InvalidatableConstNode<{value_type}>")
    }
}
/// A node whose value is computed synchronously by evaluating a `Rule`.
pub(crate) struct RuleNode<R, V, S> {
/// The rule evaluated on every update.
rule: R,
/// The last evaluated value; `None` until the first update. Shared via `value_rc`.
value: Rc<RefCell<Option<V>>>,
/// Whether `value` is up to date.
valid: bool,
/// Records the synchronicity `S` without storing a value of it.
synchronicity: std::marker::PhantomData<S>,
}
impl<R: Rule, S> RuleNode<R, R::Output, S> {
    /// Wraps `rule` in a node that starts out invalid and without a value.
    pub(crate) fn new(rule: R) -> Self {
        let value = Rc::new(RefCell::new(None));
        Self {
            valid: false,
            value,
            rule,
            synchronicity: std::marker::PhantomData,
        }
    }
}
/// Reports the node id of every input of `visitable` to `visitor`.
///
/// For a `DynamicInput`, the dynamic node itself is always reported. Once it has been evaluated,
/// every node it produced is reported too.
///
/// # Panics
/// Panics if a visited `Input` has not been assigned a node id yet.
fn visit_inputs<V: InputVisitable>(visitable: &V, visitor: &mut dyn FnMut(NodeId) -> ()) {
    /// Adapts a plain `FnMut(NodeId)` callback to the `InputVisitor` trait.
    struct IdForwarder<'a>(&'a mut dyn FnMut(NodeId) -> ());

    impl<'a> InputVisitor for IdForwarder<'a> {
        fn visit<T>(&mut self, input: &Input<T>) {
            let id = input.node_idx.get().unwrap();
            (self.0)(id);
        }

        fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>) {
            // The dynamic node itself is always a dependency.
            self.visit(&input.input);

            // Its children are only known after its first evaluation, so there is nothing more
            // to report before that. When evaluation later changes the produced nodes, the graph
            // is modified and the inputs are visited again.
            //
            // This may be slightly overzealous: a node could depend only on the dynamic node and
            // not directly on any of the nodes it produces.
            if let Some(output) = input.input.value.borrow().as_ref() {
                output.inputs.iter().for_each(|child| self.visit(child));
            }
        }
    }

    let mut forwarder = IdForwarder(visitor);
    visitable.visit_inputs(&mut forwarder);
}
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
    fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    fn invalidate(&mut self) {
        self.valid = false;
    }

    fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
        visit_inputs(&self.rule, visitor);
    }

    /// Re-evaluates the rule. The result is stored, and dependents are invalidated, only when it
    /// differs from the previous value according to `NodeValue::node_value_eq`.
    fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
        self.valid = true;
        let fresh = self.rule.evaluate();
        let mut slot = self.value.borrow_mut();
        let changed = match slot.as_ref() {
            Some(previous) => !previous.node_value_eq(&fresh),
            None => true,
        };
        if changed {
            *slot = Some(fresh);
            ctx.invalidate_dependent_nodes();
        }
        S::make_update_result(crate::synchronicity::private::Token)
    }
}
/// `Display` adapter that renders a rule's label via its `node_label` method.
struct RuleLabel<'a, R: Rule>(&'a R);
impl<'a, R: Rule> std::fmt::Display for RuleLabel<'a, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RuleLabel(rule) = self;
        rule.node_label(f)
    }
}
impl<R: Rule, V, S> std::fmt::Debug for RuleNode<R, V, S> {
    /// Formats as `RuleNode<RuleType>(label)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_type = pretty_type_name::<R>();
        let label = RuleLabel(&self.rule);
        write!(f, "RuleNode<{rule_type}>({label})")
    }
}
/// An asynchronous node whose value is produced once by awaiting a one-shot provider.
pub(crate) struct AsyncConstNode<V, P: FnOnce() -> F, F: Future<Output = V>> {
/// Factory for the value future; `Some` until the first update consumes it.
provider: Option<P>,
/// `None` until the provider's future has completed.
value: Rc<RefCell<Option<V>>>,
/// Set to `true` when the first update starts.
valid: bool,
}
impl<V, P: FnOnce() -> F, F: Future<Output = V>> AsyncConstNode<V, P, F> {
    /// Creates a node whose value will be produced by awaiting `provider()` on its first update.
    pub(crate) fn new(provider: P) -> Self {
        Self {
            provider: Some(provider),
            value: Rc::new(RefCell::new(None)),
            valid: false,
        }
    }

    /// Consumes the provider, awaits its future, stores the result and invalidates dependents.
    ///
    /// # Panics
    /// Panics if called more than once, because the `FnOnce` provider has already been consumed.
    async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
        self.valid = true;
        let provider = self
            .provider
            .take()
            .expect("AsyncConstNode updated more than once: its provider was already consumed");
        // Await before touching the value cell so no `RefMut` is held across the await point.
        let produced = provider().await;
        *self.value.borrow_mut() = Some(produced);
        ctx.invalidate_dependent_nodes();
    }
}
impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous>
    for AsyncConstNode<V, P, F>
{
    fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    /// Not expected to be called: the provider is consumed by the first update, so the value
    /// could never be recomputed.
    fn invalidate(&mut self) {
        unreachable!()
    }

    /// An async constant has no inputs, so `_visitor` is never called.
    fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}

    fn update<'a>(
        &'a mut self,
        ctx: &'a mut NodeUpdateContext<Asynchronous>,
    ) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
        let pending = self.do_update(ctx);
        Box::pin(pending)
    }
}
impl<V, P: FnOnce() -> F, F: Future<Output = V>> std::fmt::Debug for AsyncConstNode<V, P, F> {
    /// Formats as `AsyncConstNode<ValueType>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value_type = pretty_type_name::<V>();
        write!(f, "AsyncConstNode<{value_type}>")
    }
}
/// A node whose value is computed by awaiting an `AsyncRule`.
pub(crate) struct AsyncRuleNode<R, V> {
/// The rule evaluated (and awaited) on every update.
rule: R,
/// The last evaluated value; `None` until the first update completes.
value: Rc<RefCell<Option<V>>>,
/// Whether `value` is up to date (set when an update starts).
valid: bool,
}
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
    /// Wraps `rule` in a node that starts out invalid and without a value.
    pub(crate) fn new(rule: R) -> Self {
        let value = Rc::new(RefCell::new(None));
        Self {
            valid: false,
            value,
            rule,
        }
    }

    /// Awaits the rule. The result is stored, and dependents are invalidated, only when it
    /// differs from the previous value according to `NodeValue::node_value_eq`.
    async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
        self.valid = true;
        let fresh = self.rule.evaluate().await;
        // The value cell is borrowed only after the await, so the `RefMut` never spans it.
        let mut slot = self.value.borrow_mut();
        if let Some(previous) = slot.as_ref() {
            if previous.node_value_eq(&fresh) {
                return;
            }
        }
        *slot = Some(fresh);
        ctx.invalidate_dependent_nodes();
    }
}
impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output> {
    fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    fn invalidate(&mut self) {
        self.valid = false;
    }

    fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
        visit_inputs(&self.rule, visitor);
    }

    fn update<'a>(
        &'a mut self,
        ctx: &'a mut NodeUpdateContext<Asynchronous>,
    ) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
        let pending = self.do_update(ctx);
        Box::pin(pending)
    }
}
/// `Display` adapter that renders an async rule's label via its `node_label` method.
struct AsyncRuleLabel<'a, R: AsyncRule>(&'a R);
impl<'a, R: AsyncRule> std::fmt::Display for AsyncRuleLabel<'a, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let AsyncRuleLabel(rule) = self;
        rule.node_label(f)
    }
}
impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
    /// Formats as `AsyncRuleNode<RuleType>(label)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_type = pretty_type_name::<R>();
        let label = AsyncRuleLabel(&self.rule);
        write!(f, "AsyncRuleNode<{rule_type}>({label})")
    }
}
// todo: better name for this
/// The value produced by a dynamic rule node: handles to the child nodes it currently exposes.
///
/// For invalidation purposes, two outputs are equal when their inputs refer to the same node
/// ids in the same order.
pub struct DynamicRuleOutput<O> {
/// Handles to the nodes exposed by the dynamic rule, in the order the rule returned them.
pub inputs: Vec<Input<O>>,
}
impl<O: 'static> NodeValue for DynamicRuleOutput<O> {
    /// Equal when both outputs have the same length and pairwise-equal node id slots.
    ///
    /// Only the node ids are compared; the child values themselves are not inspected.
    fn node_value_eq(&self, other: &Self) -> bool {
        self.inputs.len() == other.inputs.len()
            && self
                .inputs
                .iter()
                .zip(&other.inputs)
                .all(|(mine, theirs)| mine.node_idx == theirs.node_idx)
    }
}
impl<O> std::fmt::Debug for DynamicRuleOutput<O> {
    /// Formats as a struct named by the full type name, with its `inputs` field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = std::any::type_name::<Self>();
        let mut builder = f.debug_struct(name);
        builder.field("inputs", &self.inputs);
        builder.finish()
    }
}
/// A node that runs a `DynamicRule`.
///
/// The rule may queue other nodes for addition or removal while it evaluates. The node's value
/// is the list of `Input`s the rule returns.
pub(crate) struct DynamicRuleNode<R, O, S> {
/// The dynamic rule evaluated on every update.
rule: R,
/// Whether `value` is up to date.
valid: bool,
/// The last returned inputs; `None` until the first update.
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
/// Records the synchronicity `S` without storing a value of it.
synchronicity: std::marker::PhantomData<S>,
}
impl<R, O, S> DynamicRuleNode<R, O, S> {
    /// Wraps `rule` in a node that starts out invalid and without a value.
    pub(crate) fn new(rule: R) -> Self {
        let value = Rc::new(RefCell::new(None));
        Self {
            synchronicity: std::marker::PhantomData,
            value,
            valid: false,
            rule,
        }
    }
}
impl<R: DynamicRule, S: Synchronicity> Node<DynamicRuleOutput<R::ChildOutput>, S>
    for DynamicRuleNode<R, R::ChildOutput, S>
{
    fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    fn invalidate(&mut self) {
        self.valid = false;
    }

    fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
        visit_inputs(&self.rule, visitor);
    }

    /// Re-evaluates the dynamic rule, letting it queue node additions and removals in `ctx`.
    /// The returned inputs are stored, and dependents are invalidated, only when they differ
    /// from the previous ones.
    fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
        self.valid = true;
        let children = {
            let mut rule_ctx = DynamicRuleUpdateContext(&mut *ctx);
            self.rule.evaluate(&mut rule_ctx)
        };
        let fresh = DynamicRuleOutput { inputs: children };
        let mut slot = self.value.borrow_mut();
        let changed = match slot.as_ref() {
            None => true,
            Some(previous) => !previous.node_value_eq(&fresh),
        };
        if changed {
            *slot = Some(fresh);
            ctx.invalidate_dependent_nodes();
        }
        S::make_update_result(crate::synchronicity::private::Token)
    }
}
/// Adapter that exposes a `NodeUpdateContext` to a dynamic rule as a `DynamicRuleContext`.
/// Node additions and removals requested by the rule are only queued on the wrapped context.
struct DynamicRuleUpdateContext<'a, Synch: Synchronicity>(&'a mut NodeUpdateContext<Synch>);
impl<'a, S: Synchronicity> DynamicRuleUpdateContext<'a, S> {
fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
let node_idx = Rc::new(Cell::new(None));
let value = Rc::clone(node.value_rc());
let erased = ErasedNode::new(node);
self.0.added_nodes.push((erased, Rc::clone(&node_idx)));
Input { node_idx, value }
}
}
impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S> {
    /// Queues node `id` for removal from the graph.
    fn remove_node(&mut self, id: NodeId) {
        let removed = &mut self.0.removed_nodes;
        removed.push(id);
    }

    /// Wraps `rule` in a synchronous rule node and queues it for insertion.
    fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
    where
        R: Rule,
    {
        let node = RuleNode::new(rule);
        self.add_node(node)
    }
}
/// `Display` adapter that renders a dynamic rule's label via its `node_label` method.
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
impl<'a, R: DynamicRule> std::fmt::Display for DynamicRuleLabel<'a, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DynamicRuleLabel(rule) = self;
        rule.node_label(f)
    }
}
impl<R: DynamicRule, O, S> std::fmt::Debug for DynamicRuleNode<R, O, S> {
    /// Formats as `DynamicRuleNode<RuleType>(label)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_type = pretty_type_name::<R>();
        let label = DynamicRuleLabel(&self.rule);
        write!(f, "DynamicRuleNode<{rule_type}>({label})")
    }
}
/// Asynchronous counterpart of `DynamicRuleNode`: runs an `AsyncDynamicRule`.
///
/// The rule may queue both synchronous and asynchronous rule nodes while it evaluates.
pub(crate) struct AsyncDynamicRuleNode<R, O> {
/// The async dynamic rule evaluated (and awaited) on every update.
rule: R,
/// Whether `value` is up to date (set when an update starts).
valid: bool,
/// The last returned inputs; `None` until the first update completes.
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
}
impl<R: AsyncDynamicRule> AsyncDynamicRuleNode<R, R::ChildOutput> {
    /// Wraps `rule` in a node that starts out invalid and without a value.
    pub(crate) fn new(rule: R) -> Self {
        let value = Rc::new(RefCell::new(None));
        Self {
            value,
            valid: false,
            rule,
        }
    }

    /// Awaits the dynamic rule, letting it queue node additions and removals in `ctx`. The
    /// returned inputs are stored, and dependents are invalidated, only when they differ from
    /// the previous ones.
    async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
        self.valid = true;
        let children = {
            let mut rule_ctx = AsyncDynamicRuleUpdateContext(&mut *ctx);
            self.rule.evaluate(&mut rule_ctx).await
        };
        let fresh = DynamicRuleOutput { inputs: children };
        // The value cell is borrowed only after the await, so the `RefMut` never spans it.
        let mut slot = self.value.borrow_mut();
        if let Some(previous) = slot.as_ref() {
            if previous.node_value_eq(&fresh) {
                return;
            }
        }
        *slot = Some(fresh);
        ctx.invalidate_dependent_nodes();
    }
}
impl<R: AsyncDynamicRule> Node<DynamicRuleOutput<R::ChildOutput>, Asynchronous>
    for AsyncDynamicRuleNode<R, R::ChildOutput>
{
    fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
        &self.value
    }

    fn is_valid(&self) -> bool {
        self.valid
    }

    fn invalidate(&mut self) {
        self.valid = false;
    }

    fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
        visit_inputs(&self.rule, visitor);
    }

    fn update<'a>(
        &'a mut self,
        ctx: &'a mut NodeUpdateContext<Asynchronous>,
    ) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
        let pending = self.do_update(ctx);
        Box::pin(pending)
    }
}
/// Adapter that exposes an asynchronous `NodeUpdateContext` to an async dynamic rule.
/// It supports queueing both synchronous and asynchronous rule nodes.
struct AsyncDynamicRuleUpdateContext<'a>(&'a mut NodeUpdateContext<Asynchronous>);
impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
    /// Queues node `id` for removal, delegating to the synchronous-context implementation.
    fn remove_node(&mut self, id: NodeId) {
        let mut inner = DynamicRuleUpdateContext(&mut *self.0);
        inner.remove_node(id);
    }

    /// Queues a synchronous rule node, delegating to the synchronous-context implementation.
    fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
    where
        R: Rule,
    {
        let mut inner = DynamicRuleUpdateContext(&mut *self.0);
        inner.add_rule(rule)
    }
}
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
    /// Wraps `rule` in an asynchronous rule node and queues it for insertion.
    fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
    where
        R: AsyncRule,
    {
        let node = AsyncRuleNode::new(rule);
        let mut inner = DynamicRuleUpdateContext(&mut *self.0);
        inner.add_node(node)
    }
}
/// `Display` adapter that renders an async dynamic rule's label via its `node_label` method.
struct AsyncDynamicRuleLabel<'a, R: AsyncDynamicRule>(&'a R);
impl<'a, R: AsyncDynamicRule> std::fmt::Display for AsyncDynamicRuleLabel<'a, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let AsyncDynamicRuleLabel(rule) = self;
        rule.node_label(f)
    }
}
/// Formats as `AsyncDynamicRuleNode<RuleType>(label)`.
///
/// Implemented for every child-output type `O`, like the `Debug` impls of the other node types.
/// The output type does not appear in the formatted text, so there is no reason to restrict it
/// to `R::ChildOutput`.
impl<R: AsyncDynamicRule, O> std::fmt::Debug for AsyncDynamicRuleNode<R, O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "AsyncDynamicRuleNode<{}>({})",
            pretty_type_name::<R>(),
            AsyncDynamicRuleLabel(&self.rule)
        )
    }
}
/// Returns a shortened, human-readable name for `T`, intended for `Debug` output.
///
/// Each path is reduced to its last segment, recursively for generic type arguments; for
/// example, `alloc::vec::Vec<alloc::string::String>` becomes `Vec<String>`.
///
/// `std::any::type_name` can contain compiler placeholders such as `{{closure}}`, which are not
/// valid Rust syntax. Every `{{name}}` is therefore rewritten to `__name__` before parsing. If
/// the result still cannot be parsed as a type, the raw `type_name` output is returned rather
/// than panicking inside a `Debug` implementation.
fn pretty_type_name<T>() -> String {
    let raw = std::any::type_name::<T>();
    let sanitized = raw.replace("{{", "__").replace("}}", "__");
    match syn::parse_str::<syn::Type>(&sanitized) {
        Ok(ty) => pretty_type_name_type(ty),
        Err(_) => raw.to_owned(),
    }
}
/// Shortens a parsed type. Paths go through `pretty_type_name_path`; every other kind of type
/// (references, tuples, trait objects, …) is printed verbatim from its tokens.
fn pretty_type_name_type(ty: syn::Type) -> String {
    if let syn::Type::Path(path) = ty {
        return pretty_type_name_path(path);
    }
    ty.into_token_stream().to_string()
}
/// Shortens a type path to its last segment, recursively shortening generic type arguments.
///
/// * Qualified paths (`<T as Trait>::Assoc`) are printed verbatim.
/// * `a::b::Foo` becomes `Foo`.
/// * `a::Foo<b::Bar, 'x>` becomes `Foo<Bar, 'x>`. Type arguments are shortened; other generic
///   arguments (lifetimes, consts, associated-type bindings, …) are printed verbatim.
/// * A segment with parenthesized arguments (`Fn(A) -> B`) is printed verbatim.
fn pretty_type_name_path(path: syn::TypePath) -> String {
    if path.qself.is_some() {
        return path.into_token_stream().to_string();
    }
    let last_segment = path
        .path
        .segments
        .last()
        .expect("a parsed type path always has at least one segment");
    let ident = last_segment.ident.to_token_stream().to_string();
    match &last_segment.arguments {
        syn::PathArguments::None => ident,
        syn::PathArguments::AngleBracketed(args) => {
            // Join the rendered arguments instead of trimming a trailing ", ". Trimming
            // corrupted an empty argument list (`Foo<>` came out as `Fo>`).
            let rendered: Vec<String> = args
                .args
                .iter()
                .map(|arg| match arg {
                    syn::GenericArgument::Type(ty) => pretty_type_name_type(ty.clone()),
                    other => other.into_token_stream().to_string(),
                })
                .collect();
            format!("{}<{}>", ident, rendered.join(", "))
        }
        syn::PathArguments::Parenthesized(_) => last_segment.into_token_stream().to_string(),
    }
}