Use associated types for rule outputs

This commit is contained in:
Shadowfacts 2024-10-31 11:19:16 -04:00
parent d8f2a393ba
commit ca0b77349a

View File

@ -107,7 +107,7 @@ impl<Output: Clone + 'static> Graph<Output, Asynchronous> {
}
impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
pub fn set_output<R: Rule<O> + 'static>(&mut self, rule: R) {
pub fn set_output<R: Rule<Output = O>>(&mut self, rule: R) {
let input = self.add_rule(rule);
self.output = Some(input.node_idx);
}
@ -126,14 +126,18 @@ impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
return self.add_node(ConstNode::new(value.clone()));
}
pub fn add_rule<R: Rule<V> + 'static, V: Clone + 'static>(&mut self, rule: R) -> Input<V> {
pub fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
R::Output: Clone,
{
return self.add_node(RuleNode::new(rule));
}
pub fn add_invalidatable_rule<R, V, F>(&mut self, mut f: F) -> Input<V>
pub fn add_invalidatable_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
where
R: Rule<V> + 'static,
V: Clone + 'static,
R: Rule,
R::Output: Clone,
F: FnMut(InvalidationSignal<S>) -> R,
{
let node_idx = Rc::new(Cell::new(None));
@ -189,22 +193,23 @@ impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
}
impl<O: Clone + 'static> Graph<O, Asynchronous> {
pub fn set_async_output<R: AsyncRule<O> + 'static>(&mut self, rule: R) {
pub fn set_async_output<R: AsyncRule<Output = O>>(&mut self, rule: R) {
let input = self.add_async_rule(rule);
self.output = Some(input.node_idx);
}
pub fn add_async_rule<R: AsyncRule<V> + 'static, V: Clone + 'static>(
&mut self,
rule: R,
) -> Input<V> {
pub fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: AsyncRule,
R::Output: Clone,
{
self.add_node(AsyncRuleNode::new(rule))
}
pub fn add_invalidatable_async_rule<R, V, F>(&mut self, mut f: F) -> Input<V>
pub fn add_invalidatable_async_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
where
R: AsyncRule<V> + 'static,
V: Clone + 'static,
R: AsyncRule,
R::Output: Clone,
F: FnMut(InvalidationSignal<Asynchronous>) -> R,
{
let node_idx = Rc::new(Cell::new(None));
@ -494,7 +499,7 @@ struct RuleNode<R, V, S> {
synchronicity: std::marker::PhantomData<S>,
}
impl<R: Rule<V>, V, S> RuleNode<R, V, S> {
impl<R: Rule, S> RuleNode<R, R::Output, S> {
fn new(rule: R) -> Self {
Self {
rule,
@ -505,7 +510,7 @@ impl<R: Rule<V>, V, S> RuleNode<R, V, S> {
}
}
impl<R: Rule<V> + 'static, V: Clone + 'static, S: Synchronicity> Node<V, S> for RuleNode<R, V, S> {
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
fn is_valid(&self) -> bool {
self.valid
}
@ -531,7 +536,7 @@ impl<R: Rule<V> + 'static, V: Clone + 'static, S: Synchronicity> Node<V, S> for
S::make_update_result()
}
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
fn value_rc(&self) -> Rc<RefCell<Option<R::Output>>> {
Rc::clone(&self.value)
}
}
@ -542,7 +547,7 @@ struct AsyncRuleNode<R, V> {
valid: bool,
}
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
fn new(rule: R) -> Self {
Self {
rule,
@ -552,7 +557,10 @@ impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
}
}
impl<R: AsyncRule<V> + 'static, V: Clone + 'static> Node<V, Asynchronous> for AsyncRuleNode<R, V> {
impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
where
R::Output: Clone,
{
fn is_valid(&self) -> bool {
self.valid
}
@ -575,12 +583,12 @@ impl<R: AsyncRule<V> + 'static, V: Clone + 'static> Node<V, Asynchronous> for As
Box::pin(self.do_update())
}
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
fn value_rc(&self) -> Rc<RefCell<Option<R::Output>>> {
Rc::clone(&self.value)
}
}
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
async fn do_update(&mut self) {
let new_value = self.rule.evaluate().await;
self.valid = true;
@ -588,16 +596,20 @@ impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
}
}
pub trait Rule<Output> {
pub trait Rule: 'static {
type Output;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
fn evaluate(&mut self) -> Output;
fn evaluate(&mut self) -> Self::Output;
}
pub trait AsyncRule<Output> {
pub trait AsyncRule: 'static {
type Output: 'static;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
async fn evaluate(&mut self) -> Output;
async fn evaluate(&mut self) -> Self::Output;
}
pub trait InputVisitor {
@ -616,7 +628,8 @@ mod tests {
}
struct ConstantRule(i32);
impl Rule<i32> for ConstantRule {
impl Rule for ConstantRule {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> i32 {
self.0
@ -641,7 +654,8 @@ mod tests {
}
struct Double(Input<i32>);
impl Rule<i32> for Double {
impl Rule for Double {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
visitor.visit(&mut self.0);
}
@ -667,7 +681,8 @@ mod tests {
assert_eq!(graph.freeze().unwrap().evaluate(), 168);
}
struct Inc(i32);
impl Rule<i32> for Inc {
impl Rule for Inc {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> i32 {
self.0 += 1;
@ -694,7 +709,8 @@ mod tests {
}
struct Add(Input<i32>, Input<i32>);
impl Rule<i32> for Add {
impl Rule for Add {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
visitor.visit(&mut self.0);
visitor.visit(&mut self.1);
@ -741,7 +757,8 @@ mod tests {
}
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
impl Rule<i32> for DeferredInput {
impl Rule for DeferredInput {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
let mut borrowed = self.0.borrow_mut();
let input = borrowed.as_mut().unwrap();
@ -811,7 +828,8 @@ mod tests {
#[tokio::test]
async fn async_rule() {
struct AsyncConst(i32);
impl AsyncRule<i32> for AsyncConst {
impl AsyncRule for AsyncConst {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
async fn evaluate(&mut self) -> i32 {
self.0