3-websocket协议实现

golang 2019-04-20

websocket链接的实现

编写基本的httpserver

启动一个基本的httpserver,提供两个接口,一个index返回主页,另一个是就是我们自定义的websocket协议接口

@main.go

package main

import (
    "html/template"
    "log"
    "net/http"
    "http/websocket"
)
//websocket处理器
func echo(w http.ResponseWriter, r *http.Request){
    //协议升级
    c,err := websocket.Upgrade(w,r)
    if err != nil {
        log.Print("Upgrade error:",err)
        return
    }
    defer c.Close()
    //循环处理数据,接受数据,然后返回
    for {
        message,err := c.ReadData()
        if err != nil {
            log.Println("read:",err)
            break
        }
        log.Printf("recv:%s",message)
        c.SendData(message)
    }
}

//index
func home(w http.ResponseWriter, r *http.Request){
    homeTemplate.Execute(w,"ws://"+r.Host+"/echo")
}

func main(){
    log.SetFlags(0)

    http.HandleFunc("/echo",echo)
    http.HandleFunc("/",home)
    log.Fatal(http.ListenAndServe("0.0.0.0:8000",nil))
}
var homeTemplate = template.Must(template.New("").Parse(`
<!DOCTYPE html>
<head>
<meta charset="utf-8">
<script>
window.addEventListener("load", function(evt) {

    var output = document.getElementById("output");
    var input = document.getElementById("input");
    var ws;

    var print = function(message) {
        var d = document.createElement("div");
        d.innerHTML = message;
        output.appendChild(d);
    };

    document.getElementById("open").onclick = function(evt) {
        if (ws) {
            return false;
        }
        ws = new WebSocket("{{.}}");
        ws.onopen = function(evt) {
            print("OPEN");
        }
        ws.onclose = function(evt) {
            print("CLOSE");
            ws = null;
        }
        ws.onmessage = function(evt) {
            print("RESPONSE: " + evt.data);
        }
        ws.onerror = function(evt) {
            print("ERROR: " + evt.data);
        }
        return false;
    };

    document.getElementById("send").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        print("SEND: " + input.value);
        ws.send(input.value);
        return false;
    };

    document.getElementById("close").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        ws.close();
        return false;
    };

});
</script>
</head>
<body>
<table>
<tr><td valign="top" width="50%">
<p>
点击 "Open" 开始一个新的WebSocket链接,
“Send" 将内容发送到服务器,
"Close" 将关闭链接。
<p>
<form>
<button id="open">Open</button>
<button id="close">Close</button>
<p><input id="input" type="text" value="hello world!">
<button id="send">Send</button>
</form>
</td><td valign="top" width="50%">
<div id="output"></div>
</td></tr></table>
</body>
</html>
`))

index 返回我们自定义的一段html代码,

echo 接口进行升级我们的websocket

页面如下

index

自定义的webscoket upgrade进行升级

根据之前的协议分析,我知道握手的过程其实就是检查 HTTP 请求头部字段的过程,值得注意的一点就是需要针对客户端发送的 Sec-WebSocket-Key 生成一个正确的 Sec-WebSocket-Accept 只。关于生成的 Sec-WebSocket-Accpet 的实现,可以参考之前的分析。握手过程的具体代码如下:

@upgrade.go

package websocket

import(
    "net/http"
    "net"
    "errors"
    "log"
    "bufio"
)

