From 1253999961c32fd625120c03b9c9e3846f060892 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Wed, 1 Jan 2025 18:00:49 -0500 Subject: [PATCH] More compute_graph changes --- crates/compute_graph/src/lib.rs | 4 +-- crates/compute_graph/src/node.rs | 14 ++++++++ crates/compute_graph/src/rule.rs | 47 +++++++++++++++++++------- crates/compute_graph_macros/src/lib.rs | 16 +++++++-- crates/derive_test/src/lib.rs | 26 +++++++++++++- 5 files changed, 89 insertions(+), 18 deletions(-) diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index e9dffdf..7937721 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -863,7 +863,7 @@ mod tests { let count = *self.count.value(); for i in 1..=count { self.node_factory - .add_rule(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i))); + .add_node(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i))); } self.node_factory.all_nodes(ctx) } @@ -942,7 +942,7 @@ mod tests { ) -> Vec> { let count = *self.count.value(); for i in 1..=count { - self.node_factory.add_rule(ctx, i, |ctx| { + self.node_factory.add_node(ctx, i, |ctx| { let constant = ctx.add_rule(ConstantRule::new(i)); let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0)); self.signals.borrow_mut().push(signal); diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index 3898586..a2fce26 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -567,6 +567,13 @@ impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S self.add_node(RuleNode::new(rule)) } + fn add_dynamic_rule(&mut self, rule: R) -> Input> + where + R: DynamicRule, + { + self.add_node(DynamicRuleNode::new(rule)) + } + fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) where R: Rule, @@ -676,6 +683,13 @@ impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> { DynamicRuleUpdateContext(self.0).add_rule(rule) } + fn add_dynamic_rule(&mut self, rule: R) -> Input> + where + R: DynamicRule, + { + DynamicRuleUpdateContext(self.0).add_dynamic_rule(rule) + } + fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) where R: Rule, diff --git a/crates/compute_graph/src/rule.rs b/crates/compute_graph/src/rule.rs index 0dc5f68..19dd06d 100644 --- a/crates/compute_graph/src/rule.rs +++ b/crates/compute_graph/src/rule.rs @@ -117,6 +117,11 @@ pub trait DynamicRuleContext { where R: Rule; + /// Adds a node whose output is additional nodes produced by the given dynamic rule. + fn add_dynamic_rule(&mut self, rule: R) -> Input> + where + R: DynamicRule; + /// Adds an externally-invalidatable node whose value is produced using the given rule. fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) where @@ -148,38 +153,48 @@ impl DynamicNodeFactory { /// /// This method must be called for every node that is part of the output. The `build` function /// will only be called for nodes that have not previously been built. - pub fn add_rule(&mut self, ctx: &mut C, id: ID, build: F) + pub fn add_node(&mut self, ctx: &mut C, id: ID, build: F) -> Input where C: DynamicRuleContext, F: FnOnce(&mut C) -> Input, { - if !self.existing_nodes.contains_key(&id) { + let input = if let Some(input) = self.existing_nodes.get(&id) { + input.clone() + } else { let input = build(ctx); - self.existing_nodes.insert(id.clone(), input); - } + self.existing_nodes.insert(id.clone(), input.clone()); + input + }; self.ids_added_this_evaluation.insert(id); + input } /// Registers a node that is part of the output. /// - /// See [`DynamicNodeFactory::add_rule`]. - pub fn add_async_rule(&mut self, ctx: &mut C, id: ID, build: F) + /// See [`DynamicNodeFactory::add_node`]. + pub fn add_async_node(&mut self, ctx: &mut C, id: ID, build: F) -> Input where C: AsyncDynamicRuleContext, F: FnOnce(&mut C) -> Input, { - if !self.existing_nodes.contains_key(&id) { + let input = if let Some(input) = self.existing_nodes.get(&id) { + input.clone() + } else { let input = build(ctx); - self.existing_nodes.insert(id.clone(), input); - } + self.existing_nodes.insert(id.clone(), input.clone()); + input + }; self.ids_added_this_evaluation.insert(id); + input } - /// Builds the final list of all nodes currently present in the output. - /// /// Removes any nodes that were previously output but which have not had [`DynamicNodeFactory::add_rule`] /// called during this evaluation. - pub fn all_nodes(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec> { + /// + /// Either this method or [`DynamicNodeFactory::all_nodes`] should only be called once per evaluation. + /// + /// This method is useful when adding nodes that are not directly part of a dynamic node's output. + pub fn finalize_nodes(&mut self, ctx: &mut impl DynamicRuleContext) { // collect everything up front so we can mutably borrow existing_nodes let to_remove = self .existing_nodes @@ -192,6 +207,14 @@ impl DynamicNodeFactory { ctx.remove_node(input.node_id()); } self.ids_added_this_evaluation.clear(); + } + + /// Builds the final list of all nodes currently present in the output. + /// + /// This method calls [`DynamicNodeFactory::finalize_nodes`], and this method or that one should only + /// be called once per evaluation. + pub fn all_nodes(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec> { + self.finalize_nodes(ctx); self.existing_nodes.values().cloned().collect() } } diff --git a/crates/compute_graph_macros/src/lib.rs b/crates/compute_graph_macros/src/lib.rs index dc453be..e79c655 100644 --- a/crates/compute_graph_macros/src/lib.rs +++ b/crates/compute_graph_macros/src/lib.rs @@ -2,8 +2,8 @@ use proc_macro::TokenStream; use proc_macro2::Literal; use quote::{format_ident, quote, ToTokens}; use syn::{ - parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, - PathArguments, Type, + parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, + GenericParam, PathArguments, Type, }; extern crate proc_macro; @@ -18,7 +18,7 @@ extern crate proc_macro; /// For named fields, the generated method name matches the field name. In both cases, the method /// returns a reference to the input value. As with the `Input::value` method, calling the helper methods /// before the referenced node has been evaluated is forbidden. -#[proc_macro_derive(InputVisitable)] +#[proc_macro_derive(InputVisitable, attributes(ignore_input))] pub fn derive_rule(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); if let Data::Struct(ref data) = input.data { @@ -155,6 +155,9 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { } fn input_value_type(field: &Field) -> Option<(&Type, bool)> { + if field.attrs.iter().any(|attr| is_ignore_attr(attr)) { + return None; + } if let Type::Path(ref path) = field.ty { let last_segment = path.path.segments.last().unwrap(); if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" { @@ -179,3 +182,10 @@ fn input_value_type(field: &Field) -> Option<(&Type, bool)> { None } } + +fn is_ignore_attr(attr: &Attribute) -> bool { + match attr.meta.require_path_only() { + Ok(path) => path.is_ident("ignore_input"), + Err(_) => false, + } +} diff --git a/crates/derive_test/src/lib.rs b/crates/derive_test/src/lib.rs index becb59c..b7e81e2 100644 --- a/crates/derive_test/src/lib.rs +++ b/crates/derive_test/src/lib.rs @@ -51,7 +51,8 @@ impl Rule for Sum { mod tests { use compute_graph::{ builder::GraphBuilder, - rule::{ConstantRule, DynamicRule}, + rule::{ConstantRule, DynamicRule, InputVisitor}, + synchronicity::Synchronous, }; use super::*; @@ -98,4 +99,27 @@ mod tests { let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 3); } + + #[test] + fn test_ignore() { + #[derive(InputVisitable)] + struct Ignore { + #[ignore_input] + input: Input, + } + let mut builder = GraphBuilder::::new(); + struct Visitor; + impl InputVisitor for Visitor { + fn visit(&mut self, _input: &Input) { + assert!(false); + } + fn visit_dynamic(&mut self, _input: &DynamicInput) { + unreachable!(); + } + } + Ignore { + input: builder.add_value(0), + } + .visit_inputs(&mut Visitor); + } }