Skip to content

Commit

Permalink
some changes and gen and ssa improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Westsi committed Jan 9, 2025
1 parent 67f6833 commit af053f4
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 99 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ To see output, view `out/ssa`. An example program is shown below.

Before optimizations:
```
func main(int v1, int v2) int {
func test(int v1, int v2) int {
int v3 = 7 + v1
int v4 = 8 + v2
int v5 = v3 + v4
Expand All @@ -41,7 +41,7 @@ func main(int v1, int v2) int {

After optimizations:
```
func main(int v1, int v2) int {
func test(int v1, int v2) int {
v5 = 15 + v1 + v2
return v5
}
Expand Down
1 change: 1 addition & 0 deletions ci/test/metadata.tests
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ define:8
import:48
fncall:27
while:3
nested:1
11 changes: 11 additions & 0 deletions ci/test/nested.dor
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
int main() {
int x = 8
int y = 12
if (x > 6) {
if (y == 11) {
return 0
}
return 1
}
return 113
}
105 changes: 78 additions & 27 deletions codegen/aarch64_clang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type AARCH64Generator struct {
AST ast.Program
VirtualStack *util.Armstack[codegen.VTabVar]
VirtualRegisters map[StorageLoc]string
LabelCounter int
ConditionCounter int
Gdefs map[string]string
}

Expand Down Expand Up @@ -70,14 +70,14 @@ var FNCallRegs = []StorageLoc{X0, X1, X2, X3, X4, X5, X6, X7}

// https://johannst.github.io/notes/arch/arm64.html

func New(fpath string, ast *ast.Program, defs map[string]string, lc int) *AARCH64Generator {
func New(fpath string, ast *ast.Program, defs map[string]string, cc int) *AARCH64Generator {
generator := &AARCH64Generator{
fpath: fpath,
out: strings.Builder{},
AST: *ast,
VirtualStack: util.NewAStack[codegen.VTabVar](32),
VirtualRegisters: map[StorageLoc]string{},
LabelCounter: lc,
ConditionCounter: cc,
Gdefs: defs,
}
os.MkdirAll("out/aarch64", os.ModePerm)
Expand All @@ -93,7 +93,7 @@ func (g *AARCH64Generator) Generate() int {
g.GenerateFunction(stmt)
}
}
return g.LabelCounter
return g.ConditionCounter
}

func (g *AARCH64Generator) e(tok lex.LexedTok, err string) {
Expand Down Expand Up @@ -148,8 +148,8 @@ func (g *AARCH64Generator) GenerateExpression(node ast.Expression) StorageLoc {
return g.GenerateIntegerLiteral(node)
case *ast.IfExpression:
g.GenerateIf(node)
// case *ast.WhileExpression:
// g.GenerateWhileLoop(node)
case *ast.WhileExpression:
g.GenerateWhileLoop(node)
case *ast.CallExpression:
g.GenerateCall(node)
return X0
Expand Down Expand Up @@ -313,6 +313,7 @@ func (g *AARCH64Generator) GenerateInfix(node *ast.InfixExpression) StorageLoc {
case "|":
g.out.WriteString("orr ")
}
// fmt.Println(leftS, rightS, destLoc)
g.out.WriteString(StorageLocs[destLoc] + ", " + leftS + ", " + rightS + "\n")
return destLoc
}
Expand Down Expand Up @@ -359,6 +360,7 @@ func (g *AARCH64Generator) GetInfixOperands(node *ast.InfixExpression) (string,
}

func (g *AARCH64Generator) GenerateIntegerLiteral(il *ast.IntegerLiteral) StorageLoc {
defer tracer.Untrace(tracer.Trace("GenerateIntegerLiteral"))
var sloc StorageLoc
for _, v := range Sls {
_, ok := g.VirtualRegisters[v]
Expand All @@ -374,14 +376,6 @@ func (g *AARCH64Generator) GenerateIntegerLiteral(il *ast.IntegerLiteral) Storag
return sloc
}

func (g *AARCH64Generator) GenerateLabel() string {
tracer.Trace("GenerateLabel")
defer tracer.Untrace("GenerateLabel")
g.out.WriteString(fmt.Sprintf("LBB%d:\n", g.LabelCounter))
g.LabelCounter++
return fmt.Sprintf("LBB%d", g.LabelCounter-1)
}

func (g *AARCH64Generator) GenerateIf(i *ast.IfExpression) {
tracer.Trace("GenerateIf")
defer tracer.Untrace("GenerateIf")
Expand Down Expand Up @@ -410,10 +404,10 @@ func (g *AARCH64Generator) GenerateIf(i *ast.IfExpression) {
// b LBB3
// LBB3:
// ...

predictedTrueLabel := fmt.Sprintf("LBB%d", g.LabelCounter)
predictedFalseLabel := fmt.Sprintf("LBB%d", g.LabelCounter+1)
predictedEndLabel := fmt.Sprintf("LBB%d", g.LabelCounter+2)
trueLabel := fmt.Sprintf("LBBif%dtrue", g.ConditionCounter)
falseLabel := fmt.Sprintf("LBBif%dfalse", g.ConditionCounter)
endLabel := fmt.Sprintf("LBBif%dend", g.ConditionCounter)
g.ConditionCounter++

g.out.WriteString("cmp " + leftS + ", " + rightS + "\n")
g.out.WriteString("cset x8, ")
Expand All @@ -431,15 +425,72 @@ func (g *AARCH64Generator) GenerateIf(i *ast.IfExpression) {
case ">=":
g.out.WriteString("ge\n")
}
g.out.WriteString("tbnz x8, #0, " + predictedTrueLabel + "\n")
g.out.WriteString("b " + predictedFalseLabel + "\n")
g.GenerateLabel()
g.out.WriteString("tbnz x8, #0, " + trueLabel + "\n")
g.out.WriteString("b " + falseLabel + "\n")
g.out.WriteString(trueLabel + ":\n")
g.GenerateBlock(i.Consequence)
g.out.WriteString("b " + predictedEndLabel + "\n")
g.GenerateLabel()
g.GenerateBlock(i.Alternative)
g.out.WriteString("b " + predictedEndLabel + "\n")
g.GenerateLabel()
g.out.WriteString("b " + endLabel + "\n")
g.out.WriteString(falseLabel + ":\n")
if i.Alternative != nil {
g.GenerateBlock(i.Alternative)
}
g.out.WriteString("b " + endLabel + "\n")
g.out.WriteString(endLabel + ":\n")
}

func (g *AARCH64Generator) GenerateWhileLoop(w *ast.WhileExpression) {
defer tracer.Untrace(tracer.Trace("GenerateWhileLoop"))
// b LBB0_1 ; jump to LBB0_1
// COMPAR: LBB0_1: ; =>This Inner Loop Header: Depth=1
// ldr w8, [sp, #8] ; load a to w8
// subs w8, w8, #5 ; subtract 5 from w8 and set flags
// cset w8, ge ; set w8 to 1 if w8-5 is greater than or equal to 0, and 0 otherwise
// tbnz w8, #0, LBB0_3 ; test if bit 0 of w8 is 0, and if not exit loop (jump to LBB0_3)
// b LBB0_2
// BODY: LBB0_2: ; in Loop: Header=BB0_1 Depth=1
// ldr w8, [sp, #4] ; load s into w8
// add w8, w8, #1 ; increment s
// str w8, [sp, #4] ; store s into sp#4
// ldr w8, [sp, #8] ; load a into w8
// add w8, w8, #1 ; increment a
// str w8, [sp, #8] ; store a into sp#8
// b LBB0_1 ; branch to checking of loop
// END: LBB0_3:

comparLabel := fmt.Sprintf("LBBwhile%dcompar", g.ConditionCounter)
bodyLabel := fmt.Sprintf("LBBwhile%dbody", g.ConditionCounter)
endLabel := fmt.Sprintf("LBBwhile%dend", g.ConditionCounter)
g.out.WriteString("b " + comparLabel + "\n")
g.out.WriteString(comparLabel + ":\n")
g.GenerateComparisonCheck(w.Condition.(*ast.InfixExpression), bodyLabel, endLabel)
g.out.WriteString(bodyLabel + ":\n")
g.GenerateBlock(w.Body)
g.out.WriteString("b " + comparLabel + "\n")
g.out.WriteString(endLabel + ":\n")
}

func (g *AARCH64Generator) GenerateComparisonCheck(c *ast.InfixExpression, trueLab, falseLab string) {
defer tracer.Untrace(tracer.Trace("GenerateComparisonCheck"))
leftS, rightS, _ := g.GetInfixOperands(c)
g.out.WriteString("cmp " + leftS + ", " + rightS + "\n")
g.out.WriteString("cset x8, ")
switch c.Operator {
case "==":
g.out.WriteString("eq\n")
case "!=":
g.out.WriteString("ne\n")
case "<":
g.out.WriteString("lt\n")
case ">":
g.out.WriteString("gt\n")
case "<=":
g.out.WriteString("le\n")
case ">=":
g.out.WriteString("ge\n")
}
g.out.WriteString("tbnz x8, #0, " + trueLab + "\n")
g.out.WriteString("b " + falseLab + "\n")
g.VirtualRegisters = map[StorageLoc]string{}
}

// TODO: nested ifs!!
Expand All @@ -461,7 +512,7 @@ func (g *AARCH64Generator) GenerateVarReassignment(v *ast.VarReassignmentStateme
g.out.WriteString("str " + "x0" + ", " + fmt.Sprintf("[sp, #%d]", offset) + "\n")
case *ast.InfixExpression:
sloc := g.GenerateInfix(v.Value.(*ast.InfixExpression))
g.out.WriteString("str " + StorageLocs[sloc] + ", " + fmt.Sprintf("-%d(%%rbp)", offset) + "\n")
g.out.WriteString("str " + StorageLocs[sloc] + ", " + fmt.Sprintf("[sp, #%d]", offset) + "\n")
}
// remove the old value from any registers
sloc, _ := g.GetVarStorageLoc(v.Name.Value)
Expand Down
6 changes: 3 additions & 3 deletions dormouse/ssa.dor
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
int main() {
int a = 6
int b = 7
int test(int x, int y) {
int a = 7 + x
int b = 8 + y
int c = a + b
return c
}
31 changes: 18 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package main

/*
TODO:
bring new label improvements to x86_64 for nested loops
*/

import (
"flag"
"fmt"
Expand All @@ -13,12 +19,11 @@ import (
"github.com/westsi/dormouse/codegen/x86_64_as"
"github.com/westsi/dormouse/lex"
"github.com/westsi/dormouse/parse"
"github.com/westsi/dormouse/ssa"
"github.com/westsi/dormouse/tracer"
)

var globalDefines = make(map[string]string)
var labelcnt int = 0
var condcnt int = 0

func main() {
opts := Options{}
Expand Down Expand Up @@ -82,27 +87,27 @@ func Compile(opts *Options, lexer *lex.Lexer) {
fname := strings.Split((strings.Split(lexer.GetRdrFname(), "/")[len(strings.Split(lexer.GetRdrFname(), "/"))-1]), ".")[0]
p := parse.New(tokens)
ast := p.Parse()
fmt.Println("Errors:")
for _, err := range p.Errors() {
fmt.Println(err)
}
if len(p.Errors()) > 0 {
fmt.Println("Errors:")
for _, err := range p.Errors() {
fmt.Println(err)
}
os.Exit(1)
}

ssag := ssa.New(fname+".dssa", ast, globalDefines)
ssag.Generate()
ssag.Write()
os.Exit(0)
// ssag := ssa.New(fname+".dssa", ast, globalDefines)
// ssag.Generate()
// ssag.Write()
// os.Exit(0)
// fmt.Println(ast.String())
var cg codegen.CodeGenerator
switch opts.TargetArch {
case "x86_64":
cg = x86_64_as.New(fname+".s", ast, globalDefines, labelcnt)
cg = x86_64_as.New(fname+".s", ast, globalDefines, condcnt)
case "aarch64":
cg = aarch64_clang.New(fname+".s", ast, globalDefines, labelcnt)
cg = aarch64_clang.New(fname+".s", ast, globalDefines, condcnt)
}
labelcnt = cg.Generate()
condcnt = cg.Generate()
cg.Write()

}
Expand Down
68 changes: 14 additions & 54 deletions ssa/ssa.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package ssa

import (
"fmt"
"os"
"strconv"
"strings"

"github.com/westsi/dormouse/ast"
)

type SSAGen struct {
fpath string
out strings.Builder
AST ast.Program
Gdefs map[string]string
fpath string
out strings.Builder
AST ast.Program
Gdefs map[string]string
vcounts map[string]int
}

type DType int
Expand All @@ -28,12 +27,17 @@ type Dependent struct {
Value string // holds int val itoa or var name
}

func (s *SSAGen) ow(st string) {
s.out.WriteString(st)
}

func New(fpath string, ast *ast.Program, defs map[string]string) *SSAGen {
generator := &SSAGen{
fpath: fpath,
out: strings.Builder{},
AST: *ast,
Gdefs: defs,
fpath: fpath,
out: strings.Builder{},
AST: *ast,
Gdefs: defs,
vcounts: make(map[string]int),
}
os.MkdirAll("out/ssa", os.ModePerm)
return generator
Expand All @@ -60,50 +64,6 @@ func (s *SSAGen) Write() {
}
}

func (s *SSAGen) ProcessFunction(f *ast.FunctionDefinition) {
s.ProcessBlock(f.Body)
}

func (s *SSAGen) ProcessBlock(b *ast.BlockStatement) {
for _, stmt := range b.Statements {
switch stmt := stmt.(type) {
case *ast.VarStatement:
s.ProcessVarStatement(stmt)
}
}
}

func (s *SSAGen) ProcessVarStatement(v *ast.VarStatement) {
dependents := s.GetDefinitionDependents(v.Value)
fmt.Println(dependents)
}

func (s *SSAGen) GetDefinitionDependents(e ast.Expression) []Dependent {
var deps []Dependent
switch et := e.(type) {
case *ast.InfixExpression:
deps = append(deps, s.GetDefinitionDependents(et.Left)...)
deps = append(deps, s.GetDefinitionDependents(et.Right)...)
case *ast.IntegerLiteral:
dep := Dependent{
Type: CONST,
Value: strconv.Itoa(int(et.Value)),
}
deps = append(deps, dep)
case *ast.Identifier:
dep := Dependent{
Type: VAR,
Value: et.Value,
}
deps = append(deps, dep)
case *ast.ExpressionStatement:
deps = append(deps, s.GetDefinitionDependents(et.Expression)...)
default:
fmt.Printf("%T\n", e)
}
return deps
}

/*
Var statement steps
- check what its set to
Expand Down

0 comments on commit af053f4

Please sign in to comment.