diff --git a/src/handshake.rs b/src/handshake.rs index cbdff5d..8a1ad5e 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -9,8 +9,8 @@ pub mod serverbound { } impl HandshakeEnum { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; let packet_id = mc_types::get_var_int(&mut data)?; diff --git a/src/login.rs b/src/login.rs index 5784eb4..9c88d6b 100644 --- a/src/login.rs +++ b/src/login.rs @@ -13,8 +13,8 @@ pub mod clientbound { } impl Login { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; let packet_id = mc_types::get_var_int(&mut data)?; @@ -241,8 +241,8 @@ pub mod serverbound { } impl Login { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; let packet_id = mc_types::get_var_int(&mut data)?; diff --git a/src/mc_types.rs b/src/mc_types.rs index e07541b..d6a0f51 100644 --- a/src/mc_types.rs +++ b/src/mc_types.rs @@ -53,6 +53,16 @@ pub struct Chat { pub text: String, } +#[async_trait] +pub trait ProtocolRead { + async fn read_data(&mut self) -> Result>; +} + +#[async_trait] +pub trait ProtocolWrite { + async fn write_data(&mut self, data: &mut Vec) -> Result<()>; +} + pub struct ProtocolConnection<'a> { pub stream_read: &'a mut OwnedReadHalf, pub stream_write: &'a mut OwnedWriteHalf, @@ -77,74 +87,6 @@ impl<'a> ProtocolConnection<'a> { } } - pub async fn read_data(&mut self) -> Result> { - match self.aes_encryption_key { - Some(aes_key) => { - let mut buffer: Vec = vec![0; 16]; - self.stream_read.read_exact(&mut buffer).await?; - buffer = encrypt::decrypt_aes( - &aes_key, buffer[0..16].try_into().unwrap()); - let raw_length = read_var_int_vec(&mut buffer)?; - let length = - if (raw_length - buffer.len() as i32) % 16 == 0 { - (raw_length - buffer.len() as i32) / 16 - } else { - ((raw_length - buffer.len() as i32) / 16) + 1 - }; - - for _ in 0..length { - let mut block: Vec = vec![0; 16]; - self.stream_read.read_exact(&mut block).await?; - buffer.append(&mut block); - } - - Ok(buffer) - }, - None => { - let length = read_var_int_stream( - self.stream_read).await? as usize; - - let mut buffer: Vec = vec![0; length]; - self.stream_read.read_exact(&mut buffer).await?; - - Ok(buffer) - } - } - } - - pub async fn write_data( - &mut self, - data: &mut Vec, - ) -> Result<()> { - let mut out_data = convert_var_int(data.len() as i32); - out_data.append(data); - match self.aes_encryption_key { - Some(aes_key) => { - let length = - if (data.len() as i32) % 16 == 0 { - (data.len() as i32) / 16 - } else { - ((data.len() as i32) / 16) + 1 - }; - - for _ in 0..length { - let mut block: Vec = out_data[0..16].to_vec(); - block = encrypt::encrypt_aes( - &aes_key, block[0..16].try_into().unwrap()); - self.stream_write.write_all(&block).await?; - } - - - Ok(()) - }, - None => { - self.stream_write.write_all(&out_data).await?; - - Ok(()) - } - } - } - pub fn create_encryption_request( &mut self, private_key: RsaPrivateKey, @@ -248,6 +190,213 @@ impl<'a> ProtocolConnection<'a> { }; } } + + pub fn split_conn( + &mut self + ) -> (WriteHaftProtocolConnection, ReadHaftProtocolConnection) { + (WriteHaftProtocolConnection { + stream_write: self.stream_write, + aes_encryption_key: self.aes_encryption_key, + }, + ReadHaftProtocolConnection { + stream_read: self.stream_read, + aes_encryption_key: self.aes_encryption_key, + }) + } +} + +unsafe impl<'a> Send for ProtocolConnection<'a> {} + +#[async_trait] +impl<'a> ProtocolRead for ProtocolConnection<'a> { + async fn read_data(&mut self) -> Result> { + match self.aes_encryption_key { + Some(aes_key) => { + let mut buffer: Vec = vec![0; 16]; + self.stream_read.read_exact(&mut buffer).await?; + buffer = encrypt::decrypt_aes( + &aes_key, buffer[0..16].try_into().unwrap()); + let raw_length = read_var_int_vec(&mut buffer)?; + let length = + if (raw_length - buffer.len() as i32) % 16 == 0 { + (raw_length - buffer.len() as i32) / 16 + } else { + ((raw_length - buffer.len() as i32) / 16) + 1 + }; + + for _ in 0..length { + let mut block: Vec = vec![0; 16]; + self.stream_read.read_exact(&mut block).await?; + buffer.append(&mut block); + } + + Ok(buffer) + }, + None => { + let length = read_var_int_stream( + self.stream_read).await? as usize; + + let mut buffer: Vec = vec![0; length]; + self.stream_read.read_exact(&mut buffer).await?; + + Ok(buffer) + } + } + } +} + +#[async_trait] +impl<'a> ProtocolWrite for ProtocolConnection<'a> { + async fn write_data(&mut self, data: &mut Vec) -> Result<()> { + let mut out_data = convert_var_int(data.len() as i32); + out_data.append(data); + match self.aes_encryption_key { + Some(aes_key) => { + let length = + if (data.len() as i32) % 16 == 0 { + (data.len() as i32) / 16 + } else { + ((data.len() as i32) / 16) + 1 + }; + + for _ in 0..length { + let mut block: Vec = out_data[0..16].to_vec(); + block = encrypt::encrypt_aes( + &aes_key, block[0..16].try_into().unwrap()); + self.stream_write.write_all(&block).await?; + } + + Ok(()) + }, + None => { + self.stream_write.write_all(&out_data).await?; + + Ok(()) + } + } + } +} + +pub struct WriteHaftProtocolConnection<'a> { + pub stream_write: &'a mut OwnedWriteHalf, + aes_encryption_key: Option<[u8; 16]>, +} + +impl<'a> WriteHaftProtocolConnection<'a> { + pub fn new( + stream_write: &'a mut OwnedWriteHalf, + ) -> Self { + WriteHaftProtocolConnection { + stream_write, + aes_encryption_key: None, + } + } +} + +unsafe impl<'a> Send for WriteHaftProtocolConnection<'a> {} + +#[async_trait] +impl<'a> ProtocolWrite for WriteHaftProtocolConnection<'a> { + async fn write_data( + &mut self, + data: &mut Vec, + ) -> Result<()> { + let mut out_data = convert_var_int(data.len() as i32); + out_data.append(data); + match self.aes_encryption_key { + Some(aes_key) => { + let length = + if (data.len() as i32) % 16 == 0 { + (data.len() as i32) / 16 + } else { + ((data.len() as i32) / 16) + 1 + }; + + for _ in 0..length { + let mut block: Vec = out_data[0..16].to_vec(); + block = encrypt::encrypt_aes( + &aes_key, block[0..16].try_into().unwrap()); + self.stream_write.write_all(&block).await?; + } + + + Ok(()) + }, + None => { + self.stream_write.write_all(&out_data).await?; + + Ok(()) + } + } + } +} + +pub struct ReadHaftProtocolConnection<'a> { + pub stream_read: &'a mut OwnedReadHalf, + aes_encryption_key: Option<[u8; 16]>, +} + +impl<'a> ReadHaftProtocolConnection<'a> { + pub fn new( + stream_read: &'a mut OwnedReadHalf, + ) -> Self { + ReadHaftProtocolConnection { + stream_read, + aes_encryption_key: None, + } + } + + pub async fn forward_play( + &mut self, + other: &mut T, + ) -> Result<()> { + loop { + let packet = Play::read(self).await?; + match packet { + Play::PlayPacket(packet) => packet.write(other).await?, + }; + } + } +} + +unsafe impl<'a> Send for ReadHaftProtocolConnection<'a> {} + +#[async_trait] +impl<'a> ProtocolRead for ReadHaftProtocolConnection<'a> { + async fn read_data(&mut self) -> Result> { + match self.aes_encryption_key { + Some(aes_key) => { + let mut buffer: Vec = vec![0; 16]; + self.stream_read.read_exact(&mut buffer).await?; + buffer = encrypt::decrypt_aes( + &aes_key, buffer[0..16].try_into().unwrap()); + let raw_length = read_var_int_vec(&mut buffer)?; + let length = + if (raw_length - buffer.len() as i32) % 16 == 0 { + (raw_length - buffer.len() as i32) / 16 + } else { + ((raw_length - buffer.len() as i32) / 16) + 1 + }; + + for _ in 0..length { + let mut block: Vec = vec![0; 16]; + self.stream_read.read_exact(&mut block).await?; + buffer.append(&mut block); + } + + Ok(buffer) + }, + None => { + let length = read_var_int_stream( + self.stream_read).await? as usize; + + let mut buffer: Vec = vec![0; length]; + self.stream_read.read_exact(&mut buffer).await?; + + Ok(buffer) + } + } + } } #[async_trait] @@ -256,7 +405,7 @@ pub trait Packet: Sized { fn get(data: &mut Vec) -> Result; fn convert(&self) -> Vec; - async fn read(conn: &mut ProtocolConnection<'_>) -> Result { + async fn read(conn: &mut T) -> Result { let mut data = conn.read_data().await?; let packet_id = get_var_int(&mut data)?; if packet_id == Self::packet_id() { @@ -266,7 +415,7 @@ pub trait Packet: Sized { } } - async fn write(&self, conn: &mut ProtocolConnection<'_>) -> Result<()> { + async fn write(&self, conn: &mut T) -> Result<()> { conn.write_data(&mut self.convert()).await } } diff --git a/src/play.rs b/src/play.rs index 52ad9a3..7313ab1 100644 --- a/src/play.rs +++ b/src/play.rs @@ -7,8 +7,8 @@ pub enum Play { } impl Play { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; Ok(Self::PlayPacket(PlayPacket::get(&mut data)?)) diff --git a/src/status.rs b/src/status.rs index e20f18e..fe2c86c 100644 --- a/src/status.rs +++ b/src/status.rs @@ -45,8 +45,8 @@ pub mod clientbound { } impl StatusPackets { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; let packet_id = mc_types::get_var_int(&mut data)?; @@ -134,8 +134,8 @@ pub mod serverbound { } impl StatusPackets { - pub async fn read( - conn: &mut mc_types::ProtocolConnection<'_>, + pub async fn read( + conn: &mut T, ) -> Result { let mut data = conn.read_data().await?; let packet_id = mc_types::get_var_int(&mut data)?;