aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--interpreter.py146
-rw-r--r--structs.py86
2 files changed, 124 insertions, 108 deletions
diff --git a/interpreter.py b/interpreter.py
index c7a2d40..2f6dc3d 100644
--- a/interpreter.py
+++ b/interpreter.py
@@ -2,7 +2,6 @@ from structs import *
from exceptions import *
from lexer import lex
from parser import parse
-from structs import T
from pathlib import Path
from glob import glob
from collections import namedtuple
@@ -16,10 +15,15 @@ import sys
@dataclass
class Arg:
name: str
- type_: T
+ type_: TypeEnum
optional: bool
lazy: bool
+ def __str__(self):
+ opt = "?" if self.optional else ""
+ lazy = "~" if self.lazy else ""
+ return f"{lazy}{opt}{self.name} {self.type_}"
+
class Function:
@@ -30,6 +34,17 @@ class Function:
self.args = args
self.many = many
+ def describe(self, name=None):
+ if name is None:
+ name = self.name
+ out = [f"({name}"]
+ if self.args is not None:
+ for arg in self.args:
+ out.append(f"{arg}")
+ if self.many is not None:
+ out.append(f"*{self.many}")
+ return " ".join(out) + ")"
+
def arity_check(self, symbol, params):
min_arity = len([a for a in self.args if not a.optional])
max_arity = -1 if self.many is not None else len(self.args)
@@ -57,9 +72,9 @@ class Function:
ret.append(param)
continue
ev = evaluate(param, env)
- if not isinstance(ev, arg.type_):
- exp = f":{arg.type_.__name__.lower()}"
- rec = f":{ev.type_.__name__.lower()}"
+ if not is_subtype_of(ev.type_, arg.type_):
+ exp = f"{arg.type_}"
+ rec = f"{ev.type_}"
raise InterpretPanic(symbol, f"received {rec}, expected {exp}", ev)
ret.append(ev)
return ret
@@ -82,7 +97,7 @@ class UserFunction(Function):
def __init__(self, name, params, body):
# TODO this doesn't do type checking, or optional, or lazy
- args = [Arg("arg", T.Any, False, False)] * len(params)
+ args = [Arg("arg", TypeEnum.ANY, False, False)] * len(params)
super().__init__(name, params, body, args)
def call(self, expr, env):
@@ -176,7 +191,7 @@ def interpretOr(symbol, args, env):
return Bool(False)
#GLOBALS.register("or", Builtin(interpretOr, 2))
-or_arg = Arg("arg", T.Bool, False, True)
+or_arg = Arg("arg", TypeEnum.BOOL, False, True)
GLOBALS.register("or", Builtin(interpretOr, [or_arg, or_arg], or_arg))
def interpretAnd(symbol, args, env):
@@ -199,13 +214,13 @@ def interpretEq(symbol, args, env):
else:
return Bool(False)
-eq_arg = Arg("value", T.Literal, False, False)
+eq_arg = Arg("value", TypeEnum.LITERAL, False, False)
GLOBALS.register("eq?", Builtin(interpretEq, [eq_arg, eq_arg]))
def interpretGreaterThan(symbol, args, env):
return Bool(args[0].value > args[1].value)
-compare_arg = Arg("num", T.Number, False, False)
+compare_arg = Arg("num", TypeEnum.NUMBER, False, False)
GLOBALS.register(">", Builtin(interpretGreaterThan, [compare_arg, compare_arg]))
def interpretGreaterThanEqual(symbol, args, env):
@@ -232,7 +247,7 @@ def interpretAddition(symbol, args, env):
else:
return Int(res)
-term_arg = Arg("term", T.Number, False, False)
+term_arg = Arg("term", TypeEnum.NUMBER, False, False)
GLOBALS.register("+", Builtin(interpretAddition, [term_arg], term_arg))
def interpretSubtraction(symbol, args, env):
@@ -258,7 +273,7 @@ def interpretMultiplication(symbol, args, env):
else:
return Int(res)
-factor_arg = Arg("factor", T.Number, False, False)
+factor_arg = Arg("factor", TypeEnum.NUMBER, False, False)
GLOBALS.register("*", Builtin(interpretMultiplication, [factor_arg, factor_arg], factor_arg))
def interpretDivision(symbol, args, env):
@@ -273,7 +288,7 @@ GLOBALS.register("/", Builtin(interpretDivision, [factor_arg, factor_arg]))
def interpretNot(symbol, args, env):
return Bool(not args[0].value)
-not_arg = Arg("not", T.Bool, False, False)
+not_arg = Arg("not", TypeEnum.BOOL, False, False)
GLOBALS.register("not", Builtin(interpretNot, [not_arg]))
def interpretIf(symbol, args, env):
@@ -284,16 +299,16 @@ def interpretIf(symbol, args, env):
return evaluate(args[2], env)
return List([])
-cond = Arg("cond", T.Bool, False, False)
-t_branch = Arg("t-branch", T.Any, False, True)
-f_branch = Arg("f-branch", T.Any, True, True)
+cond = Arg("cond", TypeEnum.BOOL, False, False)
+t_branch = Arg("t-branch", TypeEnum.ANY, False, True)
+f_branch = Arg("f-branch", TypeEnum.ANY, True, True)
GLOBALS.register("if", Builtin(interpretIf, [cond, t_branch, f_branch]))
def interpretPrint(symbol, args, env):
print(args[0].value)
return List([]) # print returns nothing
-GLOBALS.register("print", Builtin(interpretPrint, [Arg("arg", T.String, False, False)]))
+GLOBALS.register("print", Builtin(interpretPrint, [Arg("arg", TypeEnum.STRING, False, False)]))
def interpretDef(symbol, args, env):
if not isinstance(args[0], Symbol):
@@ -305,8 +320,8 @@ def interpretDef(symbol, args, env):
env.register(name, args[1]) # TODO since this isn't lazily evaluated, side effects are allowed (bad!)
return List([])
-def_name_arg = Arg("name", T.Any, False, True)
-def_val_arg = Arg("value", T.Any, False, False)
+def_name_arg = Arg("name", TypeEnum.ANY, False, True)
+def_val_arg = Arg("value", TypeEnum.ANY, False, False)
GLOBALS.register("def", Builtin(interpretDef, [def_name_arg, def_val_arg]))
def interpretRedef(symbol, args, env):
@@ -328,8 +343,8 @@ def interpretLambda(symbol, args, env):
func = UserFunction("<lambda>", [], args[1:])
return func
-lambda_args_arg = Arg("args", T.Any, False, True)
-lambda_body_arg = Arg("body", T.Any, False, True)
+lambda_args_arg = Arg("args", TypeEnum.ANY, False, True)
+lambda_body_arg = Arg("body", TypeEnum.ANY, False, True)
GLOBALS.register("lambda", Builtin(interpretLambda, [lambda_args_arg, lambda_body_arg], lambda_body_arg))
def interpretToString(symbol, args, env):
@@ -341,7 +356,7 @@ def interpretToString(symbol, args, env):
else:
return String(f"{item}")
-GLOBALS.register("->string", Builtin(interpretToString, [Arg("arg", T.Any, False, False)]))
+GLOBALS.register("->string", Builtin(interpretToString, [Arg("arg", TypeEnum.ANY, False, False)]))
def interpretConcat(symbol, args, env):
# concat str1 str2...strN
@@ -350,7 +365,7 @@ def interpretConcat(symbol, args, env):
out += arg.value
return String(out)
-string_arg = Arg("arg", T.String, False, False)
+string_arg = Arg("arg", TypeEnum.STRING, False, False)
GLOBALS.register("concat", Builtin(interpretConcat, [string_arg, string_arg], string_arg))
def interpretForCount(symbol, args, env):
@@ -365,8 +380,8 @@ def interpretForCount(symbol, args, env):
return List([])
return ret
-for_count_arg = Arg("count", T.Int, False, False)
-for_body_arg = Arg("body", T.Any, False, True)
+for_count_arg = Arg("count", TypeEnum.INT, False, False)
+for_body_arg = Arg("body", TypeEnum.ANY, False, True)
GLOBALS.register("for-count", Builtin(interpretForCount, [for_count_arg, for_body_arg], for_body_arg))
def interpretForEach(symbol, args, env):
@@ -381,7 +396,7 @@ def interpretForEach(symbol, args, env):
return List([])
return ret
-for_each_arg = Arg("list", T.List, False, False)
+for_each_arg = Arg("list", TypeEnum.LIST, False, False)
GLOBALS.register("for-each", Builtin(interpretForEach, [for_each_arg, for_body_arg], for_body_arg))
def interpretPipe(symbol, args, env):
@@ -444,13 +459,13 @@ def interpretReadLines(symbol, args, env):
out = List([String(d) for d in data], True) # all lines are strings
return out
-GLOBALS.register("read-lines", Builtin(interpretReadLines, [Arg("filename", T.String, False, False)]))
+GLOBALS.register("read-lines", Builtin(interpretReadLines, [Arg("filename", TypeEnum.STRING, False, False)]))
# - strip whitespace from string
def interpretStrip(symbol, args, env):
return String(args[0].value.strip())
-GLOBALS.register("strip", Builtin(interpretStrip, [Arg("filename", T.String, False, False)]))
+GLOBALS.register("strip", Builtin(interpretStrip, [Arg("filename", TypeEnum.STRING, False, False)]))
# - string->int and string->float
def interpretStringToInt(symbol, args, env):
@@ -460,7 +475,7 @@ def interpretStringToInt(symbol, args, env):
except:
raise InterpretPanic(symbol, "can't convert to an :int", args[0])
-GLOBALS.register("string->int", Builtin(interpretStringToInt, [Arg("arg", T.String, False, False)]))
+GLOBALS.register("string->int", Builtin(interpretStringToInt, [Arg("arg", TypeEnum.STRING, False, False)]))
# - split a string by a given field
def interpretSplit(symbol, args, env):
@@ -471,13 +486,13 @@ def interpretSplit(symbol, args, env):
ret = target.value.split(splitter.value)
return List([String(r) for r in ret], True)
-GLOBALS.register("split", Builtin(interpretSplit, [Arg("target", T.String, False, False)], Arg("splitter", T.String, True, False)))
+GLOBALS.register("split", Builtin(interpretSplit, [Arg("target", TypeEnum.STRING, False, False)], Arg("splitter", TypeEnum.STRING, True, False)))
# - get the length of a list
def interpretListLength(symbol, args, env):
return Int(len(args[0].args))
-GLOBALS.register("list-length", Builtin(interpretListLength, [Arg("arg", T.List, False, False)]))
+GLOBALS.register("list-length", Builtin(interpretListLength, [Arg("arg", TypeEnum.LIST, False, False)]))
# - first/rest of list
def interpretFirst(symbol, args, env):
@@ -485,13 +500,13 @@ def interpretFirst(symbol, args, env):
raise InterpretPanic(symbol, "list is empty")
return evaluate(args[0].args[0], env)
-GLOBALS.register("first", Builtin(interpretFirst, [Arg("arg", T.List, False, False)]))
+GLOBALS.register("first", Builtin(interpretFirst, [Arg("arg", TypeEnum.LIST, False, False)]))
def interpretRest(symbol, args, env):
# TODO do we know it's not evaluated?
return List(args[0].args[1:], True) # we don't evaluate the remainder of the list
-GLOBALS.register("rest", Builtin(interpretRest, [Arg("arg", T.List, False, False)]))
+GLOBALS.register("rest", Builtin(interpretRest, [Arg("arg", TypeEnum.LIST, False, False)]))
# - iterate over list
# - map
@@ -509,8 +524,7 @@ def interpretMap(symbol, args, env):
out.append(ev)
return List(out, True)
-GLOBALS.register("map", Builtin(interpretMap, [Arg("func", T.Any, False, True), Arg("list", T.List, False, False)]))
-#GLOBALS.register("map", Builtin(interpretMap, [Arg("func", T.Any, False, False), Arg("list", T.List, False, False)]))
+GLOBALS.register("map", Builtin(interpretMap, [Arg("func", TypeEnum.ANY, False, True), Arg("list", TypeEnum.LIST, False, False)]))
def interpretZip(symbol, args, env):
z1 = args[0]
@@ -524,20 +538,20 @@ def interpretZip(symbol, args, env):
out.append(List([f, s], True))
return List(out, True)
-zip_arg = Arg("list", T.List, False, False)
+zip_arg = Arg("list", TypeEnum.LIST, False, False)
GLOBALS.register("zip", Builtin(interpretZip, [zip_arg, zip_arg]))
def interpretList(symbol, args, env):
return List(args, True)
-GLOBALS.register("list", Builtin(interpretList, [], Arg("item", T.Any, False, False)))
+GLOBALS.register("list", Builtin(interpretList, [], Arg("item", TypeEnum.ANY, False, False)))
def interpretListReverse(symbol, args, env):
new_args = args[0].args[:] # make a copy of the args
new_args.reverse()
return List(new_args, True)
-GLOBALS.register("list-reverse", Builtin(interpretListReverse, [Arg("list", T.List, False, False)]))
+GLOBALS.register("list-reverse", Builtin(interpretListReverse, [Arg("list", TypeEnum.LIST, False, False)]))
def interpretApply(symbol, args, env):
# TODO: to support lambdas, we can't assume the func is defined
@@ -547,37 +561,37 @@ def interpretApply(symbol, args, env):
new_lst = List([func] + args[1].args)
return evaluate(new_lst, env)
-GLOBALS.register("apply", Builtin(interpretApply, [Arg("func", T.Any, False, True), Arg("list", T.List, False, False)]))
+GLOBALS.register("apply", Builtin(interpretApply, [Arg("func", TypeEnum.ANY, False, True), Arg("list", TypeEnum.LIST, False, False)]))
def interpretGlob(symbol, args, env):
items = glob(args[0].value)
return List([String(item) for item in items], True)
-GLOBALS.register("glob", Builtin(interpretGlob, [Arg("regex", T.String, False, False)]))
+GLOBALS.register("glob", Builtin(interpretGlob, [Arg("regex", TypeEnum.STRING, False, False)]))
def interpretShell(symbol, args, env):
# TODO either fail or throw exception (?) on error
ret = subprocess.run(shlex.split(args[0].value), capture_output=True)
return List([String(r) for r in ret.stdout.decode("utf-8").split("\n")], True)
-GLOBALS.register("$", Builtin(interpretShell, [Arg("command", T.String, False, False)]))
+GLOBALS.register("$", Builtin(interpretShell, [Arg("command", TypeEnum.STRING, False, False)]))
def interpretEmpty(symbol, args, env):
return Bool(len(args[0].args) == 0)
-GLOBALS.register("empty?", Builtin(interpretEmpty, [Arg("list", T.List, False, False)]))
+GLOBALS.register("empty?", Builtin(interpretEmpty, [Arg("list", TypeEnum.LIST, False, False)]))
def interpretShuf(symbol, args, env):
items = args[0].args[:]
random.shuffle(items)
return List(items, True)
-GLOBALS.register("shuf", Builtin(interpretShuf, [Arg("list", T.List, False, False)]))
+GLOBALS.register("shuf", Builtin(interpretShuf, [Arg("list", TypeEnum.LIST, False, False)]))
def interpretIsList(symbol, args, env):
return Bool(isinstance(args[0], List))
-GLOBALS.register("list?", Builtin(interpretIsList, [Arg("arg", T.Any, False, False)]))
+GLOBALS.register("list?", Builtin(interpretIsList, [Arg("arg", TypeEnum.ANY, False, False)]))
def interpretBlock(symbol, args, env):
ret = List([])
@@ -585,7 +599,7 @@ def interpretBlock(symbol, args, env):
ret = evaluate(arg, env)
return ret
-block_arg = Arg("expr", T.Any, False, True)
+block_arg = Arg("expr", TypeEnum.ANY, False, True)
GLOBALS.register("block", Builtin(interpretBlock, [block_arg], block_arg))
def interpretExit(symbol, args, env):
@@ -593,7 +607,7 @@ def interpretExit(symbol, args, env):
sys.exit(status)
return List([])
-exit_arg = Arg("status", T.Int, True, False)
+exit_arg = Arg("status", TypeEnum.INT, True, False)
GLOBALS.register("exit", Builtin(interpretExit, [exit_arg]))
def interpretUnlink(symbol, args, env):
@@ -603,7 +617,7 @@ def interpretUnlink(symbol, args, env):
target_path.unlink()
return List([])
-GLOBALS.register("unlink", Builtin(interpretUnlink, [Arg("filename", T.String, False, False)]))
+GLOBALS.register("unlink", Builtin(interpretUnlink, [Arg("filename", TypeEnum.STRING, False, False)]))
def interpretArgv(symbol, args, env):
out = []
@@ -621,8 +635,8 @@ def interpretIn(symbol, args, env):
return Bool(True)
return Bool(False)
-in_target_arg = Arg("target", T.Literal, False, False)
-in_list_arg = Arg("list", T.List, False, False)
+in_target_arg = Arg("target", TypeEnum.LITERAL, False, False)
+in_list_arg = Arg("list", TypeEnum.LIST, False, False)
GLOBALS.register("in?", Builtin(interpretIn, [in_target_arg, in_list_arg]))
def interpretLast(symbol, args, env):
@@ -630,15 +644,15 @@ def interpretLast(symbol, args, env):
raise InterpretPanic("List is empty")
return evaluate(args[0].args[-1], env)
-GLOBALS.register("last", Builtin(interpretLast, [Arg("list", T.List, False, False)]))
+GLOBALS.register("last", Builtin(interpretLast, [Arg("list", TypeEnum.LIST, False, False)]))
def interpretJoin(symbol, args, env):
lst = args[0]
target = args[1]
return String(target.value.join([a.value for a in lst.args]))
-join_list_arg = Arg("list", T.List, False, False)
-join_string_arg = Arg("joiner", T.String, False, False)
+join_list_arg = Arg("list", TypeEnum.LIST, False, False)
+join_string_arg = Arg("joiner", TypeEnum.STRING, False, False)
GLOBALS.register("join", Builtin(interpretJoin, [join_list_arg, join_string_arg]))
def interpretWithWrite(symbol, args, env):
@@ -652,7 +666,7 @@ def interpretWithWrite(symbol, args, env):
ret = evaluate(arg, new_env)
return ret
-GLOBALS.register("with-write", Builtin(interpretWithWrite, [Arg("filename", T.String, False, False)], Arg("exprs", T.Any, False, True)))
+GLOBALS.register("with-write", Builtin(interpretWithWrite, [Arg("filename", TypeEnum.STRING, False, False)], Arg("exprs", TypeEnum.ANY, False, True)))
def interpretWrite(symbol, args, env):
# write :string :filehandle
@@ -661,7 +675,7 @@ def interpretWrite(symbol, args, env):
handle.args[0].write(line.value) # TODO wrong! how do we evaluate a handle?
return Literal([])
-GLOBALS.register("write", Builtin(interpretWrite, [Arg("string", T.String, False, False), Arg("filename", T.List, False, False)]))
+GLOBALS.register("write", Builtin(interpretWrite, [Arg("string", TypeEnum.STRING, False, False), Arg("filename", TypeEnum.LIST, False, False)]))
def interpretNewline(symbol, args, env):
return String("\n")
@@ -671,19 +685,19 @@ GLOBALS.register("newline", Builtin(interpretNewline, []))
def interpretExists(symbol, args, env):
return Bool(Path(args[0].value).resolve().exists())
-GLOBALS.register("exists?", Builtin(interpretExists, [Arg("filename", T.String, False, False)]))
+GLOBALS.register("exists?", Builtin(interpretExists, [Arg("filename", TypeEnum.STRING, False, False)]))
def interpretFirstChar(symbol, args, env):
if len(args[0].value) == 0:
raise InterpretPanic(symbol, ":string is empty", ev)
return String(args[0].value[0])
-GLOBALS.register("first-char", Builtin(interpretFirstChar, [Arg("string", T.String, False, False)]))
+GLOBALS.register("first-char", Builtin(interpretFirstChar, [Arg("string", TypeEnum.STRING, False, False)]))
def interpretRestChar(symbol, args, env):
return String(args[0].value[1:])
-GLOBALS.register("rest-char", Builtin(interpretRestChar, [Arg("string", T.String, False, False)]))
+GLOBALS.register("rest-char", Builtin(interpretRestChar, [Arg("string", TypeEnum.STRING, False, False)]))
def interpretSlice(symbol, args, env):
lst = args[0]
@@ -694,9 +708,9 @@ def interpretSlice(symbol, args, env):
diff = idx.value - 1 + length.value
return List(lst.args[idx.value - 1:diff])
-slice_list_arg = Arg("list", T.List, False, False)
-slice_idx_arg = Arg("idx", T.Int, False, False)
-slice_length_arg = Arg("length", T.Int, True, False)
+slice_list_arg = Arg("list", TypeEnum.LIST, False, False)
+slice_idx_arg = Arg("idx", TypeEnum.INT, False, False)
+slice_length_arg = Arg("length", TypeEnum.INT, True, False)
GLOBALS.register("slice", Builtin(interpretSlice, [slice_list_arg, slice_idx_arg, slice_length_arg]))
def interpretClear(symbol, args, env):
@@ -709,7 +723,7 @@ def interpretReadLine(symbol, args, env):
ret = input(args[0].value)
return String(ret)
-GLOBALS.register("read-line", Builtin(interpretReadLine, [Arg("prompt", T.String, False, False)]))
+GLOBALS.register("read-line", Builtin(interpretReadLine, [Arg("prompt", TypeEnum.STRING, False, False)]))
def interpretReadChar(symbol, args, env):
import termios, tty
@@ -732,7 +746,7 @@ def interpretAppend(symbol, args, env):
items = lst.args[:]
return List(items + [val], True)
-GLOBALS.register("append", Builtin(interpretAppend, [Arg("list", T.List, False, False), Arg("item", T.Any, False, False)]))
+GLOBALS.register("append", Builtin(interpretAppend, [Arg("list", TypeEnum.LIST, False, False), Arg("item", TypeEnum.ANY, False, False)]))
# TODO: this is actually for records/structs/whatever they're called
def interpretRemove(symbol, args, env):
@@ -744,7 +758,7 @@ def interpretRemove(symbol, args, env):
out.append(arg)
return List(out, True)
-GLOBALS.register("remove", Builtin(interpretRemove, [Arg("list", T.List, False, False), Arg("key", T.Any, False, False)]))
+GLOBALS.register("remove", Builtin(interpretRemove, [Arg("list", TypeEnum.LIST, False, False), Arg("key", TypeEnum.ANY, False, False)]))
def interpretWhile(symbol, args, env):
cond = args[0]
@@ -759,7 +773,7 @@ def interpretWhile(symbol, args, env):
ret = evaluate(arg, env)
return ret
-GLOBALS.register("while", Builtin(interpretWhile, [Arg("cond", T.Bool, False, True)], Arg("expr", T.Any, False, True)))
+GLOBALS.register("while", Builtin(interpretWhile, [Arg("cond", TypeEnum.BOOL, False, True)], Arg("expr", TypeEnum.ANY, False, True)))
def interpretUse(symbol, args, env):
target_file_name = args[0].value
@@ -771,11 +785,11 @@ def interpretUse(symbol, args, env):
interpret(parse(lex(data)))
return List([])
-GLOBALS.register("use", Builtin(interpretUse, [Arg("filename", T.String, False, False)]))
+GLOBALS.register("use", Builtin(interpretUse, [Arg("filename", TypeEnum.STRING, False, False)]))
-def interpretAssert(symbol, args, env, ns):
+def interpretAssert(symbol, args, env):
if args[0].value != True:
raise InterpretPanic(symbol, "assertion failed")
return List([])
-GLOBALS.register("assert", Builtin(interpretAssert, [Arg("cond", T.Bool, False, False)]))
+GLOBALS.register("assert", Builtin(interpretAssert, [Arg("cond", TypeEnum.BOOL, False, False)]))
diff --git a/structs.py b/structs.py
index c3e344c..72fabf7 100644
--- a/structs.py
+++ b/structs.py
@@ -59,6 +59,38 @@ class TokenType(Enum):
LIST_TYPE = auto()
+class TypeEnum(Enum):
+ ANY = auto()
+ STRING = auto()
+ INT = auto()
+ FLOAT = auto()
+ NUMBER = auto()
+ LIST = auto()
+ LITERAL = auto()
+ BOOL = auto()
+
+ def __str__(self):
+ return f":{self.name.lower()}"
+
+TYPE_HIERARCHY = { TypeEnum.ANY: None,
+ TypeEnum.LITERAL: TypeEnum.ANY,
+ TypeEnum.LIST: TypeEnum.ANY,
+ TypeEnum.STRING: TypeEnum.LITERAL,
+ TypeEnum.BOOL: TypeEnum.LITERAL,
+ TypeEnum.NUMBER: TypeEnum.LITERAL,
+ TypeEnum.INT: TypeEnum.NUMBER,
+ TypeEnum.FLOAT: TypeEnum.NUMBER }
+
+def is_subtype_of(candidate, expected):
+ if candidate == expected:
+ return True
+ parent = TYPE_HIERARCHY[candidate]
+ while parent is not None:
+ if parent == expected:
+ return True
+ parent = TYPE_HIERARCHY[parent]
+ return False
+
@dataclass
class Token:
type_: TokenType
@@ -69,67 +101,37 @@ class Token:
def __str__(self):
return f"{self.type_.name} {self.text} {self.line}"
-class T:
- def __repr__(self):
- return "T"
- class Any:
- def __repr__(self):
- return ":any"
- class List(Any):
- pass
- class Literal(Any):
- def __repr__(self):
- return ":literal"
- class String(Literal):
- def __repr__(self):
- return ":string"
- pass
- class Bool(Literal):
- def __repr__(self):
- return ":bool"
- pass
- class Number(Literal):
- def __repr__(self):
- return ":number"
- class Int(Number):
- def __repr__(self):
- return ":int"
- class Float(Number):
- def __repr__(self):
- return ":float"
-
-# Literals
-class Literal(T.Any):
+class Literal:
def __init__(self, value, type_=None):
self.value = value
if type_ is None:
- self.type_ = T.Any
+ self.type_ = TypeEnum.ANY
else:
self.type_ = type_
def __str__(self):
return f"{self.value}:literal"
-class Int(Literal, T.Int):
+class Int(Literal):
def __init__(self, value):
- super().__init__(value, T.Int)
+ super().__init__(value, TypeEnum.INT)
def __str__(self):
return f"{self.value}"
-class Float(Literal, T.Float):
+class Float(Literal):
def __init__(self, value):
- super().__init__(value, T.Float)
+ super().__init__(value, TypeEnum.FLOAT)
def __str__(self):
return f"{self.value}"
-class Bool(Literal, T.Bool):
+class Bool(Literal):
def __init__(self, value):
- super().__init__(value, T.Bool)
+ super().__init__(value, TypeEnum.BOOL)
def __str__(self):
return f"#{str(self.value).lower()}"
-class String(Literal, T.String):
+class String(Literal):
def __init__(self, value):
- super().__init__(value, T.String)
+ super().__init__(value, TypeEnum.STRING)
def __str__(self):
return f'"{repr(self.value)[1:-1]}"'
@@ -146,11 +148,11 @@ class Symbol:
def __str__(self):
return f"'{self.name}"
-class List(T.List):
+class List:
def __init__(self, args, data=False):
self.args = args
self.data = data
- self.type_ = T.List
+ self.type_ = TypeEnum.LIST
def __str__(self):
return "(" + " ".join(f"{arg}" for arg in self.args) + ")"