More compute_graph changes

This commit is contained in:
Shadowfacts 2025-01-01 18:00:49 -05:00
parent f467025569
commit 1253999961
5 changed files with 89 additions and 18 deletions

View File

@ -863,7 +863,7 @@ mod tests {
let count = *self.count.value(); let count = *self.count.value();
for i in 1..=count { for i in 1..=count {
self.node_factory self.node_factory
.add_rule(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i))); .add_node(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i)));
} }
self.node_factory.all_nodes(ctx) self.node_factory.all_nodes(ctx)
} }
@ -942,7 +942,7 @@ mod tests {
) -> Vec<Input<Self::ChildOutput>> { ) -> Vec<Input<Self::ChildOutput>> {
let count = *self.count.value(); let count = *self.count.value();
for i in 1..=count { for i in 1..=count {
self.node_factory.add_rule(ctx, i, |ctx| { self.node_factory.add_node(ctx, i, |ctx| {
let constant = ctx.add_rule(ConstantRule::new(i)); let constant = ctx.add_rule(ConstantRule::new(i));
let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0)); let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0));
self.signals.borrow_mut().push(signal); self.signals.borrow_mut().push(signal);

View File

@ -567,6 +567,13 @@ impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S
self.add_node(RuleNode::new(rule)) self.add_node(RuleNode::new(rule))
} }
fn add_dynamic_rule<R>(&mut self, rule: R) -> Input<DynamicRuleOutput<R::ChildOutput>>
where
R: DynamicRule,
{
self.add_node(DynamicRuleNode::new(rule))
}
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal) fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where where
R: Rule, R: Rule,
@ -676,6 +683,13 @@ impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
DynamicRuleUpdateContext(self.0).add_rule(rule) DynamicRuleUpdateContext(self.0).add_rule(rule)
} }
fn add_dynamic_rule<R>(&mut self, rule: R) -> Input<DynamicRuleOutput<R::ChildOutput>>
where
R: DynamicRule,
{
DynamicRuleUpdateContext(self.0).add_dynamic_rule(rule)
}
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal) fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where where
R: Rule, R: Rule,

View File

@ -117,6 +117,11 @@ pub trait DynamicRuleContext {
where where
R: Rule; R: Rule;
/// Adds a node whose output is additional nodes produced by the given dynamic rule.
fn add_dynamic_rule<R>(&mut self, rule: R) -> Input<DynamicRuleOutput<R::ChildOutput>>
where
R: DynamicRule;
/// Adds an externally-invalidatable node whose value is produced using the given rule. /// Adds an externally-invalidatable node whose value is produced using the given rule.
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal) fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where where
@ -148,38 +153,48 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
/// ///
/// This method must be called for every node that is part of the output. The `build` function /// This method must be called for every node that is part of the output. The `build` function
/// will only be called for nodes that have not previously been built. /// will only be called for nodes that have not previously been built.
pub fn add_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F) pub fn add_node<C, F>(&mut self, ctx: &mut C, id: ID, build: F) -> Input<ChildOutput>
where where
C: DynamicRuleContext, C: DynamicRuleContext,
F: FnOnce(&mut C) -> Input<ChildOutput>, F: FnOnce(&mut C) -> Input<ChildOutput>,
{ {
if !self.existing_nodes.contains_key(&id) { let input = if let Some(input) = self.existing_nodes.get(&id) {
input.clone()
} else {
let input = build(ctx); let input = build(ctx);
self.existing_nodes.insert(id.clone(), input); self.existing_nodes.insert(id.clone(), input.clone());
} input
};
self.ids_added_this_evaluation.insert(id); self.ids_added_this_evaluation.insert(id);
input
} }
/// Registers a node that is part of the output. /// Registers a node that is part of the output.
/// ///
/// See [`DynamicNodeFactory::add_rule`]. /// See [`DynamicNodeFactory::add_node`].
pub fn add_async_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F) pub fn add_async_node<C, F>(&mut self, ctx: &mut C, id: ID, build: F) -> Input<ChildOutput>
where where
C: AsyncDynamicRuleContext, C: AsyncDynamicRuleContext,
F: FnOnce(&mut C) -> Input<ChildOutput>, F: FnOnce(&mut C) -> Input<ChildOutput>,
{ {
if !self.existing_nodes.contains_key(&id) { let input = if let Some(input) = self.existing_nodes.get(&id) {
input.clone()
} else {
let input = build(ctx); let input = build(ctx);
self.existing_nodes.insert(id.clone(), input); self.existing_nodes.insert(id.clone(), input.clone());
} input
};
self.ids_added_this_evaluation.insert(id); self.ids_added_this_evaluation.insert(id);
input
} }
/// Builds the final list of all nodes currently present in the output.
///
/// Removes any nodes that were previously output but which have not had [`DynamicNodeFactory::add_rule`] /// Removes any nodes that were previously output but which have not had [`DynamicNodeFactory::add_rule`]
/// called during this evaluation. /// called during this evaluation.
pub fn all_nodes(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<ChildOutput>> { ///
/// Either this method or [`DynamicNodeFactory::all_nodes`] should only be called once per evaluation.
///
/// This method is useful when adding nodes that are not directly part of a dynamic node's output.
pub fn finalize_nodes(&mut self, ctx: &mut impl DynamicRuleContext) {
// collect everything up front so we can mutably borrow existing_nodes // collect everything up front so we can mutably borrow existing_nodes
let to_remove = self let to_remove = self
.existing_nodes .existing_nodes
@ -192,6 +207,14 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
ctx.remove_node(input.node_id()); ctx.remove_node(input.node_id());
} }
self.ids_added_this_evaluation.clear(); self.ids_added_this_evaluation.clear();
}
/// Builds the final list of all nodes currently present in the output.
///
/// This method calls [`DynamicNodeFactory::finalize_nodes`], and this method or that one should only
/// be called once per evaluation.
pub fn all_nodes(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<ChildOutput>> {
self.finalize_nodes(ctx);
self.existing_nodes.values().cloned().collect() self.existing_nodes.values().cloned().collect()
} }
} }

