workgroups/vendor/github.com/DisgoOrg/restclient/rest_client.go

156 lines
4.1 KiB
Go
Raw Normal View History

2021-09-24 17:34:17 +02:00
package restclient
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"github.com/DisgoOrg/log"
)
// Sentinel errors reported by RestClient for specific HTTP status codes.
// Callers can compare against them to react to a particular failure class.

// ErrBadGateway is used when the server answers with HTTP 502 Bad Gateway.
var ErrBadGateway = errors.New("bad gateway could not reach destination")

// ErrUnauthorized is used when the server answers with HTTP 401 Unauthorized.
var ErrUnauthorized = errors.New("not authorized for this endpoint")

// ErrBadRequest is used when the server answers with HTTP 400 Bad Request.
var ErrBadRequest = errors.New("bad request please check your request")

// ErrRatelimited is used when the server answers with HTTP 429 Too Many Requests.
var ErrRatelimited = errors.New("too many requests")
// NewRestClient constructs a new RestClient with the given http.Client, log.Logger & useragent.
//
// A nil httpClient falls back to http.DefaultClient and a nil logger falls back
// to log.Default(). customHeader is used as the default header set by Do and
// may be nil.
//goland:noinspection GoUnusedExportedFunction
func NewRestClient(httpClient *http.Client, logger log.Logger, userAgent string, customHeader http.Header) RestClient {
	client := &restClientImpl{
		userAgent:    userAgent,
		httpClient:   httpClient,
		logger:       logger,
		customHeader: customHeader,
	}

	// Fill in defaults for any dependency the caller left unset.
	if client.httpClient == nil {
		client.httpClient = http.DefaultClient
	}
	if client.logger == nil {
		client.logger = log.Default()
	}

	return client
}
// RestClient allows doing requests to different endpoints.
type RestClient interface {
	// UserAgent returns the value sent as the User-Agent header on every request.
	UserAgent() string

	// HTTPClient returns the *http.Client used to execute requests.
	HTTPClient() *http.Client

	// Logger returns the logger that receives request/response diagnostics.
	Logger() log.Logger

	// Do executes route using the client's default custom headers. rqBody is
	// encoded as the request payload (a *MultipartBuffer, url.Values, or any
	// JSON-encodable value). For a successful response, the body is decoded
	// as JSON into rsBody when rsBody is non-nil.
	Do(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}) RestError

	// DoWithHeaders behaves like Do but sends customHeader instead of the
	// client's default headers.
	DoWithHeaders(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}, customHeader http.Header) RestError

	// Close releases idle connections held by the underlying HTTP client.
	Close()
}
// restClientImpl is the RestClient implementation returned by NewRestClient.
type restClientImpl struct {
	userAgent    string       // sent as the User-Agent header on every request
	httpClient   *http.Client // executes the HTTP requests
	logger       log.Logger   // receives debug output and error logs
	customHeader http.Header  // default headers passed to DoWithHeaders by Do
}

// Close releases idle keep-alive connections held by the underlying http.Client.
func (r *restClientImpl) Close() {
	client := r.HTTPClient()
	client.CloseIdleConnections()
}

// UserAgent returns the configured User-Agent header value.
func (r *restClientImpl) UserAgent() string { return r.userAgent }

// HTTPClient returns the *http.Client used for requests.
func (r *restClientImpl) HTTPClient() *http.Client { return r.httpClient }

// Logger returns the logger used for request/response diagnostics.
func (r *restClientImpl) Logger() log.Logger { return r.logger }

