diff --git a/src/codegen/amd64_asm/mod.rs b/src/codegen/amd64_asm/mod.rs index c2fbf3d..b5383aa 100644 --- a/src/codegen/amd64_asm/mod.rs +++ b/src/codegen/amd64_asm/mod.rs @@ -4,10 +4,8 @@ mod register; use super::Codegen; use crate::{ - ir::{ - Block, Expr, ExprKind, ExprLit, Id, IntTy, Item, ItemFn, Node, Stmt, Ty, UintTy, Variable, - }, - parser::{BinOp, BitwiseOp, CmpOp, OpParseError, UnOp}, + ir::{Block, Expr, ExprKind, ExprLit, Id, Item, ItemFn, Node, Stmt, Ty, Variable}, + parser::{BinOp, BitwiseOp, CmpOp, IntTy, OpParseError, UintTy, UnOp}, Context, }; use allocator::RegisterAllocator; diff --git a/src/ir/mod.rs b/src/ir/mod.rs index d8c37cc..ee1fcca 100644 --- a/src/ir/mod.rs +++ b/src/ir/mod.rs @@ -5,7 +5,7 @@ use crate::parser::{BinOp, UnOp}; use bumpalo::Bump; pub use ordered_map::OrderedMap; -pub use types::{IntTy, Ty, TyArray, UintTy}; +pub use types::{Ty, TyArray}; #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)] pub struct Id { diff --git a/src/ir/types.rs b/src/ir/types.rs index f67f8b7..a1bd29d 100644 --- a/src/ir/types.rs +++ b/src/ir/types.rs @@ -1,4 +1,7 @@ -use crate::ty_problem; +use crate::{ + parser::{IntTy, UintTy}, + ty_problem, +}; #[derive(Debug, PartialEq)] pub struct TyArray<'ir> { @@ -6,15 +9,6 @@ pub struct TyArray<'ir> { pub len: usize, } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub enum IntTy { - I8, - I16, - I32, - I64, - Isize, -} - impl IntTy { fn size(&self) -> Option { Some(match self { @@ -27,27 +21,6 @@ impl IntTy { } } -impl std::fmt::Display for IntTy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::I8 => write!(f, "i8"), - Self::I16 => write!(f, "i16"), - Self::I32 => write!(f, "i32"), - Self::I64 => write!(f, "i64"), - Self::Isize => write!(f, "isize"), - } - } -} - -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub enum UintTy { - U8, - U16, - U32, - U64, - Usize, -} - impl UintTy { fn size(&self) -> Option { Some(match self { @@ -70,18 +43,6 @@ impl UintTy { } } -impl std::fmt::Display for UintTy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::U8 => write!(f, "u8"), - Self::U16 => write!(f, "u16"), - Self::U32 => write!(f, "u32"), - Self::U64 => write!(f, "u64"), - Self::Usize => write!(f, "usize"), - } - } -} - #[derive(Debug, PartialEq)] pub enum Ty<'ir> { Void, diff --git a/src/lexer/error.rs b/src/lexer/error.rs index 5540a26..939b801 100644 --- a/src/lexer/error.rs +++ b/src/lexer/error.rs @@ -1,7 +1,7 @@ use thiserror::Error; #[derive(Error, Debug)] -pub enum LexerError { +pub enum Error { #[error("Failed to parse char {0}")] UnknownCharacter(char), } diff --git a/src/lexer/lexer.rs b/src/lexer/lexer.rs deleted file mode 100644 index 08e10cf..0000000 --- a/src/lexer/lexer.rs +++ /dev/null @@ -1,378 +0,0 @@ -use super::{LexerError, Token}; - -#[derive(Debug)] -pub struct Lexer { - input: String, - position: usize, - read_position: usize, - ch: char, -} - -impl Lexer { - pub fn new(input: String) -> Self { - let mut lexer = Self { - input, - ch: '\0', - position: 0, - read_position: 0, - }; - lexer.read_char(); - - lexer - } - - fn read_char(&mut self) { - match self.input[self.read_position..].chars().next() { - Some(ch) => { - self.ch = ch; - self.position = self.read_position; - self.read_position += ch.len_utf8(); - } - None => { - self.ch = '\0'; - } - } - } - - fn peek(&self) -> Option { - self.input[self.read_position..].chars().next() - } - - fn read_ident(&mut self) -> String { - let pos = self.position; - - while self.ch.is_alphanumeric() || self.ch == '_' { - self.read_char(); - } - - self.input[pos..self.position].to_string() - } - - fn read_int(&mut self) -> String { - let pos = self.position; - - while self.ch.is_ascii_digit() { - self.read_char(); - } - - self.input[pos..self.position].to_string() - } - - fn read_string(&mut self) -> String { - let pos = self.position + 1; - - loop { - self.read_char(); - - if self.ch == '"' || self.ch == '\0' { - break self.input[pos..self.position].to_string(); - } - } - } - - fn skip_whitespace(&mut self) { - while self.ch.is_ascii_whitespace() { - self.read_char(); - } - } - - fn skip_comment(&mut self) { - while self.ch != '\n' { - self.read_char(); - } - } -} - -impl Iterator for Lexer { - type Item = Result; - - fn next(&mut self) -> Option { - self.skip_whitespace(); - - let token = match self.ch { - '=' => { - if self.peek() == Some('=') { - self.read_char(); - Token::Equal - } else { - Token::Assign - } - } - '-' => { - if self.peek() == Some('>') { - self.read_char(); - Token::Arrow - } else { - Token::Minus - } - } - '+' => Token::Plus, - '/' => { - if self.peek() == Some('/') { - self.skip_comment(); - - return self.next(); - } else { - Token::Slash - } - } - '.' => Token::Period, - '~' => Token::Tilde, - '&' => { - if self.peek() == Some('&') { - self.read_char(); - Token::And - } else { - Token::Ampersand - } - } - '|' => { - if self.peek() == Some('|') { - self.read_char(); - Token::Or - } else { - Token::Bar - } - } - '!' => { - if self.peek() == Some('=') { - self.read_char(); - Token::NotEqual - } else { - Token::Bang - } - } - '*' => Token::Asterisk, - '<' => match self.peek() { - Some('=') => { - self.read_char(); - Token::LessEqual - } - Some('<') => { - self.read_char(); - Token::Shl - } - _ => Token::LessThan, - }, - '>' => match self.peek() { - Some('=') => { - self.read_char(); - Token::GreaterEqual - } - Some('>') => { - self.read_char(); - Token::Shr - } - _ => Token::GreaterThan, - }, - ';' => Token::Semicolon, - '(' => Token::LParen, - ')' => Token::RParen, - '{' => Token::LBrace, - '}' => Token::RBrace, - '[' => Token::LBracket, - ']' => Token::RBracket, - ',' => Token::Comma, - ':' => Token::Colon, - '"' => Token::String(self.read_string()), - '0'..='9' => { - return Some(Ok(Token::Integer(self.read_int()))); - } - '\0' => return None, - ch if ch.is_alphanumeric() || ch == '_' => { - let ident = self.read_ident(); - - return Some(Ok(match ident.as_str() { - "const" => Token::Const, - "true" => Token::True, - "let" => Token::Let, - "fn" => Token::Fn, - "enum" => Token::Enum, - "struct" => Token::Struct, - "false" => Token::False, - "if" => Token::If, - "while" => Token::While, - "for" => Token::For, - "else" => Token::Else, - "return" => Token::Return, - "as" => Token::As, - "continue" => Token::Continue, - "break" => Token::Break, - "u8" => Token::U8, - "u16" => Token::U16, - "u32" => Token::U32, - "u64" => Token::U64, - "i8" => Token::I8, - "i16" => Token::I16, - "i32" => Token::I32, - "i64" => Token::I64, - "usize" => Token::Usize, - "isize" => Token::Isize, - "bool" => Token::Bool, - "void" => Token::Void, - "NULL" => Token::Null, - _ => Token::Ident(ident), - })); - } - ch => { - return Some(Err(LexerError::UnknownCharacter(ch))); - } - }; - - self.read_char(); - - Some(Ok(token)) - } -} - -#[cfg(test)] -mod test { - use super::{Lexer, LexerError}; - use crate::lexer::Token; - - #[test] - fn source_into_tokens() -> Result<(), LexerError> { - let input = r#" - ident - 69 - "string" - - = - + - - - ! - * - / - -> - . - ~ - & - | - == - != - < - > - <= - >= - && - || - << - >> - , - ; - : - ( - ) - { - } - [ - ] - - // keywords - // heyo :D - const - true - false - let - fn - enum - struct - if - while - for - else - return - as - continue - break - - u8 - u16 - u32 - u64 - i8 - i16 - i32 - i64 - usize - isize - bool - void - NULL - "#; - - let tokens = vec![ - Token::Ident(String::from("ident")), - Token::Integer(String::from("69")), - Token::String(String::from("string")), - Token::Assign, - Token::Plus, - Token::Minus, - Token::Bang, - Token::Asterisk, - Token::Slash, - Token::Arrow, - Token::Period, - Token::Tilde, - Token::Ampersand, - Token::Bar, - Token::Equal, - Token::NotEqual, - Token::LessThan, - Token::GreaterThan, - Token::LessEqual, - Token::GreaterEqual, - Token::And, - Token::Or, - Token::Shl, - Token::Shr, - Token::Comma, - Token::Semicolon, - Token::Colon, - Token::LParen, - Token::RParen, - Token::LBrace, - Token::RBrace, - Token::LBracket, - Token::RBracket, - Token::Const, - Token::True, - Token::False, - Token::Let, - Token::Fn, - Token::Enum, - Token::Struct, - Token::If, - Token::While, - Token::For, - Token::Else, - Token::Return, - Token::As, - Token::Continue, - Token::Break, - Token::U8, - Token::U16, - Token::U32, - Token::U64, - Token::I8, - Token::I16, - Token::I32, - Token::I64, - Token::Usize, - Token::Isize, - Token::Bool, - Token::Void, - Token::Null, - ]; - - let mut lexer = Lexer::new(input.to_string()); - - for token in tokens { - let next_token = lexer.next().unwrap()?; - - assert_eq!(token, next_token); - } - - Ok(()) - } -} diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 2ad8cdc..aa4aaaa 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -1,7 +1,382 @@ mod error; -mod lexer; mod token; -pub use error::LexerError; -pub use lexer::Lexer; +pub use error::Error; pub use token::Token; + +#[derive(Debug)] +pub struct Lexer { + input: String, + position: usize, + read_position: usize, + ch: char, +} + +impl Lexer { + pub fn new(input: String) -> Self { + let mut lexer = Self { + input, + ch: '\0', + position: 0, + read_position: 0, + }; + lexer.read_char(); + + lexer + } + + fn read_char(&mut self) { + match self.input[self.read_position..].chars().next() { + Some(ch) => { + self.ch = ch; + self.position = self.read_position; + self.read_position += ch.len_utf8(); + } + None => { + self.ch = '\0'; + } + } + } + + fn peek(&self) -> Option { + self.input[self.read_position..].chars().next() + } + + fn read_ident(&mut self) -> String { + let pos = self.position; + + while self.ch.is_alphanumeric() || self.ch == '_' { + self.read_char(); + } + + self.input[pos..self.position].to_string() + } + + fn read_int(&mut self) -> String { + let pos = self.position; + + while self.ch.is_ascii_digit() { + self.read_char(); + } + + self.input[pos..self.position].to_string() + } + + fn read_string(&mut self) -> String { + let pos = self.position + 1; + + loop { + self.read_char(); + + if self.ch == '"' || self.ch == '\0' { + break self.input[pos..self.position].to_string(); + } + } + } + + fn skip_whitespace(&mut self) { + while self.ch.is_ascii_whitespace() { + self.read_char(); + } + } + + fn skip_comment(&mut self) { + while self.ch != '\n' { + self.read_char(); + } + } +} + +impl Iterator for Lexer { + type Item = Result; + + fn next(&mut self) -> Option { + self.skip_whitespace(); + + let token = match self.ch { + '=' => { + if self.peek() == Some('=') { + self.read_char(); + Token::Equal + } else { + Token::Assign + } + } + '-' => { + if self.peek() == Some('>') { + self.read_char(); + Token::Arrow + } else { + Token::Minus + } + } + '+' => Token::Plus, + '/' => { + if self.peek() == Some('/') { + self.skip_comment(); + + return self.next(); + } else { + Token::Slash + } + } + '.' => Token::Period, + '~' => Token::Tilde, + '&' => { + if self.peek() == Some('&') { + self.read_char(); + Token::And + } else { + Token::Ampersand + } + } + '|' => { + if self.peek() == Some('|') { + self.read_char(); + Token::Or + } else { + Token::Bar + } + } + '!' => { + if self.peek() == Some('=') { + self.read_char(); + Token::NotEqual + } else { + Token::Bang + } + } + '*' => Token::Asterisk, + '<' => match self.peek() { + Some('=') => { + self.read_char(); + Token::LessEqual + } + Some('<') => { + self.read_char(); + Token::Shl + } + _ => Token::LessThan, + }, + '>' => match self.peek() { + Some('=') => { + self.read_char(); + Token::GreaterEqual + } + Some('>') => { + self.read_char(); + Token::Shr + } + _ => Token::GreaterThan, + }, + ';' => Token::Semicolon, + '(' => Token::LParen, + ')' => Token::RParen, + '{' => Token::LBrace, + '}' => Token::RBrace, + '[' => Token::LBracket, + ']' => Token::RBracket, + ',' => Token::Comma, + ':' => Token::Colon, + '"' => Token::String(self.read_string()), + '0'..='9' => { + return Some(Ok(Token::Integer(self.read_int()))); + } + '\0' => return None, + ch if ch.is_alphanumeric() || ch == '_' => { + let ident = self.read_ident(); + + return Some(Ok(match ident.as_str() { + "const" => Token::Const, + "true" => Token::True, + "let" => Token::Let, + "fn" => Token::Fn, + "enum" => Token::Enum, + "struct" => Token::Struct, + "false" => Token::False, + "if" => Token::If, + "while" => Token::While, + "for" => Token::For, + "else" => Token::Else, + "return" => Token::Return, + "as" => Token::As, + "continue" => Token::Continue, + "break" => Token::Break, + "u8" => Token::U8, + "u16" => Token::U16, + "u32" => Token::U32, + "u64" => Token::U64, + "i8" => Token::I8, + "i16" => Token::I16, + "i32" => Token::I32, + "i64" => Token::I64, + "usize" => Token::Usize, + "isize" => Token::Isize, + "bool" => Token::Bool, + "void" => Token::Void, + "NULL" => Token::Null, + _ => Token::Ident(ident), + })); + } + ch => { + return Some(Err(Error::UnknownCharacter(ch))); + } + }; + + self.read_char(); + + Some(Ok(token)) + } +} + +#[cfg(test)] +mod test { + use super::{Error, Lexer}; + use crate::lexer::Token; + + #[test] + fn source_into_tokens() -> Result<(), Error> { + let input = r#" + ident + 69 + "string" + + = + + + - + ! + * + / + -> + . + ~ + & + | + == + != + < + > + <= + >= + && + || + << + >> + , + ; + : + ( + ) + { + } + [ + ] + + // keywords + // heyo :D + const + true + false + let + fn + enum + struct + if + while + for + else + return + as + continue + break + + u8 + u16 + u32 + u64 + i8 + i16 + i32 + i64 + usize + isize + bool + void + NULL + "#; + + let tokens = vec![ + Token::Ident(String::from("ident")), + Token::Integer(String::from("69")), + Token::String(String::from("string")), + Token::Assign, + Token::Plus, + Token::Minus, + Token::Bang, + Token::Asterisk, + Token::Slash, + Token::Arrow, + Token::Period, + Token::Tilde, + Token::Ampersand, + Token::Bar, + Token::Equal, + Token::NotEqual, + Token::LessThan, + Token::GreaterThan, + Token::LessEqual, + Token::GreaterEqual, + Token::And, + Token::Or, + Token::Shl, + Token::Shr, + Token::Comma, + Token::Semicolon, + Token::Colon, + Token::LParen, + Token::RParen, + Token::LBrace, + Token::RBrace, + Token::LBracket, + Token::RBracket, + Token::Const, + Token::True, + Token::False, + Token::Let, + Token::Fn, + Token::Enum, + Token::Struct, + Token::If, + Token::While, + Token::For, + Token::Else, + Token::Return, + Token::As, + Token::Continue, + Token::Break, + Token::U8, + Token::U16, + Token::U32, + Token::U64, + Token::I8, + Token::I16, + Token::I32, + Token::I64, + Token::Usize, + Token::Isize, + Token::Bool, + Token::Void, + Token::Null, + ]; + + let mut lexer = Lexer::new(input.to_string()); + + for token in tokens { + let next_token = lexer.next().unwrap()?; + + assert_eq!(token, next_token); + } + + Ok(()) + } +} diff --git a/src/lexer/token.rs b/src/lexer/token.rs index 4ee6d9b..ab29493 100644 --- a/src/lexer/token.rs +++ b/src/lexer/token.rs @@ -1,5 +1,3 @@ -use std::fmt::Display; - #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub enum Token { Ident(String), @@ -69,7 +67,7 @@ pub enum Token { Null, } -impl Display for Token { +impl std::fmt::Display for Token { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use Token::*; diff --git a/src/lowering/mod.rs b/src/lowering/mod.rs index c59a742..3a1ac91 100644 --- a/src/lowering/mod.rs +++ b/src/lowering/mod.rs @@ -2,7 +2,7 @@ mod scopes; use crate::{ ir::{self, Id, OrderedMap, Stmt}, - parser::{self, BinOp, Item, UnOp, Variable}, + parser::{self, BinOp, IntTy, Item, UintTy, UnOp, Variable}, ty_problem, Context, }; use scopes::Scopes; @@ -374,36 +374,26 @@ impl<'a, 'ir> Lowering<'a, 'ir> { parser::Ty::Void => self.ctx.allocator.alloc(ir::Ty::Void), parser::Ty::Bool => self.ctx.allocator.alloc(ir::Ty::Bool), parser::Ty::Int(ty) => match ty { - parser::IntTy::I8 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I8)), - parser::IntTy::I16 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I16)), - parser::IntTy::I32 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I32)), - parser::IntTy::I64 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I64)), - parser::IntTy::Isize => { - self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::Isize)) - } + parser::IntTy::I8 => self.ctx.allocator.alloc(ir::Ty::Int(IntTy::I8)), + parser::IntTy::I16 => self.ctx.allocator.alloc(ir::Ty::Int(IntTy::I16)), + parser::IntTy::I32 => self.ctx.allocator.alloc(ir::Ty::Int(IntTy::I32)), + parser::IntTy::I64 => self.ctx.allocator.alloc(ir::Ty::Int(IntTy::I64)), + parser::IntTy::Isize => self.ctx.allocator.alloc(ir::Ty::Int(IntTy::Isize)), }, parser::Ty::UInt(ty) => match ty { - parser::UintTy::U8 => { - self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U8)) - } - parser::UintTy::U16 => { - self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U16)) - } - parser::UintTy::U32 => { - self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U32)) - } - parser::UintTy::U64 => { - self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U64)) - } + parser::UintTy::U8 => self.ctx.allocator.alloc(ir::Ty::UInt(UintTy::U8)), + parser::UintTy::U16 => self.ctx.allocator.alloc(ir::Ty::UInt(UintTy::U16)), + parser::UintTy::U32 => self.ctx.allocator.alloc(ir::Ty::UInt(UintTy::U32)), + parser::UintTy::U64 => self.ctx.allocator.alloc(ir::Ty::UInt(UintTy::U64)), parser::UintTy::Usize => { - self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::Usize)) + self.ctx.allocator.alloc(ir::Ty::UInt(UintTy::Usize)) } }, parser::Ty::Ptr(ref ty) => self .ctx .allocator .alloc(ir::Ty::Ptr(self.lower_ty(*ty.clone()))), - parser::Ty::Array(parser::TyArray { ref ty, len }) => { + parser::Ty::Array { ty, len } => { self.ctx.allocator.alloc(ir::Ty::Array(ir::TyArray { len: *len, ty: self.lower_ty(*ty.clone()), @@ -448,7 +438,7 @@ impl<'a, 'ir> Lowering<'a, 'ir> { .alloc(ir::Ty::Infer(self.ctx.ty_problem.new_infer_ty_var())), parser::Expr::Lit(lit) => match lit { parser::ExprLit::Bool(_) => &ir::Ty::Bool, - parser::ExprLit::String(_) => &ir::Ty::Ptr(&ir::Ty::UInt(ir::UintTy::U8)), + parser::ExprLit::String(_) => &ir::Ty::Ptr(&ir::Ty::UInt(UintTy::U8)), _ => self .ctx .allocator diff --git a/src/parser/error.rs b/src/parser/error.rs index d497be7..c5b5244 100644 --- a/src/parser/error.rs +++ b/src/parser/error.rs @@ -1,11 +1,11 @@ use super::{OpParseError, Ty}; -use crate::lexer::{LexerError, Token}; +use crate::lexer::{self, Token}; use thiserror::Error; #[derive(Error, Debug)] -pub enum ParserError { +pub enum Error { #[error(transparent)] - Lexer(#[from] LexerError), + Lexer(#[from] lexer::Error), #[error(transparent)] Type(#[from] TyError), #[error(transparent)] diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 92fa918..4a96633 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,21 +1,21 @@ mod error; mod item; mod op; -mod parser; mod precedence; mod stmt; mod types; pub mod expr; -pub use error::{ParserError, TyError}; +use crate::lexer::{self, Token}; +pub use error::{Error, TyError}; pub use expr::*; pub use item::{Item, ItemFn, ItemStruct}; pub use op::{BinOp, BitwiseOp, CmpOp, OpParseError, UnOp}; -pub use parser::Parser; pub use precedence::Precedence; +use std::collections::HashMap; pub use stmt::{Stmt, StmtFor, StmtIf, StmtReturn, StmtWhile}; -pub use types::{IntTy, Ty, TyArray, UintTy}; +pub use types::{IntTy, Ty, UintTy}; #[derive(Debug, Clone, PartialEq)] pub struct Variable { @@ -26,3 +26,933 @@ pub struct Variable { #[derive(Debug, Clone, PartialEq)] pub struct Block(pub Vec); + +type PrefixFn = fn(&mut Parser) -> Result; +type InfixFn = fn(&mut Parser, left: Expr) -> Result; + +pub struct Parser>> { + lexer: T, + cur_token: Option, + peek_token: Option, + prefix_fns: HashMap>, + infix_fns: HashMap>, +} + +impl>> Parser { + pub fn new(mut lexer: T) -> Result { + Ok(Self { + cur_token: lexer.next().transpose()?, + peek_token: lexer.next().transpose()?, + lexer, + prefix_fns: HashMap::from([ + (Token::Ident(Default::default()), Self::ident as PrefixFn), + (Token::String(Default::default()), Self::string_lit), + (Token::Integer(Default::default()), Self::int_lit), + (Token::Null, Self::null), + (Token::True, Self::bool), + (Token::False, Self::bool), + (Token::Minus, Self::unary_expr), + (Token::Bang, Self::unary_expr), + (Token::LParen, Self::grouped_expr), + (Token::Ampersand, Self::unary_expr), + (Token::Asterisk, Self::unary_expr), + (Token::Tilde, Self::unary_expr), + (Token::LBracket, Self::array_expr), + ]), + infix_fns: HashMap::from([ + (Token::Plus, Self::bin_expr as InfixFn), + (Token::Minus, Self::bin_expr), + (Token::Asterisk, Self::bin_expr), + (Token::Slash, Self::bin_expr), + (Token::Assign, Self::bin_expr), + (Token::LessThan, Self::bin_expr), + (Token::LessEqual, Self::bin_expr), + (Token::GreaterThan, Self::bin_expr), + (Token::GreaterEqual, Self::bin_expr), + (Token::Equal, Self::bin_expr), + (Token::NotEqual, Self::bin_expr), + (Token::And, Self::bin_expr), + (Token::Or, Self::bin_expr), + (Token::LParen, Self::bin_expr), + (Token::Ampersand, Self::bin_expr), + (Token::Bar, Self::bin_expr), + (Token::Shl, Self::bin_expr), + (Token::Shr, Self::bin_expr), + (Token::Arrow, Self::pointer_access), + (Token::Period, Self::struct_access), + (Token::LBracket, Self::array_access), + (Token::As, Self::cast_expr), + (Token::LParen, Self::func_call_expr), + (Token::Bang, Self::macro_call_expr), + ]), + }) + } + + fn next_token(&mut self) -> Result, Error> { + let mut token = self.lexer.next().transpose()?; + + std::mem::swap(&mut self.cur_token, &mut self.peek_token); + std::mem::swap(&mut token, &mut self.peek_token); + + Ok(token) + } + + fn cur_token_is(&self, token: &Token) -> bool { + self.cur_token.as_ref() == Some(token) + } + + fn peek_token_is(&self, token: &Token) -> bool { + self.peek_token.as_ref() == Some(token) + } + + fn expect(&mut self, token: &Token) -> Result<(), Error> { + match self.next_token()? { + Some(ref cur) if cur == token => Ok(()), + Some(cur) => Err(Error::UnexpectedToken(token.to_owned(), cur)), + None => Err(Error::Expected(token.to_owned())), + } + } + + pub fn parse(&mut self) -> Result, Error> { + let mut items = Vec::new(); + + while let Some(token) = &self.cur_token { + let item = match token { + Token::Struct => self.parse_struct()?, + Token::Let => self.global()?, + Token::Fn => self.function(true)?, + _ => unreachable!(), + }; + items.push(item); + } + + Ok(items) + } + + pub fn expr(&mut self, precedence: Precedence) -> Result { + let token = match self.cur_token.as_ref().unwrap() { + Token::Ident(_) => Token::Ident(Default::default()), + Token::Integer(_) => Token::Integer(Default::default()), + Token::String(_) => Token::String(Default::default()), + token => token.clone(), + }; + + let mut left = match self.prefix_fns.get(&token) { + Some(func) => func(self), + None => { + return Err(Error::Prefix(token)); + } + }; + + while !self.cur_token_is(&Token::Semicolon) + && self.cur_token.is_some() + && precedence < Precedence::from(self.cur_token.as_ref().unwrap()) + { + left = match self.infix_fns.get(self.cur_token.as_ref().unwrap()) { + Some(func) => func(self, left?), + None => { + return Err(Error::Infix(self.cur_token.clone().unwrap())); + } + }; + } + + left + } + + fn parse_struct(&mut self) -> Result { + self.expect(&Token::Struct)?; + + let name = match self + .next_token()? + .ok_or(Error::Expected(Token::Ident(Default::default())))? + { + Token::Ident(ident) => Ok(ident), + token => Err(Error::UnexpectedToken( + Token::Ident(Default::default()), + token, + )), + }?; + + self.expect(&Token::LBrace)?; + + let mut fields = Vec::new(); + + while !self.cur_token_is(&Token::RBrace) { + if self.cur_token_is(&Token::Fn) { + // Handle struct methods here + } else { + let name = match self.next_token()? { + Some(Token::Ident(ident)) => ident, + _ => todo!("Don't know what error to return yet"), + }; + self.expect(&Token::Colon)?; + let mut type_ = self.parse_type()?; + self.array_type(&mut type_)?; + + match fields.iter().find(|(field_name, _)| field_name == &name) { + Some(_) => todo!("Don't know yet what error to return"), + None => fields.push((name, type_)), + }; + + if !self.cur_token_is(&Token::RBrace) { + self.expect(&Token::Semicolon)?; + } + } + } + + self.expect(&Token::RBrace)?; + + Ok(Item::Struct(ItemStruct { name, fields })) + } + + fn stmt(&mut self) -> Result { + match self.cur_token.as_ref().unwrap() { + Token::Return => self.parse_return(), + Token::If => self.if_stmt(), + Token::While => self.while_stmt(), + Token::For => self.for_stmt(), + Token::Let => self.local(), + Token::Continue => { + self.expect(&Token::Continue)?; + self.expect(&Token::Semicolon)?; + + Ok(Stmt::Continue) + } + Token::Break => { + self.expect(&Token::Break)?; + self.expect(&Token::Semicolon)?; + + Ok(Stmt::Break) + } + Token::Fn => Ok(Stmt::Item(self.function(true)?)), + _ => { + let expr = Stmt::Expr(self.expr(Precedence::default())?); + + self.expect(&Token::Semicolon)?; + + Ok(expr) + } + } + } + + fn compound_statement(&mut self) -> Result { + let mut stmts = Vec::new(); + + self.expect(&Token::LBrace)?; + + while !self.cur_token_is(&Token::RBrace) { + stmts.push(self.stmt()?); + } + + self.expect(&Token::RBrace)?; + + Ok(Block(stmts)) + } + + // This function is used only by macro expansion + pub fn parse_stmts(&mut self) -> Result, Error> { + let mut stmts = Vec::new(); + + while self.cur_token.is_some() { + stmts.push(self.stmt()?); + } + + Ok(stmts) + } + + fn parse_type(&mut self) -> Result { + let mut n = 0; + while self.cur_token_is(&Token::Asterisk) { + self.expect(&Token::Asterisk)?; + n += 1; + } + + let mut base = match self.next_token()?.unwrap() { + Token::U8 => Ok(Ty::UInt(UintTy::U8)), + Token::U16 => Ok(Ty::UInt(UintTy::U16)), + Token::U32 => Ok(Ty::UInt(UintTy::U32)), + Token::U64 => Ok(Ty::UInt(UintTy::U64)), + Token::I8 => Ok(Ty::Int(IntTy::I8)), + Token::I16 => Ok(Ty::Int(IntTy::I16)), + Token::I32 => Ok(Ty::Int(IntTy::I32)), + Token::I64 => Ok(Ty::Int(IntTy::I64)), + Token::Usize => Ok(Ty::UInt(UintTy::Usize)), + Token::Isize => Ok(Ty::Int(IntTy::Isize)), + Token::Bool => Ok(Ty::Bool), + Token::Void => Ok(Ty::Void), + Token::Ident(ident) => Ok(Ty::Ident(ident)), + Token::Fn => { + self.expect(&Token::LParen)?; + + let mut params = Vec::new(); + + while !self.cur_token_is(&Token::RParen) { + params.push(self.parse_type()?); + + if !self.cur_token_is(&Token::RParen) { + self.expect(&Token::Comma)?; + } + } + + self.expect(&Token::RParen)?; + self.expect(&Token::Arrow)?; + + Ok(Ty::Fn(params, Box::new(self.parse_type()?))) + } + token => Err(Error::ParseType(token)), + }?; + + while n > 0 { + base = Ty::Ptr(Box::new(base)); + n -= 1; + } + + Ok(base) + } + + fn parse_return(&mut self) -> Result { + self.expect(&Token::Return)?; + + let expr = if !self.cur_token_is(&Token::Semicolon) { + Some(self.expr(Precedence::default())?) + } else { + None + }; + + self.expect(&Token::Semicolon)?; + + Ok(Stmt::Return(StmtReturn { expr })) + } + + fn if_stmt(&mut self) -> Result { + self.expect(&Token::If)?; + + let condition = self.expr(Precedence::default())?; + let consequence = self.compound_statement()?; + let alternative = if self.cur_token_is(&Token::Else) { + self.expect(&Token::Else)?; + + Some(self.compound_statement()?) + } else { + None + }; + + Ok(Stmt::If(StmtIf { + condition, + consequence, + alternative, + })) + } + + fn while_stmt(&mut self) -> Result { + self.expect(&Token::While)?; + + let condition = self.expr(Precedence::default())?; + let block = self.compound_statement()?; + + Ok(Stmt::While(StmtWhile { condition, block })) + } + + fn for_stmt(&mut self) -> Result { + self.expect(&Token::For)?; + + let initializer = if self.cur_token_is(&Token::Semicolon) { + None + } else { + let stmt = if self.cur_token_is(&Token::Let) { + self.local()? + } else { + Stmt::Expr(self.expr(Precedence::default())?) + }; + + Some(stmt) + }; + + let condition = if self.cur_token_is(&Token::Semicolon) { + None + } else { + Some(self.expr(Precedence::default())?) + }; + self.expect(&Token::Semicolon)?; + + let increment = if self.cur_token_is(&Token::LBrace) { + None + } else { + Some(self.expr(Precedence::default())?) + }; + + let block = self.compound_statement()?; + + Ok(Stmt::For(StmtFor { + initializer: initializer.map(|initializer| Box::new(initializer)), + condition, + increment, + block, + })) + } + + fn array_type(&mut self, type_: &mut Ty) -> Result<(), Error> { + if self.cur_token_is(&Token::LBracket) { + self.expect(&Token::LBracket)?; + + match self.next_token()?.unwrap() { + Token::Integer(int) => { + let length: usize = str::parse(&int).unwrap(); + self.expect(&Token::RBracket)?; + + *type_ = Ty::Array { + ty: Box::new(type_.clone()), + len: length, + }; + } + token => panic!("Expected integer, got {token}"), + } + } + + Ok(()) + } + + fn local(&mut self) -> Result { + self.expect(&Token::Let)?; + + let name = match self.next_token()?.unwrap() { + Token::Ident(ident) => ident, + token => { + return Err(Error::ParseType(token)); + } + }; + let ty = if self.cur_token_is(&Token::Colon) { + self.expect(&Token::Colon)?; + + let mut ty = self.parse_type()?; + self.array_type(&mut ty)?; + + ty + } else { + Ty::Infer + }; + + let expr = if self.cur_token_is(&Token::Assign) { + self.expect(&Token::Assign)?; + + Some(self.expr(Precedence::default())?) + } else { + None + }; + + self.expect(&Token::Semicolon)?; + + Ok(Stmt::Local(Variable { + name, + ty, + value: expr, + })) + } + + fn global(&mut self) -> Result { + self.expect(&Token::Let)?; + + let name = match self.next_token()?.unwrap() { + Token::Ident(ident) => ident, + token => { + return Err(Error::ParseType(token)); + } + }; + self.expect(&Token::Colon)?; + + let mut ty = self.parse_type()?; + self.array_type(&mut ty)?; + + let expr = if self.cur_token_is(&Token::Assign) { + self.expect(&Token::Assign)?; + + Some(self.expr(Precedence::default())?) + } else { + None + }; + + self.expect(&Token::Semicolon)?; + + Ok(Item::Global(Variable { + name, + ty, + value: expr, + })) + } + + fn function(&mut self, func_definition: bool) -> Result { + self.expect(&Token::Fn)?; + + let name = match self.next_token()?.unwrap() { + Token::Ident(ident) => ident, + token => { + return Err(Error::ParseType(token)); + } + }; + + self.expect(&Token::LParen)?; + + let params = self.params(Token::Comma, Token::RParen)?; + self.expect(&Token::Arrow)?; + + let type_ = self.parse_type()?; + let block = if self.cur_token_is(&Token::LBrace) { + Some(self.compound_statement()?) + } else { + None + }; + + if block.is_some() & !func_definition { + panic!("Function definition is not supported here"); + } + + if block.is_none() { + self.expect(&Token::Semicolon)?; + } + + Ok(Item::Fn(ItemFn { + ret_ty: type_, + name, + params, + block, + })) + } + + fn params(&mut self, delim: Token, end: Token) -> Result, Error> { + let mut params = Vec::new(); + + while !self.cur_token_is(&end) { + let name = match self.next_token()? { + Some(Token::Ident(ident)) => ident, + _ => todo!("Don't know what error to return yet"), + }; + self.expect(&Token::Colon)?; + let type_ = self.parse_type()?; + + match params.iter().find(|(field_name, _)| field_name == &name) { + Some(_) => todo!("Don't know yet what error to return"), + None => params.push((name, type_)), + }; + + if !self.cur_token_is(&end) { + self.expect(&delim)?; + } + } + + self.expect(&end)?; + + Ok(params) + } + + fn ident(&mut self) -> Result { + match self.peek_token { + Some(Token::LBrace) => self.struct_expr(), + _ => match self + .next_token()? + .ok_or(Error::Expected(Token::Ident(Default::default())))? + { + Token::Ident(ident) => Ok(Expr::Ident(ExprIdent(ident))), + token => Err(Error::ParseType(token)), + }, + } + } + + fn struct_expr(&mut self) -> Result { + let name = match self.next_token()? { + Some(Token::Ident(ident)) => ident, + _ => todo!("Don't know what error to return yet"), + }; + + self.expect(&Token::LBrace)?; + let mut fields = Vec::new(); + + while !self.cur_token_is(&Token::RBrace) { + match self.next_token()? { + Some(Token::Ident(field)) => { + self.expect(&Token::Colon)?; + let expr = self.expr(Precedence::Lowest)?; + + if !self.cur_token_is(&Token::RBrace) { + self.expect(&Token::Comma)?; + } + + fields.push((field, expr)); + } + _ => todo!("Don't know what error to return yet"), + } + } + + self.expect(&Token::RBrace)?; + + Ok(Expr::Struct(ExprStruct { name, fields })) + } + + fn string_lit(&mut self) -> Result { + match self.next_token()? { + Some(Token::String(literal)) => Ok(Expr::Lit(ExprLit::String(literal))), + Some(_) | None => unreachable!(), + } + } + + fn int_lit(&mut self) -> Result { + match self.next_token()? { + Some(Token::Integer(num_str)) => Ok(Expr::Lit(ExprLit::UInt(num_str.parse().unwrap()))), + Some(_) | None => unreachable!(), + } + } + + fn null(&mut self) -> Result { + self.expect(&Token::Null)?; + + Ok(Expr::Lit(ExprLit::Null)) + } + + fn bool(&mut self) -> Result { + match self.next_token()? { + Some(Token::True) => Ok(Expr::Lit(ExprLit::Bool(true))), + Some(Token::False) => Ok(Expr::Lit(ExprLit::Bool(false))), + Some(_) | None => unreachable!(), + } + } + + fn func_call_expr(&mut self, left: Expr) -> Result { + self.expect(&Token::LParen)?; + + let args = self.expr_list()?; + + Ok(Expr::FunctionCall(ExprFunctionCall { + expr: Box::new(left), + arguments: args, + })) + } + + fn macro_call_expr(&mut self, left: Expr) -> Result { + let name = match left { + Expr::Ident(expr) => expr.0, + _ => panic!("Macro name can only be a string"), + }; + + self.expect(&Token::Bang)?; + self.expect(&Token::LParen)?; + + let mut tokens = Vec::new(); + + while !self.cur_token_is(&Token::RParen) { + tokens.push(self.next_token()?.ok_or(Error::Expected(Token::RParen))?); + } + + self.expect(&Token::RParen)?; + + Ok(Expr::MacroCall(MacroCall { name, tokens })) + } + + fn bin_expr(&mut self, left: Expr) -> Result { + let token = self.next_token()?.unwrap(); + + // NOTE: assignment expression is right-associative + let precedence = if let &Token::Assign = &token { + Precedence::from(&token).lower() + } else { + Precedence::from(&token) + }; + let right = self.expr(precedence)?; + let op = BinOp::try_from(&token).map_err(|e| Error::Operator(e))?; + + Ok(Expr::Binary(ExprBinary { + op, + left: Box::new(left), + right: Box::new(right), + })) + } + + fn pointer_access(&mut self, left: Expr) -> Result { + self.expect(&Token::Arrow)?; + + let field = match self.next_token()? { + Some(Token::Ident(ident)) => ident, + _ => unreachable!(), + }; + + Ok(Expr::Field(ExprField { + expr: Box::new(Expr::Unary(ExprUnary { + op: UnOp::Deref, + expr: Box::new(left), + })), + field, + })) + } + + fn struct_access(&mut self, expr: Expr) -> Result { + self.expect(&Token::Period)?; + + if self.peek_token_is(&Token::LParen) { + let method = match self.next_token()? { + Some(Token::Ident(field)) => field, + _ => panic!("Struct field name should be of type string"), + }; + + self.expect(&Token::LParen)?; + let arguments = self.expr_list()?; + + Ok(Expr::StructMethod(ExprStructMethod { + expr: Box::new(expr), + method, + arguments, + })) + } else { + match self.next_token()? { + Some(Token::Ident(field)) => Ok(Expr::Field(ExprField { + expr: Box::new(expr), + field, + })), + _ => panic!("Struct field name should be of type string"), + } + } + } + + fn array_access(&mut self, expr: Expr) -> Result { + self.expect(&Token::LBracket)?; + let index = self.expr(Precedence::Access)?; + self.expect(&Token::RBracket)?; + + Ok(Expr::ArrayAccess(ExprArrayAccess { + expr: Box::new(expr), + index: Box::new(index), + })) + } + + fn cast_expr(&mut self, expr: Expr) -> Result { + self.expect(&Token::As)?; + + Ok(Expr::Cast(ExprCast { + expr: Box::new(expr), + ty: self.parse_type()?, + })) + } + + fn unary_expr(&mut self) -> Result { + let op = UnOp::try_from(&self.next_token()?.unwrap()).map_err(|e| Error::Operator(e))?; + let expr = self.expr(Precedence::Prefix)?; + + Ok(Expr::Unary(ExprUnary { + op, + expr: Box::new(expr), + })) + } + + fn expr_list(&mut self) -> Result, Error> { + let mut exprs = Vec::new(); + + while !self.cur_token_is(&Token::RParen) { + exprs.push(self.expr(Precedence::default())?); + if !self.cur_token_is(&Token::RParen) { + self.expect(&Token::Comma)?; + } + } + + self.expect(&Token::RParen)?; + + Ok(exprs) + } + + fn grouped_expr(&mut self) -> Result { + self.expect(&Token::LParen)?; + + let expr = self.expr(Precedence::default())?; + self.expect(&Token::RParen)?; + + Ok(expr) + } + + fn array_expr(&mut self) -> Result { + self.expect(&Token::LBracket)?; + let mut items = Vec::new(); + + while !self.cur_token_is(&Token::RBracket) { + items.push(self.expr(Precedence::default())?); + + if !self.cur_token_is(&Token::RBracket) { + self.expect(&Token::Comma)?; + } + } + + self.expect(&Token::RBracket)?; + + Ok(Expr::Array(ExprArray(items))) + } +} + +#[cfg(test)] +mod test { + use super::Parser; + use crate::{ + lexer::Lexer, + parser::{ + BinOp, Error, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, IntTy, Stmt, + Ty, UintTy, UnOp, Variable, + }, + }; + + #[test] + fn parse_arithmetic_expression() -> Result<(), Error> { + let tests = [ + ( + " + { + 1 * 2 + 3 / (4 + 1 as u8); + } + ", + vec![Stmt::Expr(Expr::Binary(ExprBinary { + op: BinOp::Add, + left: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Mul, + left: Box::new(Expr::Lit(ExprLit::UInt(1))), + right: Box::new(Expr::Lit(ExprLit::UInt(2))), + })), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Div, + left: Box::new(Expr::Lit(ExprLit::UInt(3))), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Add, + left: Box::new(Expr::Lit(ExprLit::UInt(4))), + right: Box::new(Expr::Cast(ExprCast { + ty: Ty::UInt(UintTy::U8), + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), + })), + })), + })), + }))], + ), + ( + " + { + let foo: u8; + foo = -1 as u8 + 5; + } + ", + vec![ + Stmt::Local(Variable { + name: "foo".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Expr(Expr::Binary(ExprBinary { + op: BinOp::Assign, + left: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Add, + left: Box::new(Expr::Cast(ExprCast { + ty: Ty::UInt(UintTy::U8), + expr: Box::new(Expr::Unary(ExprUnary { + op: UnOp::Negative, + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), + })), + })), + right: Box::new(Expr::Lit(ExprLit::UInt(5))), + })), + })), + ], + ), + ( + " + { + let foo: u8; + let bar: i8; + bar = foo as i8 + 5 / 10; + } + ", + vec![ + Stmt::Local(Variable { + name: "foo".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Local(Variable { + name: "bar".to_owned(), + ty: Ty::Int(IntTy::I8), + value: None, + }), + Stmt::Expr(Expr::Binary(ExprBinary { + op: BinOp::Assign, + left: Box::new(Expr::Ident(ExprIdent("bar".to_owned()))), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Add, + left: Box::new(Expr::Cast(ExprCast { + ty: Ty::Int(IntTy::I8), + expr: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), + })), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Div, + left: Box::new(Expr::Lit(ExprLit::UInt(5))), + right: Box::new(Expr::Lit(ExprLit::UInt(10))), + })), + })), + })), + ], + ), + ( + " + { + 1 as i8 + 2 / 3; + } + ", + vec![Stmt::Expr(Expr::Binary(ExprBinary { + op: BinOp::Add, + left: Box::new(Expr::Cast(ExprCast { + ty: Ty::Int(IntTy::I8), + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), + })), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Div, + left: Box::new(Expr::Lit(ExprLit::UInt(2))), + right: Box::new(Expr::Lit(ExprLit::UInt(3))), + })), + }))], + ), + ( + " + { + let a: u8; + let b: u8; + + a = b = 69; + } + ", + vec![ + Stmt::Local(Variable { + name: "a".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Local(Variable { + name: "b".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Expr(Expr::Binary(ExprBinary { + op: BinOp::Assign, + left: Box::new(Expr::Ident(ExprIdent("a".to_owned()))), + right: Box::new(Expr::Binary(ExprBinary { + op: BinOp::Assign, + left: Box::new(Expr::Ident(ExprIdent("b".to_owned()))), + right: Box::new(Expr::Lit(ExprLit::UInt(69))), + })), + })), + ], + ), + ]; + + for (input, expected) in tests { + let mut parser = Parser::new(Lexer::new(input.to_string())).unwrap(); + let ast = parser.compound_statement().unwrap(); + + assert_eq!( + &ast.0, &expected, + "expected: {:?}, got: {:?}", + expected, ast + ); + } + + Ok(()) + } +} diff --git a/src/parser/parser.rs b/src/parser/parser.rs deleted file mode 100644 index 1d70f17..0000000 --- a/src/parser/parser.rs +++ /dev/null @@ -1,948 +0,0 @@ -use super::{ - expr::{ExprBinary, ExprLit, ExprUnary}, - item::Item, - precedence::Precedence, - stmt::{StmtFor, StmtIf, StmtReturn, StmtWhile}, - BinOp, Block, Expr, ExprArray, ExprArrayAccess, ExprCast, ExprIdent, ExprStruct, - ExprStructMethod, IntTy, ItemFn, ItemStruct, MacroCall, ParserError, Stmt, Ty, TyArray, UintTy, - UnOp, Variable, -}; -use crate::{ - lexer::{LexerError, Token}, - parser::{ExprField, ExprFunctionCall}, -}; -use std::collections::HashMap; - -type PrefixFn = fn(&mut Parser) -> Result; -type InfixFn = fn(&mut Parser, left: Expr) -> Result; - -pub struct Parser>> { - lexer: T, - cur_token: Option, - peek_token: Option, - prefix_fns: HashMap>, - infix_fns: HashMap>, -} - -impl>> Parser { - pub fn new(mut lexer: T) -> Result { - Ok(Self { - cur_token: lexer.next().transpose()?, - peek_token: lexer.next().transpose()?, - lexer, - prefix_fns: HashMap::from([ - (Token::Ident(Default::default()), Self::ident as PrefixFn), - (Token::String(Default::default()), Self::string_lit), - (Token::Integer(Default::default()), Self::int_lit), - (Token::Null, Self::null), - (Token::True, Self::bool), - (Token::False, Self::bool), - (Token::Minus, Self::unary_expr), - (Token::Bang, Self::unary_expr), - (Token::LParen, Self::grouped_expr), - (Token::Ampersand, Self::unary_expr), - (Token::Asterisk, Self::unary_expr), - (Token::Tilde, Self::unary_expr), - (Token::LBracket, Self::array_expr), - ]), - infix_fns: HashMap::from([ - (Token::Plus, Self::bin_expr as InfixFn), - (Token::Minus, Self::bin_expr), - (Token::Asterisk, Self::bin_expr), - (Token::Slash, Self::bin_expr), - (Token::Assign, Self::bin_expr), - (Token::LessThan, Self::bin_expr), - (Token::LessEqual, Self::bin_expr), - (Token::GreaterThan, Self::bin_expr), - (Token::GreaterEqual, Self::bin_expr), - (Token::Equal, Self::bin_expr), - (Token::NotEqual, Self::bin_expr), - (Token::And, Self::bin_expr), - (Token::Or, Self::bin_expr), - (Token::LParen, Self::bin_expr), - (Token::Ampersand, Self::bin_expr), - (Token::Bar, Self::bin_expr), - (Token::Shl, Self::bin_expr), - (Token::Shr, Self::bin_expr), - (Token::Arrow, Self::pointer_access), - (Token::Period, Self::struct_access), - (Token::LBracket, Self::array_access), - (Token::As, Self::cast_expr), - (Token::LParen, Self::func_call_expr), - (Token::Bang, Self::macro_call_expr), - ]), - }) - } - - fn next_token(&mut self) -> Result, ParserError> { - let mut token = self.lexer.next().transpose()?; - - std::mem::swap(&mut self.cur_token, &mut self.peek_token); - std::mem::swap(&mut token, &mut self.peek_token); - - Ok(token) - } - - fn cur_token_is(&self, token: &Token) -> bool { - self.cur_token.as_ref() == Some(token) - } - - fn peek_token_is(&self, token: &Token) -> bool { - self.peek_token.as_ref() == Some(token) - } - - fn expect(&mut self, token: &Token) -> Result<(), ParserError> { - match self.next_token()? { - Some(ref cur) if cur == token => Ok(()), - Some(cur) => Err(ParserError::UnexpectedToken(token.to_owned(), cur)), - None => Err(ParserError::Expected(token.to_owned())), - } - } - - pub fn parse(&mut self) -> Result, ParserError> { - let mut items = Vec::new(); - - while let Some(token) = &self.cur_token { - let item = match token { - Token::Struct => self.parse_struct()?, - Token::Let => self.global()?, - Token::Fn => self.function(true)?, - _ => unreachable!(), - }; - items.push(item); - } - - Ok(items) - } - - pub fn expr(&mut self, precedence: Precedence) -> Result { - let token = match self.cur_token.as_ref().unwrap() { - Token::Ident(_) => Token::Ident(Default::default()), - Token::Integer(_) => Token::Integer(Default::default()), - Token::String(_) => Token::String(Default::default()), - token => token.clone(), - }; - - let mut left = match self.prefix_fns.get(&token) { - Some(func) => func(self), - None => { - return Err(ParserError::Prefix(token)); - } - }; - - while !self.cur_token_is(&Token::Semicolon) - && self.cur_token.is_some() - && precedence < Precedence::from(self.cur_token.as_ref().unwrap()) - { - left = match self.infix_fns.get(self.cur_token.as_ref().unwrap()) { - Some(func) => func(self, left?), - None => { - return Err(ParserError::Infix(self.cur_token.clone().unwrap())); - } - }; - } - - left - } - - fn parse_struct(&mut self) -> Result { - self.expect(&Token::Struct)?; - - let name = match self - .next_token()? - .ok_or(ParserError::Expected(Token::Ident(Default::default())))? - { - Token::Ident(ident) => Ok(ident), - token => Err(ParserError::UnexpectedToken( - Token::Ident(Default::default()), - token, - )), - }?; - - self.expect(&Token::LBrace)?; - - let mut fields = Vec::new(); - - while !self.cur_token_is(&Token::RBrace) { - if self.cur_token_is(&Token::Fn) { - // Handle struct methods here - } else { - let name = match self.next_token()? { - Some(Token::Ident(ident)) => ident, - _ => todo!("Don't know what error to return yet"), - }; - self.expect(&Token::Colon)?; - let mut type_ = self.parse_type()?; - self.array_type(&mut type_)?; - - match fields.iter().find(|(field_name, _)| field_name == &name) { - Some(_) => todo!("Don't know yet what error to return"), - None => fields.push((name, type_)), - }; - - if !self.cur_token_is(&Token::RBrace) { - self.expect(&Token::Semicolon)?; - } - } - } - - self.expect(&Token::RBrace)?; - - Ok(Item::Struct(ItemStruct { name, fields })) - } - - fn stmt(&mut self) -> Result { - match self.cur_token.as_ref().unwrap() { - Token::Return => self.parse_return(), - Token::If => self.if_stmt(), - Token::While => self.while_stmt(), - Token::For => self.for_stmt(), - Token::Let => self.local(), - Token::Continue => { - self.expect(&Token::Continue)?; - self.expect(&Token::Semicolon)?; - - Ok(Stmt::Continue) - } - Token::Break => { - self.expect(&Token::Break)?; - self.expect(&Token::Semicolon)?; - - Ok(Stmt::Break) - } - Token::Fn => Ok(Stmt::Item(self.function(true)?)), - _ => { - let expr = Stmt::Expr(self.expr(Precedence::default())?); - - self.expect(&Token::Semicolon)?; - - Ok(expr) - } - } - } - - fn compound_statement(&mut self) -> Result { - let mut stmts = Vec::new(); - - self.expect(&Token::LBrace)?; - - while !self.cur_token_is(&Token::RBrace) { - stmts.push(self.stmt()?); - } - - self.expect(&Token::RBrace)?; - - Ok(Block(stmts)) - } - - // This function is used only by macro expansion - pub fn parse_stmts(&mut self) -> Result, ParserError> { - let mut stmts = Vec::new(); - - while self.cur_token.is_some() { - stmts.push(self.stmt()?); - } - - Ok(stmts) - } - - fn parse_type(&mut self) -> Result { - let mut n = 0; - while self.cur_token_is(&Token::Asterisk) { - self.expect(&Token::Asterisk)?; - n += 1; - } - - let mut base = match self.next_token()?.unwrap() { - Token::U8 => Ok(Ty::UInt(UintTy::U8)), - Token::U16 => Ok(Ty::UInt(UintTy::U16)), - Token::U32 => Ok(Ty::UInt(UintTy::U32)), - Token::U64 => Ok(Ty::UInt(UintTy::U64)), - Token::I8 => Ok(Ty::Int(IntTy::I8)), - Token::I16 => Ok(Ty::Int(IntTy::I16)), - Token::I32 => Ok(Ty::Int(IntTy::I32)), - Token::I64 => Ok(Ty::Int(IntTy::I64)), - Token::Usize => Ok(Ty::UInt(UintTy::Usize)), - Token::Isize => Ok(Ty::Int(IntTy::Isize)), - Token::Bool => Ok(Ty::Bool), - Token::Void => Ok(Ty::Void), - Token::Ident(ident) => Ok(Ty::Ident(ident)), - Token::Fn => { - self.expect(&Token::LParen)?; - - let mut params = Vec::new(); - - while !self.cur_token_is(&Token::RParen) { - params.push(self.parse_type()?); - - if !self.cur_token_is(&Token::RParen) { - self.expect(&Token::Comma)?; - } - } - - self.expect(&Token::RParen)?; - self.expect(&Token::Arrow)?; - - Ok(Ty::Fn(params, Box::new(self.parse_type()?))) - } - token => Err(ParserError::ParseType(token)), - }?; - - while n > 0 { - base = Ty::Ptr(Box::new(base)); - n -= 1; - } - - Ok(base) - } - - fn parse_return(&mut self) -> Result { - self.expect(&Token::Return)?; - - let expr = if !self.cur_token_is(&Token::Semicolon) { - Some(self.expr(Precedence::default())?) - } else { - None - }; - - self.expect(&Token::Semicolon)?; - - Ok(Stmt::Return(StmtReturn { expr })) - } - - fn if_stmt(&mut self) -> Result { - self.expect(&Token::If)?; - - let condition = self.expr(Precedence::default())?; - let consequence = self.compound_statement()?; - let alternative = if self.cur_token_is(&Token::Else) { - self.expect(&Token::Else)?; - - Some(self.compound_statement()?) - } else { - None - }; - - Ok(Stmt::If(StmtIf { - condition, - consequence, - alternative, - })) - } - - fn while_stmt(&mut self) -> Result { - self.expect(&Token::While)?; - - let condition = self.expr(Precedence::default())?; - let block = self.compound_statement()?; - - Ok(Stmt::While(StmtWhile { condition, block })) - } - - fn for_stmt(&mut self) -> Result { - self.expect(&Token::For)?; - - let initializer = if self.cur_token_is(&Token::Semicolon) { - None - } else { - let stmt = if self.cur_token_is(&Token::Let) { - self.local()? - } else { - Stmt::Expr(self.expr(Precedence::default())?) - }; - - Some(stmt) - }; - - let condition = if self.cur_token_is(&Token::Semicolon) { - None - } else { - Some(self.expr(Precedence::default())?) - }; - self.expect(&Token::Semicolon)?; - - let increment = if self.cur_token_is(&Token::LBrace) { - None - } else { - Some(self.expr(Precedence::default())?) - }; - - let block = self.compound_statement()?; - - Ok(Stmt::For(StmtFor { - initializer: initializer.map(|initializer| Box::new(initializer)), - condition, - increment, - block, - })) - } - - fn array_type(&mut self, type_: &mut Ty) -> Result<(), ParserError> { - if self.cur_token_is(&Token::LBracket) { - self.expect(&Token::LBracket)?; - - match self.next_token()?.unwrap() { - Token::Integer(int) => { - let length: usize = str::parse(&int).unwrap(); - self.expect(&Token::RBracket)?; - - *type_ = Ty::Array(TyArray { - ty: Box::new(type_.clone()), - len: length, - }); - } - token => panic!("Expected integer, got {token}"), - } - } - - Ok(()) - } - - fn local(&mut self) -> Result { - self.expect(&Token::Let)?; - - let name = match self.next_token()?.unwrap() { - Token::Ident(ident) => ident, - token => { - return Err(ParserError::ParseType(token)); - } - }; - let ty = if self.cur_token_is(&Token::Colon) { - self.expect(&Token::Colon)?; - - let mut ty = self.parse_type()?; - self.array_type(&mut ty)?; - - ty - } else { - Ty::Infer - }; - - let expr = if self.cur_token_is(&Token::Assign) { - self.expect(&Token::Assign)?; - - Some(self.expr(Precedence::default())?) - } else { - None - }; - - self.expect(&Token::Semicolon)?; - - Ok(Stmt::Local(Variable { - name, - ty, - value: expr, - })) - } - - fn global(&mut self) -> Result { - self.expect(&Token::Let)?; - - let name = match self.next_token()?.unwrap() { - Token::Ident(ident) => ident, - token => { - return Err(ParserError::ParseType(token)); - } - }; - self.expect(&Token::Colon)?; - - let mut ty = self.parse_type()?; - self.array_type(&mut ty)?; - - let expr = if self.cur_token_is(&Token::Assign) { - self.expect(&Token::Assign)?; - - Some(self.expr(Precedence::default())?) - } else { - None - }; - - self.expect(&Token::Semicolon)?; - - Ok(Item::Global(Variable { - name, - ty, - value: expr, - })) - } - - fn function(&mut self, func_definition: bool) -> Result { - self.expect(&Token::Fn)?; - - let name = match self.next_token()?.unwrap() { - Token::Ident(ident) => ident, - token => { - return Err(ParserError::ParseType(token)); - } - }; - - self.expect(&Token::LParen)?; - - let params = self.params(Token::Comma, Token::RParen)?; - self.expect(&Token::Arrow)?; - - let type_ = self.parse_type()?; - let block = if self.cur_token_is(&Token::LBrace) { - Some(self.compound_statement()?) - } else { - None - }; - - if block.is_some() & !func_definition { - panic!("Function definition is not supported here"); - } - - if block.is_none() { - self.expect(&Token::Semicolon)?; - } - - Ok(Item::Fn(ItemFn { - ret_ty: type_, - name, - params, - block, - })) - } - - fn params(&mut self, delim: Token, end: Token) -> Result, ParserError> { - let mut params = Vec::new(); - - while !self.cur_token_is(&end) { - let name = match self.next_token()? { - Some(Token::Ident(ident)) => ident, - _ => todo!("Don't know what error to return yet"), - }; - self.expect(&Token::Colon)?; - let type_ = self.parse_type()?; - - match params.iter().find(|(field_name, _)| field_name == &name) { - Some(_) => todo!("Don't know yet what error to return"), - None => params.push((name, type_)), - }; - - if !self.cur_token_is(&end) { - self.expect(&delim)?; - } - } - - self.expect(&end)?; - - Ok(params) - } - - fn ident(&mut self) -> Result { - match self.peek_token { - Some(Token::LBrace) => self.struct_expr(), - _ => match self - .next_token()? - .ok_or(ParserError::Expected(Token::Ident(Default::default())))? - { - Token::Ident(ident) => Ok(Expr::Ident(ExprIdent(ident))), - token => Err(ParserError::ParseType(token)), - }, - } - } - - fn struct_expr(&mut self) -> Result { - let name = match self.next_token()? { - Some(Token::Ident(ident)) => ident, - _ => todo!("Don't know what error to return yet"), - }; - - self.expect(&Token::LBrace)?; - let mut fields = Vec::new(); - - while !self.cur_token_is(&Token::RBrace) { - match self.next_token()? { - Some(Token::Ident(field)) => { - self.expect(&Token::Colon)?; - let expr = self.expr(Precedence::Lowest)?; - - if !self.cur_token_is(&Token::RBrace) { - self.expect(&Token::Comma)?; - } - - fields.push((field, expr)); - } - _ => todo!("Don't know what error to return yet"), - } - } - - self.expect(&Token::RBrace)?; - - Ok(Expr::Struct(ExprStruct { name, fields })) - } - - fn string_lit(&mut self) -> Result { - match self.next_token()? { - Some(Token::String(literal)) => Ok(Expr::Lit(ExprLit::String(literal))), - Some(_) | None => unreachable!(), - } - } - - fn int_lit(&mut self) -> Result { - match self.next_token()? { - Some(Token::Integer(num_str)) => Ok(Expr::Lit(ExprLit::UInt(num_str.parse().unwrap()))), - Some(_) | None => unreachable!(), - } - } - - fn null(&mut self) -> Result { - self.expect(&Token::Null)?; - - Ok(Expr::Lit(ExprLit::Null)) - } - - fn bool(&mut self) -> Result { - match self.next_token()? { - Some(Token::True) => Ok(Expr::Lit(ExprLit::Bool(true))), - Some(Token::False) => Ok(Expr::Lit(ExprLit::Bool(false))), - Some(_) | None => unreachable!(), - } - } - - fn func_call_expr(&mut self, left: Expr) -> Result { - self.expect(&Token::LParen)?; - - let args = self.expr_list()?; - - Ok(Expr::FunctionCall(ExprFunctionCall { - expr: Box::new(left), - arguments: args, - })) - } - - fn macro_call_expr(&mut self, left: Expr) -> Result { - let name = match left { - Expr::Ident(expr) => expr.0, - _ => panic!("Macro name can only be a string"), - }; - - self.expect(&Token::Bang)?; - self.expect(&Token::LParen)?; - - let mut tokens = Vec::new(); - - while !self.cur_token_is(&Token::RParen) { - tokens.push( - self.next_token()? - .ok_or(ParserError::Expected(Token::RParen))?, - ); - } - - self.expect(&Token::RParen)?; - - Ok(Expr::MacroCall(MacroCall { name, tokens })) - } - - fn bin_expr(&mut self, left: Expr) -> Result { - let token = self.next_token()?.unwrap(); - - // NOTE: assignment expression is right-associative - let precedence = if let &Token::Assign = &token { - Precedence::from(&token).lower() - } else { - Precedence::from(&token) - }; - let right = self.expr(precedence)?; - let op = BinOp::try_from(&token).map_err(|e| ParserError::Operator(e))?; - - Ok(Expr::Binary(ExprBinary { - op, - left: Box::new(left), - right: Box::new(right), - })) - } - - fn pointer_access(&mut self, left: Expr) -> Result { - self.expect(&Token::Arrow)?; - - let field = match self.next_token()? { - Some(Token::Ident(ident)) => ident, - _ => unreachable!(), - }; - - Ok(Expr::Field(ExprField { - expr: Box::new(Expr::Unary(ExprUnary { - op: UnOp::Deref, - expr: Box::new(left), - })), - field, - })) - } - - fn struct_access(&mut self, expr: Expr) -> Result { - self.expect(&Token::Period)?; - - if self.peek_token_is(&Token::LParen) { - let method = match self.next_token()? { - Some(Token::Ident(field)) => field, - _ => panic!("Struct field name should be of type string"), - }; - - self.expect(&Token::LParen)?; - let arguments = self.expr_list()?; - - Ok(Expr::StructMethod(ExprStructMethod { - expr: Box::new(expr), - method, - arguments, - })) - } else { - match self.next_token()? { - Some(Token::Ident(field)) => Ok(Expr::Field(ExprField { - expr: Box::new(expr), - field, - })), - _ => panic!("Struct field name should be of type string"), - } - } - } - - fn array_access(&mut self, expr: Expr) -> Result { - self.expect(&Token::LBracket)?; - let index = self.expr(Precedence::Access)?; - self.expect(&Token::RBracket)?; - - Ok(Expr::ArrayAccess(ExprArrayAccess { - expr: Box::new(expr), - index: Box::new(index), - })) - } - - fn cast_expr(&mut self, expr: Expr) -> Result { - self.expect(&Token::As)?; - - Ok(Expr::Cast(ExprCast { - expr: Box::new(expr), - ty: self.parse_type()?, - })) - } - - fn unary_expr(&mut self) -> Result { - let op = - UnOp::try_from(&self.next_token()?.unwrap()).map_err(|e| ParserError::Operator(e))?; - let expr = self.expr(Precedence::Prefix)?; - - Ok(Expr::Unary(ExprUnary { - op, - expr: Box::new(expr), - })) - } - - fn expr_list(&mut self) -> Result, ParserError> { - let mut exprs = Vec::new(); - - while !self.cur_token_is(&Token::RParen) { - exprs.push(self.expr(Precedence::default())?); - if !self.cur_token_is(&Token::RParen) { - self.expect(&Token::Comma)?; - } - } - - self.expect(&Token::RParen)?; - - Ok(exprs) - } - - fn grouped_expr(&mut self) -> Result { - self.expect(&Token::LParen)?; - - let expr = self.expr(Precedence::default())?; - self.expect(&Token::RParen)?; - - Ok(expr) - } - - fn array_expr(&mut self) -> Result { - self.expect(&Token::LBracket)?; - let mut items = Vec::new(); - - while !self.cur_token_is(&Token::RBracket) { - items.push(self.expr(Precedence::default())?); - - if !self.cur_token_is(&Token::RBracket) { - self.expect(&Token::Comma)?; - } - } - - self.expect(&Token::RBracket)?; - - Ok(Expr::Array(ExprArray(items))) - } -} - -#[cfg(test)] -mod test { - use super::Parser; - use crate::{ - lexer::Lexer, - parser::{ - BinOp, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, IntTy, ParserError, - Stmt, Ty, UintTy, UnOp, Variable, - }, - }; - - #[test] - fn parse_arithmetic_expression() -> Result<(), ParserError> { - let tests = [ - ( - " - { - 1 * 2 + 3 / (4 + 1 as u8); - } - ", - vec![Stmt::Expr(Expr::Binary(ExprBinary { - op: BinOp::Add, - left: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Mul, - left: Box::new(Expr::Lit(ExprLit::UInt(1))), - right: Box::new(Expr::Lit(ExprLit::UInt(2))), - })), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(3))), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Add, - left: Box::new(Expr::Lit(ExprLit::UInt(4))), - right: Box::new(Expr::Cast(ExprCast { - ty: Ty::UInt(UintTy::U8), - expr: Box::new(Expr::Lit(ExprLit::UInt(1))), - })), - })), - })), - }))], - ), - ( - " - { - let foo: u8; - foo = -1 as u8 + 5; - } - ", - vec![ - Stmt::Local(Variable { - name: "foo".to_owned(), - ty: Ty::UInt(UintTy::U8), - value: None, - }), - Stmt::Expr(Expr::Binary(ExprBinary { - op: BinOp::Assign, - left: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Add, - left: Box::new(Expr::Cast(ExprCast { - ty: Ty::UInt(UintTy::U8), - expr: Box::new(Expr::Unary(ExprUnary { - op: UnOp::Negative, - expr: Box::new(Expr::Lit(ExprLit::UInt(1))), - })), - })), - right: Box::new(Expr::Lit(ExprLit::UInt(5))), - })), - })), - ], - ), - ( - " - { - let foo: u8; - let bar: i8; - bar = foo as i8 + 5 / 10; - } - ", - vec![ - Stmt::Local(Variable { - name: "foo".to_owned(), - ty: Ty::UInt(UintTy::U8), - value: None, - }), - Stmt::Local(Variable { - name: "bar".to_owned(), - ty: Ty::Int(IntTy::I8), - value: None, - }), - Stmt::Expr(Expr::Binary(ExprBinary { - op: BinOp::Assign, - left: Box::new(Expr::Ident(ExprIdent("bar".to_owned()))), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Add, - left: Box::new(Expr::Cast(ExprCast { - ty: Ty::Int(IntTy::I8), - expr: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), - })), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(5))), - right: Box::new(Expr::Lit(ExprLit::UInt(10))), - })), - })), - })), - ], - ), - ( - " - { - 1 as i8 + 2 / 3; - } - ", - vec![Stmt::Expr(Expr::Binary(ExprBinary { - op: BinOp::Add, - left: Box::new(Expr::Cast(ExprCast { - ty: Ty::Int(IntTy::I8), - expr: Box::new(Expr::Lit(ExprLit::UInt(1))), - })), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(2))), - right: Box::new(Expr::Lit(ExprLit::UInt(3))), - })), - }))], - ), - ( - " - { - let a: u8; - let b: u8; - - a = b = 69; - } - ", - vec![ - Stmt::Local(Variable { - name: "a".to_owned(), - ty: Ty::UInt(UintTy::U8), - value: None, - }), - Stmt::Local(Variable { - name: "b".to_owned(), - ty: Ty::UInt(UintTy::U8), - value: None, - }), - Stmt::Expr(Expr::Binary(ExprBinary { - op: BinOp::Assign, - left: Box::new(Expr::Ident(ExprIdent("a".to_owned()))), - right: Box::new(Expr::Binary(ExprBinary { - op: BinOp::Assign, - left: Box::new(Expr::Ident(ExprIdent("b".to_owned()))), - right: Box::new(Expr::Lit(ExprLit::UInt(69))), - })), - })), - ], - ), - ]; - - for (input, expected) in tests { - let mut parser = Parser::new(Lexer::new(input.to_string())).unwrap(); - let ast = parser.compound_statement().unwrap(); - - assert_eq!( - &ast.0, &expected, - "expected: {:?}, got: {:?}", - expected, ast - ); - } - - Ok(()) - } -} diff --git a/src/parser/types.rs b/src/parser/types.rs index 8dabb45..9903ed4 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -1,207 +1,58 @@ -use super::error::TyError; +use derive_more::derive::Display; -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub struct TyArray { - pub ty: Box, - pub len: usize, -} - -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash, Display)] pub enum IntTy { + #[display("i8")] I8, + #[display("i16")] I16, + #[display("i32")] I32, + #[display("i64")] I64, + #[display("isize")] Isize, } -impl IntTy { - fn size(&self) -> Option { - Some(match self { - Self::I8 => 1, - Self::I16 => 2, - Self::I32 => 4, - Self::I64 => 8, - Self::Isize => return None, - }) - } -} - -impl std::fmt::Display for IntTy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::I8 => write!(f, "i8"), - Self::I16 => write!(f, "i16"), - Self::I32 => write!(f, "i32"), - Self::I64 => write!(f, "i64"), - Self::Isize => write!(f, "isize"), - } - } -} - -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash, Display)] pub enum UintTy { + #[display("u8")] U8, + #[display("u16")] U16, + #[display("u32")] U32, + #[display("u64")] U64, + #[display("usize")] Usize, } -impl UintTy { - fn size(&self) -> Option { - Some(match self { - Self::U8 => 1, - Self::U16 => 2, - Self::U32 => 4, - Self::U64 => 8, - Self::Usize => return None, - }) - } - - pub fn to_signed(self) -> IntTy { - match self { - Self::U8 => IntTy::I8, - Self::U16 => IntTy::I16, - Self::U32 => IntTy::I32, - Self::U64 => IntTy::I64, - Self::Usize => IntTy::Isize, - } - } -} - -impl std::fmt::Display for UintTy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::U8 => write!(f, "u8"), - Self::U16 => write!(f, "u16"), - Self::U32 => write!(f, "u32"), - Self::U64 => write!(f, "u64"), - Self::Usize => write!(f, "usize"), - } - } -} - -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash, Display)] pub enum Ty { + #[display("null")] Null, + #[display("void")] Void, + #[display("bool")] Bool, Int(IntTy), UInt(UintTy), Ident(String), + #[display("*{_0}")] Ptr(Box), - Array(TyArray), + #[display("{ty}[{len}]")] + Array { + ty: Box, + len: usize, + }, + #[display("fn ({}) -> {_1}", + _0 + .iter() + .map(|type_| type_.to_string()) + .collect::() + )] Fn(Vec, Box), + #[display("infer")] Infer, } - -impl std::fmt::Display for Ty { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Int(int) => int.fmt(f), - Self::UInt(uint) => uint.fmt(f), - Self::Bool => write!(f, "bool"), - Self::Void => write!(f, "void"), - Self::Ptr(type_) => write!(f, "*{type_}"), - Self::Ident(name) => write!(f, "{name}"), - Self::Array(array) => write!(f, "{}[{}]", array.ty, array.len), - Self::Fn(params, return_type) => write!( - f, - "fn ({}) -> {return_type}", - params - .iter() - .map(|type_| type_.to_string()) - .collect::() - ), - Self::Null => write!(f, "NULL"), - Self::Infer => unreachable!(), - } - } -} - -impl Ty { - pub fn ptr(&self) -> bool { - matches!(self, Self::Ptr(..)) - } - - pub fn arr(&self) -> bool { - matches!(self, Self::Array(..)) - } - - pub fn signed(&self) -> bool { - matches!(self, Self::Int(..)) - } - - pub fn int(&self) -> bool { - matches!(self, Ty::UInt(_) | Ty::Int(_)) - } - - pub fn cast(from: Self, to: Self) -> Result { - match (from, to) { - (from, to) if from.int() && to.int() => Ok(to), - (from, to) if from == Self::Bool && to.int() || from.int() && to == Ty::Bool => Ok(to), - (from, to) - if from.arr() && to.ptr() && from.inner().unwrap() == to.inner().unwrap() => - { - Ok(to) - } - (Ty::Array(_), Ty::Ptr(pointee)) if pointee.as_ref() == &Ty::Void => { - Ok(Ty::Ptr(pointee)) - } - (from, to) if from.ptr() && to.ptr() => Ok(to), - (from, to) if from.ptr() && to.int() => Ok(to), - (from, to) => Err(TyError::Cast(from, to)), - } - } - - pub fn size(&self) -> Option { - match self { - Ty::Void => Some(0), - Ty::Bool => Some(1), - Ty::Int(int) => int.size(), - Ty::UInt(uint) => uint.size(), - _ => None, - } - } - - pub fn inner(&self) -> Result { - match self { - Self::Ptr(type_) => Ok(type_.as_ref().to_owned()), - Self::Array(array) => Ok(*array.ty.clone()), - type_ => Err(TyError::Deref(type_.clone())), - } - } - - pub fn common_type(lhs: Ty, rhs: Ty) -> Ty { - match (lhs, rhs) { - (lhs, rhs) if lhs == rhs => lhs, - (type_ @ Ty::Ptr(_), int) | (int, type_ @ Ty::Ptr(_)) if int.int() => type_, - (type_ @ Ty::Ptr(_), Ty::Null) | (Ty::Null, type_ @ Ty::Ptr(_)) => type_, - (Ty::UInt(lhs), Ty::UInt(rhs)) => { - if lhs > rhs { - Ty::UInt(lhs) - } else { - Ty::UInt(rhs) - } - } - (Ty::Int(lhs), Ty::Int(rhs)) => { - if lhs > rhs { - Ty::Int(lhs) - } else { - Ty::Int(rhs) - } - } - (Ty::UInt(uint), Ty::Int(int)) | (Ty::Int(int), Ty::UInt(uint)) => { - let uint_int = uint.to_signed(); - - if uint_int <= int { - Ty::Int(int) - } else { - Ty::Int(uint_int) - } - } - (lhs, rhs) => unreachable!("Failed to get common type for {lhs} and {rhs}"), - } - } -} diff --git a/src/ty_problem/mod.rs b/src/ty_problem/mod.rs index a0e8db4..ed7701f 100644 --- a/src/ty_problem/mod.rs +++ b/src/ty_problem/mod.rs @@ -1,5 +1,6 @@ use crate::{ ir::{self, Ir, Item, Node, OrderedMap, Ty}, + parser::IntTy, Context, }; @@ -154,8 +155,8 @@ impl<'ir> TyProblem<'ir> { if let Some(ty) = self.get_ty_var(*lhs).ty() { match ty { Ty::Ptr(_) => { - *self.get_ty_var_mut(*rhs) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); - *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); + *self.get_ty_var_mut(*rhs) = TyVar::Typed(&Ty::Int(IntTy::Isize)); + *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(IntTy::Isize)); progress |= true; } Ty::Int(_) | Ty::UInt(_) => { @@ -182,7 +183,7 @@ impl<'ir> TyProblem<'ir> { match (self.get_ty_var(*lhs).ty(), self.get_ty_var(*rhs).ty()) { (Some(Ty::Ptr(_)), Some(Ty::Ptr(_))) => { - *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); + *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(IntTy::Isize)); progress |= true; false