aboutsummaryrefslogtreecommitdiff
path: root/interpreter.py
diff options
context:
space:
mode:
authormryouse2022-06-10 00:16:53 +0000
committermryouse2022-06-10 00:16:53 +0000
commit2821c14272c4296a64d94532fa8665ed53f5a0ef (patch)
tree81dd895a57ede879c2e5c0ae71a31e41b4b6f7ca /interpreter.py
parentd7466520fd61c153509710c257b358047cd01606 (diff)
refactor: more flexibility with builtin arities
Diffstat (limited to 'interpreter.py')
-rw-r--r--interpreter.py153
1 files changed, 69 insertions, 84 deletions
diff --git a/interpreter.py b/interpreter.py
index a5fab45..0136a04 100644
--- a/interpreter.py
+++ b/interpreter.py
@@ -11,18 +11,23 @@ import sys
class Function:
- def __init__(self, name, params, body, *arities):
+ def __init__(self, name, params, body, min_arity=0, max_arity=-1):
self.name = name
self.params = params
self.body = body
- if len(arities) == 0:
- self.arities = None
- else:
- self.arities = arities
+ self.min_arity = min_arity
+ self.max_arity = max_arity
def arity_check(self, symbol, args):
- if self.arities is not None and len(args) not in self.arities:
- fmt = ", ".join([f"{arity}" for arity in self.arities])
+ if len(args) < self.min_arity or (self.max_arity >= 0 and len(args) > self.max_arity):
+ if self.max_arity < 0:
+ fmt = f"{self.min_arity}+"
+ elif self.min_arity != self.max_arity:
+ fmt = f"{self.min_arity}-{self.max_arity}"
+ else:
+ fmt = f"{self.min_arity}"
+ #if self.arities is not None and len(args) not in self.arities:
+ #fmt = ", ".join([f"{arity}" for arity in self.arities])
raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(args)}")
return True
@@ -127,8 +132,6 @@ def evaluate(expr, env):
def interpretOr(symbol, args, env):
# or returns true for the first expression that returns true
- if len(args) < 2:
- raise InterpretPanic(symbol, "requires at least two arguments")
for arg in args:
ev = evaluate(arg, env)
if not isinstance(ev, Bool):
@@ -137,12 +140,10 @@ def interpretOr(symbol, args, env):
return ev
return Bool(False)
-GLOBALS.register("or", Builtin(interpretOr))
+GLOBALS.register("or", Builtin(interpretOr, 2))
def interpretAnd(symbol, args, env):
# and returns false for the first expression that returns false
- if len(args) < 2:
- raise InterpretPanic(symbol, "requires at least two arguments")
for arg in args:
ev = evaluate(arg, env)
if not isinstance(ev, Bool):
@@ -151,7 +152,7 @@ def interpretAnd(symbol, args, env):
return ev
return Bool(True)
-GLOBALS.register("and", Builtin(interpretAnd))
+GLOBALS.register("and", Builtin(interpretAnd, 2))
def interpretEq(symbol, args, env):
# equal
@@ -167,7 +168,7 @@ def interpretEq(symbol, args, env):
else:
return Bool(False)
-GLOBALS.register("eq?", Builtin(interpretEq, 2))
+GLOBALS.register("eq?", Builtin(interpretEq, 2, 2))
def interpretGreaterThan(symbol, args, env):
left = evaluate(args[0], env)
@@ -178,7 +179,7 @@ def interpretGreaterThan(symbol, args, env):
raise InterpretPanic(symbol, "second argument must be a :number", right)
return Bool(left.value > right.value)
-GLOBALS.register(">", Builtin(interpretGreaterThan, 2))
+GLOBALS.register(">", Builtin(interpretGreaterThan, 2, 2))
def interpretGreaterThanEqual(symbol, args, env):
left = evaluate(args[0], env)
@@ -189,7 +190,7 @@ def interpretGreaterThanEqual(symbol, args, env):
raise InterpretPanic(symbol, "second argument must be a :number", right)
return Bool(left.value >= right.value)
-GLOBALS.register(">=", Builtin(interpretGreaterThanEqual, 2))
+GLOBALS.register(">=", Builtin(interpretGreaterThanEqual, 2, 2))
def interpretLessThan(symbol, args, env):
left = evaluate(args[0], env)
@@ -200,7 +201,7 @@ def interpretLessThan(symbol, args, env):
raise InterpretPanic(symbol, "second argument must be a :number", right)
return Bool(left.value < right.value)
-GLOBALS.register("<", Builtin(interpretLessThan, 2))
+GLOBALS.register("<", Builtin(interpretLessThan, 2, 2))
def interpretLessThanEqual(symbol, args, env):
left = evaluate(args[0], env)
@@ -211,11 +212,9 @@ def interpretLessThanEqual(symbol, args, env):
raise InterpretPanic(symbol, "second argument must be a :number", right)
return Bool(left.value <= right.value)
-GLOBALS.register("<=", Builtin(interpretLessThanEqual, 2))
+GLOBALS.register("<=", Builtin(interpretLessThanEqual, 2, 2))
def interpretAddition(symbol, args, env):
- if len(args) < 1:
- raise InterpretPanic(symbol, "requires at least one argument")
res = 0
for arg in args:
ev = evaluate(arg, env)
@@ -227,11 +226,9 @@ def interpretAddition(symbol, args, env):
else:
return Int(res)
-GLOBALS.register("+", Builtin(interpretAddition))
+GLOBALS.register("+", Builtin(interpretAddition, 1))
def interpretSubtraction(symbol, args, env):
- if len(args) < 1:
- raise InterpretPanic(symbol, "requires at least one argument")
first = evaluate(args[0], env)
if not (isinstance(first, Int) or isinstance(first, Float)):
raise InterpretPanic(symbol, "argument must be a :number", first)
@@ -249,11 +246,9 @@ def interpretSubtraction(symbol, args, env):
else:
return Int(res)
-GLOBALS.register("-", Builtin(interpretSubtraction))
+GLOBALS.register("-", Builtin(interpretSubtraction, 1))
def interpretMultiplication(symbol, args, env):
- if len(args) < 2:
- raise InterpretPanic(symbol, "requires at least two arguments")
first = evaluate(args[0], env)
if not (isinstance(first, Int) or isinstance(first, Float)):
raise InterpretPanic(symbol, "argument must be a :number", first)
@@ -268,7 +263,7 @@ def interpretMultiplication(symbol, args, env):
else:
return Int(res)
-GLOBALS.register("*", Builtin(interpretMultiplication))
+GLOBALS.register("*", Builtin(interpretMultiplication, 2))
def interpretDivision(symbol, args, env):
num = evaluate(args[0], env)
@@ -283,7 +278,7 @@ def interpretDivision(symbol, args, env):
else:
return Float(ret)
-GLOBALS.register("/", Builtin(interpretDivision, 2))
+GLOBALS.register("/", Builtin(interpretDivision, 2, 2))
def interpretNot(symbol, args, env):
res = evaluate(args[0], env)
@@ -291,7 +286,7 @@ def interpretNot(symbol, args, env):
raise InterpretPanic(symbol, "requires a :bool", res)
return Bool(not res.value)
-GLOBALS.register("not", Builtin(interpretNot, 1))
+GLOBALS.register("not", Builtin(interpretNot, 1, 1))
def interpretIf(symbol, args, env):
# if cond t-branch [f-branch]
@@ -314,7 +309,7 @@ def interpretPrint(symbol, args, env):
return List([]) # print returns nothing
-GLOBALS.register("print", Builtin(interpretPrint, 1))
+GLOBALS.register("print", Builtin(interpretPrint, 1, 1))
def interpretDef(symbol, args, env):
if not isinstance(args[0], Symbol):
@@ -327,7 +322,7 @@ def interpretDef(symbol, args, env):
env.register(name, ev)
return List([])
-GLOBALS.register("def", Builtin(interpretDef, 2))
+GLOBALS.register("def", Builtin(interpretDef, 2, 2))
def interpretRedef(symbol, args, env):
if not isinstance(args[0], Symbol):
@@ -340,7 +335,7 @@ def interpretRedef(symbol, args, env):
env.reregister(name, ev)
return List([])
-GLOBALS.register("redef", Builtin(interpretRedef, 2))
+GLOBALS.register("redef", Builtin(interpretRedef, 2, 2))
def interpretLambda(symbol, args, env):
if len(args[0].args) != 0:
@@ -360,12 +355,10 @@ def interpretToString(symbol, args, env):
else:
return String(f"{ev}")
-GLOBALS.register("->string", Builtin(interpretToString, 1))
+GLOBALS.register("->string", Builtin(interpretToString, 1, 1))
def interpretConcat(symbol, args, env):
# concat str1 str2...strN
- if len(args) < 2:
- raise InterpretPanic(symbol, "requires at least two arguments")
out = ""
for arg in args:
tmp = evaluate(arg, env)
@@ -374,7 +367,7 @@ def interpretConcat(symbol, args, env):
out += tmp.value
return String(out)
-GLOBALS.register("concat", Builtin(interpretConcat))
+GLOBALS.register("concat", Builtin(interpretConcat, 2))
def interpretForCount(symbol, args, env):
# for-count int exprs
@@ -391,7 +384,7 @@ def interpretForCount(symbol, args, env):
return List([])
return ret
-GLOBALS.register("for-count", Builtin(interpretForCount))
+GLOBALS.register("for-count", Builtin(interpretForCount, 2))
def interpretForEach(symbol, args, env):
# for-each list exprs
@@ -408,11 +401,9 @@ def interpretForEach(symbol, args, env):
return List([])
return ret
-GLOBALS.register("for-each", Builtin(interpretForEach))
+GLOBALS.register("for-each", Builtin(interpretForEach, 2))
def interpretPipe(symbol, args, env):
- if len(args) < 2:
- raise InterpretPanic(symbol, "requires at least two expressions")
new_env = Environment(env)
pipe = None
for arg in args:
@@ -423,11 +414,9 @@ def interpretPipe(symbol, args, env):
return List([])
return pipe
-GLOBALS.register("|", Builtin(interpretPipe))
+GLOBALS.register("|", Builtin(interpretPipe, 2))
def interpretBranch(symbol, args, env):
- if len(args) == 0:
- raise InterpretPanic(symbol, "requires at least one pair of expressions")
for arg in args:
if len(arg.args) != 2:
raise InterpretPanic(symbol, "each branch requires two expressions")
@@ -438,12 +427,10 @@ def interpretBranch(symbol, args, env):
return evaluate(arg.args[1], env)
return List([])
-GLOBALS.register("branch", Builtin(interpretBranch))
+GLOBALS.register("branch", Builtin(interpretBranch, 1))
def interpretFunc(symbol, args, env):
# func <name> (args) (exprs)
- if len(args) < 3:
- raise InterpretPanic(symbol, "requires a name, argument list, and at least one expression")
if not isinstance(args[0], Symbol):
raise InterpretPanic(symbol, "requires a :string name")
name = args[0].name # NOTE: we are not evaluating the name!!
@@ -454,7 +441,7 @@ def interpretFunc(symbol, args, env):
env.register(name, func)
return List([])
-GLOBALS.register("func", Builtin(interpretFunc))
+GLOBALS.register("func", Builtin(interpretFunc, 3))
# THINGS NEEDED FOR AOC
# - read the contents of a file
@@ -468,7 +455,7 @@ def interpretReadLines(symbol, args, env):
out = List([String(d) for d in data], True) # all lines are strings
return out
-GLOBALS.register("read-lines", Builtin(interpretReadLines, 1))
+GLOBALS.register("read-lines", Builtin(interpretReadLines, 1, 1))
# - strip whitespace from string
def interpretStrip(symbol, args, env):
@@ -477,7 +464,7 @@ def interpretStrip(symbol, args, env):
raise InterpretPanic(symbol, "requires a :string", out)
return String(out.value.strip())
-GLOBALS.register("strip", Builtin(interpretStrip, 1))
+GLOBALS.register("strip", Builtin(interpretStrip, 1, 1))
# - string->int and string->float
def interpretStringToInt(symbol, args, env):
@@ -490,7 +477,7 @@ def interpretStringToInt(symbol, args, env):
except:
raise InterpretPanic(symbol, "can't convert to an :int", ev)
-GLOBALS.register("string->int", Builtin(interpretStringToInt, 1))
+GLOBALS.register("string->int", Builtin(interpretStringToInt, 1, 1))
# - split a string by a given field
def interpretSplit(symbol, args, env):
@@ -514,7 +501,7 @@ def interpretListLength(symbol, args, env):
raise InterpretPanic(symbol, "requires a :list", ev)
return Int(len(ev.args))
-GLOBALS.register("list-length", Builtin(interpretListLength, 1))
+GLOBALS.register("list-length", Builtin(interpretListLength, 1, 1))
# - first/rest of list
def interpretFirst(symbol, args, env):
@@ -525,7 +512,7 @@ def interpretFirst(symbol, args, env):
raise InterpretPanic(symbol, "list is empty")
return evaluate(ev.args[0], env)
-GLOBALS.register("first", Builtin(interpretFirst, 1))
+GLOBALS.register("first", Builtin(interpretFirst, 1, 1))
def interpretRest(symbol, args, env):
ev = evaluate(args[0], env)
@@ -534,7 +521,7 @@ def interpretRest(symbol, args, env):
# TODO do we know it's not evaluated?
return List(ev.args[1:], True) # we don't evaluate the remainder of the list
-GLOBALS.register("rest", Builtin(interpretRest, 1))
+GLOBALS.register("rest", Builtin(interpretRest, 1, 1))
# - iterate over list
# - map
@@ -552,7 +539,7 @@ def interpretMap(symbol, args, env):
out.append(ev)
return List(out, True)
-GLOBALS.register("map", Builtin(interpretMap, 2))
+GLOBALS.register("map", Builtin(interpretMap, 2, 2))
def interpretZip(symbol, args, env):
z1 = evaluate(args[0], env)
@@ -570,7 +557,7 @@ def interpretZip(symbol, args, env):
out.append(List([f, s], True))
return List(out, True)
-GLOBALS.register("zip", Builtin(interpretZip, 2))
+GLOBALS.register("zip", Builtin(interpretZip, 2, 2))
def interpretList(symbol, args, env):
out = []
@@ -578,7 +565,7 @@ def interpretList(symbol, args, env):
out.append(evaluate(arg, env))
return List(out, True)
-GLOBALS.register("list", Builtin(interpretList))
+GLOBALS.register("list", Builtin(interpretList, 0))
def interpretListReverse(symbol, args, env):
lst = evaluate(args[0], env)
@@ -588,7 +575,7 @@ def interpretListReverse(symbol, args, env):
new_args.reverse()
return List(new_args, True)
-GLOBALS.register("list-reverse", Builtin(interpretListReverse, 1))
+GLOBALS.register("list-reverse", Builtin(interpretListReverse, 1, 1))
def interpretApply(symbol, args, env):
func = args[0]
@@ -600,7 +587,7 @@ def interpretApply(symbol, args, env):
new_lst = List([func] + lst.args)
return evaluate(new_lst, env)
-GLOBALS.register("apply", Builtin(interpretApply, 2))
+GLOBALS.register("apply", Builtin(interpretApply, 2, 2))
def interpretGlob(symbol, args, env):
ev = evaluate(args[0], env)
@@ -609,7 +596,7 @@ def interpretGlob(symbol, args, env):
items = glob(ev.value)
return List([String(item) for item in items], True)
-GLOBALS.register("glob", Builtin(interpretGlob, 1))
+GLOBALS.register("glob", Builtin(interpretGlob, 1, 1))
def interpretShell(symbol, args, env):
ev = evaluate(args[0], env)
@@ -619,7 +606,7 @@ def interpretShell(symbol, args, env):
ret = subprocess.run(shlex.split(ev.value), capture_output=True)
return List([String(r) for r in ret.stdout.decode("utf-8").split("\n")], True)
-GLOBALS.register("$", Builtin(interpretShell, 1))
+GLOBALS.register("$", Builtin(interpretShell, 1, 1))
def interpretEmpty(symbol, args, env):
ev = evaluate(args[0], env)
@@ -627,7 +614,7 @@ def interpretEmpty(symbol, args, env):
raise InterpretPanic(symbol, "requires a :list", ev)
return Bool(len(ev.args) == 0)
-GLOBALS.register("empty?", Builtin(interpretEmpty, 1))
+GLOBALS.register("empty?", Builtin(interpretEmpty, 1, 1))
def interpretShuf(symbol, args, env):
ev = evaluate(args[0], env)
@@ -637,13 +624,13 @@ def interpretShuf(symbol, args, env):
random.shuffle(items)
return List(items, True)
-GLOBALS.register("shuf", Builtin(interpretShuf, 1))
+GLOBALS.register("shuf", Builtin(interpretShuf, 1, 1))
def interpretIsList(symbol, args, env):
ev = evaluate(args[0], env)
return Bool(isinstance(ev, List))
-GLOBALS.register("list?", Builtin(interpretIsList, 1))
+GLOBALS.register("list?", Builtin(interpretIsList, 1, 1))
def interpretBlock(symbol, args, env):
ret = List([])
@@ -651,11 +638,9 @@ def interpretBlock(symbol, args, env):
ret = evaluate(arg, env)
return ret
-GLOBALS.register("block", Builtin(interpretBlock))
+GLOBALS.register("block", Builtin(interpretBlock, 1))
def interpretExit(symbol, args, env):
- if len(args) > 1:
- raise InterpretPanic(symbol, "expects one (optional) argument")
status = 0 if len(args) == 0 else evaluate(args[0], env).value
if not isinstance(status, int):
raise InterpretPanic(symbol, "expects an :int", status)
@@ -674,7 +659,7 @@ def interpretUnlink(symbol, args, env):
target_path.unlink()
return List([])
-GLOBALS.register("unlink", Builtin(interpretUnlink, 1))
+GLOBALS.register("unlink", Builtin(interpretUnlink, 1, 1))
def interpretArgv(symbol, args, env):
out = []
@@ -682,7 +667,7 @@ def interpretArgv(symbol, args, env):
out.append(String(arg))
return List(out, True)
-GLOBALS.register("argv", Builtin(interpretArgv, 0))
+GLOBALS.register("argv", Builtin(interpretArgv, 0, 0))
def interpretIn(symbol, args, env):
target = evaluate(args[0], env)
@@ -697,7 +682,7 @@ def interpretIn(symbol, args, env):
return Bool(True)
return Bool(False)
-GLOBALS.register("in?", Builtin(interpretIn, 2))
+GLOBALS.register("in?", Builtin(interpretIn, 2, 2))
def interpretLast(symbol, args, env):
ev = evaluate(args[0], env)
@@ -707,7 +692,7 @@ def interpretLast(symbol, args, env):
raise InterpretPanic("List is empty")
return evaluate(ev.args[-1], env)
-GLOBALS.register("last", Builtin(interpretLast, 1))
+GLOBALS.register("last", Builtin(interpretLast, 1, 1))
def interpretJoin(symbol, args, env):
lst = evaluate(args[0], env)
@@ -718,7 +703,7 @@ def interpretJoin(symbol, args, env):
raise InterpretPanic(symbol, "expects a :string as its second argument", target)
return String(target.value.join([a.value for a in lst.args]))
-GLOBALS.register("join", Builtin(interpretJoin, 2))
+GLOBALS.register("join", Builtin(interpretJoin, 2, 2))
def interpretWithWrite(symbol, args, env):
if len(args) == 0:
@@ -735,7 +720,7 @@ def interpretWithWrite(symbol, args, env):
ret = evaluate(arg, new_env)
return ret
-GLOBALS.register("with-write", Builtin(interpretWithWrite))
+GLOBALS.register("with-write", Builtin(interpretWithWrite, 1))
def interpretWrite(symbol, args, env):
# write :string :filehandle
@@ -746,12 +731,12 @@ def interpretWrite(symbol, args, env):
handle.args[0].write(line.value) # TODO wrong! how do we evaluate a handle?
return Literal([])
-GLOBALS.register("write", Builtin(interpretWrite, 2))
+GLOBALS.register("write", Builtin(interpretWrite, 2, 2))
def interpretNewline(symbol, args, env):
return String("\n")
-GLOBALS.register("newline", Builtin(interpretNewline, 0))
+GLOBALS.register("newline", Builtin(interpretNewline, 0, 0))
def interpretExists(symbol, args, env):
file_or_dir = evaluate(args[0], env)
@@ -759,7 +744,7 @@ def interpretExists(symbol, args, env):
raise InterpretPanic(symbol, "expects a :string", file_or_dir)
return Bool(Path(file_or_dir.value).resolve().exists())
-GLOBALS.register("exists?", Builtin(interpretExists, 1))
+GLOBALS.register("exists?", Builtin(interpretExists, 1, 1))
def interpretFirstChar(symbol, args, env):
ev = evaluate(args[0], env)
@@ -769,7 +754,7 @@ def interpretFirstChar(symbol, args, env):
raise InterpretPanic(symbol, ":string is empty", ev)
return String(ev.value[0])
-GLOBALS.register("first-char", Builtin(interpretFirstChar, 1))
+GLOBALS.register("first-char", Builtin(interpretFirstChar, 1, 1))
def interpretRestChar(symbol, args, env):
ev = evaluate(args[0], env)
@@ -777,7 +762,7 @@ def interpretRestChar(symbol, args, env):
raise InterpretPanic(symbol, "expects a string", ev)
return String(ev.value[1:])
-GLOBALS.register("rest-char", Builtin(interpretRestChar, 1))
+GLOBALS.register("rest-char", Builtin(interpretRestChar, 1, 1))
def interpretSlice(symbol, args, env):
lst = evaluate(args[0], env)
@@ -800,7 +785,7 @@ def interpretClear(symbol, args, env):
subprocess.run(["clear"])
return List([])
-GLOBALS.register("clear", Builtin(interpretClear, 0))
+GLOBALS.register("clear", Builtin(interpretClear, 0, 0))
def interpretInput(symbol, args, env):
ev = evaluate(args[0], env)
@@ -809,7 +794,7 @@ def interpretInput(symbol, args, env):
ret = input(ev.value)
return String(ret)
-GLOBALS.register("input", Builtin(interpretInput, 1))
+GLOBALS.register("input", Builtin(interpretInput, 1, 1))
def interpretAppend(symbol, args, env):
lst = evaluate(args[0], env)
@@ -819,7 +804,7 @@ def interpretAppend(symbol, args, env):
items = lst.args[:]
return List(items + [val], True)
-GLOBALS.register("append", Builtin(interpretAppend, 2))
+GLOBALS.register("append", Builtin(interpretAppend, 2, 2))
def interpretRemove(symbol, args, env):
lst = evaluate(args[0], env)
@@ -832,7 +817,7 @@ def interpretRemove(symbol, args, env):
out.append(arg)
return List(out, True)
-GLOBALS.register("remove", Builtin(interpretRemove, 2))
+GLOBALS.register("remove", Builtin(interpretRemove, 2, 2))
def interpretWhile(symbol, args, env):
cond = args[0]
@@ -847,12 +832,12 @@ def interpretWhile(symbol, args, env):
ret = evaluate(arg, env)
return ret
-GLOBALS.register("while", Builtin(interpretWhile))
+GLOBALS.register("while", Builtin(interpretWhile, 2))
def interpretAnsiEscape(symbol, args, env):
return String(f"\033")
-GLOBALS.register("ansi-escape", Builtin(interpretAnsiEscape, 0))
+GLOBALS.register("ansi-escape", Builtin(interpretAnsiEscape, 0, 0))
def interpretUse(symbol, args, env):
target_file_name = evaluate(args[0], env).value
@@ -864,5 +849,5 @@ def interpretUse(symbol, args, env):
interpret(parse(lex(data)))
return List([])
-GLOBALS.register("use", Builtin(interpretUse, 1))
+GLOBALS.register("use", Builtin(interpretUse, 1, 1))