mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Initial implementation of source caching
Yet untested.
This commit is contained in:
parent
a72b0f7ae5
commit
e86d03dca2
19
src/conn.go
19
src/conn.go
@ -34,16 +34,21 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
return addr, err
|
||||
}
|
||||
|
||||
func ListeningUpdate(device *Device) error {
|
||||
func UpdateUDPListener(device *Device) error {
|
||||
device.mutex.Lock()
|
||||
defer device.mutex.Unlock()
|
||||
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
defer netc.mutex.Unlock()
|
||||
|
||||
// close existing sockets
|
||||
|
||||
if err := device.net.bind.Close(); err != nil {
|
||||
if netc.bind != nil {
|
||||
if err := netc.bind.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// open new sockets
|
||||
|
||||
@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: clear endpoint (src) caches
|
||||
// clear cached source addresses
|
||||
|
||||
for _, peer := range device.peers {
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint.value.ClearSrc()
|
||||
peer.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListeningClose(device *Device) error {
|
||||
func CloseUDPListener(device *Device) error {
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
defer netc.mutex.Unlock()
|
||||
|
@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (end *Endpoint) DestinationIP() net.IP {
|
||||
func (end *Endpoint) DstIP() net.IP {
|
||||
switch end.dst.Family {
|
||||
case unix.AF_INET6:
|
||||
return end.dst.Addr[:]
|
||||
@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP {
|
||||
}
|
||||
}
|
||||
|
||||
func (end *Endpoint) SourceToBytes() []byte {
|
||||
func (end *Endpoint) SrcToBytes() []byte {
|
||||
ptr := unsafe.Pointer(&end.src)
|
||||
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
|
||||
return arr[:]
|
||||
}
|
||||
|
||||
func (end *Endpoint) SourceToString() string {
|
||||
func (end *Endpoint) SrcToString() string {
|
||||
return sockaddrToString(end.src)
|
||||
}
|
||||
|
||||
func (end *Endpoint) DestinationToString() string {
|
||||
func (end *Endpoint) DstToString() string {
|
||||
return sockaddrToString(end.dst)
|
||||
}
|
||||
|
||||
func (end *Endpoint) ClearDst() {
|
||||
end.dst = unix.RawSockaddrInet6{}
|
||||
}
|
||||
|
||||
func (end *Endpoint) ClearSrc() {
|
||||
end.src = unix.RawSockaddrInet6{}
|
||||
}
|
||||
|
@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() {
|
||||
func (device *Device) Close() {
|
||||
device.RemoveAllPeers()
|
||||
close(device.signal.stop)
|
||||
ListeningClose(device)
|
||||
CloseUDPListener(device)
|
||||
}
|
||||
|
||||
func (device *Device) WaitChannel() chan struct{} {
|
||||
|
@ -14,8 +14,6 @@ func printUsage() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
test()
|
||||
|
||||
// parse arguments
|
||||
|
||||
var foreground bool
|
||||
|
22
src/peer.go
22
src/peer.go
@ -14,8 +14,11 @@ type Peer struct {
|
||||
persistentKeepaliveInterval uint64
|
||||
keyPairs KeyPairs
|
||||
handshake Handshake
|
||||
endpoint Endpoint
|
||||
device *Device
|
||||
endpoint struct {
|
||||
set bool // has a known endpoint been discovered
|
||||
value Endpoint // source / destination cache
|
||||
}
|
||||
stats struct {
|
||||
txBytes uint64 // bytes send to peer (endpoint)
|
||||
rxBytes uint64 // bytes received from peer
|
||||
@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
// reset endpoint
|
||||
|
||||
peer.endpoint.set = false
|
||||
peer.endpoint.value.ClearDst()
|
||||
peer.endpoint.value.ClearSrc()
|
||||
|
||||
// prepare queuing
|
||||
|
||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||
@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
/* Returns a short string identification for logging
|
||||
*/
|
||||
func (peer *Peer) String() string {
|
||||
if !peer.endpoint.set {
|
||||
return fmt.Sprintf(
|
||||
"peer(%d unknown %s)",
|
||||
peer.id,
|
||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"peer(%d %s %s)",
|
||||
peer.id,
|
||||
peer.endpoint.DestinationToString(),
|
||||
peer.endpoint.value.DstToString(),
|
||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||
)
|
||||
}
|
||||
|
@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() {
|
||||
return
|
||||
}
|
||||
|
||||
srcBytes := elem.endpoint.SourceToBytes()
|
||||
srcBytes := elem.endpoint.SrcToBytes()
|
||||
if device.IsUnderLoad() {
|
||||
|
||||
// verify MAC2 field
|
||||
@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// construct cookie reply
|
||||
|
||||
logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
|
||||
|
||||
logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString())
|
||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
|
||||
if err != nil {
|
||||
@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// check ratelimiter
|
||||
|
||||
if !device.ratelimiter.Allow(
|
||||
elem.endpoint.DestinationIP(),
|
||||
) {
|
||||
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() {
|
||||
if peer == nil {
|
||||
logInfo.Println(
|
||||
"Recieved invalid initiation message from",
|
||||
elem.endpoint.DestinationToString(),
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() {
|
||||
// TODO: Discover destination address also, only update on change
|
||||
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = elem.endpoint
|
||||
peer.endpoint.set = true
|
||||
peer.endpoint.value = elem.endpoint
|
||||
peer.mutex.Unlock()
|
||||
|
||||
// create response
|
||||
@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// send response
|
||||
|
||||
_, err = peer.SendBuffer(packet)
|
||||
err = peer.SendBuffer(packet)
|
||||
if err == nil {
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
}
|
||||
@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() {
|
||||
if peer == nil {
|
||||
logInfo.Println(
|
||||
"Recieved invalid response message from",
|
||||
elem.endpoint.DestinationToString(),
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
19
src/send.go
19
src/send.go
@ -105,24 +105,15 @@ func addToEncryptionQueue(
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
peer.device.net.mutex.RLock()
|
||||
defer peer.device.net.mutex.RUnlock()
|
||||
|
||||
peer.mutex.RLock()
|
||||
defer peer.mutex.RUnlock()
|
||||
|
||||
endpoint := peer.endpoint
|
||||
if endpoint == nil {
|
||||
return 0, errors.New("No known endpoint for peer")
|
||||
if !peer.endpoint.set {
|
||||
return errors.New("No known endpoint for peer")
|
||||
}
|
||||
|
||||
conn := peer.device.net.conn
|
||||
if conn == nil {
|
||||
return 0, errors.New("No UDP socket for device")
|
||||
}
|
||||
|
||||
return conn.WriteToUDP(buffer, endpoint)
|
||||
return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||
// send message and return buffer to pool
|
||||
|
||||
length := uint64(len(elem.packet))
|
||||
_, err := peer.SendBuffer(elem.packet)
|
||||
err := peer.SendBuffer(elem.packet)
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
if err != nil {
|
||||
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
|
||||
|
@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||
packet := writer.Bytes()
|
||||
peer.mac.AddMacs(packet)
|
||||
|
||||
_, err = peer.SendBuffer(packet)
|
||||
err = peer.SendBuffer(packet)
|
||||
if err != nil {
|
||||
logError.Println(
|
||||
"Failed to send handshake initiation message to",
|
||||
|
@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() {
|
||||
if !device.tun.isUp.Get() {
|
||||
logInfo.Println("Interface set up")
|
||||
device.tun.isUp.Set(true)
|
||||
updateUDPConn(device)
|
||||
UpdateUDPListener(device)
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() {
|
||||
if device.tun.isUp.Get() {
|
||||
logInfo.Println("Interface set down")
|
||||
device.tun.isUp.Set(false)
|
||||
closeUDPConn(device)
|
||||
CloseUDPListener(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
61
src/uapi.go
61
src/uapi.go
@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
send("private_key=" + device.privateKey.ToHex())
|
||||
}
|
||||
|
||||
if device.net.addr != nil {
|
||||
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
|
||||
if device.net.port != 0 {
|
||||
send(fmt.Sprintf("listen_port=%d", device.net.port))
|
||||
}
|
||||
|
||||
if device.net.fwmark != 0 {
|
||||
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
|
||||
}
|
||||
@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
defer peer.mutex.RUnlock()
|
||||
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
||||
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
||||
if peer.endpoint != nil {
|
||||
send("endpoint=" + peer.endpoint.String())
|
||||
if peer.endpoint.set {
|
||||
send("endpoint=" + peer.endpoint.value.DstToString())
|
||||
}
|
||||
|
||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
||||
@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
device.net.mutex.Lock()
|
||||
device.net.addr = addr
|
||||
device.net.mutex.Unlock()
|
||||
|
||||
err = updateUDPConn(device)
|
||||
if err != nil {
|
||||
device.net.port = uint16(port)
|
||||
if err := UpdateUDPListener(device); err != nil {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorPortInUse}
|
||||
}
|
||||
|
||||
// TODO: Clear source address of all peers
|
||||
|
||||
case "fwmark":
|
||||
fwmark, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
logError.Println("Invalid fwmark", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
device.net.mutex.Lock()
|
||||
if fwmark > 0 || device.net.fwmark > 0 {
|
||||
device.net.fwmark = uint32(fwmark)
|
||||
err := SetMark(
|
||||
device.net.conn,
|
||||
device.net.fwmark,
|
||||
)
|
||||
if err != nil {
|
||||
logError.Println("Failed to set fwmark:", err)
|
||||
device.net.mutex.Unlock()
|
||||
return &IPCError{Code: ipcErrorIO}
|
||||
}
|
||||
|
||||
// TODO: Clear source address of all peers
|
||||
}
|
||||
device.net.mutex.Unlock()
|
||||
|
||||
case "public_key":
|
||||
|
||||
// switch to peer configuration
|
||||
|
||||
deviceConfig = false
|
||||
|
||||
case "replace_peers":
|
||||
@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
device.mutex.RLock()
|
||||
if device.publicKey.Equals(pubKey) {
|
||||
|
||||
// create dummy instance
|
||||
// create dummy instance (not added to device)
|
||||
|
||||
peer = &Peer{}
|
||||
dummy = true
|
||||
@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
}
|
||||
|
||||
case "remove":
|
||||
|
||||
// remove currently selected peer from device
|
||||
|
||||
if value != "true" {
|
||||
logError.Println("Failed to set remove, invalid value:", value)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
dummy = true
|
||||
|
||||
case "preshared_key":
|
||||
|
||||
// update PSK
|
||||
|
||||
peer.mutex.Lock()
|
||||
err := peer.handshake.presharedKey.FromHex(value)
|
||||
peer.mutex.Unlock()
|
||||
@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
}
|
||||
|
||||
case "endpoint":
|
||||
addr, err := parseEndpoint(value)
|
||||
|
||||
// set endpoint destination and reset handshake timer
|
||||
|
||||
peer.mutex.Lock()
|
||||
err := peer.endpoint.value.Set(value)
|
||||
peer.endpoint.set = (err == nil)
|
||||
peer.mutex.Unlock()
|
||||
if err != nil {
|
||||
logError.Println("Failed to set endpoint:", value)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = addr
|
||||
peer.mutex.Unlock()
|
||||
signalSend(peer.signal.handshakeReset)
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
|
Loading…
Reference in New Issue
Block a user