mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Added source verification
This commit is contained in:
parent
ed31e75739
commit
5c1ccbddf0
@ -61,8 +61,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
||||
if peer.endpoint != nil {
|
||||
send("endpoint=" + peer.endpoint.String())
|
||||
}
|
||||
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
|
||||
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
|
||||
send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
|
||||
send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
|
||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
||||
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
||||
send("allowed_ip=" + ip.String())
|
||||
@ -73,7 +73,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
||||
// send lines
|
||||
|
||||
for _, line := range lines {
|
||||
device.log.Debug.Println("Response:", line)
|
||||
_, err := socket.WriteString(line + "\n")
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -31,10 +31,16 @@ type Device struct {
|
||||
signal struct {
|
||||
stop chan struct{}
|
||||
}
|
||||
peers map[NoisePublicKey]*Peer
|
||||
mac MACStateDevice
|
||||
congestionState int32 // used as an atomic bool
|
||||
peers map[NoisePublicKey]*Peer
|
||||
mac MACStateDevice
|
||||
}
|
||||
|
||||
const (
|
||||
CongestionStateUnderLoad = iota
|
||||
CongestionStateOkay
|
||||
)
|
||||
|
||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
||||
device.mutex.Lock()
|
||||
defer device.mutex.Unlock()
|
||||
@ -93,6 +99,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
go device.RoutineDecryption()
|
||||
go device.RoutineHandshake()
|
||||
}
|
||||
go device.RoutineBusyMonitor()
|
||||
go device.RoutineReadFromTUN(tun)
|
||||
go device.RoutineReceiveIncomming()
|
||||
go device.RoutineWriteToTUN(tun)
|
||||
|
@ -17,8 +17,8 @@ type Peer struct {
|
||||
keyPairs KeyPairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
tx_bytes uint64
|
||||
rx_bytes uint64
|
||||
txBytes uint64
|
||||
rxBytes uint64
|
||||
time struct {
|
||||
lastSend time.Time // last send message
|
||||
lastHandshake time.Time // last completed handshake
|
||||
|
137
src/receive.go
137
src/receive.go
@ -72,12 +72,48 @@ func addToHandshakeQueue(
|
||||
}
|
||||
}
|
||||
|
||||
/* Routine determining the busy state of the interface
|
||||
*
|
||||
* TODO: prehaps nicer to do this in response to events
|
||||
* TODO: more well reasoned definition of "busy"
|
||||
*/
|
||||
func (device *Device) RoutineBusyMonitor() {
|
||||
samples := 0
|
||||
interval := time.Second
|
||||
for timer := time.NewTimer(interval); ; {
|
||||
|
||||
select {
|
||||
case <-device.signal.stop:
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
// compute busy heuristic
|
||||
|
||||
if len(device.queue.handshake) > QueueHandshakeBusySize {
|
||||
samples += 1
|
||||
} else if samples > 0 {
|
||||
samples -= 1
|
||||
}
|
||||
samples %= 30
|
||||
busy := samples > 5
|
||||
|
||||
// update busy state
|
||||
|
||||
if busy {
|
||||
atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad)
|
||||
} else {
|
||||
atomic.StoreInt32(&device.congestionState, CongestionStateOkay)
|
||||
}
|
||||
|
||||
timer.Reset(interval)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
debugLog := device.log.Debug
|
||||
debugLog.Println("Routine, receive incomming, started")
|
||||
|
||||
errorLog := device.log.Error
|
||||
logDebug := device.log.Debug
|
||||
logDebug.Println("Routine, receive incomming, started")
|
||||
|
||||
var buffer []byte
|
||||
|
||||
@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
case MessageInitiationType, MessageResponseType:
|
||||
|
||||
// verify mac1
|
||||
|
||||
if !device.mac.CheckMAC1(packet) {
|
||||
debugLog.Println("Received packet with invalid mac1")
|
||||
return
|
||||
}
|
||||
|
||||
// check if busy, TODO: refine definition of "busy"
|
||||
|
||||
busy := len(device.queue.handshake) > QueueHandshakeBusySize
|
||||
if busy && !device.mac.CheckMAC2(packet, raddr) {
|
||||
sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
|
||||
if err != nil {
|
||||
errorLog.Println("Failed to create cookie reply:", err)
|
||||
return
|
||||
}
|
||||
writer := bytes.NewBuffer(packet[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet = writer.Bytes()
|
||||
_, err = device.net.conn.WriteToUDP(packet, raddr)
|
||||
if err != nil {
|
||||
debugLog.Println("Failed to send cookie reply:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// add to handshake queue
|
||||
|
||||
addToHandshakeQueue(
|
||||
@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
reader := bytes.NewReader(packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
if err != nil {
|
||||
debugLog.Println("Failed to decode cookie reply")
|
||||
logDebug.Println("Failed to decode cookie reply")
|
||||
return
|
||||
}
|
||||
device.ConsumeMessageCookieReply(&reply)
|
||||
@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
default:
|
||||
// unknown message type
|
||||
debugLog.Println("Got unknown message from:", raddr)
|
||||
logDebug.Println("Got unknown message from:", raddr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
func() {
|
||||
|
||||
// verify mac1
|
||||
|
||||
if !device.mac.CheckMAC1(elem.packet) {
|
||||
logDebug.Println("Received packet with invalid mac1")
|
||||
return
|
||||
}
|
||||
|
||||
// verify mac2
|
||||
|
||||
busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad
|
||||
|
||||
if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
|
||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
|
||||
if err != nil {
|
||||
logError.Println("Failed to create cookie reply:", err)
|
||||
return
|
||||
}
|
||||
writer := bytes.NewBuffer(elem.packet[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
elem.packet = writer.Bytes()
|
||||
_, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
|
||||
if err != nil {
|
||||
logDebug.Println("Failed to send cookie reply:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ratelimit
|
||||
|
||||
// handle messages
|
||||
|
||||
switch elem.msgType {
|
||||
case MessageInitiationType:
|
||||
|
||||
@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() {
|
||||
logError.Println("Failed to create response message:", err)
|
||||
return
|
||||
}
|
||||
|
||||
outElem := device.NewOutboundElement()
|
||||
writer := bytes.NewBuffer(outElem.data[:0])
|
||||
binary.Write(writer, binary.LittleEndian, response)
|
||||
elem.packet = writer.Bytes()
|
||||
peer.mac.AddMacs(elem.packet)
|
||||
device.log.Debug.Println(elem.packet)
|
||||
addToOutboundQueue(peer.queue.outbound, outElem)
|
||||
|
||||
case MessageResponseType:
|
||||
@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
}
|
||||
elem.mutex.Lock()
|
||||
|
||||
// process IP packet
|
||||
// process packet
|
||||
|
||||
func() {
|
||||
if elem.IsDropped() {
|
||||
@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
return
|
||||
}
|
||||
|
||||
// strip padding
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case IPv4version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < IPv4headerSize {
|
||||
return
|
||||
}
|
||||
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv4 source
|
||||
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
if device.routingTable.LookupIPv4(dst) != peer {
|
||||
return
|
||||
}
|
||||
|
||||
case IPv6version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < IPv6headerSize {
|
||||
return
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += IPv6headerSize
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
if device.routingTable.LookupIPv6(dst) != peer {
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Debug.Println("Receieved packet with unknown IP version")
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet)))
|
||||
addToInboundQueue(device.queue.inbound, elem)
|
||||
}()
|
||||
}
|
||||
|
@ -329,7 +329,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
|
||||
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
|
||||
|
||||
// shift keep-alive timer
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user