diff options
Diffstat (limited to 'neb/__init__.py')
| -rw-r--r-- | neb/__init__.py | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/neb/__init__.py b/neb/__init__.py index a492df8..c505a04 100644 --- a/neb/__init__.py +++ b/neb/__init__.py @@ -11,15 +11,18 @@ def interpret(exprs, env, ns=None): return ret def evaluate(expr, env, ns=None): - if isinstance(expr, Literal) or isinstance(expr, Function) or isinstance(expr, Type): + if isinstance(expr, Literal) or isinstance(expr, Function) or isinstance(expr, TypeWrap): return expr - elif isinstance(expr, Symbol): + elif isinstance(expr, Symbol) or isinstance(expr, Type): if env.contains(expr.name): return evaluate(env.get(expr.name), env, ns) elif ns is not None and env.contains(f"{ns}/{expr.name}"): return evaluate(env.get(f"{ns}/{expr.name}"), env, ns) else: - raise NebPanic(f"no such symbol: {expr}") + if isinstance(expr, Symbol): + raise NebPanic(f"no such symbol: {expr}") + else: + raise NebPanic(f"no such type {expr}") # if it's an empty list, return it elif len(expr.args) == 0: @@ -79,7 +82,10 @@ class Function: ret.append(param) continue ev = evaluate(param, env, ns) - if not is_subtype_of(ev.type_, arg.type_): + expected_name = f"{arg.type_}" + expected_type = env.get(expected_name) + valid = expected_type.validate_type(ev, env, ns) + if not valid.value: exp = f"{arg.type_}" rec = f"{ev.type_}" raise InterpretPanic(symbol, f"received {rec}, expected {exp}", ev) @@ -160,3 +166,31 @@ class UserFunction(Function): this_env.register(self.many.name, List(evaluated_args[len(self.params):])) return interpret(self.body, env=this_env, ns=ns) + + +class TypeWrap: + + def __init__(self, name, parent, is_func): + self.name = name + self.parent = parent + self.is_func = is_func + + def validate_type(self, target, env, ns): + valid = self.is_func(None, [target], env, ns) + if valid.value == True: + return valid + parent_type = env.get(f"{target.type_}") + while valid.value != True and parent_type.parent is not None: + parent_type = env.get(f"{parent_type.parent}") + valid = Bool(self.name == parent_type.name) + return valid + + def __str__(self): + return f"{self.name}" + + +class NebType(TypeWrap): + pass + +class UserType(TypeWrap): + pass |