// Do executes route with the client's default custom headers; it is a thin
// wrapper around DoWithHeaders.
func (r *restClientImpl) Do(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}) RestError {
	defaultHeaders := r.customHeader
	return r.DoWithHeaders(route, rqBody, rsBody, defaultHeaders)
}
// DoWithHeaders executes route and returns nil on success or a RestError
// describing the failure.
//
// rqBody, when non-nil, is encoded according to its type:
//   - *MultipartBuffer: sent as-is with its own ContentType
//   - url.Values: form-encoded (application/x-www-form-urlencoded)
//   - anything else: JSON-encoded (application/json)
//
// customHeader, when non-nil, is copied into the request; the caller's map is
// never modified. The User-Agent header is always set, and Content-Type is set
// whenever a body is sent.
//
// The response body is fully read, closed, and replaced on rs with an in-memory
// copy, so it remains readable by NewError. For 200/201/204 responses with a
// non-empty body, the body is decoded as JSON into rsBody when rsBody is non-nil.
// Status 429, 502, 400 and 401 map to ErrRatelimited, ErrBadGateway,
// ErrBadRequest and ErrUnauthorized respectively. Any other status yields an
// error containing the status code and the response body.
func (r *restClientImpl) DoWithHeaders(route *CompiledAPIRoute, rqBody interface{}, rsBody interface{}, customHeader http.Header) RestError {
	rqBuffer := &bytes.Buffer{}
	var contentType string
	if rqBody != nil {
		var buffer *bytes.Buffer
		switch v := rqBody.(type) {
		case *MultipartBuffer:
			contentType = v.ContentType
			buffer = v.Buffer
		case url.Values:
			contentType = "application/x-www-form-urlencoded"
			buffer = bytes.NewBufferString(v.Encode())
		default:
			contentType = "application/json"
			buffer = &bytes.Buffer{}
			if err := json.NewEncoder(buffer).Encode(rqBody); err != nil {
				return NewError(nil, err)
			}
		}
		// Copy the payload into rqBuffer while capturing it for the debug log.
		// Reading from a bytes.Buffer cannot fail, so the error is ignored.
		body, _ := ioutil.ReadAll(io.TeeReader(buffer, rqBuffer))
		r.Logger().Debugf("request to %s, body: %s", route.URL(), string(body))
	}

	rq, err := http.NewRequest(route.Method().String(), route.URL(), rqBuffer)
	if err != nil {
		return NewError(nil, err)
	}
	if customHeader != nil {
		// Clone so the Set calls below never mutate the caller's map. Do passes
		// the client-wide default header, which is shared across concurrent
		// requests; mutating it would race and leak Content-Type between calls.
		rq.Header = customHeader.Clone()
	}
	rq.Header.Set("User-Agent", r.UserAgent())
	if contentType != "" {
		rq.Header.Set("Content-Type", contentType)
	}

	rs, err := r.httpClient.Do(rq)
	if err != nil {
		return NewError(rs, err)
	}

	var rsBytes []byte
	if rs.Body != nil {
		rawBody := rs.Body
		rsBytes, err = ioutil.ReadAll(rawBody)
		// The original body must be closed so the transport can reuse or free
		// the connection; a close error is irrelevant once the data is read.
		_ = rawBody.Close()
		// Swap in an in-memory copy so NewError can still read the body.
		rs.Body = ioutil.NopCloser(bytes.NewReader(rsBytes))
		if err != nil {
			return NewError(rs, err)
		}
		r.Logger().Debugf("response from %s, code %d, body: %s", route.URL(), rs.StatusCode, string(rsBytes))
	}

	switch rs.StatusCode {
	case http.StatusOK, http.StatusCreated, http.StatusNoContent:
		// An empty body (always the case for 204) has nothing to decode;
		// json.Decoder would otherwise report io.EOF for a successful call.
		if rsBody != nil && len(rsBytes) > 0 {
			// Decode from a separate reader so rs.Body stays intact for NewError.
			if err = json.NewDecoder(bytes.NewReader(rsBytes)).Decode(rsBody); err != nil {
				r.Logger().Errorf("error unmarshalling response. error: %s", err)
				return NewError(rs, err)
			}
		}
		return nil
	case http.StatusTooManyRequests:
		r.Logger().Error(ErrRatelimited)
		return NewError(rs, ErrRatelimited)
	case http.StatusBadGateway:
		r.Logger().Error(ErrBadGateway)
		return NewError(rs, ErrBadGateway)
	case http.StatusBadRequest:
		r.Logger().Error(ErrBadRequest)
		return NewError(rs, ErrBadRequest)
	case http.StatusUnauthorized:
		r.Logger().Error(ErrUnauthorized)
		return NewError(rs, ErrUnauthorized)
	default:
		// Report the server's response body (not the already-sent request body).
		return NewError(rs, fmt.Errorf("request to %s failed. statuscode: %d, body: %s", rq.URL, rs.StatusCode, rsBytes))
	}
}