1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: allow pipelining UAPI requests

The original spec ends with \n\n especially for this reason.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-01-25 19:00:43 +01:00
parent a29767dda6
commit 18e47795e5

View File

@ -380,9 +380,6 @@ func (device *Device) IpcSet(uapiConf string) error {
} }
func (device *Device) IpcHandle(socket net.Conn) { func (device *Device) IpcHandle(socket net.Conn) {
// create buffered read/writer
defer socket.Close() defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter { buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@ -391,34 +388,43 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer) return bufio.NewReadWriter(reader, writer)
}(socket) }(socket)
defer buffered.Flush() for {
op, err := buffered.ReadString('\n')
if err != nil {
return
}
op, err := buffered.ReadString('\n') // handle operation
if err != nil { switch op {
return case "set=1\n":
} err = device.IpcSetOperation(buffered.Reader)
case "get=1\n":
nextByte, err := buffered.ReadByte()
if err != nil {
return
}
if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %c", nextByte, err)
break
}
err = device.IpcGetOperation(buffered.Writer)
default:
device.log.Error.Println("invalid UAPI operation:", op)
return
}
// handle operation // write status
switch op { var status *IPCError
case "set=1\n": if err != nil && !errors.As(err, &status) {
err = device.IpcSetOperation(buffered.Reader) // shouldn't happen
case "get=1\n": status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
err = device.IpcGetOperation(buffered.Writer) }
default: if status != nil {
device.log.Error.Println("invalid UAPI operation:", op) device.log.Error.Println(status)
return fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} } else {
fmt.Fprintf(buffered, "errno=0\n\n")
// write status }
var status *IPCError buffered.Flush()
if err != nil && !errors.As(err, &status) {
// shouldn't happen
status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
}
if status != nil {
device.log.Error.Println(status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
} }
} }