func Upgrade(w http.ResponseWriter,r *http.Request)(c *Conn,err error){
    //是否是Get方法
    if r.Method != "GET" {
        http.Error(w,http.StatusText(http.StatusMethodNotAllowed),http.StatusMethodNotAllowed)
        return nil,errors.New("websocket:method not GET")
    }
    //检查 Sec-WebSocket-Version 版本
    if values := r.Header["Sec-Websocket-Version"];len(values) == 0 || values[0] != "13" {
        http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
        return nil,errors.New("websocket:version != 13")
    }

    //检查Connection 和  Upgrade
    if !tokenListContainsValue(r.Header,"Connection","upgrade") {
        http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
        return nil,errors.New("websocket:could not find connection header with token 'upgrade'")
    }
    if !tokenListContainsValue(r.Header,"Upgrade","websocket") {
        http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
        return nil,errors.New("websocket:could not find connection header with token 'websocket'")
    }


    //计算Sec-Websocket-Accept的值
    challengeKey := r.Header.Get("Sec-Websocket-Key")
    if challengeKey == "" {
        http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
        return nil,errors.New("websocket:key missing or blank")
    }

    var (
        netConn net.Conn
        br *bufio.Reader
    )
    h,ok := w.(http.Hijacker)
    if  !ok {
        http.Error(w,http.StatusText(http.StatusInternalServerError),http.StatusInternalServerError)
        return nil,errors.New("websocket:response dose not implement http.Hijacker")
    }
    var rw *bufio.ReadWriter
    //接管当前tcp连接,阻止内置http接管连接
    netConn,rw,err = h.Hijack()
    if err != nil {
        http.Error(w,http.StatusText(http.StatusInternalServerError),http.StatusInternalServerError)
        return nil,err
    }

    br = rw.Reader
    if br.Buffered() > 0 {
        netConn.Close()
        return nil,errors.New("websocket:client send data before hanshake is complete")
    }
    // 构造握手成功后返回的 response
    p := []byte{}
    p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
    p = append(p, computeAcceptKey(challengeKey)...)
    p = append(p, "\r\n\r\n"...)
    //返回repson 但不关闭连接
    if _,err = netConn.Write(p);err != nil {
        netConn.Close()
        return nil,err
    }
    //升级为websocket
    log.Println("Upgrade http to websocket successfully")
    conn := newConn(netConn)
    return conn,nil
}

握手过程的代码比较直观,就不多做解释了。到这里 WebSocket 的实现就基本完成了,可以看到有了之前的各种约定,我们实现 WebSocket 协议也是比较简单的。

封装的websocket结构体和对应的方法

@conn.go

package websocket

import (
    "fmt"
    "encoding/binary"
    "log"
    "errors"
    "net"
)

const (
    /*
    * 是否是最后一个数据帧
    * Fin Rsv1 Rsv2 Rsv3 Opcode
    *  1  0    0    0    0 0 0 0  => 128
    */
    finalBit = 1 << 7

    /*
    * 是否需要掩码处理
    *  Mask payload-len 第一位mask表示是否需要进行掩码处理 后面
    *  7位表示数据包长度 1.0-125 表示长度 2.126 后面需要扩展2 字节 16bit
    *  3.127则扩展8bit
    *  1    0 0 0 0 0 0 0  => 128
    */
    maskBit = 1 << 7

    /*
    * 文本帧类型
    * 0 0 0 0 0 0 0 1
    */
    TextMessage = 1
    /*
    * 关闭数据帧类型
    * 0 0 0 0 1 0 0 0
    */
    CloseMessage = 8
)

//websocket 连接
type Conn struct {
    writeBuf []byte
    maskKey [4]byte
    conn net.Conn
}
func newConn(conn net.Conn)*Conn{
    return &Conn{conn:conn}
}
func (c *Conn)Close(){
    c.conn.Close()
}

//发送数据
func (c *Conn)SendData(data []byte){
    length := len(data)
    c.writeBuf = make([]byte,10 + length)

    //数据开始和结束位置
    payloadStart := 2
    /**
    *数据帧的第一个字节,不支持且只能发送文本类型数据
    *finalBit 1 0 0 0 0 0 0 0
    *                |
    *Text     0 0 0 0 0 0 0 1
    * =>      1 0 0 0 0 0 0 1
    */
    c.writeBuf[0] = byte(TextMessage) | finalBit
    fmt.Printf("1 bit:%b\n",c.writeBuf[0])

    //数据帧第二个字节,服务器发送的数据不需要进行掩码处理
    switch{
    //大于2字节的长度
    case length >= 1 << 16 ://65536
        //c.writeBuf[1] = byte(0x00) | 127 // 127
        c.writeBuf[1] = byte(127) // 127
        //大端写入64位
        binary.BigEndian.PutUint64(c.writeBuf[payloadStart:],uint64(length))
        //需要8byte来存储数据长度
        payloadStart += 8
    case length > 125:
        //c.writeBuf[1] = byte(0x00) | 126
        c.writeBuf[1] = byte(126)
        binary.BigEndian.PutUint16(c.writeBuf[payloadStart:],uint16(length))
        payloadStart += 2
    default:
        //c.writeBuf[1] = byte(0x00) | byte(length)
        c.writeBuf[1] = byte(length)
    }
    fmt.Printf("2 bit:%b\n",c.writeBuf[1])

    copy(c.writeBuf[payloadStart:],data[:])
    c.conn.Write(c.writeBuf[:payloadStart+length])
}

