diff --git a/Cargo.toml b/Cargo.toml index 8642dd1..abe48b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" base64 = "0.21.5" rand = "0.8.5" +async-trait = "0.1.75" diff --git a/src/handshake.rs b/src/handshake.rs index 47349d2..78726e3 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -1,45 +1,57 @@ // Yeahbut December 2023 -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +pub mod serverbound { -use crate::mc_types::{self, Result}; + use tokio::net::tcp::OwnedReadHalf; -pub struct Handshake { - pub protocol_version: i32, - pub server_address: String, - pub server_port: u16, - pub next_state: i32, -} - -pub async fn read_handshake(stream: &mut OwnedReadHalf) -> Result { - let mut data = mc_types::read_packet(stream).await?; - let _packet_id = mc_types::get_var_int(&mut data); - Ok(get_handshake(&mut data)?) -} - -pub fn get_handshake(data: &mut Vec) -> Result { - Ok(Handshake { - protocol_version: mc_types::get_var_int(data)?, - server_address: mc_types::get_string(data)?, - server_port: mc_types::get_u16(data), - next_state: mc_types::get_var_int(data)?, - }) -} - -pub fn convert_handshake(handshake: Handshake) -> Vec { - let mut data: Vec = vec![0]; - data.append(&mut mc_types::convert_var_int(handshake.protocol_version)); - data.append(&mut mc_types::convert_string(&handshake.server_address)); - data.append(&mut mc_types::convert_u16(handshake.server_port)); - data.append(&mut mc_types::convert_var_int(handshake.next_state)); - - data -} - -pub async fn write_handshake( - stream: &mut OwnedWriteHalf, - handshake: Handshake, -) -> Result<()> { - mc_types::write_packet(stream, &mut convert_handshake(handshake)).await?; - Ok(()) + use crate::mc_types::{self, Result, Packet, PacketError}; + + enum HandshakeEnum { + Handshake(Handshake), + } + + impl HandshakeEnum { + pub async fn read(stream: &mut OwnedReadHalf) -> Result { + let mut data = mc_types::read_data(stream).await?; + let packet_id = mc_types::get_var_int(&mut data)?; + if packet_id == Handshake::packet_id() { + return Ok(Self::Handshake(Handshake::get(&mut data)?)) + } else { + return Err(Box::new(PacketError::InvalidPacketId)) + } + } + } + + pub struct Handshake { + pub protocol_version: i32, + pub server_address: String, + pub server_port: u16, + pub next_state: i32, + } + + impl Packet for Handshake { + + fn packet_id() -> i32 {0} + + fn get(data: &mut Vec) -> Result { + Ok(Self { + protocol_version: mc_types::get_var_int(data)?, + server_address: mc_types::get_string(data)?, + server_port: mc_types::get_u16(data), + next_state: mc_types::get_var_int(data)?, + }) + } + + fn convert(&self) -> Vec { + let mut data: Vec = vec![]; + data.append(&mut mc_types::convert_var_int(Self::packet_id())); + data.append(&mut mc_types::convert_var_int(self.protocol_version)); + data.append(&mut mc_types::convert_string(&self.server_address)); + data.append(&mut mc_types::convert_u16(self.server_port)); + data.append(&mut mc_types::convert_var_int(self.next_state)); + + data + } + + } } diff --git a/src/login.rs b/src/login.rs index 7f7bc1a..7ea06f3 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,22 +1,49 @@ // Yeahbut December 2023 -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +pub mod clientbound { -use crate::mc_types::{self, Result}; + use tokio::net::tcp::OwnedReadHalf; -pub fn convert_clientbound_disconnect(reason: String) -> Vec { - let mut data: Vec = vec![0]; - data.append(&mut &mut mc_types::convert_string(&reason)); + use crate::mc_types::{self, Result, Packet, PacketError}; + + enum Login { + Disconnect(Disconnect), + } + + impl Login { + pub async fn read(stream: &mut OwnedReadHalf) -> Result { + let mut data = mc_types::read_data(stream).await?; + let packet_id = mc_types::get_var_int(&mut data)?; + if packet_id == Disconnect::packet_id() { + return Ok(Self::Disconnect(Disconnect::get(&mut data)?)) + } else { + return Err(Box::new(PacketError::InvalidPacketId)) + } + } + } + + pub struct Disconnect { + pub reason: String + } + + impl Packet for Disconnect { + + fn packet_id() -> i32 {0} + + fn get(mut data: &mut Vec) -> Result { + Ok(Self { + reason: mc_types::get_string(&mut data)? + }) + } + + fn convert(&self) -> Vec { + let mut data: Vec = vec![]; + data.append(&mut mc_types::convert_var_int(Self::packet_id())); + data.append(&mut mc_types::convert_string(&self.reason)); + + data + } + + } - data -} -pub async fn write_clientbound_disconnect( - stream: &mut OwnedWriteHalf, - reason: String, -) -> Result<()> { - mc_types::write_packet( - stream, - &mut convert_clientbound_disconnect(reason), - ).await?; - Ok(()) } diff --git a/src/main.rs b/src/main.rs index b89eff6..eda0438 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,8 @@ mod handshake; mod status; mod login; +use mc_types::Packet; + #[tokio::main] async fn main() -> Result<(), Box> { let listener = TcpListener::bind("127.0.0.1:25565").await?; @@ -52,8 +54,9 @@ async fn handle_client(client_socket: TcpStream) { .await.expect("Error handling legacy status request"); return; } else { - let handshake_packet = handshake::read_handshake(&mut client_reader) - .await.expect("Error reading handshake packet"); + let handshake_packet = + handshake::serverbound::Handshake::read(&mut client_reader) + .await.expect("Error reading handshake packet"); println!("Next state: {}", handshake_packet.next_state); if handshake_packet.next_state == 1 { println!("Receiving Status Request"); @@ -67,15 +70,15 @@ async fn handle_client(client_socket: TcpStream) { } else if handshake_packet.next_state == 2 { match server_writer { Some(mut server_writer) => { - handshake::write_handshake( - &mut server_writer, - handshake::Handshake { - protocol_version: mc_types::VERSION_PROTOCOL, - server_address: "localhost".to_string(), - server_port: 25565, - next_state: 2, - }, - ).await.expect("Error logging into backend server"); + handshake::serverbound::Handshake { + protocol_version: mc_types::VERSION_PROTOCOL, + server_address: "localhost".to_string(), + server_port: 25565, + next_state: 2, + } + .write(&mut server_writer) + .await + .expect("Error logging into backend server"); // Forward from client to backend tokio::spawn(async move { @@ -98,10 +101,13 @@ async fn handle_client(client_socket: TcpStream) { }; }, None => { - login::write_clientbound_disconnect( - &mut client_writer, - "\"Server Error (Server may be starting)\"".to_string(), - ).await.expect("Error sending disconnect on: \ + login::clientbound::Disconnect { + reason: "\"Server Error (Server may be starting)\"" + .to_string() + } + .write(&mut client_writer) + .await + .expect("Error sending disconnect on: \ Failed to connect to the backend server"); } }; diff --git a/src/mc_types.rs b/src/mc_types.rs index 0c21620..621d8db 100644 --- a/src/mc_types.rs +++ b/src/mc_types.rs @@ -6,6 +6,7 @@ use std::fmt; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use serde::{Serialize, Deserialize}; +use async_trait::async_trait; pub type Result = std::result::Result>; @@ -15,6 +16,22 @@ pub const VERSION_PROTOCOL: i32 = 762; const SEGMENT_BITS: u8 = 0x7F; const CONTINUE_BIT: u8 = 0x80; +#[derive(Debug)] +pub enum PacketError { + InvalidPacketId, +} + +impl fmt::Display for PacketError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PacketError::InvalidPacketId => + write!(f, "Invalid packet id"), + } + } +} + +impl Error for PacketError {} + #[derive(Debug)] pub enum VarIntError { ValueTooLarge, @@ -39,7 +56,29 @@ pub struct Chat { pub text: String, } -pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result> { + +#[async_trait] +pub trait Packet: Sized { + fn packet_id() -> i32; + fn get(data: &mut Vec) -> Result; + fn convert(&self) -> Vec; + + async fn read(stream: &mut OwnedReadHalf) -> Result { + let mut data = read_data(stream).await?; + let packet_id = get_var_int(&mut data)?; + if packet_id == Self::packet_id() { + return Ok(Self::get(&mut data)?) + } else { + return Err(Box::new(PacketError::InvalidPacketId)) + } + } + + async fn write(&self, stream: &mut OwnedWriteHalf) -> Result<()> { + write_data(stream, &mut self.convert()).await + } +} + +pub async fn read_data(stream: &mut OwnedReadHalf) -> Result> { let length = read_var_int_stream(stream).await? as usize; let mut buffer: Vec = vec![0; length]; @@ -47,7 +86,7 @@ pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result> { Ok(buffer) } -pub async fn write_packet( +pub async fn write_data( stream: &mut OwnedWriteHalf, data: &mut Vec, ) -> Result<()> { diff --git a/src/status.rs b/src/status.rs index 787b633..52f1e75 100644 --- a/src/status.rs +++ b/src/status.rs @@ -10,7 +10,7 @@ use serde_json::Value; use base64::{Engine as _, engine::general_purpose}; use rand::Rng; -use crate::mc_types::{self, Result}; +use crate::mc_types::{self, Result, Packet}; use crate::handshake; #[derive(Serialize, Deserialize)] @@ -122,7 +122,7 @@ pub async fn respond_status( )-> Result<()> { loop { println!("Status Handling"); - let mut data = mc_types::read_packet(client_reader).await?; + let mut data = mc_types::read_data(client_reader).await?; let packet_id = mc_types::get_var_int(&mut data)?; println!("Status Packet ID: {}", packet_id); @@ -188,12 +188,12 @@ pub async fn respond_status( let mut out_data: Vec = vec![0]; out_data.append(&mut mc_types::convert_string(&json)); - mc_types::write_packet(client_writer, &mut out_data).await?; + mc_types::write_data(client_writer, &mut out_data).await?; } else if packet_id == 0x01 { println!("Handling Ping"); let mut out_data: Vec = vec![1]; out_data.append(&mut data); - mc_types::write_packet(client_writer, &mut out_data).await?; + mc_types::write_data(client_writer, &mut out_data).await?; break; } else { break; @@ -206,14 +206,14 @@ pub async fn get_upstream_status( server_reader: &mut OwnedReadHalf, server_writer: &mut OwnedWriteHalf, ) -> Result { - handshake::write_handshake(server_writer, handshake::Handshake{ + handshake::serverbound::Handshake{ protocol_version: mc_types::VERSION_PROTOCOL, server_address: "localhost".to_string(), server_port: 25565, next_state: 1, - }).await?; - mc_types::write_packet(server_writer, &mut vec![0]).await?; - let mut data = mc_types::read_packet(server_reader).await?; + }.write(server_writer).await?; + mc_types::write_data(server_writer, &mut vec![0]).await?; + let mut data = mc_types::read_data(server_reader).await?; mc_types::get_u8(&mut data); let json = mc_types::get_string(&mut data)?;