From 45a043f126ea926689cc4b16dbf0bebffa2512a2 Mon Sep 17 00:00:00 2001 From: mryouse Date: Thu, 9 Jun 2022 20:58:55 +0000 Subject: add arity check to user defined functions --- interpreter.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'interpreter.py') diff --git a/interpreter.py b/interpreter.py index 3266830..a4bba6d 100644 --- a/interpreter.py +++ b/interpreter.py @@ -20,6 +20,12 @@ class Function: else: self.arities = arities + def arity_check(self, symbol, args): + if self.arities is not None and len(args) not in self.arities: + fmt = ", ".join([f"{arity}" for arity in self.arities]) + raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(args)}") + return True + def call(self, expr, env): pass @@ -29,12 +35,7 @@ class Builtin(Function): super().__init__("", None, callable_, *arities) def call(self, expr, env): - if self.arities is not None and len(expr.args[1:]) not in self.arities: - fmt = f"[{self.arities[0]}" - for arity in self.arities[1:]: - fmt += f", {arity}" - fmt += "]" - raise InterpretPanic(expr.args[0], f"expected {fmt} arguments, received {len(expr.args) - 1}") + self.arity_check(expr.args[0], expr.args[1:]) return self.body(expr.args[0], expr.args[1:], env) class UserFunction(Function): @@ -43,6 +44,7 @@ class UserFunction(Function): super().__init__(name, params, body, len(params)) def call(self, expr, env): + self.arity_check(expr.args[0], expr.args[1:]) this_env = Environment(env) for idx, param in enumerate(self.params): # TODO this is wrong!!! this won't always be a literal -- cgit v1.2.3