From fcc625f4fbd225e0c0cfb202801a4fcd8e8d3d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=8Ddiohabara=E5=8D=8D?= Date: Wed, 5 Jul 2023 19:35:43 -0500 Subject: [PATCH] feat: add oprators --- ccc.h | 17 ++++++ codegen.c | 145 +++++++++++++++++++++++++++++++++++++++++++++- examples/nqueen.c | 29 ++++------ parse.c | 112 +++++++++++++++++++++++++++++++++-- tests | 57 ++++++++++++++++++ tokenize.c | 5 +- type.c | 21 +++++++ 7 files changed, 359 insertions(+), 27 deletions(-) diff --git a/ccc.h b/ccc.h index e3b62ff..7b58b4d 100644 --- a/ccc.h +++ b/ccc.h @@ -78,6 +78,7 @@ typedef enum { ND_LT, // < ND_LE, // <= ND_ASSIGN, // = + ND_COMMA, // , ND_RETURN, // "return" ND_IF, // "if" ND_WHILE, // "while" @@ -94,6 +95,22 @@ typedef enum { ND_SIZEOF, // "sizeof" ND_MEMBER, // . (struct member access) ND_CAST, // Type cast + ND_PRE_INC, // pre ++ + ND_PRE_DEC, // pre -- + ND_POST_INC, // post ++ + ND_POST_DEC, // post -- + ND_A_ADD, // += + ND_A_SUB, // -= + ND_A_MUL, // *= + ND_A_DIV, // /= + ND_A_MOD, // %= + ND_NOT, // ! + ND_BITNOT, // ~ + ND_BITAND, // & + ND_BITOR, // | + ND_BITXOR, // ^ + ND_LOGAND, // && + ND_LOGOR, // || } NodeKind; // AST node type typedef struct Node Node; diff --git a/codegen.c b/codegen.c index e64a89b..a8b7ccd 100644 --- a/codegen.c +++ b/codegen.c @@ -102,6 +102,18 @@ void store(Type *ty) { printf(" push rdi\n"); } +void inc(Type *ty) { + printf(" pop rax\n"); + printf(" add rax, %d\n", ty->base ? size_of(ty->base) : 1); + printf(" push rax\n"); +} + +void dec(Type *ty) { + printf(" pop rax\n"); + printf(" sub rax, %d\n", ty->base ? size_of(ty->base) : 1); + printf(" push rax\n"); +} + // Generate code for a given node. void gen(Node *node) { switch (node->kind) { @@ -131,6 +143,10 @@ void gen(Node *node) { gen(node->rhs); store(node->ty); return; + case ND_COMMA: + gen(node->lhs); + gen(node->rhs); + return; case ND_ADDR: gen_addr(node->lhs); return; @@ -140,6 +156,40 @@ void gen(Node *node) { load(node->ty); } return; + case ND_LOGAND: { + int seq = labelseq++; + gen(node->lhs); + printf(" pop rax\n"); + printf(" cmp rax, 0\n"); + printf(" je .L.false.%d\n", seq); + gen(node->rhs); + printf(" pop rax\n"); + printf(" cmp rax, 0\n"); + printf(" je .L.false.%d\n", seq); + printf(" push 1\n"); + printf(" jmp .L.end.%d\n", seq); + printf(".L.false.%d:\n", seq); + printf(" push 0\n"); + printf(".L.end.%d:\n", seq); + return; + } + case ND_LOGOR: { + int seq = labelseq++; + gen(node->lhs); + printf(" pop rax\n"); + printf(" cmp rax, 0\n"); + printf(" jne .L.true.%d\n", seq); + gen(node->rhs); + printf(" pop rax\n"); + printf(" cmp rax, 0\n"); + printf(" jne .L.true.%d\n", seq); + printf(" push 0\n"); + printf(" jmp .L.end.%d\n", seq); + printf(".L.true.%d:\n", seq); + printf(" push 1\n"); + printf(".L.end.%d:\n", seq); + return; + } case ND_IF: { int seq = labelseq++; if (node->els) { @@ -233,6 +283,90 @@ void gen(Node *node) { gen(node->lhs); truncate(node->ty); return; + case ND_PRE_INC: + gen_lval(node->lhs); + printf(" push [rsp]\n"); + load(node->ty); + inc(node->ty); + store(node->ty); + return; + case ND_PRE_DEC: + gen_lval(node->lhs); + printf(" push [rsp]\n"); + load(node->ty); + dec(node->ty); + store(node->ty); + return; + case ND_POST_INC: + gen_lval(node->lhs); + printf(" push [rsp]\n"); + load(node->ty); + inc(node->ty); + store(node->ty); + dec(node->ty); + return; + case ND_POST_DEC: + gen_lval(node->lhs); + printf(" push [rsp]\n"); + load(node->ty); + dec(node->ty); + store(node->ty); + inc(node->ty); + return; + case ND_A_ADD: + case ND_A_SUB: + case ND_A_MUL: + case ND_A_DIV: + case ND_A_MOD: + gen_lval(node->lhs); + printf(" push [rsp]\n"); + load(node->ty); + gen(node->rhs); + printf(" pop rdi\n"); + printf(" pop rax\n"); + switch (node->kind) { + case ND_A_ADD: + if (node->ty->base) { + printf(" imul rdi, %d\n", size_of(node->ty->base)); + } + printf(" add rax, rdi\n"); + break; + case ND_A_SUB: + if (node->ty->base) { + printf(" imul rdi, %d\n", size_of(node->ty->base)); + } + printf(" sub rax, rdi\n"); + break; + case ND_A_MUL: + printf(" imul rax, rdi\n"); + break; + case ND_A_DIV: + printf(" cqo\n"); + printf(" idiv rdi\n"); + break; + case ND_A_MOD: + printf(" cqo\n"); + printf(" idiv rdi\n"); + printf(" mov rax, rdx\n"); + break; + } + printf(" push rax\n"); + store(node->ty); + return; + case ND_NOT: + gen(node->lhs); + printf(" pop rax\n"); + printf(" cmp rax, 0\n"); + printf(" sete al\n"); + printf(" movzb rax, al\n"); + printf(" push rax\n"); + return; + case ND_BITNOT: + gen(node->lhs); + printf(" pop rax\n"); + printf(" not rax\n"); + printf(" push rax\n"); + return; } gen(node->lhs); @@ -279,6 +413,15 @@ void gen(Node *node) { printf(" setle al\n"); printf(" movzb rax, al\n"); break; + case ND_BITAND: + printf(" and rax, rdi\n"); + break; + case ND_BITOR: + printf(" or rax, rdi\n"); + break; + case ND_BITXOR: + printf(" xor rax, rdi\n"); + break; } printf(" push rax\n"); } @@ -345,4 +488,4 @@ void codegen(Program *prog) { printf(".intel_syntax noprefix\n"); emit_data(prog); emit_text(prog); -} \ No newline at end of file +} diff --git a/examples/nqueen.c b/examples/nqueen.c index 902b450..a2d687d 100644 --- a/examples/nqueen.c +++ b/examples/nqueen.c @@ -6,28 +6,23 @@ // $ ./tmp int print_board(int (*board)[10]) { - for (int i = 0; i < 10; i=i+1) { - for (int j = 0; j < 10; j=j+1) + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) if (board[i][j]) - printf("Q "); + printf("Q "); else - printf(". "); + printf(". "); printf("\n"); } printf("\n\n"); } int conflict(int (*board)[10], int row, int col) { - for (int i = 0; i < row; i=i+1) { - if (board[i][col]) - return 1; + for (int i = 0; i < row; i++) { + if (board[i][col]) return 1; int j = row - i; - if (0 < col - j + 1) - if (board[i][col - j]) - return 1; - if (col + j < 10) - if (board[i][col + j]) - return 1; + if (0 < col - j + 1 && board[i][col - j]) return 1; + if (col + j < 10 && board[i][col + j]) return 1; } return 0; } @@ -37,9 +32,8 @@ int solve(int (*board)[10], int row) { print_board(board); return 0; } - for (int i = 0; i < 10; i=i+1) { - if (conflict(board, row, i)) { - } else { + for (int i = 0; i < 10; i++) { + if (!conflict(board, row, i)) { board[row][i] = 1; solve(board, row + 1); board[row][i] = 0; @@ -49,8 +43,7 @@ int solve(int (*board)[10], int row) { int main() { int board[100]; - for (int i = 0; i < 100; i=i+1) - board[i] = 0; + for (int i = 0; i < 100; i++) board[i] = 0; solve(board, 0); return 0; } diff --git a/parse.c b/parse.c index 895b38b..e1888df 100644 --- a/parse.c +++ b/parse.c @@ -134,6 +134,11 @@ bool is_typename(); Node *stmt(); Node *expr(); Node *assign(); +Node *bitand(); +Node * bitor (); +Node *bitxor(); +Node *logand(); +Node *logor(); Node *equality(); Node *relational(); Node *add(); @@ -676,15 +681,89 @@ Node *stmt() { return node; } -// expr = assign -Node *expr() { return assign(); } -// assign = equality ("=" assign)? +// expr = assign ("," expr)? +Node *expr() { + Node *node = assign(); + Token *tok; + while (tok = consume(",")) { + node = new_unary(ND_EXPR_STMT, node, node->tok); + node = new_binary(ND_COMMA, node, assign(), tok); + } + return node; +} +// assign = logor (assign-op assign)? +// assign-op = "=" | "+=" | "-=" | "*=" | "/=" | "%=" Node *assign() { - Node *node = equality(); + Node *node = logor(); Token *tok; if (tok = consume("=")) { node = new_binary(ND_ASSIGN, node, assign(), tok); } + if (tok = consume("+=")) { + node = new_binary(ND_A_ADD, node, assign(), tok); + } + if (tok = consume("-=")) { + node = new_binary(ND_A_SUB, node, assign(), tok); + } + if (tok = consume("*=")) { + node = new_binary(ND_A_MUL, node, assign(), tok); + } + if (tok = consume("/=")) { + node = new_binary(ND_A_DIV, node, assign(), tok); + } + if (tok = consume("%=")) { + node = new_binary(ND_A_MOD, node, assign(), tok); + } + return node; +} + +// logor = logand ("||" logand)* +Node *logor() { + Node *node = logand(); + Token *tok; + while (tok = consume("||")) { + node = new_binary(ND_LOGOR, node, logand(), tok); + } + return node; +} + +// logand = bitor ("&&" bitor)* +Node *logand() { + Node *node = bitor (); + Token *tok; + while (tok = consume("&&")) { + node = new_binary(ND_LOGAND, node, bitor (), tok); + } + return node; +} + +// bitor = bitxor ("|" bitxor)* +Node * bitor () { + Node *node = bitxor(); + Token *tok; + while (tok = consume("|")) { + node = new_binary(ND_BITOR, node, bitxor(), tok); + } + return node; +} + +// bitxor = bitand ("^" bitand)* +Node *bitxor() { + Node *node = bitand(); + Token *tok; + while (tok = consume("^")) { + node = new_binary(ND_BITXOR, node, bitand(), tok); + } + return node; +} + +// bitand = equality ("&" equality)* +Node *bitand() { + Node *node = equality(); + Token *tok; + while (tok = consume("&")) { + node = new_binary(ND_BITAND, node, equality(), tok); + } return node; } @@ -764,7 +843,8 @@ Node *cast() { return unary(); } -// unary = ("+" | "-" | "*" | "&")? cast +// unary = ("+" | "-" | "*" | "&" | "!" | "~")? cast +// | ("++" | "--") unary // | postfix Node *unary() { Token *tok; @@ -780,10 +860,22 @@ Node *unary() { if (tok = consume("*")) { return new_unary(ND_DEREF, cast(), tok); } + if (tok = consume("!")) { + return new_unary(ND_NOT, cast(), tok); + } + if (tok = consume("~")) { + return new_unary(ND_BITNOT, cast(), tok); + } + if (tok = consume("++")) { + return new_unary(ND_PRE_INC, unary(), tok); + } + if (tok = consume("--")) { + return new_unary(ND_PRE_DEC, unary(), tok); + } return postfix(); } -// postfix = primary ("[" expr "]" | "." ident | "->" ident)* +// postfix = primary ("[" expr "]" | "." ident | "->" ident | "++" | "--")* Node *postfix() { Node *node = primary(); Token *tok; @@ -808,6 +900,14 @@ Node *postfix() { node->member_name = expect_ident(); continue; } + if (tok = consume("++")) { + node = new_unary(ND_POST_INC, node, tok); + continue; + } + if (tok = consume("--")) { + node = new_unary(ND_POST_DEC, node, tok); + continue; + } return node; } } diff --git a/tests b/tests index 6492a4a..692bcd2 100644 --- a/tests +++ b/tests @@ -720,6 +720,63 @@ int main() { assert(55, ({ int j=0; for (int i=0; i<=10; i=i+1) j=j+i; j; }), "int j=0; for (int i=0; i<=10; i=i+1) j=j+i; j;"); assert(3, ({ int i=3; int j=0; for (int i=0; i<=10; i=i+1) j=j+i; i; }), "int i=3; int j=0; for (int i=0; i<=10; i=i+1) j=j+i; i;"); + assert(3, (1, 2, 3), "(1, 2, 3)"); + + assert(3, ({ int i=2; ++i; }), "int i=2; ++i;"); + assert(1, ({ int i=2; --i; }), "int i=2; --i;"); + assert(2, ({ int i=2; i++; }), "int i=2; i++;"); + assert(2, ({ int i=2; i--; }), "int i=2; i--;"); + assert(3, ({ int i=2; i++; i; }), "int i=2; i++; i;"); + assert(1, ({ int i=2; i--; i; }), "int i=2; i--; i;"); + assert(1, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; *p++; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; *p++;"); + assert(2, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; ++*p; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; ++*p;"); + assert(1, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; *p--; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; *p--;"); + assert(0, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; --*p; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; --*p;"); + + assert(0, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++)--; a[0]; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++); a[0];"); + assert(0, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++)--; a[1]; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++); a[0];"); + assert(2, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++)--; a[2]; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++); a[0];"); + assert(2, ({ int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++)--; *p; }), "int a[3]; a[0]=0; a[1]=1; a[2]=2; int *p=a+1; (*p++); a[0];"); + + assert(7, ({ int i=2; i+=5; i; }), "int i=2; i+=5; i;"); + assert(7, ({ int i=2; i+=5; }), "int i=2; i+=5;"); + assert(3, ({ int i=5; i-=2; i; }), "int i=5; i-=2; i;"); + assert(3, ({ int i=5; i-=2; }), "int i=5; i-=2;"); + assert(6, ({ int i=3; i*=2; i; }), "int i=3; i*=2; i;"); + assert(6, ({ int i=3; i*=2; }), "int i=3; i*=2;"); + assert(3, ({ int i=6; i/=2; i; }), "int i=6; i/=2; i;"); + assert(3, ({ int i=6; i/=2; }), "int i=6; i/=2;"); + + assert(0, !1, "!1"); + assert(0, !2, "!2"); + assert(1, !0, "!0"); + + assert(-1, ~0, "~0"); + assert(0, ~-1, "~-1"); + + assert(0, 0&1, "0&1"); + assert(1, 3&1, "3&1"); + assert(3, 7&3, "7&3"); + assert(10, -1&10, " -1&10"); + + assert(1, 0|1, "0|1"); + assert(3, 2|1, "2|1"); + assert(3, 1|3, "1|3"); + + assert(0, 0^0, "0^0"); + assert(0, 8^8, "8^8"); + assert(4, 7^3, "7^3"); + assert(2, 7^5, "7^5"); + + assert(1, 0||1, "0||1"); + assert(1, 0||(2-2)||5, "0||(2-2)||5"); + assert(0, 0||0, "0||0"); + assert(0, 0||(2-2), "0||(2-2)"); + + assert(0, 0&&1, "0&&1"); + assert(0, (2-2)&&5, "(2-2)&&5"); + assert(1, 1&&5, "1&&5"); + printf("OK\n"); return 0; } diff --git a/tokenize.c b/tokenize.c index 0ded082..43887f7 100644 --- a/tokenize.c +++ b/tokenize.c @@ -159,7 +159,8 @@ char *starts_with_reserved(char *p) { } } // Multi-letter punctuator - static char *ops[] = {"==", "!=", "<=", ">=", "->"}; + static char *ops[] = {"==", "!=", "<=", ">=", "->", "++", "--", + "+=", "-=", "*=", "/=", "&=", "&&", "||"}; for (int i = 0; i < sizeof(ops) / sizeof(*ops); i++) if (startswith(p, ops[i])) { return ops[i]; @@ -282,7 +283,7 @@ Token *tokenize() { continue; } // Single-letter punctuator - if (strchr("+-*/()<>;={},&[].", *p)) { + if (strchr("+-*/()<>;={},&[].,!~|^", *p)) { cur = new_token(TK_RESERVED, cur, p++, 1); continue; } diff --git a/type.c b/type.c index 02d6e73..861fa06 100644 --- a/type.c +++ b/type.c @@ -104,6 +104,14 @@ void visit(Node* node) { case ND_NE: case ND_LT: case ND_LE: + case ND_NOT: + case ND_BITAND: + case ND_BITOR: + case ND_BITXOR: + case ND_LOGAND: + case ND_LOGOR: + node->ty = int_type(); + return; case ND_NUM: if (node->val == (int)node->val) { node->ty = int_type(); @@ -132,6 +140,19 @@ void visit(Node* node) { node->ty = node->lhs->ty; return; case ND_ASSIGN: + case ND_PRE_INC: + case ND_PRE_DEC: + case ND_POST_INC: + case ND_POST_DEC: + case ND_A_ADD: + case ND_A_SUB: + case ND_A_MUL: + case ND_A_DIV: + case ND_A_MOD: + case ND_BITNOT: + node->ty = node->lhs->ty; + return; + case ND_COMMA: node->ty = node->lhs->ty; return; case ND_MEMBER: {