1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 09:15:14 +01:00

Rework index hashtable

This commit is contained in:
Jason A. Donenfeld 2018-05-13 18:23:40 +02:00
parent 233f079a94
commit 2c27ab205c
8 changed files with 78 additions and 88 deletions

View File

@ -56,8 +56,8 @@ type Device struct {
// unprotected / "self-synchronising resources" // unprotected / "self-synchronising resources"
indices IndexTable indexTable IndexTable
mac CookieChecker mac CookieChecker
rate struct { rate struct {
underLoadUntil atomic.Value underLoadUntil atomic.Value
@ -283,7 +283,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// initialize noise & crypt-key routine // initialize noise & crypt-key routine
device.indices.Init() device.indexTable.Init()
device.routing.table.Reset() device.routing.table.Reset()
// setup buffer pool // setup buffer pool

View File

@ -7,18 +7,14 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"sync" "sync"
"unsafe"
) )
/* Index=0 is reserved for unset indecies
*
*/
type IndexTableEntry struct { type IndexTableEntry struct {
peer *Peer peer *Peer
handshake *Handshake handshake *Handshake
keyPair *Keypair keypair *Keypair
} }
type IndexTable struct { type IndexTable struct {
@ -27,34 +23,38 @@ type IndexTable struct {
} }
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
var buff [4]byte var integer [4]byte
_, err := rand.Read(buff[:]) _, err := rand.Read(integer[:])
value := binary.LittleEndian.Uint32(buff[:]) return *(*uint32)(unsafe.Pointer(&integer[0])), err
return value, err
} }
func (table *IndexTable) Init() { func (table *IndexTable) Init() {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock()
table.table = make(map[uint32]IndexTableEntry) table.table = make(map[uint32]IndexTableEntry)
table.mutex.Unlock()
} }
func (table *IndexTable) Delete(index uint32) { func (table *IndexTable) Delete(index uint32) {
if index == 0 { table.mutex.Lock()
defer table.mutex.Unlock()
delete(table.table, index)
}
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
table.mutex.Lock()
defer table.mutex.Unlock()
entry, ok := table.table[index]
if !ok {
return return
} }
table.mutex.Lock() table.table[index] = IndexTableEntry{
delete(table.table, index) peer: entry.peer,
table.mutex.Unlock() keypair: keypair,
handshake: nil,
}
} }
func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
table.mutex.Lock()
table.table[key] = value
table.mutex.Unlock()
}
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
for { for {
// generate random index // generate random index
@ -62,9 +62,6 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
if err != nil { if err != nil {
return index, err return index, err
} }
if index == 0 {
continue
}
// check if index used // check if index used
@ -75,7 +72,7 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
continue continue
} }
// map index to handshake // check again while locked
table.mutex.Lock() table.mutex.Lock()
_, found := table.table[index] _, found := table.table[index]
@ -85,8 +82,8 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
} }
table.table[index] = IndexTableEntry{ table.table[index] = IndexTableEntry{
peer: peer, peer: peer,
handshake: &peer.handshake, handshake: handshake,
keyPair: nil, keypair: nil,
} }
table.mutex.Unlock() table.mutex.Unlock()
return index, nil return index, nil

View File

