From a029b942ae25be8ea31a4945198617c76c31abcd Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Fri, 15 Jan 2021 13:24:38 -0800 Subject: [PATCH] device: expand IPCError Expand IPCError to contain a wrapped error, and add a helper to make constructing such errors easier. Add a defer-based "log on returned error" to IpcSetOperation. This lets us simplify all of the error return paths. Signed-off-by: Josh Bleecher Snyder --- device/uapi.go | 94 +++++++++++++++++++++++--------------------------- 1 file changed, 43 insertions(+), 51 deletions(-) diff --git a/device/uapi.go b/device/uapi.go index 4436e72..7f50869 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -21,15 +21,24 @@ import ( ) type IPCError struct { - int64 + code int64 // error code + err error // underlying/wrapped error } func (s IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) + return fmt.Sprintf("IPC error %d: %v", s.code, s.err) +} + +func (s IPCError) Unwrap() error { + return s.err } func (s IPCError) ErrorCode() int64 { - return s.int64 + return s.code +} + +func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { + return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } func (device *Device) IpcGetOperation(w io.Writer) error { @@ -100,24 +109,28 @@ func (device *Device) IpcGetOperation(w io.Writer) error { for _, line := range lines { _, err := io.WriteString(w, line+"\n") if err != nil { - return &IPCError{ipc.IpcErrorIO} + return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } } return nil } -func (device *Device) IpcSetOperation(r io.Reader) error { - scanner := bufio.NewScanner(r) - logError := device.log.Error +func (device *Device) IpcSetOperation(r io.Reader) (err error) { + defer func() { + if err != nil { + device.log.Error.Println(err) + } + }() + logDebug := device.log.Debug var peer *Peer - dummy := false createdNewPeer := false deviceConfig := true + scanner := bufio.NewScanner(r) for scanner.Scan() { // parse line @@ -128,7 +141,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { } parts := strings.Split(line, "=") if len(parts) != 2 { - return &IPCError{ipc.IpcErrorProtocol} + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts)) } key := parts[0] value := parts[1] @@ -142,8 +155,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { var sk NoisePrivateKey err := sk.FromMaybeZeroHex(value) if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) } logDebug.Println("UAPI: Updating private key") device.SetPrivateKey(sk) @@ -154,8 +166,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { port, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) } // update port and rebind @@ -167,8 +178,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { device.net.Unlock() if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipc.IpcErrorPortInUse} + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) } case "fwmark": @@ -184,15 +194,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error { }() if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) } logDebug.Println("UAPI: Updating fwmark") if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipc.IpcErrorPortInUse} + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) } case "public_key": @@ -202,15 +210,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error { case "replace_peers": if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) } logDebug.Println("UAPI: Removing all peers") device.RemoveAllPeers() default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } } @@ -224,8 +230,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { var publicKey NoisePublicKey err := publicKey.FromHex(value) if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) } // ignore peer with public key of device @@ -244,8 +249,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { if createdNewPeer { peer, err = device.NewPeer(publicKey) if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) } logDebug.Println(peer, "- UAPI: Created") } @@ -255,8 +259,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { // allow disabling of creation if value != "true" { - logError.Println("Failed to set update only, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) } if createdNewPeer && !dummy { device.RemovePeer(peer.handshake.remoteStatic) @@ -269,8 +272,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { // remove currently selected peer from device if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) } if !dummy { logDebug.Println(peer, "- UAPI: Removing") @@ -290,8 +292,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { peer.handshake.mutex.Unlock() if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) } case "endpoint": @@ -312,8 +313,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { }() if err != nil { - logError.Println("Failed to set endpoint:", err, ":", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } case "persistent_keepalive_interval": @@ -324,8 +324,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { secs, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) } old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) @@ -334,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { if old == 0 && secs != 0 { if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipc.IpcErrorIO} + return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) } if device.isUp.Get() && !dummy { peer.SendKeepalive() @@ -347,8 +345,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { logDebug.Println(peer, "- UAPI: Removing all allowedips") if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) } if dummy { @@ -363,8 +360,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { _, network, err := net.ParseCIDR(value) if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if dummy { @@ -377,13 +373,11 @@ func (device *Device) IpcSetOperation(r io.Reader) error { case "protocol_version": if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipc.IpcErrorInvalid} + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } } } @@ -431,16 +425,14 @@ func (device *Device) IpcHandle(socket net.Conn) { err = device.IpcSetOperation(buffered.Reader) if err != nil && !errors.As(err, &status) { // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + status = ipcErrorf(1, "invalid UAPI error: %w", err) } case "get=1\n": err = device.IpcGetOperation(buffered.Writer) if err != nil && !errors.As(err, &status) { // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + status = ipcErrorf(1, "invalid UAPI error: %w", err) } default: