Refactored all but status.rs to use traits

This commit is contained in:
Kyler 2023-12-20 22:26:15 -07:00
parent 48a471909d
commit 1d717ab8da
6 changed files with 166 additions and 81 deletions

View File

@ -13,3 +13,4 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
base64 = "0.21.5"
rand = "0.8.5"
async-trait = "0.1.75"

View File

@ -1,45 +1,57 @@
// Yeahbut December 2023
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
pub mod serverbound {
use crate::mc_types::{self, Result};
use tokio::net::tcp::OwnedReadHalf;
pub struct Handshake {
pub protocol_version: i32,
pub server_address: String,
pub server_port: u16,
pub next_state: i32,
}
pub async fn read_handshake(stream: &mut OwnedReadHalf) -> Result<Handshake> {
let mut data = mc_types::read_packet(stream).await?;
let _packet_id = mc_types::get_var_int(&mut data);
Ok(get_handshake(&mut data)?)
}
pub fn get_handshake(data: &mut Vec<u8>) -> Result<Handshake> {
Ok(Handshake {
protocol_version: mc_types::get_var_int(data)?,
server_address: mc_types::get_string(data)?,
server_port: mc_types::get_u16(data),
next_state: mc_types::get_var_int(data)?,
})
}
pub fn convert_handshake(handshake: Handshake) -> Vec<u8> {
let mut data: Vec<u8> = vec![0];
data.append(&mut mc_types::convert_var_int(handshake.protocol_version));
data.append(&mut mc_types::convert_string(&handshake.server_address));
data.append(&mut mc_types::convert_u16(handshake.server_port));
data.append(&mut mc_types::convert_var_int(handshake.next_state));
data
}
pub async fn write_handshake(
stream: &mut OwnedWriteHalf,
handshake: Handshake,
) -> Result<()> {
mc_types::write_packet(stream, &mut convert_handshake(handshake)).await?;
Ok(())
use crate::mc_types::{self, Result, Packet, PacketError};
enum HandshakeEnum {
Handshake(Handshake),
}
impl HandshakeEnum {
pub async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = mc_types::read_data(stream).await?;
let packet_id = mc_types::get_var_int(&mut data)?;
if packet_id == Handshake::packet_id() {
return Ok(Self::Handshake(Handshake::get(&mut data)?))
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
}
pub struct Handshake {
pub protocol_version: i32,
pub server_address: String,
pub server_port: u16,
pub next_state: i32,
}
impl Packet for Handshake {
fn packet_id() -> i32 {0}
fn get(data: &mut Vec<u8>) -> Result<Self> {
Ok(Self {
protocol_version: mc_types::get_var_int(data)?,
server_address: mc_types::get_string(data)?,
server_port: mc_types::get_u16(data),
next_state: mc_types::get_var_int(data)?,
})
}
fn convert(&self) -> Vec<u8> {
let mut data: Vec<u8> = vec![];
data.append(&mut mc_types::convert_var_int(Self::packet_id()));
data.append(&mut mc_types::convert_var_int(self.protocol_version));
data.append(&mut mc_types::convert_string(&self.server_address));
data.append(&mut mc_types::convert_u16(self.server_port));
data.append(&mut mc_types::convert_var_int(self.next_state));
data
}
}
}

View File

@ -1,22 +1,49 @@
// Yeahbut December 2023
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
pub mod clientbound {
use crate::mc_types::{self, Result};
use tokio::net::tcp::OwnedReadHalf;
pub fn convert_clientbound_disconnect(reason: String) -> Vec<u8> {
let mut data: Vec<u8> = vec![0];
data.append(&mut &mut mc_types::convert_string(&reason));
use crate::mc_types::{self, Result, Packet, PacketError};
enum Login {
Disconnect(Disconnect),
}
impl Login {
pub async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = mc_types::read_data(stream).await?;
let packet_id = mc_types::get_var_int(&mut data)?;
if packet_id == Disconnect::packet_id() {
return Ok(Self::Disconnect(Disconnect::get(&mut data)?))
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
}
pub struct Disconnect {
pub reason: String
}
impl Packet for Disconnect {
fn packet_id() -> i32 {0}
fn get(mut data: &mut Vec<u8>) -> Result<Self> {
Ok(Self {
reason: mc_types::get_string(&mut data)?
})
}
fn convert(&self) -> Vec<u8> {
let mut data: Vec<u8> = vec![];
data.append(&mut mc_types::convert_var_int(Self::packet_id()));
data.append(&mut mc_types::convert_string(&self.reason));
data
}
}
data
}
pub async fn write_clientbound_disconnect(
stream: &mut OwnedWriteHalf,
reason: String,
) -> Result<()> {
mc_types::write_packet(
stream,
&mut convert_clientbound_disconnect(reason),
).await?;
Ok(())
}

View File

@ -9,6 +9,8 @@ mod handshake;
mod status;
mod login;
use mc_types::Packet;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind("127.0.0.1:25565").await?;
@ -52,8 +54,9 @@ async fn handle_client(client_socket: TcpStream) {
.await.expect("Error handling legacy status request");
return;
} else {
let handshake_packet = handshake::read_handshake(&mut client_reader)
.await.expect("Error reading handshake packet");
let handshake_packet =
handshake::serverbound::Handshake::read(&mut client_reader)
.await.expect("Error reading handshake packet");
println!("Next state: {}", handshake_packet.next_state);
if handshake_packet.next_state == 1 {
println!("Receiving Status Request");
@ -67,15 +70,15 @@ async fn handle_client(client_socket: TcpStream) {
} else if handshake_packet.next_state == 2 {
match server_writer {
Some(mut server_writer) => {
handshake::write_handshake(
&mut server_writer,
handshake::Handshake {
protocol_version: mc_types::VERSION_PROTOCOL,
server_address: "localhost".to_string(),
server_port: 25565,
next_state: 2,
},
).await.expect("Error logging into backend server");
handshake::serverbound::Handshake {
protocol_version: mc_types::VERSION_PROTOCOL,
server_address: "localhost".to_string(),
server_port: 25565,
next_state: 2,
}
.write(&mut server_writer)
.await
.expect("Error logging into backend server");
// Forward from client to backend
tokio::spawn(async move {
@ -98,10 +101,13 @@ async fn handle_client(client_socket: TcpStream) {
};
},
None => {
login::write_clientbound_disconnect(
&mut client_writer,
"\"Server Error (Server may be starting)\"".to_string(),
).await.expect("Error sending disconnect on: \
login::clientbound::Disconnect {
reason: "\"Server Error (Server may be starting)\""
.to_string()
}
.write(&mut client_writer)
.await
.expect("Error sending disconnect on: \
Failed to connect to the backend server");
}
};

View File

@ -6,6 +6,7 @@ use std::fmt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use serde::{Serialize, Deserialize};
use async_trait::async_trait;
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
@ -15,6 +16,22 @@ pub const VERSION_PROTOCOL: i32 = 762;
const SEGMENT_BITS: u8 = 0x7F;
const CONTINUE_BIT: u8 = 0x80;
#[derive(Debug)]
pub enum PacketError {
InvalidPacketId,
}
impl fmt::Display for PacketError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PacketError::InvalidPacketId =>
write!(f, "Invalid packet id"),
}
}
}
impl Error for PacketError {}
#[derive(Debug)]
pub enum VarIntError {
ValueTooLarge,
@ -39,7 +56,29 @@ pub struct Chat {
pub text: String,
}
pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
#[async_trait]
pub trait Packet: Sized {
fn packet_id() -> i32;
fn get(data: &mut Vec<u8>) -> Result<Self>;
fn convert(&self) -> Vec<u8>;
async fn read(stream: &mut OwnedReadHalf) -> Result<Self> {
let mut data = read_data(stream).await?;
let packet_id = get_var_int(&mut data)?;
if packet_id == Self::packet_id() {
return Ok(Self::get(&mut data)?)
} else {
return Err(Box::new(PacketError::InvalidPacketId))
}
}
async fn write(&self, stream: &mut OwnedWriteHalf) -> Result<()> {
write_data(stream, &mut self.convert()).await
}
}
pub async fn read_data(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
let length = read_var_int_stream(stream).await? as usize;
let mut buffer: Vec<u8> = vec![0; length];
@ -47,7 +86,7 @@ pub async fn read_packet(stream: &mut OwnedReadHalf) -> Result<Vec<u8>> {
Ok(buffer)
}
pub async fn write_packet(
pub async fn write_data(
stream: &mut OwnedWriteHalf,
data: &mut Vec<u8>,
) -> Result<()> {

View File

@ -10,7 +10,7 @@ use serde_json::Value;
use base64::{Engine as _, engine::general_purpose};
use rand::Rng;
use crate::mc_types::{self, Result};
use crate::mc_types::{self, Result, Packet};
use crate::handshake;
#[derive(Serialize, Deserialize)]
@ -122,7 +122,7 @@ pub async fn respond_status(
)-> Result<()> {
loop {
println!("Status Handling");
let mut data = mc_types::read_packet(client_reader).await?;
let mut data = mc_types::read_data(client_reader).await?;
let packet_id = mc_types::get_var_int(&mut data)?;
println!("Status Packet ID: {}", packet_id);
@ -188,12 +188,12 @@ pub async fn respond_status(
let mut out_data: Vec<u8> = vec![0];
out_data.append(&mut mc_types::convert_string(&json));
mc_types::write_packet(client_writer, &mut out_data).await?;
mc_types::write_data(client_writer, &mut out_data).await?;
} else if packet_id == 0x01 {
println!("Handling Ping");
let mut out_data: Vec<u8> = vec![1];
out_data.append(&mut data);
mc_types::write_packet(client_writer, &mut out_data).await?;
mc_types::write_data(client_writer, &mut out_data).await?;
break;
} else {
break;
@ -206,14 +206,14 @@ pub async fn get_upstream_status(
server_reader: &mut OwnedReadHalf,
server_writer: &mut OwnedWriteHalf,
) -> Result<StatusResponseData> {
handshake::write_handshake(server_writer, handshake::Handshake{
handshake::serverbound::Handshake{
protocol_version: mc_types::VERSION_PROTOCOL,
server_address: "localhost".to_string(),
server_port: 25565,
next_state: 1,
}).await?;
mc_types::write_packet(server_writer, &mut vec![0]).await?;
let mut data = mc_types::read_packet(server_reader).await?;
}.write(server_writer).await?;
mc_types::write_data(server_writer, &mut vec![0]).await?;
let mut data = mc_types::read_data(server_reader).await?;
mc_types::get_u8(&mut data);
let json = mc_types::get_string(&mut data)?;