Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow calling non-class functions #535

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"compile-lsig": "bun ./src/bin/tealscript.ts tests/contracts/lsig.algo.ts tests/contracts/artifacts",
"compile-inheritance": "bun ./src/bin/tealscript.ts tests/contracts/inheritance.algo.ts tests/contracts/artifacts",
"compile-avm11": "bun ./src/bin/tealscript.ts --skip-algod tests/contracts/avm11.algo.ts tests/contracts/artifacts",
"compile-functions": "bun ./src/bin/tealscript.ts tests/contracts/functions.algo.ts tests/contracts/artifacts",
"compile-amm": "bun ./src/bin/tealscript.ts examples/amm/amm.algo.ts examples/amm/tealscript_artifacts",
"compile-arc75": "bun src/bin/tealscript.ts examples/arc75/arc75.algo.ts examples/arc75/artifacts && algokitgen generate -a examples/arc75/artifacts/ARC75.arc32.json -o examples/arc75/ARC75Client.ts",
"compile-auction": "bun ./src/bin/tealscript.ts examples/auction/auction.algo.ts examples/auction/tealscript_artifacts",
Expand Down
144 changes: 100 additions & 44 deletions src/lib/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ export type TEALInfo = {

/** @internal */
export default class Compiler {
private pendingSubroutines: ts.FunctionDeclaration[] = [];

static diagsRan: string[] = [''];

private scratch: { [name: string]: { slot?: number; type: TypeInfo; initNode: ts.CallExpression } } = {};
Expand Down Expand Up @@ -1464,7 +1466,14 @@ export default class Compiler {

this.processNode(node.getArguments()[0]);

const indexInScratch: boolean = this.teal[this.currentProgram].length - preTealLength > 1;
// Get the opcodes that were needed to process the txn index
const opcodes = this.teal[this.currentProgram]
.slice(preTealLength)
.map((t) => t.teal)
.filter((t) => !t.startsWith('//'));

// If more than one opcode was needed, it will be more efficient to store the index in scratch
const indexInScratch: boolean = opcodes.length > 1;

if (indexInScratch) {
this.pushVoid(node, `store ${compilerScratch.verifyTxnIndex}`);
Expand Down Expand Up @@ -2622,19 +2631,19 @@ export default class Compiler {
*
* @param methods The methods to process
*/
preProcessMethods(methods: ts.MethodDeclaration[]) {
preProcessMethods(methods: (ts.MethodDeclaration | ts.FunctionDeclaration)[]) {
methods.forEach((node) => {
if (!node.getNameNode().isKind(ts.SyntaxKind.Identifier)) throw Error('Method name must be identifier');
if (!node.getNameNode()?.isKind(ts.SyntaxKind.Identifier)) throw Error('Method name must be identifier');
const name = node.getNameNode()!.getText();
const typeNode = node.getReturnType();
if (typeNode === undefined)
throw Error(`A return type annotation must be defined for ${node.getNameNode().getText()}`);
if (typeNode === undefined) throw Error(`A return type annotation must be defined for ${name}`);

const returnType = node.getReturnTypeNode()?.getType()
? this.getTypeInfo(node.getReturnTypeNode()!.getType())
: StackType.void;

const sub = {
name: node.getNameNode().getText(),
name,
allows: { call: [], create: [] },
nonAbi: { call: [], create: [] },
args: [],
Expand Down Expand Up @@ -2953,6 +2962,10 @@ export default class Compiler {
});
}

while (this.pendingSubroutines.length > 0) {
this.processSubroutine(this.pendingSubroutines.pop()!);
}

this.teal[this.currentProgram] = await this.postProcessTeal(this.teal[this.currentProgram]);
this.teal[this.currentProgram] = optimizeTeal(this.teal[this.currentProgram]);
this.teal[this.currentProgram] = this.prettyTeal(this.teal[this.currentProgram]);
Expand Down Expand Up @@ -3034,6 +3047,10 @@ export default class Compiler {
}
});

while (this.pendingSubroutines.length > 0) {
this.processSubroutine(this.pendingSubroutines.pop()!);
}

this.teal.clear = await this.postProcessTeal(this.teal.clear);
this.teal.clear = optimizeTeal(this.teal.clear);
this.teal.clear = this.prettyTeal(this.teal.clear);
Expand Down Expand Up @@ -3323,8 +3340,14 @@ export default class Compiler {
else if (node.isKind(ts.SyntaxKind.IfStatement)) this.processIfStatement(node);
else if (node.isKind(ts.SyntaxKind.PrefixUnaryExpression)) this.processUnaryExpression(node);
else if (node.isKind(ts.SyntaxKind.BinaryExpression)) this.processBinaryExpression(node);
else if (node.isKind(ts.SyntaxKind.CallExpression)) this.processExpressionChain(node);
else if (node.isKind(ts.SyntaxKind.ExpressionStatement)) this.processExpressionStatement(node);
else if (node.isKind(ts.SyntaxKind.CallExpression)) {
const expr = node.getExpression();
if (expr.isKind(ts.SyntaxKind.PropertyAccessExpression)) {
this.processExpressionChain(node);
} else {
this.processCallExpression(node);
}
} else if (node.isKind(ts.SyntaxKind.ExpressionStatement)) this.processExpressionStatement(node);
else if (node.isKind(ts.SyntaxKind.ReturnStatement)) this.processReturnStatement(node);
else if (node.isKind(ts.SyntaxKind.ParenthesizedExpression)) this.processNode(node.getExpression());
else if (node.isKind(ts.SyntaxKind.VariableStatement)) this.processNode(node.getDeclarationList());
Expand Down Expand Up @@ -5014,7 +5037,10 @@ export default class Compiler {
.map((a) => a.getKind())
.includes(ts.SyntaxKind.ClassDeclaration);

if (!inClass) {
// This is true when we are in a non-class function and the identifier is a function parameter
const isFunctionParam = defNode.getParent()?.isKind(ts.SyntaxKind.FunctionDeclaration);

if (!inClass && !isFunctionParam) {
if (!defNode.isKind(ts.SyntaxKind.VariableDeclaration)) throw Error();
this.processNode(defNode.getInitializerOrThrow());
return;
Expand Down Expand Up @@ -6168,33 +6194,74 @@ export default class Compiler {

if (chain[0].isKind(ts.SyntaxKind.PropertyAccessExpression) && chain[1].isKind(ts.SyntaxKind.CallExpression)) {
const methodName = chain[0].getNameNode().getText();
const preArgsType = this.lastType;
const subroutine = this.subroutines.find((s) => s.name === methodName);
if (!subroutine) throw new Error(`Unknown subroutine ${methodName}`);
this.processCallExpression(chain[1]);
chain.splice(0, 2);
}
}

new Array(...chain[1].getArguments()).reverse().forEach((a, i) => {
const prevTypeHint = this.typeHint;
this.typeHint = subroutine.args[i].type;
this.processNode(a);
this.typeHint = prevTypeHint;
if (this.lastType.kind === 'base' && this.lastType.type.startsWith('unsafe ')) {
this.checkEncoding(a, this.lastType);
if (isSmallNumber(this.lastType)) this.push(a, 'btoi', this.lastType);
}
typeComparison(this.lastType, subroutine.args[i].type);
});
private processCallExpression(node: ts.CallExpression) {
this.addSourceComment(node);

this.lastType = preArgsType;
const returnTypeStr = typeInfoToABIString(subroutine.returns.type);
const expr = node.getExpression();
let methodName: string;
if (expr?.isKind(ts.SyntaxKind.PropertyAccessExpression)) {
methodName = expr.getNameNode().getText();
} else if (expr?.isKind(ts.SyntaxKind.Identifier)) {
methodName = expr.getText();

// If this is a custom method
if (this.customMethods[methodName] && this.customMethods[methodName].check(node)) {
this.customMethods[methodName].fn(node);
return;
}

// If this is an opcode
if (langspec.Ops.map((o) => o.Name).includes(this.opcodeAliases[methodName] ?? methodName)) {
this.processOpcode(node);
return;
}

// If a txn method like sendMethodCall, sendPayment, etc.
if (TXN_METHODS.includes(methodName)) {
const { returnType, argTypes, name } = this.methodTypeArgsToTypes(node.getTypeArguments());

let returnType = subroutine.returns.type;
if (returnTypeStr.match(/\d+$/) && !returnTypeStr.match(/^(uint|ufixed)64/)) {
returnType = { kind: 'base', type: `unsafe ${returnTypeStr}` };
this.processTransaction(node, methodName, node.getArguments()[0], argTypes, returnType, name);
return;
}
this.push(chain[1], `callsub ${methodName}`, returnType);
if (this.nodeDepth === 1 && !equalTypes(subroutine.returns.type, StackType.void)) this.pushVoid(chain[1], 'pop');
chain.splice(0, 2);

if (this.subroutines.find((s) => s.name === methodName) === undefined) {
const definition = expr.getDefinitionNodes()[0];
if (!definition.isKind(ts.SyntaxKind.FunctionDeclaration)) throw Error();
this.preProcessMethods([definition]);
this.pendingSubroutines.push(definition);
}
} else throw new Error(`Invalid parent for call expression: ${expr?.getKindName()} ${expr?.getText()}`);

const preArgsType = this.lastType;
const subroutine = this.subroutines.find((s) => s.name === methodName);
if (!subroutine) throw new Error(`Unknown subroutine ${methodName}`);

new Array(...node.getArguments()).reverse().forEach((a, i) => {
const prevTypeHint = this.typeHint;
this.typeHint = subroutine.args[i].type;
this.processNode(a);
this.typeHint = prevTypeHint;
if (this.lastType.kind === 'base' && this.lastType.type.startsWith('unsafe ')) {
this.checkEncoding(a, this.lastType);
if (isSmallNumber(this.lastType)) this.push(a, 'btoi', this.lastType);
}
typeComparison(this.lastType, subroutine.args[i].type);
});

this.lastType = preArgsType;
const returnTypeStr = typeInfoToABIString(subroutine.returns.type);

let returnType = subroutine.returns.type;
if (returnTypeStr.match(/\d+$/) && !returnTypeStr.match(/^(uint|ufixed)64/)) {
returnType = { kind: 'base', type: `unsafe ${returnTypeStr}` };
}
this.push(node, `callsub ${methodName}`, returnType);
if (this.nodeDepth === 1 && !equalTypes(subroutine.returns.type, StackType.void)) this.pushVoid(node, 'pop');
}

private methodTypeArgsToTypes(typeArgs: ts.TypeNode[]) {
Expand Down Expand Up @@ -6384,18 +6451,6 @@ export default class Compiler {
return;
}

// If a txn method like sendMethodCall, sendPayment, etc.
if (TXN_METHODS.includes(base.getText())) {
if (!chain[0].isKind(ts.SyntaxKind.CallExpression))
throw Error(`Unsupported ${chain[0].getKindName()} ${chain[0].getText()}`);

const { returnType, argTypes, name } = this.methodTypeArgsToTypes(chain[0].getTypeArguments());

this.processTransaction(node, base.getText(), chain[0].getArguments()[0], argTypes, returnType, name);
chain.splice(0, 1);
return;
}

// If this is a global variable
if (base.getText() === 'globals') {
if (!chain[0].isKind(ts.SyntaxKind.PropertyAccessExpression))
Expand Down Expand Up @@ -6633,8 +6688,9 @@ export default class Compiler {
);
}

private processSubroutine(fn: ts.MethodDeclaration) {
private processSubroutine(fn: ts.MethodDeclaration | ts.FunctionDeclaration) {
const frameStart = this.teal[this.currentProgram].length;
this.currentSubroutine = this.subroutines.find((s) => s.name === fn.getNameNode()?.getText())!;

const sigParams = fn
.getSignature()
Expand Down
130 changes: 130 additions & 0 deletions tests/contracts/artifacts/FunctionsTest.approval.teal
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#pragma version 10

// This TEAL was generated by TEALScript v0.101.0
// https://github.com/algorandfoundation/TEALScript

// This contract is compliant with and/or implements the following ARCs: [ ARC4 ]

// The following ten lines of TEAL handle initial program flow
// This pattern is used to make it easy for anyone to parse the start of the program and determine if a specific action is allowed
// Here, action refers to the OnComplete in combination with whether the app is being created or called
// Every possible action for this contract is represented in the switch statement
// If the action is not implemented in the contract, its respective branch will be "*NOT_IMPLEMENTED" which just contains "err"
txn ApplicationID
!
int 6
*
txn OnCompletion
+
switch *call_NoOp *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED *create_NoOp *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED *NOT_IMPLEMENTED

*NOT_IMPLEMENTED:
// The requested action is not implemented in this contract. Are you using the correct OnComplete? Did you set your app ID?
err

// callNonClassFunction(uint64,uint64)uint64
*abi_route_callNonClassFunction:
// The ABI return prefix
byte 0x151f7c75

// b: uint64
txna ApplicationArgs 2
btoi

// a: uint64
txna ApplicationArgs 1
btoi

// execute callNonClassFunction(uint64,uint64)uint64
callsub callNonClassFunction
itob
concat
log
int 1
return

// callNonClassFunction(a: uint64, b: uint64): uint64
callNonClassFunction:
proto 2 1

// tests/contracts/functions.algo.ts:10
// return nonClassFunction(a, b);
frame_dig -2 // b: uint64
frame_dig -1 // a: uint64
callsub nonClassFunction
retsub

// callExternalFunction(uint64,uint64)uint64
*abi_route_callExternalFunction:
// The ABI return prefix
byte 0x151f7c75

// b: uint64
txna ApplicationArgs 2
btoi

// a: uint64
txna ApplicationArgs 1
btoi

// execute callExternalFunction(uint64,uint64)uint64
callsub callExternalFunction
itob
concat
log
int 1
return

// callExternalFunction(a: uint64, b: uint64): uint64
callExternalFunction:
proto 2 1

// tests/contracts/functions.algo.ts:14
// return externalFunction(a, b);
frame_dig -2 // b: uint64
frame_dig -1 // a: uint64
callsub externalFunction
retsub

*abi_route_createApplication:
int 1
return

*create_NoOp:
method "createApplication()void"
txna ApplicationArgs 0
match *abi_route_createApplication

// this contract does not implement the given ABI method for create NoOp
err

*call_NoOp:
method "callNonClassFunction(uint64,uint64)uint64"
method "callExternalFunction(uint64,uint64)uint64"
txna ApplicationArgs 0
match *abi_route_callNonClassFunction *abi_route_callExternalFunction

// this contract does not implement the given ABI method for call NoOp
err

// externalFunction(a: uint64, b: uint64): uint64
externalFunction:
proto 2 1

// tests/contracts/functions-external.algo.ts:2
// return a + b;
frame_dig -1 // a: uint64
frame_dig -2 // b: uint64
+
retsub

// nonClassFunction(a: uint64, b: uint64): uint64
nonClassFunction:
proto 2 1

// tests/contracts/functions.algo.ts:5
// return a + b;
frame_dig -1 // a: uint64
frame_dig -2 // b: uint64
+
retsub
Loading
Loading