diff options
| author | mryouse | 2022-05-19 23:44:53 +0000 |
|---|---|---|
| committer | mryouse | 2022-05-19 23:44:53 +0000 |
| commit | 0a7853b37e8e0ea86ce355338be285d298b90080 (patch) | |
| tree | 9b8b4ea69359e24ceb63af597af22c7a1c8353e7 /interpreter.py | |
| parent | 5d4c63b0664561f6fb696f552bea92f612118908 (diff) | |
refactor: this might be worse
Diffstat (limited to 'interpreter.py')
| -rw-r--r-- | interpreter.py | 126 |
1 files changed, 87 insertions, 39 deletions
diff --git a/interpreter.py b/interpreter.py index 9a375d0..f342caa 100644 --- a/interpreter.py +++ b/interpreter.py @@ -1,6 +1,49 @@ from parser import Expr -ENVIRONMENT = {} +class Function: + + def __init__(self, callable_, *arities): + self.callable = callable_ + if len(arities) == 0: + self.arities = None + else: + self.arities = arities + + def call(self, expr): + if self.arities is not None and len(expr.args) not in self.arities: + #if self.arity >= 0 and len(args) != self.arity: + fmt = f"[{self.arities[0]}" + for arity in self.arities[1:]: + fmt += f", {arity}" + fmt += "]" + raise Exception(f"expected {fmt} arguments, received {len(expr.args)}") + return self.callable(expr) + +class Environment: + + def __init__(self, parent=None): + self.parent = parent + self.environment = {} + + def register(self, key, value): + self.environment[key] = value + + def contains(self, key): + if key in self.environment: + return True + elif self.parent is not None: + return self.parent.contains(key) + else: + return False + + def get(self, key): + if not self.contains(key): + raise Exception(f"undefined symbol: '{key}") + if key in self.environment: + return self.environment[key] + else: + return self.parent.get(key) + def interpret(exprs): ret = None @@ -15,40 +58,19 @@ def evaluate(expr): # defined symbols will evaluate to their expression, # undefined will return their own name # TODO this is bad - if not expr.name in ENVIRONMENT: + if not GLOBALS.contains(expr.name): raise Exception(f"no such symbol: {expr}") - return interpretEnv(expr, ENVIRONMENT[expr.name]) - elif expr.symbol.name == "not": - return interpretNot(expr) - elif expr.symbol.name in ("*", "/"): - return interpretFactor(expr) - elif expr.symbol.name in ("-", "+"): - return interpretTerm(expr) - elif expr.symbol.name in (">", ">=", "<", "<="): - return interpretComparison(expr) - elif expr.symbol.name == "eq?": - return interpretEq(expr) - elif expr.symbol.name == "and": - return interpretAnd(expr) - elif expr.symbol.name == "or": - return interpretOr(expr) - - # flow control - elif expr.symbol.name == "if": - return interpretIf(expr) - - # io - elif expr.symbol.name == "print": - return interpretPrint(expr) - - # global variables - elif expr.symbol.name == "def": - return interpretDef(expr) - + return interpretEnv(expr, GLOBALS.get(expr.name)) + name = expr.symbol.name + if name == "def": + return interpretDef(expr, GLOBALS) + elif GLOBALS.contains(name): + return GLOBALS.get(name).call(expr) else: raise Exception(f"unable to evaluate: {expr}") +GLOBALS = Environment() def interpretOr(expr): # or returns true for the first expression that returns true @@ -62,6 +84,8 @@ def interpretOr(expr): return True return False +GLOBALS.register("or", Function(interpretOr)) + def interpretAnd(expr): # and returns false for the first expression that returns false if len(expr.args) < 2: @@ -74,6 +98,8 @@ def interpretAnd(expr): return False return True +GLOBALS.register("and", Function(interpretAnd)) + def interpretEq(expr): # equal if len(expr.args) != 2: @@ -81,7 +107,9 @@ def interpretEq(expr): first = evaluate(expr.args[0]) second = evaluate(expr.args[1]) return first == second - + +GLOBALS.register("eq?", Function(interpretEq, 2)) + def interpretComparison(expr): if len(expr.args) != 2: raise Exception("comparisons have two operands") @@ -101,20 +129,30 @@ def interpretComparison(expr): elif expr.symbol.name == "<=": return left <= right +GLOBALS.register(">", Function(interpretComparison, 2)) +GLOBALS.register(">=", Function(interpretComparison, 2)) +GLOBALS.register("<", Function(interpretComparison, 2)) +GLOBALS.register("<=", Function(interpretComparison, 2)) + def interpretTerm(expr): if len(expr.args) < 1: raise Exception("term has at least one operand") - res = 0 + res = None for arg in expr.args: ev = evaluate(arg) if type(ev) not in (int, float): raise Exception("term must be a number") - if expr.symbol.name == "+": + if res is None: + res = ev + elif expr.symbol.name == "+": res += ev elif expr.symbol.name == "-": res -= ev return res +GLOBALS.register("+", Function(interpretTerm)) +GLOBALS.register("-", Function(interpretTerm)) + def interpretFactor(expr): if expr.symbol.name == "/": if len(expr.args) != 2: @@ -140,6 +178,9 @@ def interpretFactor(expr): res = res * tmp return res +GLOBALS.register("*", Function(interpretFactor)) +GLOBALS.register("/", Function(interpretFactor, 2)) + def interpretNot(expr): if len(expr.args) != 1: raise Exception("'not' takes one operand") @@ -148,6 +189,8 @@ def interpretNot(expr): raise Exception("'not' only works on booleans") return not res +GLOBALS.register("not", Function(interpretNot, 1)) + def interpretIf(expr): # if cond t-branch [f-branch] if len(expr.args) not in (2, 3): @@ -161,20 +204,22 @@ def interpretIf(expr): return evaluate(expr.args[2]) return None # this shouldn't be reached +GLOBALS.register("if", Function(interpretIf, 2, 3)) + def interpretPrint(expr): - if len(expr.args) == 0: - print() - elif len(expr.args) == 1: + if len(expr.args) == 1: ev = evaluate(expr.args[0]) if not isinstance(ev, str): raise Exception("can only 'print' strings") print(ev) else: - raise Exception("'print' takes zero or one argument") + raise Exception("'print' takes one argument") return None # print returns nothing -def interpretDef(expr): +GLOBALS.register("print", Function(interpretPrint, 1)) + +def interpretDef(expr, env): if len(expr.args) != 2: raise Exception("'def' requires a name and an expression") if not isinstance(expr.args[0], Expr.Symbol): @@ -183,8 +228,11 @@ def interpretDef(expr): if not isinstance(name, str): raise Exception("'def' requires a string literal as a name") - ENVIRONMENT[name] = expr.args[1] + env.register(name, expr.args[1]) return None +GLOBALS.register("def", Function(interpretDef, 2)) + def interpretEnv(expr, env_expr): return evaluate(env_expr) # TODO more than this + |
