maj/src/server/mod.rs

194 lines
5.3 KiB
Rust

use crate::{Response, StatusCode};
use async_std::{
io::prelude::*,
net::{TcpListener, TcpStream},
stream::StreamExt,
task,
};
use async_tls::TlsAcceptor;
use async_trait::async_trait;
use rustls::Certificate;
use std::{error::Error as StdError, net::SocketAddr, sync::Arc};
use url::Url;
/// A Gemini request and its associated metadata.
///
/// Built by this module for each successfully parsed request and passed by
/// value to [`Handler::handle`].
// NOTE(review): the `dead_code` allowance has no stated reason — presumably
// not every field is read yet; confirm it is still needed.
#[allow(dead_code)]
pub struct Request {
/// The URL parsed from the client's request line.
pub url: Url,
/// Certificates associated with the request, if any.
// NOTE(review): presumably the client's TLS certificate chain. The code
// visible in this module always sets this to `None` — confirm where (or
// whether) it is populated.
pub certs: Option<Vec<Certificate>>,
}
/// Boxed, thread-safe error type returned by the server's fallible functions.
pub type Error = Box<dyn StdError + Sync + Send>;
/// Module-local result alias: the error type is always [`Error`] and the
/// success type defaults to `()`.
type Result<T = ()> = std::result::Result<T, Error>;
/// Errors raised while reading and validating a client's request.
#[derive(thiserror::Error, Debug)]
enum RequestParsingError {
/// The request URL used a scheme other than `gemini`; carries the
/// offending scheme.
#[error("invalid scheme {0}")]
InvalidScheme(String),
/// No more data could be read before a CRLF-terminated request was
/// received.
#[error("unexpected end of request")]
UnexpectedEnd,
}
// Private `routes` module; its public items are re-exported from this module
// by the `pub use` below.
// NOTE(review): these lint allowances cover the whole module and have no
// stated reason — presumably the module is still work in progress; confirm
// they are still needed.
#[allow(dead_code, unused_assignments, unused_mut, unused_variables)]
mod routes;
pub use routes::*;
// Public submodules.
// NOTE(review): judging by the names only, presumably CGI-backed and
// filesystem-backed handlers — confirm against their contents.
pub mod cgi;
pub mod files;
/// Application logic that turns a [`Request`] into a [`Response`].
///
/// The server shares a single handler between connection tasks as an
/// `Arc<dyn Handler + Send + Sync>`.
#[async_trait]
pub trait Handler {
/// Handle one request.
///
/// On `Ok`, the response's status, meta and body are written to the client.
/// On `Err`, the server replies with a permanent-failure status whose meta
/// line is the error's display text, and logs the error.
async fn handle(&self, r: Request) -> Result<Response>;
}
pub async fn serve(
h: Arc<dyn Handler + Send + Sync>,
cfg: rustls::ServerConfig,
host: String,
port: u16,
) -> Result
where
{
let cfg = Arc::new(cfg);
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let mut incoming = listener.incoming();
let acceptor = Arc::new(TlsAcceptor::from(cfg.clone()));
while let Some(Ok(stream)) = incoming.next().await {
let h = h.clone();
let acceptor = acceptor.clone();
let addr = stream.peer_addr().unwrap();
let port = port.clone();
task::spawn(handle_request(h, stream, acceptor, addr, port));
}
Ok(())
}
/// Serve one client connection from TLS handshake to response.
///
/// Steps, in order: perform the TLS handshake, read and parse the request
/// URL, refuse URLs that name a port other than `port` or a scheme other than
/// `gemini`, then hand the request to `h`.
///
/// # Errors
///
/// Returns an error if the handshake fails or the URL scheme is not
/// `gemini`. An unparsable request or a port mismatch is answered on the
/// wire and reported as `Ok(())`.
async fn handle_request(
    h: Arc<(dyn Handler + Send + Sync)>,
    stream: TcpStream,
    acceptor: Arc<TlsAcceptor>,
    addr: SocketAddr,
    port: u16,
) -> Result {
    let mut tls = acceptor.accept(stream).await?;

    // Anything we cannot parse gets a "bad request" reply and a log entry.
    let url = match parse_request(&mut tls).await {
        Ok(parsed) => parsed,
        Err(e) => {
            let _ = write_header(&mut tls, StatusCode::BadRequest, "Invalid request").await;
            log::error!("error from {}: {}", addr, e);
            return Ok(());
        }
    };

    // An explicit port in the URL must be the one we are listening on.
    let wrong_port = url.port().map_or(false, |requested| requested != port);
    if wrong_port {
        let _ = write_header(
            &mut tls,
            StatusCode::ProxyRequestRefused,
            "Cannot proxy to that URL",
        )
        .await;
        return Ok(());
    }

    // Only the gemini scheme is served; anything else is a proxy request.
    if url.scheme() != "gemini" {
        let _ = write_header(
            &mut tls,
            StatusCode::ProxyRequestRefused,
            "Cannot proxy to that URL",
        )
        .await;
        return Err(RequestParsingError::InvalidScheme(url.scheme().to_string()).into());
    }

    handle(h, Request { url, certs: None }, &mut tls, addr).await;
    Ok(())
}
/// Write a response header line of the form `<status> <meta>\r\n` to `stream`.
///
/// `status` is sent as its numeric `u8` value. No body is written.
///
/// # Errors
///
/// Returns any I/O error raised while writing or flushing the stream.
pub async fn write_header<W: Write + Unpin>(
    mut stream: W,
    status: StatusCode,
    meta: &str,
) -> Result {
    let header = format!("{} {}\r\n", status as u8, meta);
    // `write` may accept only part of the buffer; `write_all` keeps writing
    // until the whole header has been handed to the stream.
    stream.write_all(header.as_bytes()).await?;
    // Callers typically drop the stream right after this call, so push any
    // buffered bytes out now.
    stream.flush().await?;
    Ok(())
}
/// Return the URL requested by the client.
///
/// Reads from `stream` until the first CRLF and parses the bytes before it
/// as a URL. Bytes received after that CRLF are ignored. A scheme-relative
/// request (`//host/path`) is resolved against the `gemini:` scheme.
///
/// # Errors
///
/// Fails if the stream ends, or the 1026-byte buffer fills up, before a CRLF
/// is seen; if the request is not valid UTF-8; or if it is not a valid URL.
async fn parse_request<R: Read + Unpin>(mut stream: R) -> Result<Url> {
    // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we
    // can use a fixed-sized buffer on the stack, avoiding allocations and
    // copying, and stopping bad clients from making us use too much memory.
    let mut request = [0u8; 1026];
    let mut len = 0;

    // Read until CRLF, end-of-stream, or there's no buffer space left.
    let end = loop {
        // Check for a full buffer explicitly rather than relying on what a
        // zero-length read happens to return.
        if len == request.len() {
            return Err(RequestParsingError::UnexpectedEnd.into());
        }
        let bytes_read = stream.read(&mut request[len..]).await?;
        if bytes_read == 0 {
            return Err(RequestParsingError::UnexpectedEnd.into());
        }
        // The CR and LF may arrive in separate reads, so resume the search
        // one byte before the newly received data.
        let search_from = len.saturating_sub(1);
        len += bytes_read;
        let found = request[search_from..len]
            .windows(2)
            .position(|pair| pair == b"\r\n");
        if let Some(offset) = found {
            break search_from + offset;
        }
    };

    let request = std::str::from_utf8(&request[..end])?;
    // Handle scheme-relative URLs.
    let url = if request.starts_with("//") {
        Url::parse(&format!("gemini:{}", request))?
    } else {
        Url::parse(request)?
    };
    // Validate the URL. TODO: Check the hostname and port.
    Ok(url)
}
/// Run the handler for `req` and write its response to `stream`.
///
/// On success the handler's status, meta and body are sent and the request
/// is logged at info level. On handler failure a permanent-failure header is
/// sent with the error's display text as the meta line, and the error is
/// logged. Failures to write to the client are logged as warnings; they are
/// never propagated.
// NOTE(review): the handler's error text is sent to the client verbatim. It
// may expose internal details or contain line breaks that would corrupt the
// header line — consider replacing it with a fixed message.
async fn handle<T>(
    h: Arc<(dyn Handler + Send + Sync)>,
    req: Request,
    stream: &mut T,
    addr: SocketAddr,
) where
    T: Write + Unpin,
{
    let u = req.url.clone();
    match h.handle(req).await {
        Ok(resp) => {
            let header = format!("{} {}\r\n", resp.status as u8, resp.meta);
            if let Err(io_err) = send_response(stream, &header, &resp.body).await {
                log::warn!("{}: {}: failed to send response: {}", addr, u, io_err);
            }
            log::info!("{}: {} {:?} {}", addr, u, resp.status, resp.meta);
        }
        Err(why) => {
            let header = format!("{} {}\r\n", StatusCode::PermanentFailure as u8, why);
            if let Err(io_err) = send_response(stream, &header, &[]).await {
                log::warn!("{}: {}: failed to send response: {}", addr, u, io_err);
            }
            log::error!("{}: {}: {:?}", addr, u, why);
        }
    };
}

/// Write a complete response (header line, then body) and flush the stream.
///
/// Uses `write_all` so a short write cannot truncate the header or body.
async fn send_response<T>(stream: &mut T, header: &str, body: &[u8]) -> std::io::Result<()>
where
    T: Write + Unpin,
{
    stream.write_all(header.as_bytes()).await?;
    stream.write_all(body).await?;
    stream.flush().await
}