diff options
Diffstat (limited to 'interpreter.py')
| -rw-r--r-- | interpreter.py | 162 | 
1 files changed, 101 insertions, 61 deletions
| diff --git a/interpreter.py b/interpreter.py index e8c0271..d5b2b4c 100644 --- a/interpreter.py +++ b/interpreter.py @@ -1,15 +1,25 @@  from parser import Expr  class Function: -     -    def __init__(self, callable_, *arities): -        self.callable = callable_ + +    def __init__(self, name, params, body, *arities): +        self.name = name +        self.params = params +        self.body = body          if len(arities) == 0:              self.arities = None          else:              self.arities = arities -    def call(self, expr): +    def call(self, expr, env): +        pass + +class Builtin(Function): +     +    def __init__(self, callable_, *arities): +        super().__init__("<builtin>", None, callable_, *arities) + +    def call(self, expr, env):          if self.arities is not None and len(expr.args) not in self.arities:          #if self.arity >= 0 and len(args) != self.arity:              fmt = f"[{self.arities[0]}" @@ -17,7 +27,18 @@ class Function:                  fmt += f", {arity}"              fmt += "]"              raise Exception(f"expected {fmt} arguments, received {len(expr.args)}") -        return self.callable(expr) +        return self.body(expr, env) + +class UserFunction(Function): +     +    def __init__(self, name, params, body): +        super().__init__(name, params, body, len(params)) + +    def call(self, expr, env): +        this_env = Environment(env) +        for idx, param in enumerate(self.params): +            this_env.register(param.name, expr.args[idx]) +        return interpret(self.body, this_env)  class Environment: @@ -44,75 +65,77 @@ class Environment:          else:              return self.parent.get(key) +    def __str__(self): +        out = "" +        for k, v in self.environment.items(): +            out += f"{k}: {v}, " +        return out + +GLOBALS = Environment() -def interpret(exprs): +def interpret(exprs, env=GLOBALS):      ret = None      for expr in exprs: -        ret = evaluate(expr) +        ret = evaluate(expr, env)      return ret -def evaluate(expr): +def evaluate(expr, env):      if isinstance(expr, Expr.Literal):          return expr.value      elif isinstance(expr, Expr.Symbol): -        # defined symbols will evaluate to their expression, -        # undefined will return their own name -        # TODO this is bad -        if not GLOBALS.contains(expr.name): +        if not env.contains(expr.name):              raise Exception(f"no such symbol: {expr}") -        return interpretEnv(expr, GLOBALS.get(expr.name)) +        return interpretEnv(expr, env.get(expr.name), env)      name = expr.symbol.name      if name == "def": -        return interpretDef(expr, GLOBALS) -    elif GLOBALS.contains(name): -        return GLOBALS.get(name).call(expr) +        return interpretDef(expr, env) +    elif env.contains(name): +        return env.get(name).call(expr, env)      else:          raise Exception(f"unable to evaluate: {expr}") -GLOBALS = Environment() - -def interpretOr(expr): +def interpretOr(expr, env):      # or returns true for the first expression that returns true      if len(expr.args) < 2:          raise Exception("'or' has at least two operands")      for arg in expr.args: -        ev = evaluate(arg) +        ev = evaluate(arg, env)          if ev not in (True, False):              raise Exception("'or' needs boolean arguments")          if ev == True:              return True      return False -GLOBALS.register("or", Function(interpretOr)) +GLOBALS.register("or", Builtin(interpretOr)) -def interpretAnd(expr): +def interpretAnd(expr, env):      # and returns false for the first expression that returns false      if len(expr.args) < 2:          raise Exception("'and' has at least two operands")      for arg in expr.args: -        ev = evaluate(arg) +        ev = evaluate(arg, env)          if ev not in (True, False):              raise Exception("'and' needs boolean arguments")          if ev == False:              return False      return True -GLOBALS.register("and", Function(interpretAnd)) +GLOBALS.register("and", Builtin(interpretAnd)) -def interpretEq(expr): +def interpretEq(expr, env):      # equal -    first = evaluate(expr.args[0]) -    second = evaluate(expr.args[1]) +    first = evaluate(expr.args[0], env) +    second = evaluate(expr.args[1], env)      return first == second -GLOBALS.register("eq?", Function(interpretEq, 2)) +GLOBALS.register("eq?", Builtin(interpretEq, 2)) -def interpretComparison(expr): -    left = evaluate(expr.args[0]) +def interpretComparison(expr, env): +    left = evaluate(expr.args[0], env)      if type(left) not in (int, float):          raise Exception("'left' must be a number") -    right = evaluate(expr.args[1]) +    right = evaluate(expr.args[1], env)      if type(right) not in (int, float):          raise Exception("'right' must be a number") @@ -125,17 +148,17 @@ def interpretComparison(expr):      elif expr.symbol.name == "<=":          return left <= right -GLOBALS.register(">", Function(interpretComparison, 2)) -GLOBALS.register(">=", Function(interpretComparison, 2)) -GLOBALS.register("<", Function(interpretComparison, 2)) -GLOBALS.register("<=", Function(interpretComparison, 2)) +GLOBALS.register(">", Builtin(interpretComparison, 2)) +GLOBALS.register(">=", Builtin(interpretComparison, 2)) +GLOBALS.register("<", Builtin(interpretComparison, 2)) +GLOBALS.register("<=", Builtin(interpretComparison, 2)) -def interpretTerm(expr): +def interpretTerm(expr, env):      if len(expr.args) < 1:          raise Exception("term has at least one operand")      res = None      for arg in expr.args: -        ev = evaluate(arg) +        ev = evaluate(arg, env)          if type(ev) not in (int, float):              raise Exception("term must be a number")          if res is None: @@ -146,65 +169,65 @@ def interpretTerm(expr):              res -= ev      return res -GLOBALS.register("+", Function(interpretTerm)) -GLOBALS.register("-", Function(interpretTerm)) +GLOBALS.register("+", Builtin(interpretTerm)) +GLOBALS.register("-", Builtin(interpretTerm)) -def interpretFactor(expr): +def interpretFactor(expr, env):      if expr.symbol.name == "/": -        num = evaluate(expr.args[0]) +        num = evaluate(expr.args[0], env)          if type(num) not in (int, float):              raise Exception("numerator must be a number") -        denom = evaluate(expr.args[1]) +        denom = evaluate(expr.args[1], env)          if type(denom) not in (int, float):              raise Exception("denominator must be a number")          return num / denom  # TODO floats and ints      else:          if len(expr.args) < 2:              raise Exception("'*' requires at least two operands") -        first = evaluate(expr.args[0]) +        first = evaluate(expr.args[0], env)          if type(first) not in (int, float):              raise Exception("'*' operand must be a number")          res = first          for arg in expr.args[1:]: -            tmp = evaluate(arg) +            tmp = evaluate(arg, env)              if type(tmp) not in (int, float):                  raise Exception("'*' operand must be a number")              res = res * tmp          return res -GLOBALS.register("*", Function(interpretFactor)) -GLOBALS.register("/", Function(interpretFactor, 2)) +GLOBALS.register("*", Builtin(interpretFactor)) +GLOBALS.register("/", Builtin(interpretFactor, 2)) -def interpretNot(expr): -    res = evaluate(expr.args[0]) +def interpretNot(expr, env): +    res = evaluate(expr.args[0], env)      if res not in (True, False):          raise Exception("'not' only works on booleans")      return not res -GLOBALS.register("not", Function(interpretNot, 1)) +GLOBALS.register("not", Builtin(interpretNot, 1)) -def interpretIf(expr): +def interpretIf(expr, env):      # if cond t-branch [f-branch] -    cond = evaluate(expr.args[0]) +    cond = evaluate(expr.args[0], env)      if cond not in (True, False):          raise Exception("'if' condition must be boolean")      if cond: -        return evaluate(expr.args[1]) +        return evaluate(expr.args[1], env)      elif len(expr.args) == 3: -        return evaluate(expr.args[2]) +        return evaluate(expr.args[2], env)      return None  # this shouldn't be reached -GLOBALS.register("if", Function(interpretIf, 2, 3)) +GLOBALS.register("if", Builtin(interpretIf, 2, 3)) -def interpretPrint(expr): -    ev = evaluate(expr.args[0]) +def interpretPrint(expr, env): +    ev = evaluate(expr.args[0], env)      if not isinstance(ev, str):          raise Exception("can only 'print' strings")      print(ev)      return None  # print returns nothing -GLOBALS.register("print", Function(interpretPrint, 1)) +GLOBALS.register("print", Builtin(interpretPrint, 1))  def interpretDef(expr, env):      if not isinstance(expr.args[0], Expr.Symbol): @@ -213,11 +236,28 @@ def interpretDef(expr, env):      if not isinstance(name, str):          raise Exception("'def' requires a string literal as a name") -    env.register(name, expr.args[1]) +    ev = evaluate(expr.args[1], env) +    if isinstance(ev, UserFunction): +        env.register(name, ev) +    else: +        env.register(name, expr.args[1])      return None -GLOBALS.register("def", Function(interpretDef, 2)) +GLOBALS.register("def", Builtin(interpretDef, 2)) + +def interpretLambda(expr, env): +    if expr.args[0].symbol != None: +        args = expr.args[0].args +        args = [expr.args[0].symbol] + args +        func = UserFunction("<lambda>", args, expr.args[1:]) +    else: +        func = UserFunction("<lambda>", [], expr.args[1:]) +    #GLOBALS.register(name, func) +    return func + +GLOBALS.register("lambda", Builtin(interpretLambda)) -def interpretEnv(expr, env_expr): -    return evaluate(env_expr)  # TODO more than this +def interpretEnv(expr, env_expr, env): +    ev = evaluate(env_expr, env) +    return ev  # TODO more than this? | 