//读取数据
func (c *Conn)ReadData()(data []byte,err error){
    var b [8]byte
    //读取数据帧的前两个字节
    if _,err := c.conn.Read(b[:2]); err != nil {
        return nil,err
    }
    //开始解析第一个字节 是否还有后续数据帧
    final := b[0] & finalBit != 0
    fmt.Printf("read data 1 bit :%b\n",b[0])
    //不支持数据分片
    if !final {
        log.Println("Recived fragemented frame,not support")
        return nil,errors.New("not suppeort fragmented message")
    }

    //数据帧类型
    /*
    *1 0 0 0  0 0 0 1
    *        &
    *0 0 0 0  1 1 1 1
    *0 0 0 0  0 0 0 1
    * => 1 这样就可以直接获取到类型了
    */
    frameType := int(b[0] & 0xf)
    //如果 关闭类型,则关闭连接
    if frameType == CloseMessage {
        c.conn.Close()
        log.Println("Recived closed message,connection will be closed")
        return nil,errors.New("recived closed message")
    }
    //只实现了文本格式的传输,编码utf-8
    if frameType != TextMessage {
        return nil,errors.New("only support text message")
    }
    //检查数据帧是否被掩码处理
    //maskBit => 1 0 0 0 0 0 0 0 任何与他 要么为0 要么为 128
    mask := b[1] & maskBit != 0
    //数据长度
    payloadLen := int64(b[1] & 0x7F)//0 1 1 1 1 1 1 1 1 127
    dataLen := int64(payloadLen)
    //根据payload length 判断数据的真实长度
    switch payloadLen {
    case 126://扩展2字节
        if _,err := c.conn.Read(b[:2]);err != nil {
            return nil,err
        }
        //获取扩展二字节的真实数据长度
        dataLen = int64(binary.BigEndian.Uint16(b[:2]))
    case 127 :
        if _,err := c.conn.Read(b[:8]);err != nil {
            return nil,err
        }
        dataLen = int64(binary.BigEndian.Uint64(b[:8]))
    }

    log.Printf("Read data length :%d,payload length %d",payloadLen,dataLen)
    //读取mask key
    if mask {//如果需要掩码处理的话 需要取出key
        //maskKey 是 4 字节  32位
        if _,err := c.conn.Read(c.maskKey[:]);err != nil {
        return nil ,err
        }
    }
    //读取数据内容
    p := make([]byte,dataLen)
    if _,err := c.conn.Read(p);err != nil {
        return nil,err
    }
    if mask {
        maskBytes(c.maskKey,p)//进行解码
    }
    return p,nil
}

http 头部检查

import (
    "crypto/sha1"
    "encoding/base64"
    "strings"
    "net/http"
)


var KeyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
//握手阶段使用 加密key返回 进行握手
func computeAcceptKey(challengeKey string)string{
    h := sha1.New()
    h.Write([]byte(challengeKey))
    h.Write(KeyGUID)
    return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

//解码
func maskBytes(key [4]byte,b []byte){
    pos := 0
    for i := range b {
        b[i] ^= key[pos & 3]
        pos ++
    }
}

// 检查http 头部字段中是否包含指定的值
func tokenListContainsValue(header http.Header, name string, value string)bool{
    for _,v := range header[name] {
        for _, s := range strings.Split(v,","){
            if strings.EqualFold(value,strings.TrimSpace(s)) {
                return true
            }
        }
    }
    return false
}

本文由 小东@xiaodo 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。