1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Fixed cookie reply processing bug

This commit is contained in:
Mathias Hall-Andersen 2017-07-07 13:47:09 +02:00
parent 70179f8c8c
commit ed31e75739
10 changed files with 173 additions and 95 deletions

View File

@ -12,7 +12,7 @@ const (
RejectAfterTime = time.Second * 180 RejectAfterTime = time.Second * 180
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2 CookieRefreshTime = time.Minute * 2
MaxHandshakeAttemptTime = time.Second * 90 MaxHandshakeAttemptTime = time.Second * 90
) )

View File

@ -25,8 +25,8 @@ type Device struct {
queue struct { queue struct {
encryption chan *QueueOutboundElement encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement decryption chan *QueueInboundElement
inbound chan *QueueInboundElement
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
inbound chan []byte // inbound queue for TUN
} }
signal struct { signal struct {
stop chan struct{} stop chan struct{}
@ -77,10 +77,10 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// create queues // create queues
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
device.queue.inbound = make(chan []byte, QueueInboundSize) device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals // prepare signals

View File

@ -112,7 +112,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
elem.packet = writer.Bytes() elem.packet = writer.Bytes()
peer.mac.AddMacs(elem.packet) peer.mac.AddMacs(elem.packet)
println(elem)
addToOutboundQueue(peer.queue.outbound, elem) addToOutboundQueue(peer.queue.outbound, elem)
if attempts == 0 { if attempts == 0 {

View File

@ -5,14 +5,17 @@ import (
) )
const ( const (
IPv4version = 4 IPv4version = 4
IPv4offsetSrc = 12 IPv4offsetTotalLength = 2
IPv4offsetDst = IPv4offsetSrc + net.IPv4len IPv4offsetSrc = 12
IPv4headerSize = 20 IPv4offsetDst = IPv4offsetSrc + net.IPv4len
IPv4headerSize = 20
) )
const ( const (
IPv6version = 6 IPv6version = 6
IPv6offsetSrc = 8 IPv6offsetPayloadLength = 4
IPv6offsetDst = IPv6offsetSrc + net.IPv6len IPv6offsetSrc = 8
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
IPv6headerSize = 40
) )

View File

@ -15,21 +15,31 @@ type MACStateDevice struct {
mutex sync.RWMutex mutex sync.RWMutex
refreshed time.Time refreshed time.Time
secret [blake2s.Size]byte secret [blake2s.Size]byte
keyMac1 [blake2s.Size]byte keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte
xaead cipher.AEAD xaead cipher.AEAD
} }
func (state *MACStateDevice) Init(pk NoisePublicKey) { func (state *MACStateDevice) Init(pk NoisePublicKey) {
state.mutex.Lock() state.mutex.Lock()
defer state.mutex.Unlock() defer state.mutex.Unlock()
func() { func() {
hsh, _ := blake2s.New256(nil) hsh, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelMAC1)) hsh.Write([]byte(WGLabelMAC1))
hsh.Write(pk[:]) hsh.Write(pk[:])
hsh.Sum(state.keyMac1[:0]) hsh.Sum(state.keyMAC1[:0])
}() }()
state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
state.refreshed = time.Time{} // never func() {
hsh, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelCookie))
hsh.Write(pk[:])
hsh.Sum(state.keyMAC2[:0])
}()
state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMAC2[:])
state.refreshed = time.Time{}
} }
func (state *MACStateDevice) CheckMAC1(msg []byte) bool { func (state *MACStateDevice) CheckMAC1(msg []byte) bool {
@ -39,7 +49,7 @@ func (state *MACStateDevice) CheckMAC1(msg []byte) bool {
var mac1 [blake2s.Size128]byte var mac1 [blake2s.Size128]byte
func() { func() {
mac, _ := blake2s.New128(state.keyMac1[:]) mac, _ := blake2s.New128(state.keyMAC1[:])
mac.Write(msg[:startMac1]) mac.Write(msg[:startMac1])
mac.Sum(mac1[:0]) mac.Sum(mac1[:0])
}() }()
@ -117,7 +127,7 @@ func (device *Device) CreateMessageCookieReply(msg []byte, receiver uint32, addr
startMac1 := size - (blake2s.Size128 * 2) startMac1 := size - (blake2s.Size128 * 2)
startMac2 := size - blake2s.Size128 startMac2 := size - blake2s.Size128
M := msg[startMac1:startMac2] mac1 := msg[startMac1:startMac2]
reply := new(MessageCookieReply) reply := new(MessageCookieReply)
reply.Type = MessageCookieReplyType reply.Type = MessageCookieReplyType
@ -127,7 +137,7 @@ func (device *Device) CreateMessageCookieReply(msg []byte, receiver uint32, addr
state.mutex.RUnlock() state.mutex.RUnlock()
return nil, err return nil, err
} }
state.xaead.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], M) state.xaead.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], mac1)
state.mutex.RUnlock() state.mutex.RUnlock()
return reply, nil return reply, nil
} }
@ -149,9 +159,11 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
state := &lookup.peer.mac state := &lookup.peer.mac
state.mutex.Lock() state.mutex.Lock()
defer state.mutex.Unlock() defer state.mutex.Unlock()
_, err := state.xaead.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], state.lastMac1[:])
_, err := state.xaead.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], state.lastMAC1[:])
if err != nil { if err != nil {
return false return false
} }

