Skip to content

Commit

Permalink
Handle string literals in range()
Browse files Browse the repository at this point in the history
  • Loading branch information
huynhtrankhanh committed Dec 7, 2023
1 parent baa3717 commit ed451a1
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 55 deletions.
1 change: 1 addition & 0 deletions _CoqProject
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ Codeforces/contests/1154/A/restore_three_numbers.v
theories/BubbleSort.v
theories/BubbleSortProperties.v
theories/Imperative.v
theories/ExampleProgram.v
2 changes: 1 addition & 1 deletion compiler/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function transform(
case 'undefined binder':
case 'not representable int64':
case 'bad number literal':
case 'range end must be int64':
case 'range end must be int64 or string':
case 'instruction expects int8':
case 'instruction expects int64':
case 'instruction expects tuple':
Expand Down
40 changes: 32 additions & 8 deletions compiler/coqCodegen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ Proof. simpl. repeat destruct (decide _). all: solve_decision. Defined.

// every element of body is an Action returning absolutely anything
const statements = body.map((statement) => {
const localBinderMap = new Map<string, number>()
type BinderInfo = { number: number; type: 'int8' | 'int64' }
const localBinderMap = new Map<string, BinderInfo>()
let binderCounter = 0
const dfs = (
value: ValueType
Expand Down Expand Up @@ -439,6 +440,8 @@ Proof. simpl. repeat destruct (decide _). all: solve_decision. Defined.
type: 'int64',
}
}
case 'string':
assert(false, 'only makes sense within range()')
}
}
case 'subscript': {
Expand Down Expand Up @@ -497,27 +500,48 @@ Proof. simpl. repeat destruct (decide _). all: solve_decision. Defined.
const { name } = value
const binder = localBinderMap.get(name)
assert(binder !== undefined)
return { expression: 'binder_' + binder, type: 'int64' }
return {
expression: 'binder_' + binder.number,
type: binder.type,
}
}
case 'range': {
const { loopVariable, loopBody, end } = value
const previousBinderValue = localBinderMap.get(loopVariable)

localBinderMap.set(loopVariable, binderCounter++)
if (end.type === 'literal' && end.valueType === 'string')
localBinderMap.set(loopVariable, {
type: 'int8',
number: binderCounter++,
})
else
localBinderMap.set(loopVariable, {
type: 'int64',
number: binderCounter++,
})

const bodyExpression = joinStatements(
loopBody.map(dfs).map((x) => x.expression)
)

const { expression: endExpression } = dfs(end)

if (previousBinderValue === undefined) localBinderMap.delete(name)
else localBinderMap.set(loopVariable, previousBinderValue)
binderCounter--

return {
expression: `(bind ${endExpression} (fun x => loop (Z.to_nat x) (fun binder_${binderCounter}_intermediate => let binder_${binderCounter} := Done _ _ _ (Z.sub (Z.sub x (Z.of_nat binder_${binderCounter}_intermediate)) 1%Z) in bind (${bodyExpression}) (fun ignored => Done _ _ _ KeepGoing))))`,
type: 'statement',
if (end.type === 'literal' && end.valueType === 'string') {
return {
expression: `(loopString (${getCoqString(
end.raw
)}) (fun binder_${binderCounter}_intermediate => let binder_${binderCounter} := Done _ _ _ binder_${binderCounter}_intermediate in bind (${bodyExpression}) (fun ignored => Done _ _ _ KeepGoing)))`,
type: 'statement',
}
} else {
return {
expression: `(bind ${
dfs(end).expression
} (fun x => loop (Z.to_nat x) (fun binder_${binderCounter}_intermediate => let binder_${binderCounter} := Done _ _ _ (Z.sub (Z.sub x (Z.of_nat binder_${binderCounter}_intermediate)) 1%Z) in bind (${bodyExpression}) (fun ignored => Done _ _ _ KeepGoing))))`,
type: 'statement',
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/cppCodegen.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ int main() {
continue;
flushSTDOUT();
}
for (uint8_t binder_0 : { 104, 101, 108, 108, 111, 32, 108, 105, 102, 101 }) {
writeChar(binder_0);
}
(uint64_t(2) / uint64_t(3));
(toSigned(uint64_t(1)) / toSigned(uint64_t(2)));
binaryOp([&]() { return readChar(); }, [&]() { return uint8_t(uint64_t(3)); }, [&](auto a, auto b) { return a / b; });
Expand Down
68 changes: 46 additions & 22 deletions compiler/cppCodegen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -347,28 +347,52 @@ export const cppCodegen = ({ environment, procedures }: CoqCPAST): string => {

const baseIndent = indent.repeat(indentationLevel)

const constructed =
baseIndent +
'for (uint64_t binder_' +
index +
' = 0; binder_' +
index +
' < ' +
print(end) +
'; binder_' +
index +
'++) {\n' +
loopBody
.map((x) =>
print(x, {
type: 'inside block',
indentationLevel: indentationLevel + 1,
})
)
.join('') +
baseIndent +
'}\n'

let constructed = ''
if (end.type === 'literal' && end.valueType === 'string') {
constructed =
baseIndent +
'for (uint8_t binder_' +
index +
' : { ' +
(() => {
const encoder = new TextEncoder()
const encoded = encoder.encode(end.raw)
return encoded.join(", ")
})() +
' }) {\n' +
loopBody
.map((x) =>
print(x, {
type: 'inside block',
indentationLevel: indentationLevel + 1,
})
)
.join('') +
baseIndent +
'}\n'
} else {
constructed =
baseIndent +
'for (uint64_t binder_' +
index +
' = 0; binder_' +
index +
' < ' +
print(end) +
'; binder_' +
index +
'++) {\n' +
loopBody
.map((x) =>
print(x, {
type: 'inside block',
indentationLevel: indentationLevel + 1,
})
)
.join('') +
baseIndent +
'}\n'
}
if (previousIndex === undefined) {
localBinderMap.delete(loopVariable)
} else {
Expand Down
2 changes: 2 additions & 0 deletions compiler/exampleCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ procedure("fibonacci", { n: int32, a: int32, b: int32, i: int32 }, () => {
"flush"
})
range("hello life", x => { writeChar(x) })
divide(2, 3)
sDivide(1, 2)
Expand Down
4 changes: 4 additions & 0 deletions compiler/printCoqCode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import { transformer } from './exampleCode'
import { coqCodegen } from './coqCodegen'

process.stdout.write(coqCodegen(transformer.transform()))
13 changes: 13 additions & 0 deletions compiler/validateAST.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ procedure("hello", { a: bool }, () => {
`procedure('hello', {}, () => {
range(3, x => {})
})`,
`procedure('hello', {}, () => {
range("hello", a => {});
})`,
`procedure('hello', { a: int8 }, () => {
range("hello", a => {
writeChar(get("a") + a)
})
})`
]
for (const program of programs) {
if (!noErrors(program)) {
Expand Down Expand Up @@ -175,6 +183,11 @@ procedure("hello", { a: bool }, () => {
`procedure('hello', {}, () => {
range(74829387492847492947392874928473974929748293737, x => {})
})`,
`procedure('hello', { a: int64 }, () => {
range("hello", a => {
writeChar(get("a") + a)
})
})`
]
for (const program of programs) {
if (!hasValidationErrorsOnly(program)) {
Expand Down
66 changes: 43 additions & 23 deletions compiler/validateAST.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,36 @@ import {
export type ValidationError = (
| {
type: 'binary expression expects numeric' | 'instruction expects numeric'
actualType1: PrimitiveType | PrimitiveType[]
actualType2: PrimitiveType | PrimitiveType[]
actualType1: PrimitiveType | PrimitiveType[] | 'string'
actualType2: PrimitiveType | PrimitiveType[] | 'string'
}
| {
type: 'binary expression expects boolean'
actualType1: PrimitiveType | PrimitiveType[]
actualType2: PrimitiveType | PrimitiveType[]
actualType1: PrimitiveType | PrimitiveType[] | 'string'
actualType2: PrimitiveType | PrimitiveType[] | 'string'
}
| {
type: 'binary expression type mismatch' | 'instruction type mismatch'
actualType1: PrimitiveType | PrimitiveType[]
actualType2: PrimitiveType | PrimitiveType[]
actualType1: PrimitiveType | PrimitiveType[] | 'string'
actualType2: PrimitiveType | PrimitiveType[] | 'string'
}
| { type: 'expression no statement' }
| { type: 'procedure not found'; name: string }
| { type: 'variable not present'; variables: string[] }
| {
type: 'variable type mismatch'
expectedType: PrimitiveType | PrimitiveType[]
actualType: PrimitiveType | PrimitiveType[]
expectedType: PrimitiveType | PrimitiveType[] | 'string'
actualType: PrimitiveType | PrimitiveType[] | 'string'
}
| {
type: 'condition must be boolean'
actualType: PrimitiveType | PrimitiveType[]
actualType: PrimitiveType | PrimitiveType[] | 'string'
}
| { type: 'no surrounding range command' }
| { type: 'undefined variable' | 'undefined binder' }
| { type: 'not representable int64' }
| { type: 'bad number literal' }
| { type: 'range end must be int64' }
| { type: 'range end must be int64 or string' }
| {
type:
| 'instruction expects int8'
Expand All @@ -52,6 +52,7 @@ export type ValidationError = (
type:
| 'unary operator expects numeric'
| "unary operator can't operate on tuples"
| "unary operator can't operate on strings"
| 'unary operator expects boolean'
}
| { type: "array length can't be negative" }
Expand Down Expand Up @@ -100,9 +101,14 @@ export const validateAST = ({
}
const procedureMap = new Map<string, Procedure>()
for (const procedure of procedures) {
type Type = PrimitiveType | 'statement' | 'illegal' | PrimitiveType[]
type Type =
| PrimitiveType
| 'string'
| 'statement'
| 'illegal'
| PrimitiveType[]
let hasSurroundingRangeCommand = false
const presentBinders = new Set<string>()
const presentBinderType = new Map<string, 'int64' | 'int8'>()
const dfs = (instruction: ValueType): Type => {
switch (instruction.type) {
case 'binaryOp': {
Expand Down Expand Up @@ -313,10 +319,10 @@ export const validateAST = ({
return instruction.type === 'coerceInt16'
? 'int16'
: instruction.type === 'coerceInt32'
? 'int32'
: instruction.type === 'coerceInt64'
? 'int64'
: 'int8'
? 'int32'
: instruction.type === 'coerceInt64'
? 'int64'
: 'int8'
}
case 'condition': {
const { alternate, body, condition, location } = instruction
Expand Down Expand Up @@ -414,6 +420,7 @@ export const validateAST = ({
}
case 'literal': {
if (instruction.valueType === 'boolean') return 'bool'
else if (instruction.valueType === 'string') return 'string'
else if (instruction.valueType === 'number') {
if (
instruction.raw !== '0' &&
Expand All @@ -439,14 +446,15 @@ export const validateAST = ({
}
case 'local binder': {
const { name } = instruction
if (!presentBinders.has(name)) {
const binderType = presentBinderType.get(name)
if (binderType === undefined) {
errors.push({
type: 'undefined binder',
location: instruction.location,
})
return 'illegal'
}
return 'int64'
return binderType
}
case 'range': {
const { end, loopBody, loopVariable } = instruction
Expand All @@ -459,22 +467,27 @@ export const validateAST = ({
})
return 'illegal'
}
if (endType !== 'int64') {
if (endType !== 'int64' && endType !== 'string') {
errors.push({
type: 'range end must be int64',
type: 'range end must be int64 or string',
location: instruction.location,
})
return 'illegal'
}
const previousHasSurroundingRangeCommand = hasSurroundingRangeCommand
hasSurroundingRangeCommand = true
const binderPresentBefore = presentBinders.has(loopVariable)
presentBinders.add(loopVariable)
const binderTypeBefore = presentBinderType.get(loopVariable)
presentBinderType.set(
loopVariable,
endType === 'string' ? 'int8' : 'int64'
)
const result = loopBody.map(dfs).includes('illegal')
? 'illegal'
: 'statement'
hasSurroundingRangeCommand = previousHasSurroundingRangeCommand
if (!binderPresentBefore) presentBinders.delete(loopVariable)
if (binderTypeBefore === undefined)
presentBinderType.delete(loopVariable)
else presentBinderType.set(loopVariable, binderTypeBefore)
return result
}
case 'readChar':
Expand Down Expand Up @@ -654,6 +667,13 @@ export const validateAST = ({
})
return 'illegal'
}
if (valueType === 'string') {
errors.push({
type: "unary operator can't operate on strings",
location: instruction.location,
})
return 'illegal'
}
if (valueType === 'bool') {
switch (operator) {
case 'bitwise not':
Expand Down
Loading

0 comments on commit ed451a1

Please sign in to comment.