diff options
Diffstat (limited to 'neb/__init__.py')
| -rw-r--r-- | neb/__init__.py | 170 |
1 files changed, 165 insertions, 5 deletions
diff --git a/neb/__init__.py b/neb/__init__.py index f5afe60..583bef8 100644 --- a/neb/__init__.py +++ b/neb/__init__.py @@ -1,6 +1,166 @@ -from .structs import * -from .lexer import * -from .parser import * -from .interpreter import * +from .lexer import lex +from .parser import parse from .exceptions import * -from .typeclass import * +from .typeclass import TypeEnum, is_subtype_of +from .structs import * + +def interpret(exprs, env, ns=None): + ret = None + for expr in exprs: + ret = evaluate(expr, env, ns) + return ret + +def evaluate(expr, env, ns=None): + if isinstance(expr, Literal) or isinstance(expr, Function) or isinstance(expr, Type): + return expr + elif isinstance(expr, Symbol): + if env.contains(expr.name): + return evaluate(env.get(expr.name), env, ns) + elif ns is not None and env.contains(f"{ns}/{expr.name}"): + return evaluate(env.get(f"{ns}/{expr.name}"), env, ns) + else: + raise NebPanic(f"no such symbol: {expr}") + + # if it's an empty list, return it + elif len(expr.args) == 0: + return expr + + if not isinstance(expr.args[0], Symbol): + raise NebPanic("can't evaluate without a symbol") + name = expr.args[0].name + if env.contains(name): + return env.get(name).call(expr, env, ns) + elif ns is not None and env.contains(f"{ns}/{name}"): + return env.get(f"{ns}/{name}").call(expr, env, ns) + else: + raise InterpretPanic(expr.args[0], "unable to evaluate") + +class Function: + + def __init__(self, name, params, body, args=None, many=None): + self.name = name + self.params = params + self.body = body + self.args = args + self.many = many + self.type_ = TypeEnum.ANY # TODO no it's not + + def describe(self, name=None): + if name is None: + name = self.name + out = [f"({name}"] + if self.args is not None: + for arg in self.args: + out.append(f"{arg}") + if self.many is not None: + out.append(f"{self.many}") + return " ".join(out) + ")" + + def arity_check(self, symbol, params): + min_arity = len([a for a in self.args if not a.optional]) + max_arity = -1 if self.many is not None else len(self.args) + + if len(params) < min_arity or (max_arity >= 0 and len(params) > max_arity): + if max_arity < 0: + fmt = f"{min_arity}+" + elif min_arity != max_arity: + fmt = f"{min_arity}-{max_arity}" + else: + fmt = f"{min_arity}" + raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(params)}") + return True + + def evaluate_args(self, symbol, params, env, ns): + self.arity_check(symbol, params) + ret = [] + + for idx, param in enumerate(params): + if idx < len(self.args): + arg = self.args[idx] + else: + arg = self.many + if arg.lazy: + ret.append(param) + continue + ev = evaluate(param, env, ns) + if not is_subtype_of(ev.type_, arg.type_): + exp = f"{arg.type_}" + rec = f"{ev.type_}" + raise InterpretPanic(symbol, f"received {rec}, expected {exp}", ev) + ret.append(ev) + return ret + + def call(self, expr, env): + pass + +class Builtin(Function): + + def __init__(self, callable_, args=None, many=None): + super().__init__("<builtin>", None, callable_, args, many) + + def __str__(self): + return f"builtin function {self.name}" + + def call(self, expr, env, ns): + self.arity_check(expr.args[0], expr.args[1:]) + evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns) + return self.body(expr.args[0], evaluated_args, env, ns) + + +class UserFunction(Function): + + def __init__(self, name, params, body): + newparams, args, many = self.process_params(name, params) + super().__init__(name, newparams, body, args, many) + + def __str__(self): + out = f"(func {self.name} (" + args_list = [f"{a.name} {a.type_}" for a in self.args] + if self.many: + args_list.append(f"{self.many.name} {self.many.type_}") + out = out + " ".join(args_list) + ") " + for expr in self.body: + out = out + f"{expr} " + return out.strip() + ")" + + + def process_params(self, name, params): + newparams = [] + args = [] + many = None + prev_type = False + first = True + for param in params: + if isinstance(param, Symbol): + if many is not None: + raise NebPanic("& must be last argument") + if param.name == "&": + many = Arg(param.name, TypeEnum.ANY) + else: + newparams.append(param) + args.append(Arg(param.name, TypeEnum.ANY)) + prev_type = False + elif isinstance(param, Type) and not prev_type and not first: + typ = TypeEnum.__getattr__(param.name[1:].upper()) + if many is None: + args[-1].type_ = typ + else: + many.type_ = typ + prev_type = True + else: + raise NebPanic("invalid :func signature", param) + first = False + return newparams, args, many + + def call(self, expr, env, ns): + self.arity_check(expr.args[0], expr.args[1:]) + evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns) + this_env = Environment(env) + for idx, param in enumerate(self.params): + this_env.register(param.name, evaluated_args[idx]) + + # if we got "many", wrap the rest in a list + if self.many: + this_env.register(self.many.name, List(evaluated_args[len(self.params):])) + + return interpret(self.body, env=this_env, ns=ns) |
