296 lines
7.8 KiB
Go
296 lines
7.8 KiB
Go
|
//+build linux
|
||
|
|
||
|
package wglinux
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"unsafe"
|
||
|
|
||
|
"github.com/mdlayher/netlink"
|
||
|
"github.com/mdlayher/netlink/nlenc"
|
||
|
"golang.org/x/sys/unix"
|
||
|
"golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
|
||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||
|
)
|
||
|
|
||
|
// TODO(mdlayher): netlink message chunking with large configurations.
|
||
|
|
||
|
// configAttrs creates the required encoded netlink attributes to configure
|
||
|
// the device specified by name using the non-nil fields in cfg.
|
||
|
func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) {
|
||
|
ae := netlink.NewAttributeEncoder()
|
||
|
ae.String(wgh.DeviceAIfname, name)
|
||
|
|
||
|
if cfg.PrivateKey != nil {
|
||
|
ae.Bytes(wgh.DeviceAPrivateKey, (*cfg.PrivateKey)[:])
|
||
|
}
|
||
|
|
||
|
if cfg.ListenPort != nil {
|
||
|
ae.Uint16(wgh.DeviceAListenPort, uint16(*cfg.ListenPort))
|
||
|
}
|
||
|
|
||
|
if cfg.FirewallMark != nil {
|
||
|
ae.Uint32(wgh.DeviceAFwmark, uint32(*cfg.FirewallMark))
|
||
|
}
|
||
|
|
||
|
if cfg.ReplacePeers {
|
||
|
ae.Uint32(wgh.DeviceAFlags, wgh.DeviceFReplacePeers)
|
||
|
}
|
||
|
|
||
|
// Only apply peer attributes if necessary.
|
||
|
if len(cfg.Peers) > 0 {
|
||
|
ae.Nested(wgh.DeviceAPeers, func(nae *netlink.AttributeEncoder) error {
|
||
|
// Netlink arrays use type as an array index.
|
||
|
for i, p := range cfg.Peers {
|
||
|
nae.Nested(uint16(i), func(nnae *netlink.AttributeEncoder) error {
|
||
|
return encodePeer(nnae, p)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return ae.Encode()
|
||
|
}
|
||
|
|
||
|
// ipBatchChunk is a tunable allowed IP batch limit per peer.
|
||
|
//
|
||
|
// Because we don't necessarily know how much space a given peer will occupy,
|
||
|
// we play it safe and use a reasonably small value. Note that this constant
|
||
|
// is used both in this package and tests, so be aware when making changes.
|
||
|
const ipBatchChunk = 256
|
||
|
|
||
|
// peerBatchChunk specifies the number of peers that can appear in a
|
||
|
// configuration before we start splitting it into chunks.
|
||
|
const peerBatchChunk = 32
|
||
|
|
||
|
// shouldBatch determines if a configuration is sufficiently complex that it
|
||
|
// should be split into batches.
|
||
|
func shouldBatch(cfg wgtypes.Config) bool {
|
||
|
if len(cfg.Peers) > peerBatchChunk {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
var ips int
|
||
|
for _, p := range cfg.Peers {
|
||
|
ips += len(p.AllowedIPs)
|
||
|
}
|
||
|
|
||
|
return ips > ipBatchChunk
|
||
|
}
|
||
|
|
||
|
// buildBatches produces a batch of configs from a single config, if needed.
|
||
|
func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
|
||
|
// Is this a small configuration; no need to batch?
|
||
|
if !shouldBatch(cfg) {
|
||
|
return []wgtypes.Config{cfg}
|
||
|
}
|
||
|
|
||
|
// Use most fields of cfg for our "base" configuration, and only differ
|
||
|
// peers in each batch.
|
||
|
base := cfg
|
||
|
base.Peers = nil
|
||
|
|
||
|
// Track the known peers so that peer IPs are not replaced if a single
|
||
|
// peer has its allowed IPs split into multiple batches.
|
||
|
knownPeers := make(map[wgtypes.Key]struct{})
|
||
|
|
||
|
batches := make([]wgtypes.Config, 0)
|
||
|
for _, p := range cfg.Peers {
|
||
|
batch := base
|
||
|
|
||
|
// Iterate until no more allowed IPs.
|
||
|
var done bool
|
||
|
for !done {
|
||
|
var tmp []net.IPNet
|
||
|
if len(p.AllowedIPs) < ipBatchChunk {
|
||
|
// IPs all fit within a batch; we are done.
|
||
|
tmp = make([]net.IPNet, len(p.AllowedIPs))
|
||
|
copy(tmp, p.AllowedIPs)
|
||
|
done = true
|
||
|
} else {
|
||
|
// IPs are larger than a single batch, copy a batch out and
|
||
|
// advance the cursor.
|
||
|
tmp = make([]net.IPNet, ipBatchChunk)
|
||
|
copy(tmp, p.AllowedIPs[:ipBatchChunk])
|
||
|
|
||
|
p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
|
||
|
|
||
|
if len(p.AllowedIPs) == 0 {
|
||
|
// IPs ended on a batch boundary; no more IPs left so end
|
||
|
// iteration after this loop.
|
||
|
done = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pcfg := wgtypes.PeerConfig{
|
||
|
// PublicKey denotes the peer and must be present.
|
||
|
PublicKey: p.PublicKey,
|
||
|
|
||
|
// Apply the update only flag to every chunk to ensure
|
||
|
// consistency between batches when the kernel module processes
|
||
|
// them.
|
||
|
UpdateOnly: p.UpdateOnly,
|
||
|
|
||
|
// It'd be a bit weird to have a remove peer message with many
|
||
|
// IPs, but just in case, add this to every peer's message.
|
||
|
Remove: p.Remove,
|
||
|
|
||
|
// The IPs for this chunk.
|
||
|
AllowedIPs: tmp,
|
||
|
}
|
||
|
|
||
|
// Only pass certain fields on the first occurrence of a peer, so
|
||
|
// that subsequent IPs won't be wiped out and space isn't wasted.
|
||
|
if _, ok := knownPeers[p.PublicKey]; !ok {
|
||
|
knownPeers[p.PublicKey] = struct{}{}
|
||
|
|
||
|
pcfg.PresharedKey = p.PresharedKey
|
||
|
pcfg.Endpoint = p.Endpoint
|
||
|
pcfg.PersistentKeepaliveInterval = p.PersistentKeepaliveInterval
|
||
|
|
||
|
// Important: do not move or appending peers won't work.
|
||
|
pcfg.ReplaceAllowedIPs = p.ReplaceAllowedIPs
|
||
|
}
|
||
|
|
||
|
// Add a peer configuration to this batch and keep going.
|
||
|
batch.Peers = []wgtypes.PeerConfig{pcfg}
|
||
|
batches = append(batches, batch)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Do not allow peer replacement beyond the first message in a batch,
|
||
|
// so we don't overwrite our previous batch work.
|
||
|
for i := range batches {
|
||
|
if i > 0 {
|
||
|
batches[i].ReplacePeers = false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return batches
|
||
|
}
|
||
|
|
||
|
// encodePeer converts a PeerConfig into netlink attribute encoder bytes.
|
||
|
func encodePeer(ae *netlink.AttributeEncoder, p wgtypes.PeerConfig) error {
|
||
|
ae.Bytes(wgh.PeerAPublicKey, p.PublicKey[:])
|
||
|
|
||
|
// Flags are stored in a single attribute.
|
||
|
var flags uint32
|
||
|
if p.Remove {
|
||
|
flags |= wgh.PeerFRemoveMe
|
||
|
}
|
||
|
if p.ReplaceAllowedIPs {
|
||
|
flags |= wgh.PeerFReplaceAllowedips
|
||
|
}
|
||
|
if p.UpdateOnly {
|
||
|
flags |= wgh.PeerFUpdateOnly
|
||
|
}
|
||
|
if flags != 0 {
|
||
|
ae.Uint32(wgh.PeerAFlags, flags)
|
||
|
}
|
||
|
|
||
|
if p.PresharedKey != nil {
|
||
|
ae.Bytes(wgh.PeerAPresharedKey, (*p.PresharedKey)[:])
|
||
|
}
|
||
|
|
||
|
if p.Endpoint != nil {
|
||
|
ae.Do(wgh.PeerAEndpoint, func() ([]byte, error) {
|
||
|
return sockaddrBytes(*p.Endpoint)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
if p.PersistentKeepaliveInterval != nil {
|
||
|
ae.Uint16(wgh.PeerAPersistentKeepaliveInterval, uint16(p.PersistentKeepaliveInterval.Seconds()))
|
||
|
}
|
||
|
|
||
|
// Only apply allowed IPs if necessary.
|
||
|
if len(p.AllowedIPs) > 0 {
|
||
|
ae.Nested(wgh.PeerAAllowedips, func(nae *netlink.AttributeEncoder) error {
|
||
|
return encodeAllowedIPs(nae, p.AllowedIPs)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// sockaddrBytes converts a net.UDPAddr to raw sockaddr_in or sockaddr_in6 bytes.
|
||
|
func sockaddrBytes(endpoint net.UDPAddr) ([]byte, error) {
|
||
|
if !isValidIP(endpoint.IP) {
|
||
|
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
|
||
|
}
|
||
|
|
||
|
// Is this an IPv6 address?
|
||
|
if isIPv6(endpoint.IP) {
|
||
|
var addr [16]byte
|
||
|
copy(addr[:], endpoint.IP.To16())
|
||
|
|
||
|
sa := unix.RawSockaddrInet6{
|
||
|
Family: unix.AF_INET6,
|
||
|
Port: sockaddrPort(endpoint.Port),
|
||
|
Addr: addr,
|
||
|
}
|
||
|
|
||
|
return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
|
||
|
}
|
||
|
|
||
|
// IPv4 address handling.
|
||
|
var addr [4]byte
|
||
|
copy(addr[:], endpoint.IP.To4())
|
||
|
|
||
|
sa := unix.RawSockaddrInet4{
|
||
|
Family: unix.AF_INET,
|
||
|
Port: sockaddrPort(endpoint.Port),
|
||
|
Addr: addr,
|
||
|
}
|
||
|
|
||
|
return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
|
||
|
}
|
||
|
|
||
|
// encodeAllowedIPs converts a slice net.IPNets into netlink attribute encoder
|
||
|
// bytes.
|
||
|
func encodeAllowedIPs(ae *netlink.AttributeEncoder, ipns []net.IPNet) error {
|
||
|
for i, ipn := range ipns {
|
||
|
if !isValidIP(ipn.IP) {
|
||
|
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
|
||
|
}
|
||
|
|
||
|
family := uint16(unix.AF_INET6)
|
||
|
if !isIPv6(ipn.IP) {
|
||
|
// Make sure address is 4 bytes if IPv4.
|
||
|
family = unix.AF_INET
|
||
|
ipn.IP = ipn.IP.To4()
|
||
|
}
|
||
|
|
||
|
// Netlink arrays use type as an array index.
|
||
|
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
|
||
|
nae.Uint16(wgh.AllowedipAFamily, family)
|
||
|
nae.Bytes(wgh.AllowedipAIpaddr, ipn.IP)
|
||
|
|
||
|
ones, _ := ipn.Mask.Size()
|
||
|
nae.Uint8(wgh.AllowedipACidrMask, uint8(ones))
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// isValidIP determines if IP is a valid IPv4 or IPv6 address.
|
||
|
func isValidIP(ip net.IP) bool {
|
||
|
return ip.To16() != nil
|
||
|
}
|
||
|
|
||
|
// isIPv6 determines if IP is a valid IPv6 address.
|
||
|
func isIPv6(ip net.IP) bool {
|
||
|
return isValidIP(ip) && ip.To4() == nil
|
||
|
}
|
||
|
|
||
|
// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
|
||
|
// structures to the kernel.
|
||
|
func sockaddrPort(port int) uint16 {
|
||
|
return binary.BigEndian.Uint16(nlenc.Uint16Bytes(uint16(port)))
|
||
|
}
|