@ -44,6 +44,6 @@ func (kp *Keypairs) Current() *Keypair {
func (device *Device) DeleteKeypair(key *Keypair) { func (device *Device) DeleteKeypair(key *Keypair) {
if key != nil { if key != nil {
device.indices.Delete(key.localIndex) device.indexTable.Delete(key.localIndex)
} }
} }

View File

@ -161,7 +161,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("Static shared secret is zero") return nil, errors.New("static shared secret is zero")
} }
// create ephemeral key // create ephemeral key
@ -176,8 +176,8 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// assign index // assign index
device.indices.Delete(handshake.localIndex) device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer) handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil { if err != nil {
return nil, err return nil, err
@ -328,14 +328,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed { if handshake.state != HandshakeInitiationConsumed {
return nil, errors.New("handshake initation must be consumed first") return nil, errors.New("handshake initiation must be consumed first")
} }
// assign index // assign index
var err error var err error
device.indices.Delete(handshake.localIndex) device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer) handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -393,9 +393,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil return nil
} }
// lookup handshake by reciever // lookup handshake by receiver
lookup := device.indices.Lookup(msg.Receiver) lookup := device.indexTable.Lookup(msg.Receiver)
handshake := lookup.handshake handshake := lookup.handshake
if handshake == nil { if handshake == nil {
return nil return nil
@ -528,35 +528,28 @@ func (peer *Peer) NewKeypair() *Keypair {
// create AEAD instances // create AEAD instances
keyPair := new(Keypair) keypair := new(Keypair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keypair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keypair.receive, _ = chacha20poly1305.New(recvKey[:])
setZero(sendKey[:]) setZero(sendKey[:])
setZero(recvKey[:]) setZero(recvKey[:])
keyPair.created = time.Now() keypair.created = time.Now()
keyPair.sendNonce = 0 keypair.sendNonce = 0
keyPair.replayFilter.Init() keypair.replayFilter.Init()
keyPair.isInitiator = isInitiator keypair.isInitiator = isInitiator
keyPair.localIndex = peer.handshake.localIndex keypair.localIndex = peer.handshake.localIndex
keyPair.remoteIndex = peer.handshake.remoteIndex keypair.remoteIndex = peer.handshake.remoteIndex
// remap index // remap index
device.indices.Insert( device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
handshake.localIndex,
IndexTableEntry{
peer: peer,
keyPair: keyPair,
handshake: nil,
},
)
handshake.localIndex = 0 handshake.localIndex = 0
// rotate key pairs // rotate key pairs
kp := &peer.keyPairs kp := &peer.keypairs
kp.mutex.Lock() kp.mutex.Lock()
peer.timersSessionDerived() peer.timersSessionDerived()
@ -574,14 +567,14 @@ func (peer *Peer) NewKeypair() *Keypair {
kp.previous = current kp.previous = current
} }
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
kp.current = keyPair kp.current = keypair
} else { } else {
kp.next = keyPair kp.next = keypair
device.DeleteKeypair(next) device.DeleteKeypair(next)
kp.previous = nil kp.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
} }
kp.mutex.Unlock() kp.mutex.Unlock()
return keyPair return keypair
} }

View File