View File

@ -13,21 +13,31 @@ type MACStatePeer struct {
mutex sync.RWMutex mutex sync.RWMutex
cookieSet time.Time cookieSet time.Time
cookie [blake2s.Size128]byte cookie [blake2s.Size128]byte
lastMac1 [blake2s.Size128]byte lastMAC1 [blake2s.Size128]byte
keyMac1 [blake2s.Size]byte keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte
xaead cipher.AEAD xaead cipher.AEAD
} }
func (state *MACStatePeer) Init(pk NoisePublicKey) { func (state *MACStatePeer) Init(pk NoisePublicKey) {
state.mutex.Lock() state.mutex.Lock()
defer state.mutex.Unlock() defer state.mutex.Unlock()
func() { func() {
hsh, _ := blake2s.New256(nil) hsh, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelMAC1)) hsh.Write([]byte(WGLabelMAC1))
hsh.Write(pk[:]) hsh.Write(pk[:])
hsh.Sum(state.keyMac1[:0]) hsh.Sum(state.keyMAC1[:0])
}() }()
state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
func() {
hsh, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelCookie))
hsh.Write(pk[:])
hsh.Sum(state.keyMAC2[:0])
}()
state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMAC2[:])
state.cookieSet = time.Time{} // never state.cookieSet = time.Time{} // never
} }
@ -50,11 +60,11 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
// set mac1 // set mac1
func() { func() {
mac, _ := blake2s.New128(state.keyMac1[:]) mac, _ := blake2s.New128(state.keyMAC1[:])
mac.Write(msg[:startMac1]) mac.Write(msg[:startMac1])
mac.Sum(state.lastMac1[:0]) mac.Sum(state.lastMAC1[:0])
}() }()
copy(mac1, state.lastMac1[:]) copy(mac1, state.lastMAC1[:])
// set mac2 // set mac2

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"net" "net"
"testing" "testing"
"testing/quick" "testing/quick"
@ -17,8 +16,8 @@ func TestMAC1(t *testing.T) {
peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
assertEqual(t, peer1.mac.keyMac1[:], dev1.mac.keyMac1[:]) assertEqual(t, peer1.mac.keyMAC1[:], dev1.mac.keyMAC1[:])
assertEqual(t, peer2.mac.keyMac1[:], dev2.mac.keyMac1[:]) assertEqual(t, peer2.mac.keyMAC1[:], dev2.mac.keyMAC1[:])
msg1 := make([]byte, 256) msg1 := make([]byte, 256)
copy(msg1, []byte("some content")) copy(msg1, []byte("some content"))
@ -52,17 +51,15 @@ func TestMACs(t *testing.T) {
if addr.Port < 0 { if addr.Port < 0 {
return true return true
} }
addr.Port &= 0xffff addr.Port &= 0xffff
if len(msg) < 32 { if len(msg) < 32 {
return true return true
} }
if bytes.Compare(peer1.mac.keyMac1[:], device1.mac.keyMac1[:]) != 0 {
return false assertEqual(t, peer1.mac.keyMAC1[:], device1.mac.keyMAC1[:])
} assertEqual(t, peer2.mac.keyMAC1[:], device2.mac.keyMAC1[:])
if bytes.Compare(peer2.mac.keyMac1[:], device2.mac.keyMac1[:]) != 0 {
return false
}
device2.indices.Insert(receiver, IndexTableEntry{ device2.indices.Insert(receiver, IndexTableEntry{
peer: peer1, peer: peer1,
@ -83,17 +80,17 @@ func TestMACs(t *testing.T) {
return false return false
} }
if device2.ConsumeMessageCookieReply(cr) == false { if !device2.ConsumeMessageCookieReply(cr) {
return false return false
} }
// test MAC1 + MAC2 // test MAC1 + MAC2
peer1.mac.AddMacs(msg) peer1.mac.AddMacs(msg)
if device1.mac.CheckMAC1(msg) == false { if !device1.mac.CheckMAC1(msg) {
return false return false
} }
if device1.mac.CheckMAC2(msg, &addr) == false { if !device1.mac.CheckMAC2(msg, &addr) {
return false return false
} }
@ -107,6 +104,8 @@ func TestMACs(t *testing.T) {
return false return false
} }
t.Log("Passed")
return true return true
} }

View File

@ -55,6 +55,23 @@ func addToInboundQueue(
} }
} }
func addToHandshakeQueue(
queue chan QueueHandshakeElement,
element QueueHandshakeElement,
) {
for {
select {
case queue <- element:
return
default:
select {
case <-queue:
default:
}
}
}
}
func (device *Device) RoutineReceiveIncomming() { func (device *Device) RoutineReceiveIncomming() {
debugLog := device.log.Debug debugLog := device.log.Debug
@ -62,7 +79,7 @@ func (device *Device) RoutineReceiveIncomming() {
errorLog := device.log.Error errorLog := device.log.Error
var buffer []byte // unsliced buffer var buffer []byte
for { for {
@ -116,7 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
busy := len(device.queue.handshake) > QueueHandshakeBusySize busy := len(device.queue.handshake) > QueueHandshakeBusySize
if busy && !device.mac.CheckMAC2(packet, raddr) { if busy && !device.mac.CheckMAC2(packet, raddr) {
sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" follows "type" sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
reply, err := device.CreateMessageCookieReply(packet, sender, raddr) reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
if err != nil { if err != nil {
errorLog.Println("Failed to create cookie reply:", err) errorLog.Println("Failed to create cookie reply:", err)
@ -134,12 +151,15 @@ func (device *Device) RoutineReceiveIncomming() {
// add to handshake queue // add to handshake queue
addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
packet: packet,
source: raddr,
},
)
buffer = nil buffer = nil
device.queue.handshake <- QueueHandshakeElement{
msgType: msgType,
packet: packet,
source: raddr,
}
case MessageCookieReplyType: case MessageCookieReplyType:
@ -293,7 +313,21 @@ func (device *Device) RoutineHandshake() {
) )
return return
} }
logDebug.Println("Recieved valid initiation message for peer", peer.id)
// create response
response, err := device.CreateMessageResponse(peer)
if err != nil {
logError.Println("Failed to create response message:", err)
return
}
outElem := device.NewOutboundElement()
writer := bytes.NewBuffer(outElem.data[:0])
binary.Write(writer, binary.LittleEndian, response)
elem.packet = writer.Bytes()
peer.mac.AddMacs(elem.packet)
device.log.Debug.Println(elem.packet)
addToOutboundQueue(peer.queue.outbound, outElem)
case MessageResponseType: case MessageResponseType:
@ -352,29 +386,53 @@ func (peer *Peer) RoutineSequentialReceiver() {
return return
case elem = <-peer.queue.inbound: case elem = <-peer.queue.inbound:
} }
elem.mutex.Lock() elem.mutex.Lock()
if elem.IsDropped() {
continue
}
// check for replay // process IP packet
// update timers func() {
if elem.IsDropped() {
return
}
// check for keep-alive // check for replay
if len(elem.packet) == 0 { // update timers
continue
}
// strip padding // refresh key material
// insert into inbound TUN queue // check for keep-alive
device.queue.inbound <- elem.packet if len(elem.packet) == 0 {
return
}
// update key material // strip padding
switch elem.packet[0] >> 4 {
case IPv4version:
if len(elem.packet) < IPv4headerSize {
return
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
elem.packet = elem.packet[:length]
case IPv6version:
if len(elem.packet) < IPv6headerSize {
return
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += IPv6headerSize
elem.packet = elem.packet[:length]
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
return
}
addToInboundQueue(device.queue.inbound, elem)
}()
} }
} }
@ -387,8 +445,8 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
select { select {
case <-device.signal.stop: case <-device.signal.stop:
return return
case packet := <-device.queue.inbound: case elem := <-device.queue.inbound:
_, err := tun.Write(packet) _, err := tun.Write(elem.packet)
if err != nil { if err != nil {
logError.Println("Failed to write packet to TUN device:", err) logError.Println("Failed to write packet to TUN device:", err)
} }

View File

@ -28,13 +28,13 @@ import (
* *
* If the element is inserted into the "encryption queue", * If the element is inserted into the "encryption queue",
* the content is preceeded by enough "junk" to contain the header * the content is preceeded by enough "junk" to contain the header
* (to allow the constuction of transport messages in-place) * (to allow the construction of transport messages in-place)
*/ */
type QueueOutboundElement struct { type QueueOutboundElement struct {
state uint32 state uint32
mutex sync.Mutex mutex sync.Mutex
data [MaxMessageSize]byte data [MaxMessageSize]byte
packet []byte // slice of packet (sending) packet []byte // slice of "data" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
keyPair *KeyPair // key-pair for encryption keyPair *KeyPair // key-pair for encryption
peer *Peer // related peer peer *Peer // related peer
@ -51,8 +51,12 @@ func (peer *Peer) FlushNonceQueue() {
} }
} }
/*
* Assumption: The mutex of the returned element is released
*/
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := new(QueueOutboundElement) // TODO: profile, consider sync.Pool // TODO: profile, consider sync.Pool
elem := new(QueueOutboundElement)
return elem return elem
} }
@ -160,9 +164,8 @@ func (peer *Peer) RoutineNonce() {
var elem *QueueOutboundElement var elem *QueueOutboundElement
device := peer.device device := peer.device
logger := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.id)
logger.Println("Routine, nonce worker, started for peer", peer.id)
func() { func() {
@ -193,18 +196,18 @@ func (peer *Peer) RoutineNonce() {
break break
} }
} }
logger.Println("Key pair:", keyPair) logDebug.Println("Key pair:", keyPair)
sendSignal(peer.signal.handshakeBegin) sendSignal(peer.signal.handshakeBegin)
logger.Println("Waiting for key-pair, peer", peer.id) logDebug.Println("Waiting for key-pair, peer", peer.id)
select { select {
case <-peer.signal.newKeyPair: case <-peer.signal.newKeyPair:
logger.Println("Key-pair negotiated for peer", peer.id) logDebug.Println("Key-pair negotiated for peer", peer.id)
goto NextPacket goto NextPacket
case <-peer.signal.flushNonceQueue: case <-peer.signal.flushNonceQueue:
logger.Println("Clearing queue for peer", peer.id) logDebug.Println("Clearing queue for peer", peer.id)
peer.FlushNonceQueue() peer.FlushNonceQueue()
elem = nil elem = nil
goto NextPacket goto NextPacket
@ -233,8 +236,6 @@ func (peer *Peer) RoutineNonce() {
} }
} }
}() }()
logger.Println("Routine, nonce worker, stopped for peer", peer.id)
} }
/* Encrypts the elements in the queue /* Encrypts the elements in the queue
@ -265,20 +266,16 @@ func (device *Device) RoutineEncryption() {
// encrypt content // encrypt content
func() { binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
binary.LittleEndian.PutUint64(nonce[4:], work.nonce) work.packet = work.keyPair.send.Seal(
work.packet = work.keyPair.send.Seal( work.packet[:0],
work.packet[:0], nonce[:],
nonce[:], work.packet,
work.packet, nil,
nil, )
) length := MessageTransportHeaderSize + len(work.packet)
work.mutex.Unlock() work.packet = work.data[:length]
}() work.mutex.Unlock()
// reslice to include header
work.packet = work.data[:MessageTransportHeaderSize+len(work.packet)]
// refresh key if necessary // refresh key if necessary
@ -292,15 +289,15 @@ func (device *Device) RoutineEncryption() {
* The routine terminates then the outbound queue is closed. * The routine terminates then the outbound queue is closed.
*/ */
func (peer *Peer) RoutineSequentialSender() { func (peer *Peer) RoutineSequentialSender() {
logger := peer.device.log.Debug logDebug := peer.device.log.Debug
logger.Println("Routine, sequential sender, started for peer", peer.id) logDebug.Println("Routine, sequential sender, started for peer", peer.id)
device := peer.device device := peer.device
for { for {
select { select {
case <-peer.signal.stop: case <-peer.signal.stop:
logger.Println("Routine, sequential sender, stopped for peer", peer.id) logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
return return
case work := <-peer.queue.outbound: case work := <-peer.queue.outbound:
if work.IsDropped() { if work.IsDropped() {
@ -316,7 +313,7 @@ func (peer *Peer) RoutineSequentialSender() {
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
if peer.endpoint == nil { if peer.endpoint == nil {
logger.Println("No endpoint for peer:", peer.id) logDebug.Println("No endpoint for peer:", peer.id)
return return
} }
@ -324,7 +321,7 @@ func (peer *Peer) RoutineSequentialSender() {
defer device.net.mutex.RUnlock() defer device.net.mutex.RUnlock()
if device.net.conn == nil { if device.net.conn == nil {
logger.Println("No source for device") logDebug.Println("No source for device")
return return
} }

View File

@ -74,6 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
return &NativeTun{ return &NativeTun{
fd: fd, fd: fd,
name: newName, name: newName,
mtu: 1500, // TODO: FIX mtu: 0,
}, nil }, nil
} }