Refactored all but status.rs to use traits

This commit is contained in:
Kyler 2023-12-20 22:26:15 -07:00
parent 48a471909d
commit 1d717ab8da
6 changed files with 166 additions and 81 deletions

View File

@ -13,3 +13,4 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
base64 = "0.21.5" base64 = "0.21.5"
rand = "0.8.5" rand = "0.8.5"
async-trait = "0.1.75"

View File

@ -1,45 +1,57 @@
// Yeahbut December 2023 // Yeahbut December 2023
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; pub mod serverbound {
use crate::mc_types::{self, Result}; use tokio::net::tcp::OwnedReadHalf;
pub struct Handshake { use crate::mc_types::{self, Result, Packet, PacketError};
enum HandshakeEnum {
Handshake(Handshake),
}
impl HandshakeEnum {
pub async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = mc_types::read_data(stream).await?;
let packet_id = mc_types::get_var_int(&mut data)?;
if packet_id == Handshake::packet_id() {
return Ok(Self::Handshake(Handshake::get(&mut data)?))
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
}
pub struct Handshake {
pub protocol_version: i32, pub protocol_version: i32,
pub server_address: String, pub server_address: String,
pub server_port: u16, pub server_port: u16,
pub next_state: i32, pub next_state: i32,
} }
pub async fn read_handshake(stream: &mut OwnedReadHalf) -> Result<Handshake> { impl Packet for Handshake {
let mut data = mc_types::read_packet(stream).await?;
let _packet_id = mc_types::get_var_int(&mut data);
Ok(get_handshake(&mut data)?)
}
pub fn get_handshake(data: &mut Vec<u8>) -> Result<Handshake> { fn packet_id() -> i32 {0}
Ok(Handshake {
fn get(data: &mut Vec<u8>) -> Result<Self> {
Ok(Self {
protocol_version: mc_types::get_var_int(data)?, protocol_version: mc_types::get_var_int(data)?,
server_address: mc_types::get_string(data)?, server_address: mc_types::get_string(data)?,
server_port: mc_types::get_u16(data), server_port: mc_types::get_u16(data),
next_state: mc_types::get_var_int(data)?, next_state: mc_types::get_var_int(data)?,
}) })
} }
pub fn convert_handshake(handshake: Handshake) -> Vec<u8> { fn convert(&self) -> Vec<u8> {
let mut data: Vec<u8> = vec![0]; let mut data: Vec<u8> = vec![];
data.append(&mut mc_types::convert_var_int(handshake.protocol_version)); data.append(&mut mc_types::convert_var_int(Self::packet_id()));
data.append(&mut mc_types::convert_string(&handshake.server_address)); data.append(&mut mc_types::convert_var_int(self.protocol_version));
data.append(&mut mc_types::convert_u16(handshake.server_port)); data.append(&mut mc_types::convert_string(&self.server_address));
data.append(&mut mc_types::convert_var_int(handshake.next_state)); data.append(&mut mc_types::convert_u16(self.server_port));
data.append(&mut mc_types::convert_var_int(self.next_state));
data data
} }
pub async fn write_handshake( }
stream: &mut OwnedWriteHalf,
handshake: Handshake,
) -> Result<()> {
mc_types::write_packet(stream, &mut convert_handshake(handshake)).await?;
Ok(())
} }

View File

@ -1,22 +1,49 @@
// Yeahbut December 2023 // Yeahbut December 2023
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; pub mod clientbound {
use crate::mc_types::{self, Result}; use tokio::net::tcp::OwnedReadHalf;
pub fn convert_clientbound_disconnect(reason: String) -> Vec<u8> { use crate::mc_types::{self, Result, Packet, PacketError};
let mut data: Vec<u8> = vec![0];
data.append(&mut &mut mc_types::convert_string(&reason)); enum Login {
Disconnect(Disconnect),
}
impl Login {
pub async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = mc_types::read_data(stream).await?;
let packet_id = mc_types::get_var_int(&mut data)?;
if packet_id == Disconnect::packet_id() {
return Ok(Self::Disconnect(Disconnect::get(&mut data)?))
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
}
pub struct Disconnect {
pub reason: String
}
impl Packet for Disconnect {
fn packet_id() -> i32 {0}
fn get(mut data: &mut Vec<u8>) -> Result<Self> {
Ok(Self {
reason: mc_types::get_string(&mut data)?
})
}
fn convert(&self) -> Vec<u8> {
let mut data: Vec<u8> = vec![];
data.append(&mut mc_types::convert_var_int(Self::packet_id()));
data.append(&mut mc_types::convert_string(&self.reason));
data data
} }
pub async fn write_clientbound_disconnect(
stream: &mut OwnedWriteHalf, }
reason: String,
) -> Result<()> {
mc_types::write_packet(
stream,
&mut convert_clientbound_disconnect(reason),
).await?;
Ok(())
} }

View File

@ -9,6 +9,8 @@ mod handshake;
mod status; mod status;
mod login; mod login;
use mc_types::Packet;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind("127.0.0.1:25565").await?; let listener = TcpListener::bind("127.0.0.1:25565").await?;
@ -52,7 +54,8 @@ async fn handle_client(client_socket: TcpStream) {
.await.expect("Error handling legacy status request"); .await.expect("Error handling legacy status request");
return; return;
} else { } else {
let handshake_packet = handshake::read_handshake(&mut client_reader) let handshake_packet =
handshake::serverbound::Handshake::read(&mut client_reader)
.await.expect("Error reading handshake packet"); .await.expect("Error reading handshake packet");
println!("Next state: {}", handshake_packet.next_state); println!("Next state: {}", handshake_packet.next_state);
if handshake_packet.next_state == 1 { if handshake_packet.next_state == 1 {
@ -67,15 +70,15 @@ async fn handle_client(client_socket: TcpStream) {
} else if handshake_packet.next_state == 2 { } else if handshake_packet.next_state == 2 {
match server_writer { match server_writer {
Some(mut server_writer) => { Some(mut server_writer) => {
handshake::write_handshake( handshake::serverbound::Handshake {
&mut server_writer,
handshake::Handshake {
protocol_version: mc_types::VERSION_PROTOCOL, protocol_version: mc_types::VERSION_PROTOCOL,
server_address: "localhost".to_string(), server_address: "localhost".to_string(),
server_port: 25565, server_port: 25565,
next_state: 2, next_state: 2,
}, }
).await.expect("Error logging into backend server"); .write(&mut server_writer)
.await
.expect("Error logging into backend server");
// Forward from client to backend // Forward from client to backend
tokio::spawn(async move { tokio::spawn(async move {
@ -98,10 +101,13 @@ async fn handle_client(client_socket: TcpStream) {
}; };
}, },
None => { None => {
login::write_clientbound_disconnect( login::clientbound::Disconnect {
&mut client_writer, reason: "\"Server Error (Server may be starting)\""
"\"Server Error (Server may be starting)\"".to_string(), .to_string()
).await.expect("Error sending disconnect on: \ }
.write(&mut client_writer)
.await
.expect("Error sending disconnect on: \
Failed to connect to the backend server"); Failed to connect to the backend server");
} }
}; };

View File

@ -6,6 +6,7 @@ use std::fmt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use async_trait::async_trait;
pub type Result<T> = std::result::Result<T, Box<dyn Error>>; pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
@ -15,6 +16,22 @@ pub const VERSION_PROTOCOL: i32 = 762;
const SEGMENT_BITS: u8 = 0x7F; const SEGMENT_BITS: u8 = 0x7F;
const CONTINUE_BIT: u8 = 0x80; const CONTINUE_BIT: u8 = 0x80;
#[derive(Debug)]
pub enum PacketError {
InvalidPacketId,
}
impl fmt::Display for PacketError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PacketError::InvalidPacketId =>
write!(f, "Invalid packet id"),
}
}
}
impl Error for PacketError {}
#[derive(Debug)] #[derive(Debug)]
pub enum VarIntError { pub enum VarIntError {
ValueTooLarge, ValueTooLarge,
@ -39,7 +56,29 @@ pub struct Chat {
pub text: String, pub text: String,
} }
pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
#[async_trait]
pub trait Packet: Sized {
fn packet_id() -> i32;
fn get(data: &mut Vec<u8>) -> Result<Self>;
fn convert(&self) -> Vec<u8>;
async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = read_data(stream).await?;
let packet_id = get_var_int(&mut data)?;
if packet_id == Self::packet_id() {
return Ok(Self::get(&mut data)?)
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
async fn write(&self, stream: &mut OwnedWriteHalf) -> Result<()> {
write_data(stream, &mut self.convert()).await
}
}
pub async fn read_data(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
let length = read_var_int_stream(stream).await? as usize; let length = read_var_int_stream(stream).await? as usize;
let mut buffer: Vec<u8> = vec![0; length]; let mut buffer: Vec<u8> = vec![0; length];
@ -47,7 +86,7 @@ pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
Ok(buffer) Ok(buffer)
} }
pub async fn write_packet( pub async fn write_data(
stream: &mut OwnedWriteHalf, stream: &mut OwnedWriteHalf,
data: &mut Vec<u8>, data: &mut Vec<u8>,
) -> Result<()> { ) -> Result<()> {

View File

@ -10,7 +10,7 @@ use serde_json::Value;
use base64::{Engine as _, engine::general_purpose}; use base64::{Engine as _, engine::general_purpose};
use rand::Rng; use rand::Rng;
use crate::mc_types::{self, Result}; use crate::mc_types::{self, Result, Packet};
use crate::handshake; use crate::handshake;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -122,7 +122,7 @@ pub async fn respond_status(
)-> Result<()> { )-> Result<()> {
loop { loop {
println!("Status Handling"); println!("Status Handling");
let mut data = mc_types::read_packet(client_reader).await?; let mut data = mc_types::read_data(client_reader).await?;
let packet_id = mc_types::get_var_int(&mut data)?; let packet_id = mc_types::get_var_int(&mut data)?;
println!("Status Packet ID: {}", packet_id); println!("Status Packet ID: {}", packet_id);
@ -188,12 +188,12 @@ pub async fn respond_status(
let mut out_data: Vec<u8> = vec![0]; let mut out_data: Vec<u8> = vec![0];
out_data.append(&mut mc_types::convert_string(&json)); out_data.append(&mut mc_types::convert_string(&json));
mc_types::write_packet(client_writer, &mut out_data).await?; mc_types::write_data(client_writer, &mut out_data).await?;
} else if packet_id == 0x01 { } else if packet_id == 0x01 {
println!("Handling Ping"); println!("Handling Ping");
let mut out_data: Vec<u8> = vec![1]; let mut out_data: Vec<u8> = vec![1];
out_data.append(&mut data); out_data.append(&mut data);
mc_types::write_packet(client_writer, &mut out_data).await?; mc_types::write_data(client_writer, &mut out_data).await?;
break; break;
} else { } else {
break; break;
@ -206,14 +206,14 @@ pub async fn get_upstream_status(
server_reader: &mut OwnedReadHalf, server_reader: &mut OwnedReadHalf,
server_writer: &mut OwnedWriteHalf, server_writer: &mut OwnedWriteHalf,
) -> Result<StatusResponseData> { ) -> Result<StatusResponseData> {
handshake::write_handshake(server_writer, handshake::Handshake{ handshake::serverbound::Handshake{
protocol_version: mc_types::VERSION_PROTOCOL, protocol_version: mc_types::VERSION_PROTOCOL,
server_address: "localhost".to_string(), server_address: "localhost".to_string(),
server_port: 25565, server_port: 25565,
next_state: 1, next_state: 1,
}).await?; }.write(server_writer).await?;
mc_types::write_packet(server_writer, &mut vec![0]).await?; mc_types::write_data(server_writer, &mut vec![0]).await?;
let mut data = mc_types::read_packet(server_reader).await?; let mut data = mc_types::read_data(server_reader).await?;
mc_types::get_u8(&mut data); mc_types::get_u8(&mut data);
let json = mc_types::get_string(&mut data)?; let json = mc_types::get_string(&mut data)?;