diff options
Diffstat (limited to 'interpreter.py')
| -rw-r--r-- | interpreter.py | 126 | 
1 files changed, 87 insertions, 39 deletions
| diff --git a/interpreter.py b/interpreter.py index 9a375d0..f342caa 100644 --- a/interpreter.py +++ b/interpreter.py @@ -1,6 +1,49 @@  from parser import Expr -ENVIRONMENT = {} +class Function: +     +    def __init__(self, callable_, *arities): +        self.callable = callable_ +        if len(arities) == 0: +            self.arities = None +        else: +            self.arities = arities + +    def call(self, expr): +        if self.arities is not None and len(expr.args) not in self.arities: +        #if self.arity >= 0 and len(args) != self.arity: +            fmt = f"[{self.arities[0]}" +            for arity in self.arities[1:]: +                fmt += f", {arity}" +            fmt += "]" +            raise Exception(f"expected {fmt} arguments, received {len(expr.args)}") +        return self.callable(expr) + +class Environment: +     +    def __init__(self, parent=None): +        self.parent = parent +        self.environment = {} + +    def register(self, key, value): +        self.environment[key] = value + +    def contains(self, key): +        if key in self.environment: +            return True +        elif self.parent is not None: +            return self.parent.contains(key) +        else: +            return False + +    def get(self, key): +        if not self.contains(key): +            raise Exception(f"undefined symbol: '{key}") +        if key in self.environment: +            return self.environment[key] +        else: +            return self.parent.get(key) +  def interpret(exprs):      ret = None @@ -15,40 +58,19 @@ def evaluate(expr):          # defined symbols will evaluate to their expression,          # undefined will return their own name          # TODO this is bad -        if not expr.name in ENVIRONMENT: +        if not GLOBALS.contains(expr.name):              raise Exception(f"no such symbol: {expr}") -        return interpretEnv(expr, ENVIRONMENT[expr.name]) -    elif expr.symbol.name == "not": -        return interpretNot(expr) -    elif expr.symbol.name in ("*", "/"): -        return interpretFactor(expr) -    elif expr.symbol.name in ("-", "+"): -        return interpretTerm(expr) -    elif expr.symbol.name in (">", ">=", "<", "<="): -        return interpretComparison(expr) -    elif expr.symbol.name == "eq?": -        return interpretEq(expr) -    elif expr.symbol.name == "and": -        return interpretAnd(expr) -    elif expr.symbol.name == "or": -        return interpretOr(expr) - -    # flow control -    elif expr.symbol.name == "if": -        return interpretIf(expr) - -    # io -    elif expr.symbol.name == "print": -        return interpretPrint(expr) - -    # global variables -    elif expr.symbol.name == "def": -        return interpretDef(expr) - +        return interpretEnv(expr, GLOBALS.get(expr.name)) +    name = expr.symbol.name +    if name == "def": +        return interpretDef(expr, GLOBALS) +    elif GLOBALS.contains(name): +        return GLOBALS.get(name).call(expr)      else:          raise Exception(f"unable to evaluate: {expr}") +GLOBALS = Environment()  def interpretOr(expr):      # or returns true for the first expression that returns true @@ -62,6 +84,8 @@ def interpretOr(expr):              return True      return False +GLOBALS.register("or", Function(interpretOr)) +  def interpretAnd(expr):      # and returns false for the first expression that returns false      if len(expr.args) < 2: @@ -74,6 +98,8 @@ def interpretAnd(expr):              return False      return True +GLOBALS.register("and", Function(interpretAnd)) +  def interpretEq(expr):      # equal      if len(expr.args) != 2: @@ -81,7 +107,9 @@ def interpretEq(expr):      first = evaluate(expr.args[0])      second = evaluate(expr.args[1])      return first == second -             + +GLOBALS.register("eq?", Function(interpretEq, 2)) +  def interpretComparison(expr):      if len(expr.args) != 2:          raise Exception("comparisons have two operands") @@ -101,20 +129,30 @@ def interpretComparison(expr):      elif expr.symbol.name == "<=":          return left <= right +GLOBALS.register(">", Function(interpretComparison, 2)) +GLOBALS.register(">=", Function(interpretComparison, 2)) +GLOBALS.register("<", Function(interpretComparison, 2)) +GLOBALS.register("<=", Function(interpretComparison, 2)) +  def interpretTerm(expr):      if len(expr.args) < 1:          raise Exception("term has at least one operand") -    res = 0 +    res = None      for arg in expr.args:          ev = evaluate(arg)          if type(ev) not in (int, float):              raise Exception("term must be a number") -        if expr.symbol.name == "+": +        if res is None: +            res = ev +        elif expr.symbol.name == "+":              res += ev          elif expr.symbol.name == "-":              res -= ev      return res +GLOBALS.register("+", Function(interpretTerm)) +GLOBALS.register("-", Function(interpretTerm)) +  def interpretFactor(expr):      if expr.symbol.name == "/":          if len(expr.args) != 2: @@ -140,6 +178,9 @@ def interpretFactor(expr):              res = res * tmp          return res +GLOBALS.register("*", Function(interpretFactor)) +GLOBALS.register("/", Function(interpretFactor, 2)) +  def interpretNot(expr):      if len(expr.args) != 1:          raise Exception("'not' takes one operand") @@ -148,6 +189,8 @@ def interpretNot(expr):          raise Exception("'not' only works on booleans")      return not res +GLOBALS.register("not", Function(interpretNot, 1)) +  def interpretIf(expr):      # if cond t-branch [f-branch]      if len(expr.args) not in (2, 3): @@ -161,20 +204,22 @@ def interpretIf(expr):          return evaluate(expr.args[2])      return None  # this shouldn't be reached +GLOBALS.register("if", Function(interpretIf, 2, 3)) +  def interpretPrint(expr): -    if len(expr.args) == 0: -        print() -    elif len(expr.args) == 1: +    if len(expr.args) == 1:          ev = evaluate(expr.args[0])          if not isinstance(ev, str):              raise Exception("can only 'print' strings")          print(ev)      else: -        raise Exception("'print' takes zero or one argument") +        raise Exception("'print' takes one argument")      return None  # print returns nothing -def interpretDef(expr): +GLOBALS.register("print", Function(interpretPrint, 1)) + +def interpretDef(expr, env):      if len(expr.args) != 2:          raise Exception("'def' requires a name and an expression")      if not isinstance(expr.args[0], Expr.Symbol): @@ -183,8 +228,11 @@ def interpretDef(expr):      if not isinstance(name, str):          raise Exception("'def' requires a string literal as a name") -    ENVIRONMENT[name] = expr.args[1] +    env.register(name, expr.args[1])      return None +GLOBALS.register("def", Function(interpretDef, 2)) +  def interpretEnv(expr, env_expr):      return evaluate(env_expr)  # TODO more than this + | 