View File

@ -2,8 +2,8 @@ use proc_macro::TokenStream;
use proc_macro2::Literal; use proc_macro2::Literal;
use quote::{format_ident, quote, ToTokens}; use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument,
PathArguments, Type, GenericParam, PathArguments, Type,
}; };
extern crate proc_macro; extern crate proc_macro;
@ -18,7 +18,7 @@ extern crate proc_macro;
/// For named fields, the generated method name matches the field name. In both cases, the method /// For named fields, the generated method name matches the field name. In both cases, the method
/// returns a reference to the input value. As with the `Input::value` method, calling the helper methods /// returns a reference to the input value. As with the `Input::value` method, calling the helper methods
/// before the referenced node has been evaluated is forbidden. /// before the referenced node has been evaluated is forbidden.
#[proc_macro_derive(InputVisitable)] #[proc_macro_derive(InputVisitable, attributes(ignore_input))]
pub fn derive_rule(input: TokenStream) -> TokenStream { pub fn derive_rule(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
if let Data::Struct(ref data) = input.data { if let Data::Struct(ref data) = input.data {
@ -155,6 +155,9 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
} }
fn input_value_type(field: &Field) -> Option<(&Type, bool)> { fn input_value_type(field: &Field) -> Option<(&Type, bool)> {
if field.attrs.iter().any(|attr| is_ignore_attr(attr)) {
return None;
}
if let Type::Path(ref path) = field.ty { if let Type::Path(ref path) = field.ty {
let last_segment = path.path.segments.last().unwrap(); let last_segment = path.path.segments.last().unwrap();
if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" { if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" {
@ -179,3 +182,10 @@ fn input_value_type(field: &Field) -> Option<(&Type, bool)> {
None None
} }
} }
fn is_ignore_attr(attr: &Attribute) -> bool {
match attr.meta.require_path_only() {
Ok(path) => path.is_ident("ignore_input"),
Err(_) => false,
}
}

View File

@ -51,7 +51,8 @@ impl Rule for Sum {
mod tests { mod tests {
use compute_graph::{ use compute_graph::{
builder::GraphBuilder, builder::GraphBuilder,
rule::{ConstantRule, DynamicRule}, rule::{ConstantRule, DynamicRule, InputVisitor},
synchronicity::Synchronous,
}; };
use super::*; use super::*;
@ -98,4 +99,27 @@ mod tests {
let mut graph = builder.build().unwrap(); let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 3); assert_eq!(*graph.evaluate(), 3);
} }
#[test]
fn test_ignore() {
#[derive(InputVisitable)]
struct Ignore {
#[ignore_input]
input: Input<i32>,
}
let mut builder = GraphBuilder::<i32, Synchronous>::new();
struct Visitor;
impl InputVisitor for Visitor {
fn visit<T>(&mut self, _input: &Input<T>) {
assert!(false);
}
fn visit_dynamic<T>(&mut self, _input: &DynamicInput<T>) {
unreachable!();
}
}
Ignore {
input: builder.add_value(0),
}
.visit_inputs(&mut Visitor);
}
} }