From 2821c14272c4296a64d94532fa8665ed53f5a0ef Mon Sep 17 00:00:00 2001 From: mryouse Date: Fri, 10 Jun 2022 00:16:53 +0000 Subject: refactor: more flexibility with builtin arities --- interpreter.py | 153 ++++++++++++++++++++++++++------------------------------- 1 file changed, 69 insertions(+), 84 deletions(-) (limited to 'interpreter.py') diff --git a/interpreter.py b/interpreter.py index a5fab45..0136a04 100644 --- a/interpreter.py +++ b/interpreter.py @@ -11,18 +11,23 @@ import sys class Function: - def __init__(self, name, params, body, *arities): + def __init__(self, name, params, body, min_arity=0, max_arity=-1): self.name = name self.params = params self.body = body - if len(arities) == 0: - self.arities = None - else: - self.arities = arities + self.min_arity = min_arity + self.max_arity = max_arity def arity_check(self, symbol, args): - if self.arities is not None and len(args) not in self.arities: - fmt = ", ".join([f"{arity}" for arity in self.arities]) + if len(args) < self.min_arity or (self.max_arity >= 0 and len(args) > self.max_arity): + if self.max_arity < 0: + fmt = f"{self.min_arity}+" + elif self.min_arity != self.max_arity: + fmt = f"{self.min_arity}-{self.max_arity}" + else: + fmt = f"{self.min_arity}" + #if self.arities is not None and len(args) not in self.arities: + #fmt = ", ".join([f"{arity}" for arity in self.arities]) raise InterpretPanic(symbol, f"expected [{fmt}] arguments, received {len(args)}") return True @@ -127,8 +132,6 @@ def evaluate(expr, env): def interpretOr(symbol, args, env): # or returns true for the first expression that returns true - if len(args) < 2: - raise InterpretPanic(symbol, "requires at least two arguments") for arg in args: ev = evaluate(arg, env) if not isinstance(ev, Bool): @@ -137,12 +140,10 @@ def interpretOr(symbol, args, env): return ev return Bool(False) -GLOBALS.register("or", Builtin(interpretOr)) +GLOBALS.register("or", Builtin(interpretOr, 2)) def interpretAnd(symbol, args, env): # and returns false for the first expression that returns false - if len(args) < 2: - raise InterpretPanic(symbol, "requires at least two arguments") for arg in args: ev = evaluate(arg, env) if not isinstance(ev, Bool): @@ -151,7 +152,7 @@ def interpretAnd(symbol, args, env): return ev return Bool(True) -GLOBALS.register("and", Builtin(interpretAnd)) +GLOBALS.register("and", Builtin(interpretAnd, 2)) def interpretEq(symbol, args, env): # equal @@ -167,7 +168,7 @@ def interpretEq(symbol, args, env): else: return Bool(False) -GLOBALS.register("eq?", Builtin(interpretEq, 2)) +GLOBALS.register("eq?", Builtin(interpretEq, 2, 2)) def interpretGreaterThan(symbol, args, env): left = evaluate(args[0], env) @@ -178,7 +179,7 @@ def interpretGreaterThan(symbol, args, env): raise InterpretPanic(symbol, "second argument must be a :number", right) return Bool(left.value > right.value) -GLOBALS.register(">", Builtin(interpretGreaterThan, 2)) +GLOBALS.register(">", Builtin(interpretGreaterThan, 2, 2)) def interpretGreaterThanEqual(symbol, args, env): left = evaluate(args[0], env) @@ -189,7 +190,7 @@ def interpretGreaterThanEqual(symbol, args, env): raise InterpretPanic(symbol, "second argument must be a :number", right) return Bool(left.value >= right.value) -GLOBALS.register(">=", Builtin(interpretGreaterThanEqual, 2)) +GLOBALS.register(">=", Builtin(interpretGreaterThanEqual, 2, 2)) def interpretLessThan(symbol, args, env): left = evaluate(args[0], env) @@ -200,7 +201,7 @@ def interpretLessThan(symbol, args, env): raise InterpretPanic(symbol, "second argument must be a :number", right) return Bool(left.value < right.value) -GLOBALS.register("<", Builtin(interpretLessThan, 2)) +GLOBALS.register("<", Builtin(interpretLessThan, 2, 2)) def interpretLessThanEqual(symbol, args, env): left = evaluate(args[0], env) @@ -211,11 +212,9 @@ def interpretLessThanEqual(symbol, args, env): raise InterpretPanic(symbol, "second argument must be a :number", right) return Bool(left.value <= right.value) -GLOBALS.register("<=", Builtin(interpretLessThanEqual, 2)) +GLOBALS.register("<=", Builtin(interpretLessThanEqual, 2, 2)) def interpretAddition(symbol, args, env): - if len(args) < 1: - raise InterpretPanic(symbol, "requires at least one argument") res = 0 for arg in args: ev = evaluate(arg, env) @@ -227,11 +226,9 @@ def interpretAddition(symbol, args, env): else: return Int(res) -GLOBALS.register("+", Builtin(interpretAddition)) +GLOBALS.register("+", Builtin(interpretAddition, 1)) def interpretSubtraction(symbol, args, env): - if len(args) < 1: - raise InterpretPanic(symbol, "requires at least one argument") first = evaluate(args[0], env) if not (isinstance(first, Int) or isinstance(first, Float)): raise InterpretPanic(symbol, "argument must be a :number", first) @@ -249,11 +246,9 @@ def interpretSubtraction(symbol, args, env): else: return Int(res) -GLOBALS.register("-", Builtin(interpretSubtraction)) +GLOBALS.register("-", Builtin(interpretSubtraction, 1)) def interpretMultiplication(symbol, args, env): - if len(args) < 2: - raise InterpretPanic(symbol, "requires at least two arguments") first = evaluate(args[0], env) if not (isinstance(first, Int) or isinstance(first, Float)): raise InterpretPanic(symbol, "argument must be a :number", first) @@ -268,7 +263,7 @@ def interpretMultiplication(symbol, args, env): else: return Int(res) -GLOBALS.register("*", Builtin(interpretMultiplication)) +GLOBALS.register("*", Builtin(interpretMultiplication, 2)) def interpretDivision(symbol, args, env): num = evaluate(args[0], env) @@ -283,7 +278,7 @@ def interpretDivision(symbol, args, env): else: return Float(ret) -GLOBALS.register("/", Builtin(interpretDivision, 2)) +GLOBALS.register("/", Builtin(interpretDivision, 2, 2)) def interpretNot(symbol, args, env): res = evaluate(args[0], env) @@ -291,7 +286,7 @@ def interpretNot(symbol, args, env): raise InterpretPanic(symbol, "requires a :bool", res) return Bool(not res.value) -GLOBALS.register("not", Builtin(interpretNot, 1)) +GLOBALS.register("not", Builtin(interpretNot, 1, 1)) def interpretIf(symbol, args, env): # if cond t-branch [f-branch] @@ -314,7 +309,7 @@ def interpretPrint(symbol, args, env): return List([]) # print returns nothing -GLOBALS.register("print", Builtin(interpretPrint, 1)) +GLOBALS.register("print", Builtin(interpretPrint, 1, 1)) def interpretDef(symbol, args, env): if not isinstance(args[0], Symbol): @@ -327,7 +322,7 @@ def interpretDef(symbol, args, env): env.register(name, ev) return List([]) -GLOBALS.register("def", Builtin(interpretDef, 2)) +GLOBALS.register("def", Builtin(interpretDef, 2, 2)) def interpretRedef(symbol, args, env): if not isinstance(args[0], Symbol): @@ -340,7 +335,7 @@ def interpretRedef(symbol, args, env): env.reregister(name, ev) return List([]) -GLOBALS.register("redef", Builtin(interpretRedef, 2)) +GLOBALS.register("redef", Builtin(interpretRedef, 2, 2)) def interpretLambda(symbol, args, env): if len(args[0].args) != 0: @@ -360,12 +355,10 @@ def interpretToString(symbol, args, env): else: return String(f"{ev}") -GLOBALS.register("->string", Builtin(interpretToString, 1)) +GLOBALS.register("->string", Builtin(interpretToString, 1, 1)) def interpretConcat(symbol, args, env): # concat str1 str2...strN - if len(args) < 2: - raise InterpretPanic(symbol, "requires at least two arguments") out = "" for arg in args: tmp = evaluate(arg, env) @@ -374,7 +367,7 @@ def interpretConcat(symbol, args, env): out += tmp.value return String(out) -GLOBALS.register("concat", Builtin(interpretConcat)) +GLOBALS.register("concat", Builtin(interpretConcat, 2)) def interpretForCount(symbol, args, env): # for-count int exprs @@ -391,7 +384,7 @@ def interpretForCount(symbol, args, env): return List([]) return ret -GLOBALS.register("for-count", Builtin(interpretForCount)) +GLOBALS.register("for-count", Builtin(interpretForCount, 2)) def interpretForEach(symbol, args, env): # for-each list exprs @@ -408,11 +401,9 @@ def interpretForEach(symbol, args, env): return List([]) return ret -GLOBALS.register("for-each", Builtin(interpretForEach)) +GLOBALS.register("for-each", Builtin(interpretForEach, 2)) def interpretPipe(symbol, args, env): - if len(args) < 2: - raise InterpretPanic(symbol, "requires at least two expressions") new_env = Environment(env) pipe = None for arg in args: @@ -423,11 +414,9 @@ def interpretPipe(symbol, args, env): return List([]) return pipe -GLOBALS.register("|", Builtin(interpretPipe)) +GLOBALS.register("|", Builtin(interpretPipe, 2)) def interpretBranch(symbol, args, env): - if len(args) == 0: - raise InterpretPanic(symbol, "requires at least one pair of expressions") for arg in args: if len(arg.args) != 2: raise InterpretPanic(symbol, "each branch requires two expressions") @@ -438,12 +427,10 @@ def interpretBranch(symbol, args, env): return evaluate(arg.args[1], env) return List([]) -GLOBALS.register("branch", Builtin(interpretBranch)) +GLOBALS.register("branch", Builtin(interpretBranch, 1)) def interpretFunc(symbol, args, env): # func (args) (exprs) - if len(args) < 3: - raise InterpretPanic(symbol, "requires a name, argument list, and at least one expression") if not isinstance(args[0], Symbol): raise InterpretPanic(symbol, "requires a :string name") name = args[0].name # NOTE: we are not evaluating the name!! @@ -454,7 +441,7 @@ def interpretFunc(symbol, args, env): env.register(name, func) return List([]) -GLOBALS.register("func", Builtin(interpretFunc)) +GLOBALS.register("func", Builtin(interpretFunc, 3)) # THINGS NEEDED FOR AOC # - read the contents of a file @@ -468,7 +455,7 @@ def interpretReadLines(symbol, args, env): out = List([String(d) for d in data], True) # all lines are strings return out -GLOBALS.register("read-lines", Builtin(interpretReadLines, 1)) +GLOBALS.register("read-lines", Builtin(interpretReadLines, 1, 1)) # - strip whitespace from string def interpretStrip(symbol, args, env): @@ -477,7 +464,7 @@ def interpretStrip(symbol, args, env): raise InterpretPanic(symbol, "requires a :string", out) return String(out.value.strip()) -GLOBALS.register("strip", Builtin(interpretStrip, 1)) +GLOBALS.register("strip", Builtin(interpretStrip, 1, 1)) # - string->int and string->float def interpretStringToInt(symbol, args, env): @@ -490,7 +477,7 @@ def interpretStringToInt(symbol, args, env): except: raise InterpretPanic(symbol, "can't convert to an :int", ev) -GLOBALS.register("string->int", Builtin(interpretStringToInt, 1)) +GLOBALS.register("string->int", Builtin(interpretStringToInt, 1, 1)) # - split a string by a given field def interpretSplit(symbol, args, env): @@ -514,7 +501,7 @@ def interpretListLength(symbol, args, env): raise InterpretPanic(symbol, "requires a :list", ev) return Int(len(ev.args)) -GLOBALS.register("list-length", Builtin(interpretListLength, 1)) +GLOBALS.register("list-length", Builtin(interpretListLength, 1, 1)) # - first/rest of list def interpretFirst(symbol, args, env): @@ -525,7 +512,7 @@ def interpretFirst(symbol, args, env): raise InterpretPanic(symbol, "list is empty") return evaluate(ev.args[0], env) -GLOBALS.register("first", Builtin(interpretFirst, 1)) +GLOBALS.register("first", Builtin(interpretFirst, 1, 1)) def interpretRest(symbol, args, env): ev = evaluate(args[0], env) @@ -534,7 +521,7 @@ def interpretRest(symbol, args, env): # TODO do we know it's not evaluated? return List(ev.args[1:], True) # we don't evaluate the remainder of the list -GLOBALS.register("rest", Builtin(interpretRest, 1)) +GLOBALS.register("rest", Builtin(interpretRest, 1, 1)) # - iterate over list # - map @@ -552,7 +539,7 @@ def interpretMap(symbol, args, env): out.append(ev) return List(out, True) -GLOBALS.register("map", Builtin(interpretMap, 2)) +GLOBALS.register("map", Builtin(interpretMap, 2, 2)) def interpretZip(symbol, args, env): z1 = evaluate(args[0], env) @@ -570,7 +557,7 @@ def interpretZip(symbol, args, env): out.append(List([f, s], True)) return List(out, True) -GLOBALS.register("zip", Builtin(interpretZip, 2)) +GLOBALS.register("zip", Builtin(interpretZip, 2, 2)) def interpretList(symbol, args, env): out = [] @@ -578,7 +565,7 @@ def interpretList(symbol, args, env): out.append(evaluate(arg, env)) return List(out, True) -GLOBALS.register("list", Builtin(interpretList)) +GLOBALS.register("list", Builtin(interpretList, 0)) def interpretListReverse(symbol, args, env): lst = evaluate(args[0], env) @@ -588,7 +575,7 @@ def interpretListReverse(symbol, args, env): new_args.reverse() return List(new_args, True) -GLOBALS.register("list-reverse", Builtin(interpretListReverse, 1)) +GLOBALS.register("list-reverse", Builtin(interpretListReverse, 1, 1)) def interpretApply(symbol, args, env): func = args[0] @@ -600,7 +587,7 @@ def interpretApply(symbol, args, env): new_lst = List([func] + lst.args) return evaluate(new_lst, env) -GLOBALS.register("apply", Builtin(interpretApply, 2)) +GLOBALS.register("apply", Builtin(interpretApply, 2, 2)) def interpretGlob(symbol, args, env): ev = evaluate(args[0], env) @@ -609,7 +596,7 @@ def interpretGlob(symbol, args, env): items = glob(ev.value) return List([String(item) for item in items], True) -GLOBALS.register("glob", Builtin(interpretGlob, 1)) +GLOBALS.register("glob", Builtin(interpretGlob, 1, 1)) def interpretShell(symbol, args, env): ev = evaluate(args[0], env) @@ -619,7 +606,7 @@ def interpretShell(symbol, args, env): ret = subprocess.run(shlex.split(ev.value), capture_output=True) return List([String(r) for r in ret.stdout.decode("utf-8").split("\n")], True) -GLOBALS.register("$", Builtin(interpretShell, 1)) +GLOBALS.register("$", Builtin(interpretShell, 1, 1)) def interpretEmpty(symbol, args, env): ev = evaluate(args[0], env) @@ -627,7 +614,7 @@ def interpretEmpty(symbol, args, env): raise InterpretPanic(symbol, "requires a :list", ev) return Bool(len(ev.args) == 0) -GLOBALS.register("empty?", Builtin(interpretEmpty, 1)) +GLOBALS.register("empty?", Builtin(interpretEmpty, 1, 1)) def interpretShuf(symbol, args, env): ev = evaluate(args[0], env) @@ -637,13 +624,13 @@ def interpretShuf(symbol, args, env): random.shuffle(items) return List(items, True) -GLOBALS.register("shuf", Builtin(interpretShuf, 1)) +GLOBALS.register("shuf", Builtin(interpretShuf, 1, 1)) def interpretIsList(symbol, args, env): ev = evaluate(args[0], env) return Bool(isinstance(ev, List)) -GLOBALS.register("list?", Builtin(interpretIsList, 1)) +GLOBALS.register("list?", Builtin(interpretIsList, 1, 1)) def interpretBlock(symbol, args, env): ret = List([]) @@ -651,11 +638,9 @@ def interpretBlock(symbol, args, env): ret = evaluate(arg, env) return ret -GLOBALS.register("block", Builtin(interpretBlock)) +GLOBALS.register("block", Builtin(interpretBlock, 1)) def interpretExit(symbol, args, env): - if len(args) > 1: - raise InterpretPanic(symbol, "expects one (optional) argument") status = 0 if len(args) == 0 else evaluate(args[0], env).value if not isinstance(status, int): raise InterpretPanic(symbol, "expects an :int", status) @@ -674,7 +659,7 @@ def interpretUnlink(symbol, args, env): target_path.unlink() return List([]) -GLOBALS.register("unlink", Builtin(interpretUnlink, 1)) +GLOBALS.register("unlink", Builtin(interpretUnlink, 1, 1)) def interpretArgv(symbol, args, env): out = [] @@ -682,7 +667,7 @@ def interpretArgv(symbol, args, env): out.append(String(arg)) return List(out, True) -GLOBALS.register("argv", Builtin(interpretArgv, 0)) +GLOBALS.register("argv", Builtin(interpretArgv, 0, 0)) def interpretIn(symbol, args, env): target = evaluate(args[0], env) @@ -697,7 +682,7 @@ def interpretIn(symbol, args, env): return Bool(True) return Bool(False) -GLOBALS.register("in?", Builtin(interpretIn, 2)) +GLOBALS.register("in?", Builtin(interpretIn, 2, 2)) def interpretLast(symbol, args, env): ev = evaluate(args[0], env) @@ -707,7 +692,7 @@ def interpretLast(symbol, args, env): raise InterpretPanic("List is empty") return evaluate(ev.args[-1], env) -GLOBALS.register("last", Builtin(interpretLast, 1)) +GLOBALS.register("last", Builtin(interpretLast, 1, 1)) def interpretJoin(symbol, args, env): lst = evaluate(args[0], env) @@ -718,7 +703,7 @@ def interpretJoin(symbol, args, env): raise InterpretPanic(symbol, "expects a :string as its second argument", target) return String(target.value.join([a.value for a in lst.args])) -GLOBALS.register("join", Builtin(interpretJoin, 2)) +GLOBALS.register("join", Builtin(interpretJoin, 2, 2)) def interpretWithWrite(symbol, args, env): if len(args) == 0: @@ -735,7 +720,7 @@ def interpretWithWrite(symbol, args, env): ret = evaluate(arg, new_env) return ret -GLOBALS.register("with-write", Builtin(interpretWithWrite)) +GLOBALS.register("with-write", Builtin(interpretWithWrite, 1)) def interpretWrite(symbol, args, env): # write :string :filehandle @@ -746,12 +731,12 @@ def interpretWrite(symbol, args, env): handle.args[0].write(line.value) # TODO wrong! how do we evaluate a handle? return Literal([]) -GLOBALS.register("write", Builtin(interpretWrite, 2)) +GLOBALS.register("write", Builtin(interpretWrite, 2, 2)) def interpretNewline(symbol, args, env): return String("\n") -GLOBALS.register("newline", Builtin(interpretNewline, 0)) +GLOBALS.register("newline", Builtin(interpretNewline, 0, 0)) def interpretExists(symbol, args, env): file_or_dir = evaluate(args[0], env) @@ -759,7 +744,7 @@ def interpretExists(symbol, args, env): raise InterpretPanic(symbol, "expects a :string", file_or_dir) return Bool(Path(file_or_dir.value).resolve().exists()) -GLOBALS.register("exists?", Builtin(interpretExists, 1)) +GLOBALS.register("exists?", Builtin(interpretExists, 1, 1)) def interpretFirstChar(symbol, args, env): ev = evaluate(args[0], env) @@ -769,7 +754,7 @@ def interpretFirstChar(symbol, args, env): raise InterpretPanic(symbol, ":string is empty", ev) return String(ev.value[0]) -GLOBALS.register("first-char", Builtin(interpretFirstChar, 1)) +GLOBALS.register("first-char", Builtin(interpretFirstChar, 1, 1)) def interpretRestChar(symbol, args, env): ev = evaluate(args[0], env) @@ -777,7 +762,7 @@ def interpretRestChar(symbol, args, env): raise InterpretPanic(symbol, "expects a string", ev) return String(ev.value[1:]) -GLOBALS.register("rest-char", Builtin(interpretRestChar, 1)) +GLOBALS.register("rest-char", Builtin(interpretRestChar, 1, 1)) def interpretSlice(symbol, args, env): lst = evaluate(args[0], env) @@ -800,7 +785,7 @@ def interpretClear(symbol, args, env): subprocess.run(["clear"]) return List([]) -GLOBALS.register("clear", Builtin(interpretClear, 0)) +GLOBALS.register("clear", Builtin(interpretClear, 0, 0)) def interpretInput(symbol, args, env): ev = evaluate(args[0], env) @@ -809,7 +794,7 @@ def interpretInput(symbol, args, env): ret = input(ev.value) return String(ret) -GLOBALS.register("input", Builtin(interpretInput, 1)) +GLOBALS.register("input", Builtin(interpretInput, 1, 1)) def interpretAppend(symbol, args, env): lst = evaluate(args[0], env) @@ -819,7 +804,7 @@ def interpretAppend(symbol, args, env): items = lst.args[:] return List(items + [val], True) -GLOBALS.register("append", Builtin(interpretAppend, 2)) +GLOBALS.register("append", Builtin(interpretAppend, 2, 2)) def interpretRemove(symbol, args, env): lst = evaluate(args[0], env) @@ -832,7 +817,7 @@ def interpretRemove(symbol, args, env): out.append(arg) return List(out, True) -GLOBALS.register("remove", Builtin(interpretRemove, 2)) +GLOBALS.register("remove", Builtin(interpretRemove, 2, 2)) def interpretWhile(symbol, args, env): cond = args[0] @@ -847,12 +832,12 @@ def interpretWhile(symbol, args, env): ret = evaluate(arg, env) return ret -GLOBALS.register("while", Builtin(interpretWhile)) +GLOBALS.register("while", Builtin(interpretWhile, 2)) def interpretAnsiEscape(symbol, args, env): return String(f"\033") -GLOBALS.register("ansi-escape", Builtin(interpretAnsiEscape, 0)) +GLOBALS.register("ansi-escape", Builtin(interpretAnsiEscape, 0, 0)) def interpretUse(symbol, args, env): target_file_name = evaluate(args[0], env).value @@ -864,5 +849,5 @@ def interpretUse(symbol, args, env): interpret(parse(lex(data))) return List([]) -GLOBALS.register("use", Builtin(interpretUse, 1)) +GLOBALS.register("use", Builtin(interpretUse, 1, 1)) -- cgit v1.2.3