From d241d9d023fd2646b7c58ed9a50470e80ceda71a Mon Sep 17 00:00:00 2001 From: naotaka nakane Date: Sat, 20 Jan 2024 00:14:57 +0900 Subject: [PATCH] Fix server/client. --- src/client.rs | 33 +++---- src/server.rs | 263 +++++++++++++++++++++++++++++--------------------- 2 files changed, 167 insertions(+), 129 deletions(-) diff --git a/src/client.rs b/src/client.rs index ec01509..d8ea734 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,19 +1,25 @@ use anyhow::Result; use dialoguer::{theme::ColorfulTheme, BasicHistory, Input}; -use std::{ - io::{Read, Write}, - net::TcpStream, - process, -}; +use std::{net::TcpStream, process}; + +use crate::server::{read_from_stream, write_to_stream}; pub fn client_start() -> Result<()> { println!("connecting to junkdb server..."); let mut stream = TcpStream::connect("127.0.0.1:7878")?; println!("connected!"); + let ascii = r#" + ██╗██╗ ██╗███╗ ██╗██╗ ██╗██████╗ ██████╗ + ██║██║ ██║████╗ ██║██║ ██╔╝██╔══██╗██╔══██╗ + ██║██║ ██║██╔██╗ ██║█████╔╝ ██║ ██║██████╔╝ +██ ██║██║ ██║██║╚██╗██║██╔═██╗ ██║ ██║██╔══██╗ +╚█████╔╝╚██████╔╝██║ ╚████║██║ ██╗██████╔╝██████╔╝ + ╚════╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚═════╝ ╚═════╝ + "#; + println!("{}", ascii); println!("Welcome to junkdb!"); println!("Type \"exit\" or \"quit\" to exit."); - - let mut history = BasicHistory::new().max_entries(8).no_duplicates(true); + let mut history = BasicHistory::new().max_entries(100).no_duplicates(true); loop { if let Ok(cmd) = Input::::with_theme(&ColorfulTheme::default()) .with_prompt("Query") @@ -21,18 +27,11 @@ pub fn client_start() -> Result<()> { .interact_text() { if cmd == "exit" || cmd == "quit" { + println!("Bye!"); process::exit(0); } - - stream.write(&(cmd.len() as u32).to_be_bytes())?; - stream.write(cmd.as_bytes())?; - stream.flush()?; - - let mut size_buffer = [0u8; 4]; - stream.read_exact(&mut size_buffer)?; - let mut buffer = vec![0u8; u32::from_be_bytes(size_buffer) as usize]; - stream.read_exact(&mut buffer)?; - let response = String::from_utf8(buffer)?; + write_to_stream(&mut stream, &cmd)?; + let response = read_from_stream(&mut stream)?; println!("{}", response); } } diff --git a/src/server.rs b/src/server.rs index 8a35676..99388c8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -16,12 +16,15 @@ use crate::{ parser::{Parser, StatementAST}, }; +const SERVER_DEFAULT_PORT: u16 = 7878; + pub fn server_start(init: bool) -> Result<()> { println!("junkdb server started"); // init let instance = Arc::new(RwLock::new(Instance::new("data", init)?)); + // trap signals let instance_clone = instance.clone(); let mut signals = Signals::new(TERM_SIGNALS)?; thread::spawn(move || { @@ -41,7 +44,7 @@ pub fn server_start(init: bool) -> Result<()> { }); // listen - let listener = TcpListener::bind("127.0.0.1:7878")?; + let listener = TcpListener::bind(format!("127.0.0.1:{}", SERVER_DEFAULT_PORT))?; for stream in listener.incoming() { let stream = stream?; println!("connection established: {}", stream.peer_addr()?); @@ -68,132 +71,168 @@ impl Session { } } fn start(&mut self) -> Result<()> { - let result = self.internal(); - match result { - Ok(_) => Ok(()), - Err(e) => { - if let Some(txn_id) = self.current_txn_id { - self.instance - .write() - .map_err(|_| anyhow!("lock error"))? - .rollback(txn_id)?; + loop { + let request = self.read()?; + match self.execute(&request) { + Ok(response) => { + self.write(&response)?; + } + Err(e) => { + self.write(&format!("error: {}", e))?; + self.rollback()?; } - let response = format!("{}", e); - self.stream.write(&(response.len() as u32).to_be_bytes())?; - self.stream.write_all(response.as_bytes())?; - self.stream.flush()?; - Err(e) } } } - fn internal(&mut self) -> Result<()> { - loop { - // read request - let mut size_buffer = [0u8; 4]; - match self.stream.read_exact(&mut size_buffer) { - Ok(_) => {} - Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => { - println!("Client disconnected."); - return Ok(()); + fn read(&mut self) -> Result { + match read_from_stream(&mut self.stream) { + Ok(request) => { + return Ok(request); + } + Err(e) => { + self.rollback()?; + if let Some(io_err) = e.downcast_ref::() { + if io_err.kind() == io::ErrorKind::UnexpectedEof { + println!("connection closed: {}", self.stream.peer_addr()?); + return Err(e); + } } - Err(e) => { - return Err(e.into()); + println!("read error: {}", e); + return Err(e); + } + } + } + fn write(&mut self, response: &str) -> Result<()> { + match write_to_stream(&mut self.stream, &response) { + Ok(_) => { + return Ok(()); + } + Err(e) => { + self.rollback()?; + if let Some(io_err) = e.downcast_ref::() { + if io_err.kind() == io::ErrorKind::BrokenPipe { + println!("connection closed: {}", self.stream.peer_addr()?); + return Err(e); + } } + println!("write error: {}", e); + return Err(e); } - let mut buffer = vec![0u8; u32::from_be_bytes(size_buffer) as usize]; - self.stream.read_exact(&mut buffer)?; - let request = String::from_utf8(buffer)?; - - // parse - let mut iter = request.chars().peekable(); - let tokens = tokenize(&mut iter)?; - let mut parser = Parser::new(tokens); - let statement = parser.parse()?; + } + } + fn rollback(&mut self) -> Result<()> { + if let Some(txn_id) = self.current_txn_id { + self.instance + .write() + .map_err(|_| anyhow!("lock error"))? + .rollback(txn_id)?; + self.current_txn_id = None; + } + Ok(()) + } + fn execute(&mut self, query: &str) -> Result { + // parse + let mut iter = query.chars().peekable(); + let tokens = tokenize(&mut iter)?; + let mut parser = Parser::new(tokens); + let statement = parser.parse()?; - let response = match statement { - StatementAST::Begin => { - let txn_id = self - .instance - .read() - .map_err(|_| anyhow!("lock error"))? - .begin(self.current_txn_id)?; - self.current_txn_id = Some(txn_id); - format!("transaction started.") + let response = match statement { + StatementAST::Begin => { + let txn_id = self + .instance + .read() + .map_err(|_| anyhow!("lock error"))? + .begin(self.current_txn_id)?; + self.current_txn_id = Some(txn_id); + format!("transaction started.") + } + _ => { + let txn_id_existed = self.current_txn_id.is_some(); + if !txn_id_existed { + let txn_id = Some( + self.instance + .read() + .map_err(|_| anyhow!("lock error"))? + .begin(None)?, + ); + self.current_txn_id = txn_id; } - _ => { - let txn_id_existed = self.current_txn_id.is_some(); - if !txn_id_existed { - let txn_id = Some( - self.instance - .read() - .map_err(|_| anyhow!("lock error"))? - .begin(None)?, - ); - self.current_txn_id = txn_id; - } - let txn_id = self.current_txn_id.unwrap(); - let response = match statement { - StatementAST::Commit => { - self.instance - .write() - .map_err(|_| anyhow!("lock error"))? - .commit(txn_id)?; - self.current_txn_id = None; - format!("transaction committed.") - } - StatementAST::Rollback => { - self.instance - .write() - .map_err(|_| anyhow!("lock error"))? - .rollback(txn_id)?; - self.current_txn_id = None; - format!("transaction rolled back.") - } - StatementAST::CreateTable(ast) => { - self.instance - .write() - .map_err(|_| anyhow!("lock error"))? - .create_table(&ast, txn_id)?; - format!("table {} created", ast.table_name) - } - _ => { - let (rows, schema) = self - .instance - .write() - .map_err(|_| anyhow!("lock error"))? - .execute(&statement, txn_id)?; - - // TODO: move to client - let mut table_view = Table::new(); - let mut header = vec![]; - for column in schema.columns { - header.push(Cell::new(&column.name)); - } - table_view.set_titles(Row::new(header)); - for row in rows { - let cells = row - .iter() - .map(|v| Cell::new(&v.to_string())) - .collect::>(); - table_view.add_row(Row::new(cells)); - } - format!("{}", table_view) - } - }; - if !txn_id_existed { + let txn_id = self.current_txn_id.unwrap(); + let response = match statement { + StatementAST::Commit => { self.instance .write() .map_err(|_| anyhow!("lock error"))? .commit(txn_id)?; self.current_txn_id = None; + format!("transaction committed.") } - response - } - }; + StatementAST::Rollback => { + self.instance + .write() + .map_err(|_| anyhow!("lock error"))? + .rollback(txn_id)?; + self.current_txn_id = None; + format!("transaction rolled back.") + } + StatementAST::CreateTable(ast) => { + self.instance + .write() + .map_err(|_| anyhow!("lock error"))? + .create_table(&ast, txn_id)?; + format!("table {} created", ast.table_name) + } + _ => { + let (rows, schema) = self + .instance + .write() + .map_err(|_| anyhow!("lock error"))? + .execute(&statement, txn_id)?; - self.stream.write(&(response.len() as u32).to_be_bytes())?; - self.stream.write_all(response.as_bytes())?; - self.stream.flush()?; - } + // TODO: move to client + let mut table_view = Table::new(); + let mut header = vec![]; + for column in schema.columns { + header.push(Cell::new(&column.name)); + } + table_view.set_titles(Row::new(header)); + for row in rows { + let cells = row + .iter() + .map(|v| Cell::new(&v.to_string())) + .collect::>(); + table_view.add_row(Row::new(cells)); + } + format!("{}", table_view) + } + }; + if !txn_id_existed { + self.instance + .write() + .map_err(|_| anyhow!("lock error"))? + .commit(txn_id)?; + self.current_txn_id = None; + } + response + } + }; + Ok(response) } } + +pub fn write_to_stream(stream: &mut TcpStream, response: &str) -> Result<()> { + stream.write(&(response.len() as u32).to_be_bytes())?; + stream.write_all(response.as_bytes())?; + stream.flush()?; + Ok(()) +} + +pub fn read_from_stream(stream: &mut TcpStream) -> Result { + let mut size_buffer = [0u8; 4]; + stream.read_exact(&mut size_buffer)?; + let mut buffer = vec![0u8; u32::from_be_bytes(size_buffer) as usize]; + stream.read_exact(&mut buffer)?; + let response = String::from_utf8(buffer)?; + Ok(response) +}