Fix derive macro with generics

This commit is contained in:
Shadowfacts 2024-11-03 14:57:21 -05:00
parent e69014d98d
commit b8ad929d0b
2 changed files with 33 additions and 5 deletions

View File

@ -2,7 +2,7 @@ use proc_macro::TokenStream;
use proc_macro2::Literal; use proc_macro2::Literal;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
PathArguments, Type, PathArguments, Type,
}; };
@ -25,14 +25,32 @@ pub fn derive_rule(input: TokenStream) -> TokenStream {
derive_rule_struct(&input, data) derive_rule_struct(&input, data)
} else { } else {
TokenStream::from( TokenStream::from(
syn::Error::new(input.ident.span(), "Only structs can derive `Rule`") syn::Error::new(
.to_compile_error(), input.ident.span(),
"Only structs can derive `InputVisitable`",
)
.to_compile_error(),
) )
} }
} }
fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
let name = &input.ident; let name = &input.ident;
let lt = &input.generics.lt_token;
let params = &input.generics.params;
let gt = &input.generics.gt_token;
let generics = quote!(#lt #params #gt);
let where_clause = &input.generics.where_clause;
let params_only_names = params.iter().map(|p| match p {
GenericParam::Lifetime(_) => {
panic!("Lifetime generics aren't supported when deriving `InputVisitable`")
}
GenericParam::Type(ty) => &ty.ident,
GenericParam::Const(_) => {
panic!("Const generics aren't supported when deriving `InputVisitable`")
}
});
let generics_only_names = quote!(#lt #(#params_only_names),* #gt);
let visit_inputs = match data.fields { let visit_inputs = match data.fields {
Fields::Named(ref named) => named Fields::Named(ref named) => named
@ -95,13 +113,13 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
TokenStream::from(quote!( TokenStream::from(quote!(
impl ::compute_graph::rule::InputVisitable for #name { impl #generics ::compute_graph::rule::InputVisitable for #name #generics_only_names #where_clause {
fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) { fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) {
#(#visit_inputs)* #(#visit_inputs)*
} }
} }
impl #name { impl #generics #name #generics_only_names #where_clause {
#(#input_value_methods)* #(#input_value_methods)*
} }

View File

@ -1,3 +1,4 @@
use compute_graph::node::NodeValue;
use compute_graph::rule::{Input, InputVisitable, Rule}; use compute_graph::rule::{Input, InputVisitable, Rule};
#[derive(InputVisitable)] #[derive(InputVisitable)]
@ -24,6 +25,15 @@ impl Rule for Add2 {
} }
} }
#[derive(InputVisitable)]
struct Passthrough<T: NodeValue + Clone>(Input<T>);
impl<T: NodeValue + Clone> Rule for Passthrough<T> {
type Output = T;
fn evaluate(&mut self) -> Self::Output {
self.input_0().clone()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use compute_graph::builder::GraphBuilder; use compute_graph::builder::GraphBuilder;