diff --git a/src/handshake.rs b/src/handshake.rs index 62815a3..47349d2 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -2,7 +2,7 @@ use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use crate::mc_types; +use crate::mc_types::{self, Result}; pub struct Handshake { pub protocol_version: i32, @@ -11,19 +11,19 @@ pub struct Handshake { pub next_state: i32, } -pub async fn read_handshake(stream: &mut OwnedReadHalf) -> Handshake { - let mut data = mc_types::read_packet(stream).await; +pub async fn read_handshake(stream: &mut OwnedReadHalf) -> Result { + let mut data = mc_types::read_packet(stream).await?; let _packet_id = mc_types::get_var_int(&mut data); - get_handshake(&mut data) + Ok(get_handshake(&mut data)?) } -pub fn get_handshake(data: &mut Vec) -> Handshake { - Handshake { - protocol_version: mc_types::get_var_int(data), - server_address: mc_types::get_string(data), +pub fn get_handshake(data: &mut Vec) -> Result { + Ok(Handshake { + protocol_version: mc_types::get_var_int(data)?, + server_address: mc_types::get_string(data)?, server_port: mc_types::get_u16(data), - next_state: mc_types::get_var_int(data), - } + next_state: mc_types::get_var_int(data)?, + }) } pub fn convert_handshake(handshake: Handshake) -> Vec { @@ -39,6 +39,7 @@ pub fn convert_handshake(handshake: Handshake) -> Vec { pub async fn write_handshake( stream: &mut OwnedWriteHalf, handshake: Handshake, -) { - mc_types::write_packet(stream, &mut convert_handshake(handshake)).await; +) -> Result<()> { + mc_types::write_packet(stream, &mut convert_handshake(handshake)).await?; + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 53282f2..5f841b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,55 +26,56 @@ async fn handle_client(client_socket: TcpStream) { let backend_addr = "127.0.0.1:25566"; let (mut client_reader, mut client_writer) = client_socket.into_split(); - if let Ok(backend_socket) = TcpStream::connect(backend_addr).await { - let (mut server_reader, mut server_writer) = - backend_socket.into_split(); - let mut buffer: [u8; 1] = [0; 1]; - client_reader.peek(&mut buffer) - .await.expect("Error reading from stream"); - let packet_id: u8 = buffer[0]; - if packet_id == 0xFE { - status::respond_legacy_status(&mut client_writer).await; - return; - } else { - let handshake_packet = handshake::read_handshake(&mut client_reader) - .await; - println!("Next state: {}", handshake_packet.next_state); - if handshake_packet.next_state == 1 { - println!("Receiving Status Request"); - status::respond_status( - &mut client_reader, - &mut client_writer, - &mut server_reader, - &mut server_writer, - ).await; - return; - } else if handshake_packet.next_state == 2 { - handshake::write_handshake(&mut server_writer, handshake::Handshake{ - protocol_version: mc_types::VERSION_PROTOCOL, - server_address: "localhost".to_string(), - server_port: 25565, - next_state: 2, - }).await; - } else { - return; - } - } + let backend_socket = TcpStream::connect(backend_addr) + .await.expect("Failed to connect to the backend server"); - // Forward from client to backend - tokio::spawn(async move { - io::copy(&mut client_reader, &mut server_writer) - .await.expect("Error copying from client to backend"); - }); + let (mut server_reader, mut server_writer) = backend_socket.into_split(); + let mut buffer: [u8; 1] = [0; 1]; + client_reader.peek(&mut buffer) + .await.expect("Failed to peek at first byte from stream"); + let packet_id: u8 = buffer[0]; - // Forward from backend to client - tokio::spawn(async move { - io::copy(&mut server_reader, &mut client_writer) - .await.expect("Error copying from backend to client"); - }); + if packet_id == 0xFE { + status::respond_legacy_status(&mut client_writer) + .await.expect("Error handling legacy status request"); + return; } else { - eprintln!("Failed to connect to the backend server"); + let handshake_packet = handshake::read_handshake(&mut client_reader) + .await.expect("Error reading handshake packet"); + println!("Next state: {}", handshake_packet.next_state); + if handshake_packet.next_state == 1 { + println!("Receiving Status Request"); + status::respond_status( + &mut client_reader, + &mut client_writer, + &mut server_reader, + &mut server_writer, + ).await.expect("Error handling status request"); + return; + } else if handshake_packet.next_state == 2 { + handshake::write_handshake(&mut server_writer, handshake::Handshake{ + protocol_version: mc_types::VERSION_PROTOCOL, + server_address: "localhost".to_string(), + server_port: 25565, + next_state: 2, + }).await.expect("Error logging into backend server"); + } else { + return; + } } + + // Forward from client to backend + tokio::spawn(async move { + io::copy(&mut client_reader, &mut server_writer) + .await.expect("Error copying from client to backend"); + }); + + // Forward from backend to client + tokio::spawn(async move { + io::copy(&mut server_reader, &mut client_writer) + .await.expect("Error copying from backend to client"); + }); + println!("Connection Closed"); } diff --git a/src/mc_types.rs b/src/mc_types.rs index 08ea135..0c21620 100644 --- a/src/mc_types.rs +++ b/src/mc_types.rs @@ -1,39 +1,68 @@ // Yeahbut December 2023 +use std::error::Error; +use std::fmt; + use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use serde::{Serialize, Deserialize}; +pub type Result = std::result::Result>; + pub const VERSION_NAME: &str = "1.19.4"; pub const VERSION_PROTOCOL: i32 = 762; const SEGMENT_BITS: u8 = 0x7F; const CONTINUE_BIT: u8 = 0x80; +#[derive(Debug)] +pub enum VarIntError { + ValueTooLarge, + RanOutOfBytes, +} + +impl fmt::Display for VarIntError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VarIntError::ValueTooLarge => + write!(f, "VarInt value is too large"), + VarIntError::RanOutOfBytes => + write!(f, "Ran out of bytes while reading VarInt"), + } + } +} + +impl Error for VarIntError {} + #[derive(Serialize, Deserialize)] pub struct Chat { pub text: String, } -pub async fn read_packet(stream: &mut OwnedReadHalf) -> Vec { - let length = read_var_int_stream(stream).await; - let mut buffer: Vec = vec![0; length as usize]; - stream.read_exact(&mut buffer) - .await.expect("Error reading string from stream"); - buffer +pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result> { + let length = read_var_int_stream(stream).await? as usize; + + let mut buffer: Vec = vec![0; length]; + stream.read_exact(&mut buffer).await?; + + Ok(buffer) } -pub async fn write_packet(stream: &mut OwnedWriteHalf, data: &mut Vec) { +pub async fn write_packet( + stream: &mut OwnedWriteHalf, + data: &mut Vec, +) -> Result<()> { let mut out_data = convert_var_int(data.len() as i32); out_data.append(data); - stream.write_all(&out_data) - .await.expect("Error writing to stream"); + + stream.write_all(&out_data).await?; + + Ok(()) } -async fn read_var_int_stream(stream: &mut OwnedReadHalf) -> i32 { +async fn read_var_int_stream(stream: &mut OwnedReadHalf) -> Result { let mut data: Vec = vec![]; loop { - let current_byte = stream.read_u8() - .await.expect("Error reading from stream"); + let current_byte = stream.read_u8().await?; data.append(&mut vec![current_byte]); @@ -42,7 +71,9 @@ async fn read_var_int_stream(stream: &mut OwnedReadHalf) -> i32 { } } - get_var_int(&mut data) + let varint = get_var_int(&mut data)?; + + Ok(varint) } pub fn get_bool(data: &mut Vec) -> bool { @@ -150,13 +181,31 @@ pub fn convert_f64(value: f64) -> Vec { convert_u64(value as u64) } -pub fn get_var_int(data: &mut Vec) -> i32 { - let mut value: i32 = 0; - let mut position: u32 = 0; +pub fn get_var_int(data: &mut Vec) -> Result { + Ok(get_var(data, 32)? as i32) +} +pub fn convert_var_int(value: i32) -> Vec { + convert_var(value as i64) +} + +pub fn get_var_long(data: &mut Vec) -> Result { + get_var(data, 64) +} +pub fn convert_var_long(value: i64) -> Vec { + convert_var(value) +} + +fn get_var(data: &mut Vec, size: u8) -> Result { + let mut value: i64 = 0; + let mut position: u8 = 0; loop { + if data.is_empty() { + return Err(Box::new(VarIntError::RanOutOfBytes)); + } + let current_byte = data.remove(0); - value |= ((current_byte & SEGMENT_BITS) as i32) << position; + value |= ((current_byte & SEGMENT_BITS) as i64) << position; if (current_byte & CONTINUE_BIT) == 0 { break; @@ -164,31 +213,31 @@ pub fn get_var_int(data: &mut Vec) -> i32 { position += 7; - if position >= 32 { - eprintln!("VarInt is too big"); + if position >= size { + return Err(Box::new(VarIntError::ValueTooLarge)); } } - value + Ok(value) } -pub fn convert_var_int(mut value: i32) -> Vec { +fn convert_var(mut value: i64) -> Vec { let mut data: Vec = vec![]; loop { - if (value & !(SEGMENT_BITS as i32)) == 0 { + if (value & !(SEGMENT_BITS as i64)) == 0 { data.append(&mut vec![value as u8]); return data; } data.append( - &mut vec![(value & (SEGMENT_BITS as i32)) as u8 | CONTINUE_BIT]); + &mut vec![(value & (SEGMENT_BITS as i64)) as u8 | CONTINUE_BIT]); value >>= 7; } } -pub fn get_string(data: &mut Vec) -> String { - let length = get_var_int(data) as usize; +pub fn get_string(data: &mut Vec) -> Result { + let length = get_var_int(data)? as usize; let buffer = data[..length].to_vec(); for _ in 0..length { data.remove(0); } - String::from_utf8_lossy(&buffer).to_string() + Ok(String::from_utf8_lossy(&buffer).to_string()) } pub fn convert_string(s: &str) -> Vec { let length = s.len() as i32; diff --git a/src/status.rs b/src/status.rs index c18b2af..6e203f2 100644 --- a/src/status.rs +++ b/src/status.rs @@ -10,7 +10,7 @@ use serde_json::Value; use base64::{Engine as _, engine::general_purpose}; use rand::Rng; -use crate::mc_types; +use crate::mc_types::{self, Result}; use crate::handshake; #[derive(Serialize, Deserialize)] @@ -49,8 +49,8 @@ pub struct StatusResponseData { async fn online_players( server_reader: &mut OwnedReadHalf, server_writer: &mut OwnedWriteHalf, -) -> StatusPlayers { - get_upstream_status(server_reader, server_writer).await.players +) -> Result { + Ok(get_upstream_status(server_reader, server_writer).await?.players) } fn motd() -> String { @@ -83,6 +83,8 @@ fn motd() -> String { None => return default, }; + // TODO: Birthdays, Holidays, and Announcements + let line2: &str = match motd_data["line2"][rand2].as_str() { Some(s) => s, None => return default, @@ -117,17 +119,18 @@ pub async fn respond_status( client_writer: &mut OwnedWriteHalf, server_reader: &mut OwnedReadHalf, server_writer: &mut OwnedWriteHalf, -) { +)-> Result<()> { loop { println!("Status Handling"); - let mut data = mc_types::read_packet(client_reader).await; - let packet_id = mc_types::get_var_int(&mut data); + let mut data = mc_types::read_packet(client_reader).await?; + let packet_id = mc_types::get_var_int(&mut data)?; println!("Status Packet ID: {}", packet_id); if packet_id == 0x00 { println!("Handling Status"); - let online_players = online_players(server_reader, server_writer).await; + let online_players = + online_players(server_reader, server_writer).await?; let status_response = StatusResponseData { version: StatusVersion { name: mc_types::VERSION_NAME.to_string(), @@ -148,70 +151,53 @@ pub async fn respond_status( // previewsChat: Some(false), }; - let json_result = serde_json::to_string(&status_response); + let json = serde_json::to_string(&status_response)?; - match json_result { - Ok(json) => { - let mut out_data: Vec = vec![0]; - out_data.append(&mut mc_types::convert_string(&json)); - mc_types::write_packet(client_writer, &mut out_data).await; - }, - Err(err) => { - eprintln!("Error serializing to JSON: {}", err); - break; - }, - } + let mut out_data: Vec = vec![0]; + out_data.append(&mut mc_types::convert_string(&json)); + mc_types::write_packet(client_writer, &mut out_data).await?; } else if packet_id == 0x01 { println!("Handling Ping"); let mut out_data: Vec = vec![1]; out_data.append(&mut data); - mc_types::write_packet(client_writer, &mut out_data).await; + mc_types::write_packet(client_writer, &mut out_data).await?; break; } else { break; } } + Ok(()) } pub async fn get_upstream_status( server_reader: &mut OwnedReadHalf, server_writer: &mut OwnedWriteHalf, -) -> StatusResponseData { +) -> Result { handshake::write_handshake(server_writer, handshake::Handshake{ protocol_version: mc_types::VERSION_PROTOCOL, server_address: "localhost".to_string(), server_port: 25565, next_state: 1, - }).await; - mc_types::write_packet(server_writer, &mut vec![0]).await; - let mut data = mc_types::read_packet(server_reader).await; + }).await?; + mc_types::write_packet(server_writer, &mut vec![0]).await?; + let mut data = mc_types::read_packet(server_reader).await?; mc_types::get_u8(&mut data); - let json = mc_types::get_string(&mut data); - let status_response: StatusResponseData = serde_json::from_str(&json) - .expect("Error parsing JSON"); + let json = mc_types::get_string(&mut data)?; + let status_response: StatusResponseData = serde_json::from_str(&json)?; // let mut out_data: Vec = vec![1]; // out_data.append(&mut mc_types::convert_i64(0)); - // mc_types::write_packet(server_writer, &mut out_data).await; + // mc_types::write_packet(server_writer, &mut out_data).await?; - status_response + Ok(status_response) } -pub async fn respond_legacy_status(client_writer: &mut OwnedWriteHalf) { +pub async fn respond_legacy_status( + client_writer: &mut OwnedWriteHalf, +) -> Result<()> { println!("Old Style Status"); - client_writer.write_u8(0xFF) - .await.expect("Error writing to stream"); - - // let s = "§1\0127\01.12.2\0YTD Proxy§0§10"; - // println!("String length: {}", s.len()); - // client_writer.write_u16(s.len() as u16) - // .await.expect("Error writing to stream"); - // let utf16_bytes: Vec = s.encode_utf16().collect(); - // for utf16_char in utf16_bytes { - // client_writer.write_u16(utf16_char) - // .await.expect("Error writing to stream"); - // } + client_writer.write_u8(0xFF).await?; let s = "§1\0127\0".to_string() + mc_types::VERSION_NAME + @@ -220,22 +206,11 @@ pub async fn respond_legacy_status(client_writer: &mut OwnedWriteHalf) { .encode_utf16() .flat_map(|c| std::iter::once(c).chain(std::iter::once(0))) .collect(); - println!("String length: {}", (utf16_vec.len() / 2)); - client_writer.write_u16((utf16_vec.len() / 2) as u16) - .await.expect("Error writing to stream"); + + client_writer.write_u16((utf16_vec.len() / 2) as u16).await?; for utf16_char in utf16_vec { - client_writer.write_u16(utf16_char) - .await.expect("Error writing to stream"); + client_writer.write_u16(utf16_char).await?; } - // let s = b"\x00\xa7\x001\x00\x00\x001\x002\x007\x00\x00\x001\x00.\x001\x0 - // 02\x00.\x002\x00\x00\x00Y\x00T\x00D\x00 \x00P\x00r\x00o\x00x\x00y\x00 - // \xa7\x000\xa7\x001\x000"; - // println!("String length: {}", s.len()); - // client_writer.write_u16(25) - // .await.expect("Error writing to stream"); - // for b in s { - // client_writer.write_u8(b) - // .await.expect("Error writing to stream"); - // } + Ok(()) }