aboutsummaryrefslogtreecommitdiff
path: root/neb/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'neb/__init__.py')
-rw-r--r--neb/__init__.py170
1 files changed, 165 insertions, 5 deletions
diff --git a/neb/__init__.py b/neb/__init__.py
index f5afe60..583bef8 100644
--- a/neb/__init__.py
+++ b/neb/__init__.py
@@ -1,6 +1,166 @@
-from .structs import *
-from .lexer import *
-from .parser import *
-from .interpreter import *
+from .lexer import lex
+from .parser import parse
from .exceptions import *
-from .typeclass import *
+from .typeclass import TypeEnum, is_subtype_of
+from .structs import *
+
+def interpret(exprs, env, ns=None):
+ ret = None
+ for expr in exprs:
+ ret = evaluate(expr, env, ns)
+ return ret
+
+def evaluate(expr, env, ns=None):
+ if isinstance(expr, Literal) or isinstance(expr, Function) or isinstance(expr, Type):
+ return expr
+ elif isinstance(expr, Symbol):
+ if env.contains(expr.name):
+ return evaluate(env.get(expr.name), env, ns)
+ elif ns is not None and env.contains(f"{ns}/{expr.name}"):
+ return evaluate(env.get(f"{ns}/{expr.name}"), env, ns)
+ else:
+ raise NebPanic(f"no such symbol: {expr}")
+
+ # if it's an empty list, return it
+ elif len(expr.args) == 0:
+ return expr
+
+ if not isinstance(expr.args[0], Symbol):
+ raise NebPanic("can't evaluate without a symbol")
+ name = expr.args[0].name
+ if env.contains(name):
+ return env.get(name).call(expr, env, ns)
+ elif ns is not None and env.contains(f"{ns}/{name}"):
+ return env.get(f"{ns}/{name}").call(expr, env, ns)
+ else:
+ raise InterpretPanic(expr.args[0], "unable to evaluate")
+
+class Function:
+
+ def __init__(self, name, params, body, args=None, many=None):
+ self.name = name
+ self.params = params
+ self.body = body
+ self.args = args
+ self.many = many
+ self.type_ = TypeEnum.ANY # TODO no it's not
+
+ def describe(self, name=None):
+ if name is None:
+ name = self.name
+ out = [f"({name}"]
+ if self.args is not None:
+ for arg in self.args:
+ out.append(f"{arg}")
+ if self.many is not None:
+ out.append(f"{self.many}")
+ return " ".join(out) + ")"
+
+ def arity_check(self, symbol, params):
+ min_arity = len([a for a in self.args if not a.optional])
+ max_arity = -1 if self.many is not None else len(self.args)
+
+ if len(params) < min_arity or (max_arity >= 0 and len(params) > max_arity):
+ if max_arity < 0:
+ fmt = f"{min_arity}+"
+ elif min_arity != max_arity:
+ fmt = f"{min_arity}-{max_arity}"
+ else:
+ fmt = f"{min_arity}"
+ raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(params)}")
+ return True
+
+ def evaluate_args(self, symbol, params, env, ns):
+ self.arity_check(symbol, params)
+ ret = []
+
+ for idx, param in enumerate(params):
+ if idx < len(self.args):
+ arg = self.args[idx]
+ else:
+ arg = self.many
+ if arg.lazy:
+ ret.append(param)
+ continue
+ ev = evaluate(param, env, ns)
+ if not is_subtype_of(ev.type_, arg.type_):
+ exp = f"{arg.type_}"
+ rec = f"{ev.type_}"
+ raise InterpretPanic(symbol, f"received {rec}, expected {exp}", ev)
+ ret.append(ev)
+ return ret
+
+ def call(self, expr, env):
+ pass
+
+class Builtin(Function):
+
+ def __init__(self, callable_, args=None, many=None):
+ super().__init__("<builtin>", None, callable_, args, many)
+
+ def __str__(self):
+ return f"builtin function {self.name}"
+
+ def call(self, expr, env, ns):
+ self.arity_check(expr.args[0], expr.args[1:])
+ evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns)
+ return self.body(expr.args[0], evaluated_args, env, ns)
+
+
+class UserFunction(Function):
+
+ def __init__(self, name, params, body):
+ newparams, args, many = self.process_params(name, params)
+ super().__init__(name, newparams, body, args, many)
+
+ def __str__(self):
+ out = f"(func {self.name} ("
+ args_list = [f"{a.name} {a.type_}" for a in self.args]
+ if self.many:
+ args_list.append(f"{self.many.name} {self.many.type_}")
+ out = out + " ".join(args_list) + ") "
+ for expr in self.body:
+ out = out + f"{expr} "
+ return out.strip() + ")"
+
+
+ def process_params(self, name, params):
+ newparams = []
+ args = []
+ many = None
+ prev_type = False
+ first = True
+ for param in params:
+ if isinstance(param, Symbol):
+ if many is not None:
+ raise NebPanic("& must be last argument")
+ if param.name == "&":
+ many = Arg(param.name, TypeEnum.ANY)
+ else:
+ newparams.append(param)
+ args.append(Arg(param.name, TypeEnum.ANY))
+ prev_type = False
+ elif isinstance(param, Type) and not prev_type and not first:
+ typ = TypeEnum.__getattr__(param.name[1:].upper())
+ if many is None:
+ args[-1].type_ = typ
+ else:
+ many.type_ = typ
+ prev_type = True
+ else:
+ raise NebPanic("invalid :func signature", param)
+ first = False
+ return newparams, args, many
+
+ def call(self, expr, env, ns):
+ self.arity_check(expr.args[0], expr.args[1:])
+ evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns)
+ this_env = Environment(env)
+ for idx, param in enumerate(self.params):
+ this_env.register(param.name, evaluated_args[idx])
+
+ # if we got "many", wrap the rest in a list
+ if self.many:
+ this_env.register(self.many.name, List(evaluated_args[len(self.params):]))
+
+ return interpret(self.body, env=this_env, ns=ns)