// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bufio" "fmt" "io" "net/http" ) func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { var hs serverHandshaker = &hybiServerHandshaker{Config: config} code, err := hs.ReadHandshake(buf.Reader, req) if err == ErrBadWebSocketVersion { fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) buf.WriteString("\r\n") buf.WriteString(err.Error()) buf.Flush() return } if err != nil { fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) buf.WriteString("\r\n") buf.WriteString(err.Error()) buf.Flush() return } if handshake != nil { err = handshake(config, req) if err != nil { code = http.StatusForbidden fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) buf.WriteString("\r\n") buf.Flush() return } } err = hs.AcceptHandshake(buf.Writer) if err != nil { code = http.StatusBadRequest fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) buf.WriteString("\r\n") buf.Flush() return } conn = hs.NewServerConn(buf, rwc, req) return } // Server represents a server of a WebSocket. type Server struct { // Config is a WebSocket configuration for new WebSocket connection. Config // Handshake is an optional function in WebSocket handshake. // For example, you can check, or don't check Origin header. // Another example, you can select config.Protocol. Handshake func(*Config, *http.Request) error // Handler handles a WebSocket connection. Handler } // ServeHTTP implements the http.Handler interface for a WebSocket func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.serveWebSocket(w, req) } func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { rwc, buf, err := w.(http.Hijacker).Hijack() if err != nil { panic("Hijack failed: " + err.Error()) } // The server should abort the WebSocket connection if it finds // the client did not send a handshake that matches with protocol // specification. defer rwc.Close() conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) if err != nil { return } if conn == nil { panic("unexpected nil conn") } s.Handler(conn) } // Handler is a simple interface to a WebSocket browser client. // It checks if Origin header is valid URL by default. // You might want to verify websocket.Conn.Config().Origin in the func. // If you use Server instead of Handler, you could call websocket.Origin and // check the origin in your Handshake func. So, if you want to accept // non-browser clients, which do not send an Origin header, set a // Server.Handshake that does not check the origin. type Handler func(*Conn) func checkOrigin(config *Config, req *http.Request) (err error) { config.Origin, err = Origin(config, req) if err == nil && config.Origin == nil { return fmt.Errorf("null origin") } return err } // ServeHTTP implements the http.Handler interface for a WebSocket func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { s := Server{Handler: h, Handshake: checkOrigin} s.serveWebSocket(w, req) }