diff --git a/compiler/coqCodegen.ts b/compiler/coqCodegen.ts index 2a22661..6c92058 100644 --- a/compiler/coqCodegen.ts +++ b/compiler/coqCodegen.ts @@ -320,7 +320,81 @@ Proof. simpl. repeat destruct (decide _). all: solve_decision. Defined. type: 'statement', } } - case '': { + case 'less': { + const { expression: leftExpression } = dfs(value.left) + const { expression: rightExpression } = dfs(value.right) + return { + expression: `(bind ${leftExpression} (fun a => bind ${rightExpression} (fun b => bool_decide (a < b))))`, + type: 'bool', + } + } + case 'sLess': { + const { expression: leftExpression, type: leftType } = dfs( + value.left + ) + const { expression: rightExpression, type: rightType } = dfs( + value.right + ) + assert(leftType === rightType) + assert(isNumeric(leftType)) + const bitWidth = getBitWidth(leftType) + const toSigned = 'toSigned' + bitWidth + return { + expression: `(bind ${leftExpression} (fun a => bind ${rightExpression} (fun b => bool_decide (${toSigned} a < ${toSigned} b))))`, + type: 'bool', + } + } + case 'unaryOp': { + const { expression, type } = dfs(value.value) + switch (value.operator) { + case 'bitwise not': { + assert(isNumeric(type)) + const bitWidth = getBitWidth(type) + return { + expression: `(notBits ${bitWidth} ${expression})`, + type, + } + } + case 'boolean not': { + return { + expression: `(bind ${expression} (fun x => ~x))`, + type, + } + } + case 'plus': { + return { expression, type } + } + case 'minus': { + return { + expression: `(bind ${expression} (fun x => -x))`, + type, + } + } + } + } + case "literal": { + switch (value.valueType) { + case "boolean": { + return { expression: `(Done _ _ _ ${value.raw})`, type: "bool" } + } + case "number": { + const number = BigInt(value.raw) + const converted = number < 0n ? (2n ** 64n + number) : number + return { expression: `(Done _ _ _ ${converted}%Z)`, type: "int64" } + } + } + } + case "subscript": { + const { expression, type } = dfs(value.value) + assert(Array.isArray(type)) + const length = type.length + const index = Number(value.index.raw) + // because of validation, this is nonnegative and less than length + const reverseIndex = length - index - 1 + let finalExpression = expression + for (let i = 0; i < reverseIndex; i++) finalExpression = "fst (" + finalExpression + ")" + finalExpression = `(snd (${finalExpression}))` + return { expression: finalExpression, type: type[index] } } } }