Added ProtocolRead and ProtocolWrite traits

This commit is contained in:
Kyler 2024-05-31 18:51:07 -06:00
parent 505adfb92c
commit c63ff903cb
5 changed files with 231 additions and 82 deletions

View File

@ -9,8 +9,8 @@ pub mod serverbound {
}
impl HandshakeEnum {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = mc_types::get_var_int(&mut data)?;

View File

@ -13,8 +13,8 @@ pub mod clientbound {
}
impl Login {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = mc_types::get_var_int(&mut data)?;
@ -241,8 +241,8 @@ pub mod serverbound {
}
impl Login {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = mc_types::get_var_int(&mut data)?;

View File

@ -53,6 +53,16 @@ pub struct Chat {
pub text: String,
}
#[async_trait]
pub trait ProtocolRead {
async fn read_data(&mut self) -> Result<Vec<u8>>;
}
#[async_trait]
pub trait ProtocolWrite {
async fn write_data(&mut self, data: &mut Vec<u8>) -> Result<()>;
}
pub struct ProtocolConnection<'a> {
pub stream_read: &'a mut OwnedReadHalf,
pub stream_write: &'a mut OwnedWriteHalf,
@ -77,74 +87,6 @@ impl<'a> ProtocolConnection<'a> {
}
}
pub async fn read_data(&mut self) -> Result<Vec<u8>> {
match self.aes_encryption_key {
Some(aes_key) => {
let mut buffer: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut buffer).await?;
buffer = encrypt::decrypt_aes(
&aes_key, buffer[0..16].try_into().unwrap());
let raw_length = read_var_int_vec(&mut buffer)?;
let length =
if (raw_length - buffer.len() as i32) % 16 == 0 {
(raw_length - buffer.len() as i32) / 16
} else {
((raw_length - buffer.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut block).await?;
buffer.append(&mut block);
}
Ok(buffer)
},
None => {
let length = read_var_int_stream(
self.stream_read).await? as usize;
let mut buffer: Vec<u8> = vec![0; length];
self.stream_read.read_exact(&mut buffer).await?;
Ok(buffer)
}
}
}
pub async fn write_data(
&mut self,
data: &mut Vec<u8>,
) -> Result<()> {
let mut out_data = convert_var_int(data.len() as i32);
out_data.append(data);
match self.aes_encryption_key {
Some(aes_key) => {
let length =
if (data.len() as i32) % 16 == 0 {
(data.len() as i32) / 16
} else {
((data.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = out_data[0..16].to_vec();
block = encrypt::encrypt_aes(
&aes_key, block[0..16].try_into().unwrap());
self.stream_write.write_all(&block).await?;
}
Ok(())
},
None => {
self.stream_write.write_all(&out_data).await?;
Ok(())
}
}
}
pub fn create_encryption_request(
&mut self,
private_key: RsaPrivateKey,
@ -248,6 +190,213 @@ impl<'a> ProtocolConnection<'a> {
};
}
}
pub fn split_conn(
&mut self
) -> (WriteHaftProtocolConnection, ReadHaftProtocolConnection) {
(WriteHaftProtocolConnection {
stream_write: self.stream_write,
aes_encryption_key: self.aes_encryption_key,
},
ReadHaftProtocolConnection {
stream_read: self.stream_read,
aes_encryption_key: self.aes_encryption_key,
})
}
}
unsafe impl<'a> Send for ProtocolConnection<'a> {}
#[async_trait]
impl<'a> ProtocolRead for ProtocolConnection<'a> {
async fn read_data(&mut self) -> Result<Vec<u8>> {
match self.aes_encryption_key {
Some(aes_key) => {
let mut buffer: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut buffer).await?;
buffer = encrypt::decrypt_aes(
&aes_key, buffer[0..16].try_into().unwrap());
let raw_length = read_var_int_vec(&mut buffer)?;
let length =
if (raw_length - buffer.len() as i32) % 16 == 0 {
(raw_length - buffer.len() as i32) / 16
} else {
((raw_length - buffer.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut block).await?;
buffer.append(&mut block);
}
Ok(buffer)
},
None => {
let length = read_var_int_stream(
self.stream_read).await? as usize;
let mut buffer: Vec<u8> = vec![0; length];
self.stream_read.read_exact(&mut buffer).await?;
Ok(buffer)
}
}
}
}
#[async_trait]
impl<'a> ProtocolWrite for ProtocolConnection<'a> {
async fn write_data(&mut self, data: &mut Vec<u8>) -> Result<()> {
let mut out_data = convert_var_int(data.len() as i32);
out_data.append(data);
match self.aes_encryption_key {
Some(aes_key) => {
let length =
if (data.len() as i32) % 16 == 0 {
(data.len() as i32) / 16
} else {
((data.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = out_data[0..16].to_vec();
block = encrypt::encrypt_aes(
&aes_key, block[0..16].try_into().unwrap());
self.stream_write.write_all(&block).await?;
}
Ok(())
},
None => {
self.stream_write.write_all(&out_data).await?;
Ok(())
}
}
}
}
pub struct WriteHaftProtocolConnection<'a> {
pub stream_write: &'a mut OwnedWriteHalf,
aes_encryption_key: Option<[u8; 16]>,
}
impl<'a> WriteHaftProtocolConnection<'a> {
pub fn new(
stream_write: &'a mut OwnedWriteHalf,
) -> Self {
WriteHaftProtocolConnection {
stream_write,
aes_encryption_key: None,
}
}
}
unsafe impl<'a> Send for WriteHaftProtocolConnection<'a> {}
#[async_trait]
impl<'a> ProtocolWrite for WriteHaftProtocolConnection<'a> {
async fn write_data(
&mut self,
data: &mut Vec<u8>,
) -> Result<()> {
let mut out_data = convert_var_int(data.len() as i32);
out_data.append(data);
match self.aes_encryption_key {
Some(aes_key) => {
let length =
if (data.len() as i32) % 16 == 0 {
(data.len() as i32) / 16
} else {
((data.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = out_data[0..16].to_vec();
block = encrypt::encrypt_aes(
&aes_key, block[0..16].try_into().unwrap());
self.stream_write.write_all(&block).await?;
}
Ok(())
},
None => {
self.stream_write.write_all(&out_data).await?;
Ok(())
}
}
}
}
pub struct ReadHaftProtocolConnection<'a> {
pub stream_read: &'a mut OwnedReadHalf,
aes_encryption_key: Option<[u8; 16]>,
}
impl<'a> ReadHaftProtocolConnection<'a> {
pub fn new(
stream_read: &'a mut OwnedReadHalf,
) -> Self {
ReadHaftProtocolConnection {
stream_read,
aes_encryption_key: None,
}
}
pub async fn forward_play<T: ProtocolWrite + Send>(
&mut self,
other: &mut T,
) -> Result<()> {
loop {
let packet = Play::read(self).await?;
match packet {
Play::PlayPacket(packet) => packet.write(other).await?,
};
}
}
}
unsafe impl<'a> Send for ReadHaftProtocolConnection<'a> {}
#[async_trait]
impl<'a> ProtocolRead for ReadHaftProtocolConnection<'a> {
async fn read_data(&mut self) -> Result<Vec<u8>> {
match self.aes_encryption_key {
Some(aes_key) => {
let mut buffer: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut buffer).await?;
buffer = encrypt::decrypt_aes(
&aes_key, buffer[0..16].try_into().unwrap());
let raw_length = read_var_int_vec(&mut buffer)?;
let length =
if (raw_length - buffer.len() as i32) % 16 == 0 {
(raw_length - buffer.len() as i32) / 16
} else {
((raw_length - buffer.len() as i32) / 16) + 1
};
for _ in 0..length {
let mut block: Vec<u8> = vec![0; 16];
self.stream_read.read_exact(&mut block).await?;
buffer.append(&mut block);
}
Ok(buffer)
},
None => {
let length = read_var_int_stream(
self.stream_read).await? as usize;
let mut buffer: Vec<u8> = vec![0; length];
self.stream_read.read_exact(&mut buffer).await?;
Ok(buffer)
}
}
}
}
#[async_trait]
@ -256,7 +405,7 @@ pub trait Packet: Sized {
fn get(data: &mut Vec<u8>) -> Result<Self>;
fn convert(&self) -> Vec<u8>;
async fn read(conn: &mut ProtocolConnection<'_>) -> Result<Self> {
async fn read<T: ProtocolRead + Send>(conn: &mut T) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = get_var_int(&mut data)?;
if packet_id == Self::packet_id() {
@ -266,7 +415,7 @@ pub trait Packet: Sized {
}
}
async fn write(&self, conn: &mut ProtocolConnection<'_>) -> Result<()> {
async fn write<T: ProtocolWrite + Send>(&self, conn: &mut T) -> Result<()> {
conn.write_data(&mut self.convert()).await
}
}

View File

@ -7,8 +7,8 @@ pub enum Play {
}
impl Play {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
Ok(Self::PlayPacket(PlayPacket::get(&mut data)?))

View File

@ -45,8 +45,8 @@ pub mod clientbound {
}
impl StatusPackets {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = mc_types::get_var_int(&mut data)?;
@ -134,8 +134,8 @@ pub mod serverbound {
}
impl StatusPackets {
pub async fn read(
conn: &mut mc_types::ProtocolConnection<'_>,
pub async fn read<T: mc_types::ProtocolRead>(
conn: &mut T,
) -> Result<Self> {
let mut data = conn.read_data().await?;
let packet_id = mc_types::get_var_int(&mut data)?;