@ -20,7 +20,7 @@ const (
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning AtomicBool
mutex sync.RWMutex mutex sync.RWMutex
keyPairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint Endpoint endpoint Endpoint
@ -234,7 +234,7 @@ func (peer *Peer) Stop() {
// clear key pairs // clear key pairs
kp := &peer.keyPairs kp := &peer.keypairs
kp.mutex.Lock() kp.mutex.Lock()
device.DeleteKeypair(kp.previous) device.DeleteKeypair(kp.previous)
@ -250,7 +250,7 @@ func (peer *Peer) Stop() {
hs := &peer.handshake hs := &peer.handshake
hs.mutex.Lock() hs.mutex.Lock()
device.indices.Delete(hs.localIndex) device.indexTable.Delete(hs.localIndex)
hs.Clear() hs.Clear()
hs.mutex.Unlock() hs.mutex.Unlock()

View File

@ -31,7 +31,7 @@ type QueueInboundElement struct {
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
keyPair *Keypair keypair *Keypair
endpoint Endpoint endpoint Endpoint
} }
@ -107,7 +107,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake { if peer.timers.sentLastMinuteHandshake {
return return
} }
kp := peer.keyPairs.Current() kp := peer.keypairs.Current()
if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true peer.timers.sentLastMinuteHandshake = true
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
@ -183,15 +183,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
receiver := binary.LittleEndian.Uint32( receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
) )
value := device.indices.Lookup(receiver) value := device.indexTable.Lookup(receiver)
keyPair := value.keyPair keypair := value.keypair
if keyPair == nil { if keypair == nil {
continue continue
} }
// check key-pair expiry // check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue continue
} }
@ -201,7 +201,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
elem := &QueueInboundElement{ elem := &QueueInboundElement{
packet: packet, packet: packet,
buffer: buffer, buffer: buffer,
keyPair: keyPair, keypair: keypair,
dropped: AtomicFalse, dropped: AtomicFalse,
endpoint: endpoint, endpoint: endpoint,
} }
@ -296,7 +296,7 @@ func (device *Device) RoutineDecryption() {
var err error var err error
elem.counter = binary.LittleEndian.Uint64(counter) elem.counter = binary.LittleEndian.Uint64(counter)
elem.packet, err = elem.keyPair.receive.Open( elem.packet, err = elem.keypair.receive.Open(
content[:0], content[:0],
nonce[:], nonce[:],
content, content,
@ -358,7 +358,7 @@ func (device *Device) RoutineHandshake() {
// lookup peer from index // lookup peer from index
entry := device.indices.Lookup(reply.Receiver) entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil { if entry.peer == nil {
continue continue
@ -587,7 +587,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for replay // check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { if !elem.keypair.replayFilter.ValidateCounter(elem.counter) {
continue continue
} }
@ -599,9 +599,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check if using new key-pair // check if using new key-pair
kp := &peer.keyPairs kp := &peer.keypairs
kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
if kp.next == elem.keyPair { if kp.next == elem.keypair {
old := kp.previous old := kp.previous
kp.previous = kp.current kp.previous = kp.current
device.DeleteKeypair(old) device.DeleteKeypair(old)

20
send.go
View File

@ -47,7 +47,7 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
keyPair *Keypair // key-pair for encryption keypair *Keypair // key-pair for encryption
peer *Peer // related peer peer *Peer // related peer
} }
@ -161,7 +161,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
* *
*/ */
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
kp := peer.keyPairs.Current() kp := peer.keypairs.Current()
if kp == nil { if kp == nil {
return return
} }
@ -260,7 +260,7 @@ func (peer *Peer) FlushNonceQueue() {
* Obs. A single instance per peer * Obs. A single instance per peer
*/ */
func (peer *Peer) RoutineNonce() { func (peer *Peer) RoutineNonce() {
var keyPair *Keypair var keypair *Keypair
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
@ -291,9 +291,9 @@ func (peer *Peer) RoutineNonce() {
// wait for key pair // wait for key pair
for { for {
keyPair = peer.keyPairs.Current() keypair = peer.keypairs.Current()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if keypair != nil && keypair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime { if time.Now().Sub(keypair.created) < RejectAfterTime {
break break
} }
} }
@ -328,12 +328,12 @@ func (peer *Peer) RoutineNonce() {
// populate work element // populate work element
elem.peer = peer elem.peer = peer
elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
// double check in case of race condition added by future code // double check in case of race condition added by future code
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
goto NextPacket goto NextPacket
} }
elem.keyPair = keyPair elem.keypair = keypair
elem.dropped = AtomicFalse elem.dropped = AtomicFalse
elem.mutex.Lock() elem.mutex.Lock()
@ -392,7 +392,7 @@ func (device *Device) RoutineEncryption() {
fieldNonce := header[8:16] fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType) binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16 // pad content to multiple of 16
@ -408,7 +408,7 @@ func (device *Device) RoutineEncryption() {
// encrypt content and release to consumer // encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keyPair.send.Seal( elem.packet = elem.keypair.send.Seal(
header, header,
nonce[:], nonce[:],
elem.packet, elem.packet,

View File

@ -108,7 +108,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
hs := &peer.handshake hs := &peer.handshake
hs.mutex.Lock() hs.mutex.Lock()
kp := &peer.keyPairs kp := &peer.keypairs
kp.mutex.Lock() kp.mutex.Lock()
if kp.previous != nil { if kp.previous != nil {
@ -125,7 +125,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
} }
kp.mutex.Unlock() kp.mutex.Unlock()
peer.device.indices.Delete(hs.localIndex) peer.device.indexTable.Delete(hs.localIndex)
hs.Clear() hs.Clear()
hs.mutex.Unlock() hs.mutex.Unlock()
} }