156 lines
4.1 KiB
Go
156 lines
4.1 KiB
Go
|
package restclient
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
|
||
|
"github.com/DisgoOrg/log"
|
||
|
)
|
||
|
|
||
|
// Sentinel errors handed to NewError by DoWithHeaders for well-known HTTP
// failure status codes. Compare against them to detect a specific failure.
var (
	// ErrBadRequest is reported for HTTP 400 responses.
	ErrBadRequest = errors.New("bad request please check your request")

	// ErrUnauthorized is reported for HTTP 401 responses.
	ErrUnauthorized = errors.New("not authorized for this endpoint")

	// ErrRatelimited is reported for HTTP 429 responses.
	ErrRatelimited = errors.New("too many requests")

	// ErrBadGateway is reported for HTTP 502 responses.
	ErrBadGateway = errors.New("bad gateway could not reach destination")
)
|
||
|
|
||
|
// NewRestClient constructs a new RestClient with the given http.Client, log.Logger & useragent
|
||
|
//goland:noinspection GoUnusedExportedFunction
|
||
|
func NewRestClient(httpClient *http.Client, logger log.Logger, userAgent string, customHeader http.Header) RestClient {
|
||
|
if httpClient == nil {
|
||
|
httpClient = http.DefaultClient
|
||
|
}
|
||
|
if logger == nil {
|
||
|
logger = log.Default()
|
||
|
}
|
||
|
return &restClientImpl{userAgent: userAgent, httpClient: httpClient, logger: logger, customHeader: customHeader}
|
||
|
}
|
||
|
|
||
|
// RestClient allows doing requests to different endpoints
|
||
|
type RestClient interface {
|
||
|
Close()
|
||
|
UserAgent() string
|
||
|
HTTPClient() *http.Client
|
||
|
Logger() log.Logger
|
||
|
Do(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}) RestError
|
||
|
DoWithHeaders(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}, customHeader http.Header) RestError
|
||
|
}
|
||
|
|
||
|
type restClientImpl struct {
|
||
|
userAgent string
|
||
|
httpClient *http.Client
|
||
|
logger log.Logger
|
||
|
customHeader http.Header
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) Close() {
|
||
|
r.httpClient.CloseIdleConnections()
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) UserAgent() string {
|
||
|
return r.userAgent
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) HTTPClient() *http.Client {
|
||
|
return r.httpClient
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) Logger() log.Logger {
|
||
|
return r.logger
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) Do(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}) RestError {
|
||
|
return r.DoWithHeaders(route, rqBody, rsBody, r.customHeader)
|
||
|
}
|
||
|
|
||
|
func (r *restClientImpl) DoWithHeaders(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}, customHeader http.Header) RestError {
|
||
|
rqBuffer := &bytes.Buffer{}
|
||
|
var contentType string
|
||
|
|
||
|
if rqBody != nil {
|
||
|
var buffer *bytes.Buffer
|
||
|
switch v := rqBody.(type) {
|
||
|
case *MultipartBuffer:
|
||
|
contentType = v.ContentType
|
||
|
buffer = v.Buffer
|
||
|
|
||
|
case url.Values:
|
||
|
contentType = "application/x-www-form-urlencoded"
|
||
|
buffer = bytes.NewBufferString(v.Encode())
|
||
|
|
||
|
default:
|
||
|
contentType = "application/json"
|
||
|
buffer = &bytes.Buffer{}
|
||
|
err := json.NewEncoder(buffer).Encode(rqBody)
|
||
|
if err != nil {
|
||
|
return NewError(nil, err)
|
||
|
}
|
||
|
}
|
||
|
body, _ := ioutil.ReadAll(io.TeeReader(buffer, rqBuffer))
|
||
|
r.Logger().Debugf("request to %s, body: %s", route.URL(), string(body))
|
||
|
}
|
||
|
|
||
|
rq, err := http.NewRequest(route.Method().String(), route.URL(), rqBuffer)
|
||
|
if err != nil {
|
||
|
return NewError(nil, err)
|
||
|
}
|
||
|
|
||
|
if customHeader != nil {
|
||
|
rq.Header = customHeader
|
||
|
}
|
||
|
rq.Header.Set("User-Agent", r.UserAgent())
|
||
|
if contentType != "" {
|
||
|
rq.Header.Set("Content-Type", contentType)
|
||
|
}
|
||
|
|
||
|
rs, err := r.httpClient.Do(rq)
|
||
|
if err != nil {
|
||
|
return NewError(rs, err)
|
||
|
}
|
||
|
|
||
|
if rs.Body != nil {
|
||
|
buffer := &bytes.Buffer{}
|
||
|
body, _ := ioutil.ReadAll(io.TeeReader(rs.Body, buffer))
|
||
|
rs.Body = ioutil.NopCloser(buffer)
|
||
|
r.Logger().Debugf("response from %s, code %d, body: %s", route.URL(), rs.StatusCode, string(body))
|
||
|
}
|
||
|
|
||
|
switch rs.StatusCode {
|
||
|
case http.StatusOK, http.StatusCreated, http.StatusNoContent:
|
||
|
if rsBody != nil && rs.Body != nil {
|
||
|
if err = json.NewDecoder(rs.Body).Decode(rsBody); err != nil {
|
||
|
r.Logger().Errorf("error unmarshalling response. error: %s", err)
|
||
|
return NewError(rs, err)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
|
||
|
case http.StatusTooManyRequests:
|
||
|
r.Logger().Error(ErrRatelimited)
|
||
|
return NewError(rs, ErrRatelimited)
|
||
|
|
||
|
case http.StatusBadGateway:
|
||
|
r.Logger().Error(ErrBadGateway)
|
||
|
return NewError(rs, ErrBadGateway)
|
||
|
|
||
|
case http.StatusBadRequest:
|
||
|
r.Logger().Error(ErrBadRequest)
|
||
|
return NewError(rs, ErrBadRequest)
|
||
|
|
||
|
case http.StatusUnauthorized:
|
||
|
r.Logger().Error(ErrUnauthorized)
|
||
|
return NewError(rs, ErrUnauthorized)
|
||
|
|
||
|
default:
|
||
|
body, _ := ioutil.ReadAll(rq.Body)
|
||
|
return NewError(rs, fmt.Errorf("request to %s failed. statuscode: %d, body: %s", rq.URL, rs.StatusCode, body))
|
||
|
}
|
||
|
}
|