mirror of
https://github.com/owncast/owncast.git
synced 2024-10-10 19:16:02 +00:00
Prune expired auth requests + add global max limit. Closes #2490
This commit is contained in:
parent
a5f6f49280
commit
87eeeffa1c
@ -2,9 +2,13 @@ package fediverse
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OTPRegistration represents a single OTP request.
|
// OTPRegistration represents a single OTP request.
|
||||||
@ -18,19 +22,53 @@ type OTPRegistration struct {
|
|||||||
|
|
||||||
// Key by access token to limit one OTP request for a person
|
// Key by access token to limit one OTP request for a person
|
||||||
// to be active at a time.
|
// to be active at a time.
|
||||||
var pendingAuthRequests = make(map[string]OTPRegistration)
|
var (
|
||||||
|
pendingAuthRequests = make(map[string]OTPRegistration)
|
||||||
|
lock = sync.Mutex{}
|
||||||
|
)
|
||||||
|
|
||||||
const registrationTimeout = time.Minute * 10
|
const (
|
||||||
|
registrationTimeout = time.Minute * 10
|
||||||
|
maxPendingRequests = 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go setupExpiredRequestPruner()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear out any pending requests that have been pending for greater than
|
||||||
|
// the specified timeout value.
|
||||||
|
func setupExpiredRequestPruner() {
|
||||||
|
pruneExpiredRequestsTimer := time.NewTicker(registrationTimeout)
|
||||||
|
|
||||||
|
for range pruneExpiredRequestsTimer.C {
|
||||||
|
lock.Lock()
|
||||||
|
log.Debugln("Pruning expired OTP requests.")
|
||||||
|
for k, v := range pendingAuthRequests {
|
||||||
|
if time.Since(v.Timestamp) > registrationTimeout {
|
||||||
|
delete(pendingAuthRequests, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterFediverseOTP will start the OTP flow for a user, creating a new
|
// RegisterFediverseOTP will start the OTP flow for a user, creating a new
|
||||||
// code and returning it to be sent to a destination.
|
// code and returning it to be sent to a destination.
|
||||||
func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) (OTPRegistration, bool) {
|
func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) (OTPRegistration, bool, error) {
|
||||||
request, requestExists := pendingAuthRequests[accessToken]
|
request, requestExists := pendingAuthRequests[accessToken]
|
||||||
|
|
||||||
// If a request is already registered and has not expired then return that
|
// If a request is already registered and has not expired then return that
|
||||||
// existing request.
|
// existing request.
|
||||||
if requestExists && time.Since(request.Timestamp) < registrationTimeout {
|
if requestExists && time.Since(request.Timestamp) < registrationTimeout {
|
||||||
return request, false
|
return request, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
if len(pendingAuthRequests)+1 > maxPendingRequests {
|
||||||
|
return request, false, errors.New("Please try again later. Too many pending requests.")
|
||||||
}
|
}
|
||||||
|
|
||||||
code, _ := createCode()
|
code, _ := createCode()
|
||||||
@ -43,7 +81,7 @@ func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string)
|
|||||||
}
|
}
|
||||||
pendingAuthRequests[accessToken] = r
|
pendingAuthRequests[accessToken] = r
|
||||||
|
|
||||||
return r, true
|
return r, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateFediverseOTP will verify a OTP code for a auth request.
|
// ValidateFediverseOTP will verify a OTP code for a auth request.
|
||||||
@ -54,6 +92,9 @@ func ValidateFediverseOTP(accessToken, code string) (bool, *OTPRegistration) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
delete(pendingAuthRequests, accessToken)
|
delete(pendingAuthRequests, accessToken)
|
||||||
return true, &request
|
return true, &request
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,8 @@ package fediverse
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/owncast/owncast/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -13,7 +15,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestOTPFlowValidation(t *testing.T) {
|
func TestOTPFlowValidation(t *testing.T) {
|
||||||
r, success := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
r, success, err := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
if !success {
|
if !success {
|
||||||
t.Error("Registration should be permitted.")
|
t.Error("Registration should be permitted.")
|
||||||
@ -50,8 +55,8 @@ func TestOTPFlowValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSingleOTPFlowRequest(t *testing.T) {
|
func TestSingleOTPFlowRequest(t *testing.T) {
|
||||||
r1, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
r1, _, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
||||||
r2, s2 := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
r2, s2, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
||||||
|
|
||||||
if r1.Code != r2.Code {
|
if r1.Code != r2.Code {
|
||||||
t.Error("Only one registration should be permitted.")
|
t.Error("Only one registration should be permitted.")
|
||||||
@ -65,14 +70,42 @@ func TestSingleOTPFlowRequest(t *testing.T) {
|
|||||||
func TestAccountCaseInsensitive(t *testing.T) {
|
func TestAccountCaseInsensitive(t *testing.T) {
|
||||||
account := "Account"
|
account := "Account"
|
||||||
accessToken := "another-fake-access-token"
|
accessToken := "another-fake-access-token"
|
||||||
r1, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
r1, _, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account)
|
||||||
_, reg1 := ValidateFediverseOTP(accessToken, r1.Code)
|
_, reg1 := ValidateFediverseOTP(accessToken, r1.Code)
|
||||||
|
|
||||||
// Simulate second auth with account in different case
|
// Simulate second auth with account in different case
|
||||||
r2, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, strings.ToUpper(account))
|
r2, _, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, strings.ToUpper(account))
|
||||||
_, reg2 := ValidateFediverseOTP(accessToken, r2.Code)
|
_, reg2 := ValidateFediverseOTP(accessToken, r2.Code)
|
||||||
|
|
||||||
if reg1.Account != reg2.Account {
|
if reg1.Account != reg2.Account {
|
||||||
t.Errorf("Account names should be case-insensitive: %s %s", reg1.Account, reg2.Account)
|
t.Errorf("Account names should be case-insensitive: %s %s", reg1.Account, reg2.Account)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLimitGlobalPendingRequests(t *testing.T) {
|
||||||
|
for i := 0; i < maxPendingRequests-1; i++ {
|
||||||
|
at, _ := utils.GenerateRandomString(10)
|
||||||
|
uid, _ := utils.GenerateRandomString(10)
|
||||||
|
account, _ := utils.GenerateRandomString(10)
|
||||||
|
|
||||||
|
_, success, error := RegisterFediverseOTP(at, uid, "userDisplayName", account)
|
||||||
|
if !success {
|
||||||
|
t.Error("Registration should be permitted.", i, " of ", len(pendingAuthRequests))
|
||||||
|
}
|
||||||
|
if error != nil {
|
||||||
|
t.Error(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This one should fail
|
||||||
|
at, _ := utils.GenerateRandomString(10)
|
||||||
|
uid, _ := utils.GenerateRandomString(10)
|
||||||
|
account, _ := utils.GenerateRandomString(10)
|
||||||
|
_, success, error := RegisterFediverseOTP(at, uid, "userDisplayName", account)
|
||||||
|
if success {
|
||||||
|
t.Error("Registration should not be permitted.")
|
||||||
|
}
|
||||||
|
if error == nil {
|
||||||
|
t.Error("Error should be returned.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -8,16 +8,48 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/owncast/owncast/core/data"
|
"github.com/owncast/owncast/core/data"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var pendingAuthRequests = make(map[string]*Request)
|
var (
|
||||||
|
pendingAuthRequests = make(map[string]*Request)
|
||||||
|
lock = sync.Mutex{}
|
||||||
|
)
|
||||||
|
|
||||||
|
const registrationTimeout = time.Minute * 10
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go setupExpiredRequestPruner()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear out any pending requests that have been pending for greater than
|
||||||
|
// the specified timeout value.
|
||||||
|
func setupExpiredRequestPruner() {
|
||||||
|
pruneExpiredRequestsTimer := time.NewTicker(registrationTimeout)
|
||||||
|
|
||||||
|
for range pruneExpiredRequestsTimer.C {
|
||||||
|
lock.Lock()
|
||||||
|
log.Debugln("Pruning expired IndieAuth requests.")
|
||||||
|
for k, v := range pendingAuthRequests {
|
||||||
|
if time.Since(v.Timestamp) > registrationTimeout {
|
||||||
|
delete(pendingAuthRequests, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// StartAuthFlow will begin the IndieAuth flow by generating an auth request.
|
// StartAuthFlow will begin the IndieAuth flow by generating an auth request.
|
||||||
func StartAuthFlow(authHost, userID, accessToken, displayName string) (*url.URL, error) {
|
func StartAuthFlow(authHost, userID, accessToken, displayName string) (*url.URL, error) {
|
||||||
|
if len(pendingAuthRequests) >= maxPendingRequests {
|
||||||
|
return nil, errors.New("Please try again later. Too many pending requests.")
|
||||||
|
}
|
||||||
|
|
||||||
serverURL := data.GetServerURL()
|
serverURL := data.GetServerURL()
|
||||||
if serverURL == "" {
|
if serverURL == "" {
|
||||||
return nil, errors.New("Owncast server URL must be set when using auth")
|
return nil, errors.New("Owncast server URL must be set when using auth")
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/andybalholm/cascadia"
|
"github.com/andybalholm/cascadia"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -63,6 +64,7 @@ func createAuthRequest(authDestination, userID, displayName, accessToken, baseSe
|
|||||||
State: state,
|
State: state,
|
||||||
Redirect: &redirect,
|
Redirect: &redirect,
|
||||||
Callback: &callbackURL,
|
Callback: &callbackURL,
|
||||||
|
Timestamp: time.Now(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
35
auth/indieauth/indieauth_test.go
Normal file
35
auth/indieauth/indieauth_test.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package indieauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/owncast/owncast/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLimitGlobalPendingRequests(t *testing.T) {
|
||||||
|
// Simulate 10 pending requests
|
||||||
|
for i := 0; i < maxPendingRequests-1; i++ {
|
||||||
|
cid, _ := utils.GenerateRandomString(10)
|
||||||
|
redirectURL, _ := utils.GenerateRandomString(10)
|
||||||
|
cc, _ := utils.GenerateRandomString(10)
|
||||||
|
state, _ := utils.GenerateRandomString(10)
|
||||||
|
me, _ := utils.GenerateRandomString(10)
|
||||||
|
|
||||||
|
_, err := StartServerAuth(cid, redirectURL, cc, state, me)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Registration should be permitted.", i, " of ", len(pendingAuthRequests), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should throw an error
|
||||||
|
cid, _ := utils.GenerateRandomString(10)
|
||||||
|
redirectURL, _ := utils.GenerateRandomString(10)
|
||||||
|
cc, _ := utils.GenerateRandomString(10)
|
||||||
|
state, _ := utils.GenerateRandomString(10)
|
||||||
|
me, _ := utils.GenerateRandomString(10)
|
||||||
|
|
||||||
|
_, err := StartServerAuth(cid, redirectURL, cc, state, me)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Registration should not be permitted.")
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,9 @@
|
|||||||
package indieauth
|
package indieauth
|
||||||
|
|
||||||
import "net/url"
|
import (
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// Request represents a single in-flight IndieAuth request.
|
// Request represents a single in-flight IndieAuth request.
|
||||||
type Request struct {
|
type Request struct {
|
||||||
@ -15,4 +18,5 @@ type Request struct {
|
|||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
State string
|
State string
|
||||||
Me *url.URL
|
Me *url.URL
|
||||||
|
Timestamp time.Time
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package indieauth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/owncast/owncast/core/data"
|
"github.com/owncast/owncast/core/data"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -17,6 +18,7 @@ type ServerAuthRequest struct {
|
|||||||
State string
|
State string
|
||||||
Me string
|
Me string
|
||||||
Code string
|
Code string
|
||||||
|
Timestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerProfile represents basic user-provided data about this Owncast instance.
|
// ServerProfile represents basic user-provided data about this Owncast instance.
|
||||||
@ -38,10 +40,16 @@ type ServerProfileResponse struct {
|
|||||||
|
|
||||||
var pendingServerAuthRequests = map[string]ServerAuthRequest{}
|
var pendingServerAuthRequests = map[string]ServerAuthRequest{}
|
||||||
|
|
||||||
|
const maxPendingRequests = 1000
|
||||||
|
|
||||||
// StartServerAuth will handle the authentication for the admin user of this
|
// StartServerAuth will handle the authentication for the admin user of this
|
||||||
// Owncast server. Initiated via a GET of the auth endpoint.
|
// Owncast server. Initiated via a GET of the auth endpoint.
|
||||||
// https://indieweb.org/authorization-endpoint
|
// https://indieweb.org/authorization-endpoint
|
||||||
func StartServerAuth(clientID, redirectURI, codeChallenge, state, me string) (*ServerAuthRequest, error) {
|
func StartServerAuth(clientID, redirectURI, codeChallenge, state, me string) (*ServerAuthRequest, error) {
|
||||||
|
if len(pendingServerAuthRequests)+1 >= maxPendingRequests {
|
||||||
|
return nil, errors.New("Please try again later. Too many pending requests.")
|
||||||
|
}
|
||||||
|
|
||||||
code := shortid.MustGenerate()
|
code := shortid.MustGenerate()
|
||||||
|
|
||||||
r := ServerAuthRequest{
|
r := ServerAuthRequest{
|
||||||
@ -51,6 +59,7 @@ func StartServerAuth(clientID, redirectURI, codeChallenge, state, me string) (*S
|
|||||||
State: state,
|
State: state,
|
||||||
Me: me,
|
Me: me,
|
||||||
Code: code,
|
Code: code,
|
||||||
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
pendingServerAuthRequests[code] = r
|
pendingServerAuthRequests[code] = r
|
||||||
|
@ -28,7 +28,12 @@ func RegisterFediverseOTPRequest(u user.User, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessToken := r.URL.Query().Get("accessToken")
|
accessToken := r.URL.Query().Get("accessToken")
|
||||||
reg, success := fediverseauth.RegisterFediverseOTP(accessToken, u.ID, u.DisplayName, req.FediverseAccount)
|
reg, success, err := fediverseauth.RegisterFediverseOTP(accessToken, u.ID, u.DisplayName, req.FediverseAccount)
|
||||||
|
if err != nil {
|
||||||
|
controllers.WriteSimpleResponse(w, false, "Could not register auth request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !success {
|
if !success {
|
||||||
controllers.WriteSimpleResponse(w, false, "Could not register auth request. One may already be pending. Try again later.")
|
controllers.WriteSimpleResponse(w, false, "Could not register auth request. One may already be pending. Try again later.")
|
||||||
return
|
return
|
||||||
|
@ -33,7 +33,7 @@ func handleAuthEndpointGet(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
request, err := ia.StartServerAuth(clientID, redirectURI, codeChallenge, state, me)
|
request, err := ia.StartServerAuth(clientID, redirectURI, codeChallenge, state, me)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return a human readable, HTML page as an error. JSON is no use here.
|
_ = controllers.WriteString(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user