aboutsummaryrefslogtreecommitdiff
path: root/neb/structs.py
diff options
context:
space:
mode:
Diffstat (limited to 'neb/structs.py')
-rw-r--r--neb/structs.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/neb/structs.py b/neb/structs.py
index c8e7e8b..cba0c03 100644
--- a/neb/structs.py
+++ b/neb/structs.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
from .typeclass import TypeEnum
+#from . import Function
# tokens and types
# NOTE: this can probably be simplified
@@ -166,3 +167,136 @@ class Environment:
out += f"{k}: {v}, "
return out
+'''
+class Function:
+
+ def __init__(self, name, params, body, args=None, many=None):
+ self.name = name
+ self.params = params
+ self.body = body
+ self.args = args
+ self.many = many
+ self.type_ = TypeEnum.ANY # TODO no it's not
+
+ def describe(self, name=None):
+ if name is None:
+ name = self.name
+ out = [f"({name}"]
+ if self.args is not None:
+ for arg in self.args:
+ out.append(f"{arg}")
+ if self.many is not None:
+ out.append(f"{self.many}")
+ return " ".join(out) + ")"
+
+ def arity_check(self, symbol, params):
+ min_arity = len([a for a in self.args if not a.optional])
+ max_arity = -1 if self.many is not None else len(self.args)
+
+ if len(params) < min_arity or (max_arity >= 0 and len(params) > max_arity):
+ if max_arity < 0:
+ fmt = f"{min_arity}+"
+ elif min_arity != max_arity:
+ fmt = f"{min_arity}-{max_arity}"
+ else:
+ fmt = f"{min_arity}"
+ raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(params)}")
+ return True
+
+ def evaluate_args(self, symbol, params, env, ns):
+ self.arity_check(symbol, params)
+ ret = []
+
+ for idx, param in enumerate(params):
+ if idx < len(self.args):
+ arg = self.args[idx]
+ else:
+ arg = self.many
+ if arg.lazy:
+ ret.append(param)
+ continue
+ ev = evaluate(param, env, ns)
+ if not is_subtype_of(ev.type_, arg.type_):
+ exp = f"{arg.type_}"
+ rec = f"{ev.type_}"
+ raise InterpretPanic(symbol, f"received {rec}, expected {exp}", ev)
+ ret.append(ev)
+ return ret
+
+ def call(self, expr, env):
+ pass
+'''
+
+'''
+class Builtin(Function):
+
+ def __init__(self, callable_, args=None, many=None):
+ super().__init__("<builtin>", None, callable_, args, many)
+
+ def __str__(self):
+ return f"builtin function {self.name}"
+
+ def call(self, expr, env, ns):
+ self.arity_check(expr.args[0], expr.args[1:])
+ evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns)
+ return self.body(expr.args[0], evaluated_args, env, ns)
+
+
+class UserFunction(Function):
+
+ def __init__(self, name, params, body):
+ newparams, args, many = self.process_params(name, params)
+ super().__init__(name, newparams, body, args, many)
+
+ def __str__(self):
+ out = f"(func {self.name} ("
+ args_list = [f"{a.name} {a.type_}" for a in self.args]
+ if self.many:
+ args_list.append(f"{self.many.name} {self.many.type_}")
+ out = out + " ".join(args_list) + ") "
+ for expr in self.body:
+ out = out + f"{expr} "
+ return out.strip() + ")"
+
+
+ def process_params(self, name, params):
+ newparams = []
+ args = []
+ many = None
+ prev_type = False
+ first = True
+ for param in params:
+ if isinstance(param, Symbol):
+ if many is not None:
+ raise NebPanic("& must be last argument")
+ if param.name == "&":
+ many = Arg(param.name, TypeEnum.ANY)
+ else:
+ newparams.append(param)
+ args.append(Arg(param.name, TypeEnum.ANY))
+ prev_type = False
+ elif isinstance(param, Type) and not prev_type and not first:
+ typ = TypeEnum.__getattr__(param.name[1:].upper())
+ if many is None:
+ args[-1].type_ = typ
+ else:
+ many.type_ = typ
+ prev_type = True
+ else:
+ raise NebPanic("invalid :func signature", param)
+ first = False
+ return newparams, args, many
+
+ def call(self, expr, env, ns):
+ self.arity_check(expr.args[0], expr.args[1:])
+ evaluated_args = self.evaluate_args(expr.args[0], expr.args[1:], env, ns)
+ this_env = Environment(env)
+ for idx, param in enumerate(self.params):
+ this_env.register(param.name, evaluated_args[idx])
+
+ # if we got "many", wrap the rest in a list
+ if self.many:
+ this_env.register(self.many.name, List(evaluated_args[len(self.params):]))
+
+ return interpret(self.body, env=this_env, ns=ns)
+'''