1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: expand IPCError

Expand IPCError to contain a wrapped error,
and add a helper to make constructing such errors easier.

Add a defer-based "log on returned error" to IpcSetOperation.
This lets us simplify all of the error return paths.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-01-15 13:24:38 -08:00
parent db3fa1409c
commit a029b942ae

View File

@ -21,15 +21,24 @@ import (
) )
type IPCError struct { type IPCError struct {
int64 code int64 // error code
err error // underlying/wrapped error
} }
func (s IPCError) Error() string { func (s IPCError) Error() string {
return fmt.Sprintf("IPC error: %d", s.int64) return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
}
func (s IPCError) Unwrap() error {
return s.err
} }
func (s IPCError) ErrorCode() int64 { func (s IPCError) ErrorCode() int64 {
return s.int64 return s.code
}
func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
} }
func (device *Device) IpcGetOperation(w io.Writer) error { func (device *Device) IpcGetOperation(w io.Writer) error {
@ -100,24 +109,28 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
for _, line := range lines { for _, line := range lines {
_, err := io.WriteString(w, line+"\n") _, err := io.WriteString(w, line+"\n")
if err != nil { if err != nil {
return &IPCError{ipc.IpcErrorIO} return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
} }
} }
return nil return nil
} }
func (device *Device) IpcSetOperation(r io.Reader) error { func (device *Device) IpcSetOperation(r io.Reader) (err error) {
scanner := bufio.NewScanner(r) defer func() {
logError := device.log.Error if err != nil {
device.log.Error.Println(err)
}
}()
logDebug := device.log.Debug logDebug := device.log.Debug
var peer *Peer var peer *Peer
dummy := false dummy := false
createdNewPeer := false createdNewPeer := false
deviceConfig := true deviceConfig := true
scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
// parse line // parse line
@ -128,7 +141,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
} }
parts := strings.Split(line, "=") parts := strings.Split(line, "=")
if len(parts) != 2 { if len(parts) != 2 {
return &IPCError{ipc.IpcErrorProtocol} return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
} }
key := parts[0] key := parts[0]
value := parts[1] value := parts[1]
@ -142,8 +155,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
var sk NoisePrivateKey var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value) err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
logDebug.Println("UAPI: Updating private key") logDebug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
@ -154,8 +166,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
port, err := strconv.ParseUint(value, 10, 16) port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to parse listen_port:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
// update port and rebind // update port and rebind
@ -167,8 +178,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
device.net.Unlock() device.net.Unlock()
if err := device.BindUpdate(); err != nil { if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err) return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
return &IPCError{ipc.IpcErrorPortInUse}
} }
case "fwmark": case "fwmark":
@ -184,15 +194,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
}() }()
if err != nil { if err != nil {
logError.Println("Invalid fwmark", err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
logDebug.Println("UAPI: Updating fwmark") logDebug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil { if err := device.BindSetMark(uint32(fwmark)); err != nil {
logError.Println("Failed to update fwmark:", err) return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
return &IPCError{ipc.IpcErrorPortInUse}
} }
case "public_key": case "public_key":
@ -202,15 +210,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
case "replace_peers": case "replace_peers":
if value != "true" { if value != "true" {
logError.Println("Failed to set replace_peers, invalid value:", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
return &IPCError{ipc.IpcErrorInvalid}
} }
logDebug.Println("UAPI: Removing all peers") logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
default: default:
logError.Println("Invalid UAPI device key:", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
return &IPCError{ipc.IpcErrorInvalid}
} }
} }
@ -224,8 +230,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
var publicKey NoisePublicKey var publicKey NoisePublicKey
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to get peer by public key:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
// ignore peer with public key of device // ignore peer with public key of device
@ -244,8 +249,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
if createdNewPeer { if createdNewPeer {
peer, err = device.NewPeer(publicKey) peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
logError.Println("Failed to create new peer:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
logDebug.Println(peer, "- UAPI: Created") logDebug.Println(peer, "- UAPI: Created")
} }
@ -255,8 +259,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
// allow disabling of creation // allow disabling of creation
if value != "true" { if value != "true" {
logError.Println("Failed to set update only, invalid value:", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
return &IPCError{ipc.IpcErrorInvalid}
} }
if createdNewPeer && !dummy { if createdNewPeer && !dummy {
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
@ -269,8 +272,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
// remove currently selected peer from device // remove currently selected peer from device
if value != "true" { if value != "true" {
logError.Println("Failed to set remove, invalid value:", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
return &IPCError{ipc.IpcErrorInvalid}
} }
if !dummy { if !dummy {
logDebug.Println(peer, "- UAPI: Removing") logDebug.Println(peer, "- UAPI: Removing")
@ -290,8 +292,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
if err != nil { if err != nil {
logError.Println("Failed to set preshared key:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
case "endpoint": case "endpoint":
@ -312,8 +313,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
}() }()
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", err, ":", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
return &IPCError{ipc.IpcErrorInvalid}
} }
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
@ -324,8 +324,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set persistent keepalive interval:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
@ -334,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
if old == 0 && secs != 0 { if old == 0 && secs != 0 {
if err != nil { if err != nil {
logError.Println("Failed to get tun device status:", err) return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
return &IPCError{ipc.IpcErrorIO}
} }
if device.isUp.Get() && !dummy { if device.isUp.Get() && !dummy {
peer.SendKeepalive() peer.SendKeepalive()
@ -347,8 +345,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
logDebug.Println(peer, "- UAPI: Removing all allowedips") logDebug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" { if value != "true" {
logError.Println("Failed to replace allowedips, invalid value:", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
return &IPCError{ipc.IpcErrorInvalid}
} }
if dummy { if dummy {
@ -363,8 +360,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logError.Println("Failed to set allowed ip:", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
return &IPCError{ipc.IpcErrorInvalid}
} }
if dummy { if dummy {
@ -377,13 +373,11 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {
logError.Println("Invalid protocol version:", value) return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
return &IPCError{ipc.IpcErrorInvalid}
} }
default: default:
logError.Println("Invalid UAPI peer key:", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
return &IPCError{ipc.IpcErrorInvalid}
} }
} }
} }
@ -431,16 +425,14 @@ func (device *Device) IpcHandle(socket net.Conn) {
err = device.IpcSetOperation(buffered.Reader) err = device.IpcSetOperation(buffered.Reader)
if err != nil && !errors.As(err, &status) { if err != nil && !errors.As(err, &status) {
// should never happen // should never happen
device.log.Error.Println("Invalid UAPI error:", err) status = ipcErrorf(1, "invalid UAPI error: %w", err)
status = &IPCError{1}
} }
case "get=1\n": case "get=1\n":
err = device.IpcGetOperation(buffered.Writer) err = device.IpcGetOperation(buffered.Writer)
if err != nil && !errors.As(err, &status) { if err != nil && !errors.As(err, &status) {
// should never happen // should never happen
device.log.Error.Println("Invalid UAPI error:", err) status = ipcErrorf(1, "invalid UAPI error: %w", err)
status = &IPCError{1}
} }
default: default: