From b8ad929d0b37e0c5fd44775cd0dc1bf8bf6ce9d5 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Sun, 3 Nov 2024 14:57:21 -0500 Subject: [PATCH] Fix derive macro with generics --- crates/compute_graph_macros/src/lib.rs | 28 +++++++++++++++++++++----- crates/derive_test/src/lib.rs | 10 +++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/crates/compute_graph_macros/src/lib.rs b/crates/compute_graph_macros/src/lib.rs index cea01fe..75abec2 100644 --- a/crates/compute_graph_macros/src/lib.rs +++ b/crates/compute_graph_macros/src/lib.rs @@ -2,7 +2,7 @@ use proc_macro::TokenStream; use proc_macro2::Literal; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, + parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, PathArguments, Type, }; @@ -25,14 +25,32 @@ pub fn derive_rule(input: TokenStream) -> TokenStream { derive_rule_struct(&input, data) } else { TokenStream::from( - syn::Error::new(input.ident.span(), "Only structs can derive `Rule`") - .to_compile_error(), + syn::Error::new( + input.ident.span(), + "Only structs can derive `InputVisitable`", + ) + .to_compile_error(), ) } } fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { let name = &input.ident; + let lt = &input.generics.lt_token; + let params = &input.generics.params; + let gt = &input.generics.gt_token; + let generics = quote!(#lt #params #gt); + let where_clause = &input.generics.where_clause; + let params_only_names = params.iter().map(|p| match p { + GenericParam::Lifetime(_) => { + panic!("Lifetime generics aren't supported when deriving `InputVisitable`") + } + GenericParam::Type(ty) => &ty.ident, + GenericParam::Const(_) => { + panic!("Const generics aren't supported when deriving `InputVisitable`") + } + }); + let generics_only_names = quote!(#lt #(#params_only_names),* #gt); let visit_inputs = match data.fields { Fields::Named(ref named) => named @@ -95,13 +113,13 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { TokenStream::from(quote!( - impl ::compute_graph::rule::InputVisitable for #name { + impl #generics ::compute_graph::rule::InputVisitable for #name #generics_only_names #where_clause { fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) { #(#visit_inputs)* } } - impl #name { + impl #generics #name #generics_only_names #where_clause { #(#input_value_methods)* } diff --git a/crates/derive_test/src/lib.rs b/crates/derive_test/src/lib.rs index 73be8a1..aa9d62b 100644 --- a/crates/derive_test/src/lib.rs +++ b/crates/derive_test/src/lib.rs @@ -1,3 +1,4 @@ +use compute_graph::node::NodeValue; use compute_graph::rule::{Input, InputVisitable, Rule}; #[derive(InputVisitable)] @@ -24,6 +25,15 @@ impl Rule for Add2 { } } +#[derive(InputVisitable)] +struct Passthrough(Input); +impl Rule for Passthrough { + type Output = T; + fn evaluate(&mut self) -> Self::Output { + self.input_0().clone() + } +} + #[cfg(test)] mod tests { use compute_graph::builder::GraphBuilder;