From bd1d37548ab41871e257a97b2da2976ac457fece Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Sat, 25 Jul 2020 20:16:07 -0400 Subject: [PATCH] add majsite server example --- Cargo.toml | 20 +++++++-- src/bin/majsite.rs | 78 ++++++++++++++++++++++++++++++++ src/lib.rs | 4 +- src/server/mod.rs | 108 ++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 198 insertions(+), 12 deletions(-) create mode 100644 src/bin/majsite.rs diff --git a/Cargo.toml b/Cargo.toml index effbeed..471f322 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,24 +7,38 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = { version = "0", optional = true } num = "0.2" num-derive = "0.3" num-traits = "0.2" rustls = { version = "0.18", optional = true, features = ["dangerous_configuration"] } webpki = { version = "0.21.0", optional = true } webpki-roots = { version = "0.20", optional = true } +tokio-rustls = { version = "0.14", features = ["dangerous_configuration"], optional = true } log = "0.4" url = "2" thiserror = "1" - -[dev-dependencies] +structopt = "0.3" pretty_env_logger = "0.4" +[dependencies.tokio] +version = "0.2.0" +features = [ + "macros", + "net", + "tcp", + "io-util", + "rt-core", + "time", + "stream" +] +optional = true + [features] default = ["client", "server"] client = ["rustls", "webpki", "webpki-roots"] -server = ["rustls", "webpki", "webpki-roots"] +server = ["rustls", "webpki", "webpki-roots", "tokio", "async-trait", "tokio-rustls"] [workspace] members = [ diff --git a/src/bin/majsite.rs b/src/bin/majsite.rs new file mode 100644 index 0000000..88562e7 --- /dev/null +++ b/src/bin/majsite.rs @@ -0,0 +1,78 @@ +use std::fs::File; +use std::io::{self, BufReader}; +use std::path::{Path, PathBuf}; +use structopt::StructOpt; +use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys}; +use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig}; + +#[derive(StructOpt, Debug)] +struct Options { + /// host to listen on + #[structopt(short = "H", long, env = "HOST", default_value = "0.0.0.0")] + host: String, + + /// port to listen on + #[structopt(short = "p", long, env = "PORT", default_value = "1965")] + port: u16, + + /// cert file + #[structopt(short = "c", long = "cert", env = "CERT_FILE")] + cert: PathBuf, + + /// key file + #[structopt(short = "k", long = "key", env = "KEY_FILE")] + key: PathBuf, +} + +fn load_certs(path: &Path) -> io::Result> { + certs(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) +} + +fn load_keys(path: &Path) -> io::Result> { + rsa_private_keys(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) +} + +#[tokio::main] +async fn main() -> Result<(), maj::server::Error> { + pretty_env_logger::init(); + let opts = Options::from_args(); + let certs = load_certs(&opts.cert)?; + let mut keys = load_keys(&opts.key)?; + + let mut config = ServerConfig::new(NoClientAuth::new()); + config + .set_single_cert(certs, keys.remove(0)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + + maj::server::serve(Handler{}, config, opts.host, opts.port).await?; + + Ok(()) +} + +struct Handler {} + +fn index() -> Result { + let msg = include_bytes!("../../majc/src/help.gmi"); + + Ok(maj::Response { + status: maj::StatusCode::Success, + meta: "text/gemini".to_string(), + body: msg.to_vec(), + }) +} + +#[async_trait::async_trait] +impl maj::server::Handler for Handler { + async fn handle(&self, r: maj::server::Request) -> Result { + match r.url.path() { + "/" | "" => index(), + _ => Ok(maj::Response { + status: maj::StatusCode::NotFound, + meta: format!("{} not found", r.url), + body: vec![], + }), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index ad11036..3a5bacd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,4 @@ mod client; pub use client::{get, Error as ClientError}; #[cfg(feature = "server")] -mod server; -#[cfg(feature = "server")] -pub use server::{serve, serve_plain, Error as AnyError, Handler}; +pub mod server; diff --git a/src/server/mod.rs b/src/server/mod.rs index a412f22..cbb1394 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,10 +1,17 @@ -use crate::Response; -use std::{error::Error as StdError, io}; +use crate::{Response, StatusCode}; +use async_trait::async_trait; +use rustls::{Certificate, Session}; +use std::{error::Error as StdError, sync::Arc}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::{net::TcpListener, stream::StreamExt}; +use tokio_rustls::TlsAcceptor; use url::Url; /// A Gemini request and its associated metadata. +#[allow(dead_code)] pub struct Request { - url: Url, + pub url: Url, + pub certs: Option>, } pub type Error = Box; @@ -13,14 +20,103 @@ pub type Error = Box; mod routes; pub use routes::*; +#[async_trait] pub trait Handler { - fn handle(r: Request) -> Result; + async fn handle(&self, r: Request) -> Result; } -pub fn serve(_h: impl Handler, _port: u16) -> io::Result<()> { +pub async fn serve(h: T, cfg: rustls::ServerConfig, host: String, port: u16) -> Result<(), Error> +where + T: Handler, +{ + let cfg = Arc::new(cfg); + let mut listener = TcpListener::bind(&format!("{}:{}", host, port)).await?; + let mut incoming = listener.incoming(); + let acceptor = TlsAcceptor::from(cfg.clone()); + + while let Some(stream) = incoming.next().await { + let stream = stream?; + let addr = stream.peer_addr().unwrap(); + let acceptor = acceptor.clone(); + let mut stream = acceptor.accept(stream).await?; + let mut rd = BufReader::new(&mut stream); + let mut u = String::new(); + rd.read_line(&mut u).await?; + if u.len() > 1025 { + stream + .write(format!("{} URL too long", StatusCode::BadRequest as u8).as_bytes()) + .await?; + continue; + } + + let u = Url::parse(&u)?; + match h + .handle(Request { + url: u.clone(), + certs: stream.get_ref().1.get_peer_certificates(), + }) + .await + { + Ok(resp) => { + stream + .write(format!("{} {}\r\n", resp.status as u8, resp.meta).as_bytes()) + .await?; + stream.write(&resp.body).await?; + log::info!("{}: {} {:?}", addr, u, resp.status); + } + Err(why) => { + stream + .write(format!("{} {:?}\r\n", StatusCode::PermanentFailure as u8, why).as_bytes()) + .await?; + log::error!("{}: {}: {:?}", addr, u, why); + } + }; + } + Ok(()) } -pub fn serve_plain(_h: impl Handler, _port: u16) -> io::Result<()> { +pub async fn serve_plain(h: T, host: String, port: u16) -> Result<(), Error> +where + T: Handler, +{ + let mut listener = TcpListener::bind(&format!("{}:{}", host, port)).await?; + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + let mut stream = stream?; + let mut rd = BufReader::new(&mut stream); + let mut u = String::new(); + rd.read_line(&mut u).await?; + if u.len() > 1025 { + stream + .write(format!("{} URL too long", StatusCode::BadRequest as u8).as_bytes()) + .await?; + continue; + } + + let u = Url::parse(&u)?; + match h + .handle(Request { + url: u.clone(), + certs: None, + }) + .await + { + Ok(resp) => { + stream + .write(format!("{} {}", resp.status as u8, resp.meta).as_bytes()) + .await?; + stream.write(&resp.body).await?; + log::info!("{}: {} {:?}", stream.peer_addr().unwrap(), u, resp.status); + } + Err(why) => { + stream + .write(format!("{} {:?}", StatusCode::PermanentFailure as u8, why).as_bytes()) + .await?; + log::error!("{}: {}: {:?}", stream.peer_addr().unwrap(), u, why); + } + }; + } + Ok(()) }