[feature] Refactor tokens, allow multiple app redirect_uris

This commit is contained in:
tobi 2025-03-02 16:46:44 +01:00
parent 8488ac9286
commit 6013a71ba4
77 changed files with 860 additions and 553 deletions

View file

@ -260,7 +260,7 @@
// Build handlers used in later initializations. // Build handlers used in later initializations.
mediaManager := media.NewManager(state) mediaManager := media.NewManager(state)
oauthServer := oauth.New(ctx, dbService) oauthServer := oauth.New(ctx, state, apiutil.GetClientScopeHandler(ctx, state))
typeConverter := typeutils.NewConverter(state) typeConverter := typeutils.NewConverter(state)
visFilter := visibility.NewFilter(state) visFilter := visibility.NewFilter(state)
intFilter := interaction.NewFilter(state) intFilter := interaction.NewFilter(state)

View file

@ -843,6 +843,19 @@ definitions:
example: https://example.org/callback?some=query example: https://example.org/callback?some=query
type: string type: string
x-go-name: RedirectURI x-go-name: RedirectURI
redirect_uris:
description: Post-authorization redirect URIs for the application (OAuth2).
example: '[https://example.org/callback?some=query]'
items:
type: string
type: array
x-go-name: RedirectURIs
scopes:
description: OAuth scopes for this application.
items:
type: string
type: array
x-go-name: Scopes
vapid_key: vapid_key:
description: Push API key for this application. description: Push API key for this application.
type: string type: string
@ -7442,16 +7455,17 @@ paths:
type: string type: string
x-go-name: ClientName x-go-name: ClientName
- description: |- - description: |-
Where the user should be redirected after authorization. Single redirect URI or newline-separated list of redirect URIs (optional).
To display the authorization code to the user instead of redirecting to a web page, use `urn:ietf:wg:oauth:2.0:oob` in this parameter. To display the authorization code to the user instead of redirecting to a web page, use `urn:ietf:wg:oauth:2.0:oob` in this parameter.
If no redirect URIs are provided, defaults to `urn:ietf:wg:oauth:2.0:oob`.
in: formData in: formData
name: redirect_uris name: redirect_uris
required: true
type: string type: string
x-go-name: RedirectURIs x-go-name: RedirectURIs
- description: |- - description: |-
Space separated list of scopes. Space separated list of scopes (optional).
If no scopes are provided, defaults to `read`. If no scopes are provided, defaults to `read`.
in: formData in: formData

View file

@ -50,7 +50,6 @@ type UserStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -67,7 +66,6 @@ type UserStandardTestSuite struct {
func (suite *UserStandardTestSuite) SetupSuite() { func (suite *UserStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -55,7 +55,6 @@ type AuthStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -71,7 +70,6 @@ type AuthStandardTestSuite struct {
func (suite *AuthStandardTestSuite) SetupSuite() { func (suite *AuthStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -20,7 +20,7 @@
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"testing" "testing"
"time" "time"
@ -47,21 +47,21 @@ func (suite *TokenTestSuite) TestPOSTTokenEmptyForm() {
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
suite.NoError(err) suite.NoError(err)
suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: grant_type was not set in the token request form, but must be set to authorization_code or client_credentials: client_id was not set in the token request form: client_secret was not set in the token request form: redirect_uri was not set in the token request form"}`, string(b)) suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: grant_type was not set in the token request form, but must be set to authorization_code or client_credentials: client_id was not set in the token request form: client_secret was not set in the token request form: redirect_uri was not set in the token request form"}`, string(b))
} }
func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() { func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() {
testClient := suite.testClients["local_account_1"] testApp := suite.testApplications["application_1"]
requestBody, w, err := testrig.CreateMultipartFormData( requestBody, w, err := testrig.CreateMultipartFormData(
nil, nil,
map[string][]string{ map[string][]string{
"grant_type": {"client_credentials"}, "grant_type": {"client_credentials"},
"client_id": {testClient.ID}, "client_id": {testApp.ClientID},
"client_secret": {testClient.Secret}, "client_secret": {testApp.ClientSecret},
"redirect_uri": {"http://localhost:8080"}, "redirect_uri": {"http://localhost:8080"},
}) })
if err != nil { if err != nil {
@ -79,7 +79,7 @@ func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() {
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
suite.NoError(err) suite.NoError(err)
t := &apimodel.Token{} t := &apimodel.Token{}
@ -98,16 +98,81 @@ func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() {
suite.NotNil(dbToken) suite.NotNil(dbToken)
} }
func (suite *TokenTestSuite) TestRetrieveClientCredentialsBadScope() {
testApp := suite.testApplications["application_1"]
requestBody, w, err := testrig.CreateMultipartFormData(
nil,
map[string][]string{
"grant_type": {"client_credentials"},
"client_id": {testApp.ClientID},
"client_secret": {testApp.ClientSecret},
"redirect_uri": {"http://localhost:8080"},
"scope": {"admin"},
})
if err != nil {
panic(err)
}
bodyBytes := requestBody.Bytes()
ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
ctx.Request.Header.Set("accept", "application/json")
suite.authModule.TokenPOSTHandler(ctx)
suite.Equal(http.StatusForbidden, recorder.Code)
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
suite.NoError(err)
suite.Equal(`{"error":"invalid_scope","error_description":"Forbidden: requested scope admin was not covered by client scope: If you arrived at this error during a sign in/oauth flow, please try clearing your session cookies and signing in again; if problems persist, make sure you're using the correct credentials"}`, string(b))
}
func (suite *TokenTestSuite) TestRetrieveClientCredentialsDifferentRedirectURI() {
testApp := suite.testApplications["application_1"]
requestBody, w, err := testrig.CreateMultipartFormData(
nil,
map[string][]string{
"grant_type": {"client_credentials"},
"client_id": {testApp.ClientID},
"client_secret": {testApp.ClientSecret},
"redirect_uri": {"http://somewhere.else.example.org"},
})
if err != nil {
panic(err)
}
bodyBytes := requestBody.Bytes()
ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType())
ctx.Request.Header.Set("accept", "application/json")
suite.authModule.TokenPOSTHandler(ctx)
suite.Equal(http.StatusForbidden, recorder.Code)
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
suite.NoError(err)
suite.Equal(`{"error":"invalid redirect uri","error_description":"Forbidden: requested redirect URI http://somewhere.else.example.org was not covered by client redirect URIs: If you arrived at this error during a sign in/oauth flow, please try clearing your session cookies and signing in again; if problems persist, make sure you're using the correct credentials"}`, string(b))
}
func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() { func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() {
testClient := suite.testClients["local_account_1"] testApp := suite.testApplications["application_1"]
testUserAuthorizationToken := suite.testTokens["local_account_1_user_authorization_token"] testUserAuthorizationToken := suite.testTokens["local_account_1_user_authorization_token"]
requestBody, w, err := testrig.CreateMultipartFormData( requestBody, w, err := testrig.CreateMultipartFormData(
nil, nil,
map[string][]string{ map[string][]string{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"client_id": {testClient.ID}, "client_id": {testApp.ClientID},
"client_secret": {testClient.Secret}, "client_secret": {testApp.ClientSecret},
"redirect_uri": {"http://localhost:8080"}, "redirect_uri": {"http://localhost:8080"},
"code": {testUserAuthorizationToken.Code}, "code": {testUserAuthorizationToken.Code},
}) })
@ -126,7 +191,7 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() {
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
suite.NoError(err) suite.NoError(err)
t := &apimodel.Token{} t := &apimodel.Token{}
@ -145,14 +210,14 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() {
} }
func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() { func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() {
testClient := suite.testClients["local_account_1"] testApp := suite.testApplications["application_1"]
requestBody, w, err := testrig.CreateMultipartFormData( requestBody, w, err := testrig.CreateMultipartFormData(
nil, nil,
map[string][]string{ map[string][]string{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"client_id": {testClient.ID}, "client_id": {testApp.ClientID},
"client_secret": {testClient.Secret}, "client_secret": {testApp.ClientSecret},
"redirect_uri": {"http://localhost:8080"}, "redirect_uri": {"http://localhost:8080"},
}) })
if err != nil { if err != nil {
@ -170,21 +235,21 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() {
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
suite.NoError(err) suite.NoError(err)
suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: code was not set in the token request form, but must be set since grant_type is authorization_code"}`, string(b)) suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: code was not set in the token request form, but must be set since grant_type is authorization_code"}`, string(b))
} }
func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() { func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() {
testClient := suite.testClients["local_account_1"] testApplication := suite.testApplications["application_1"]
requestBody, w, err := testrig.CreateMultipartFormData( requestBody, w, err := testrig.CreateMultipartFormData(
nil, nil,
map[string][]string{ map[string][]string{
"grant_type": {"client_credentials"}, "grant_type": {"client_credentials"},
"client_id": {testClient.ID}, "client_id": {testApplication.ClientID},
"client_secret": {testClient.Secret}, "client_secret": {testApplication.ClientSecret},
"redirect_uri": {"http://localhost:8080"}, "redirect_uri": {"http://localhost:8080"},
"code": {"peepeepoopoo"}, "code": {"peepeepoopoo"},
}) })
@ -203,7 +268,7 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() {
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
suite.NoError(err) suite.NoError(err)
suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: a code was provided in the token request form, but grant_type was not set to authorization_code"}`, string(b)) suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: a code was provided in the token request form, but grant_type was not set to authorization_code"}`, string(b))

View file

@ -56,7 +56,6 @@ type AccountStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -69,7 +68,6 @@ type AccountStandardTestSuite struct {
func (suite *AccountStandardTestSuite) SetupSuite() { func (suite *AccountStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -56,7 +56,6 @@ type AdminStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -72,7 +71,6 @@ type AdminStandardTestSuite struct {
func (suite *AdminStandardTestSuite) SetupSuite() { func (suite *AdminStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -61,7 +61,6 @@ type BookmarkTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -77,7 +76,6 @@ type BookmarkTestSuite struct {
func (suite *BookmarkTestSuite) SetupSuite() { func (suite *BookmarkTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -44,7 +44,6 @@ type ExportsTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -55,7 +54,6 @@ type ExportsTestSuite struct {
func (suite *ExportsTestSuite) SetupSuite() { func (suite *ExportsTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -48,7 +48,6 @@ type FavouritesStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -62,7 +61,6 @@ type FavouritesStandardTestSuite struct {
func (suite *FavouritesStandardTestSuite) SetupSuite() { func (suite *FavouritesStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -53,7 +53,6 @@ type FiltersTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -68,7 +67,6 @@ type FiltersTestSuite struct {
func (suite *FiltersTestSuite) SetupSuite() { func (suite *FiltersTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -53,7 +53,6 @@ type FiltersTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -68,7 +67,6 @@ type FiltersTestSuite struct {
func (suite *FiltersTestSuite) SetupSuite() { func (suite *FiltersTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -48,7 +48,6 @@ type FollowedTagsTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -60,7 +59,6 @@ type FollowedTagsTestSuite struct {
func (suite *FollowedTagsTestSuite) SetupSuite() { func (suite *FollowedTagsTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -53,7 +53,6 @@ type FollowRequestStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -66,7 +65,6 @@ type FollowRequestStandardTestSuite struct {
func (suite *FollowRequestStandardTestSuite) SetupSuite() { func (suite *FollowRequestStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -43,7 +43,6 @@ type ImportTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -54,7 +53,6 @@ type ImportTestSuite struct {
func (suite *ImportTestSuite) SetupSuite() { func (suite *ImportTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -55,7 +55,6 @@ type InstanceStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -68,7 +67,6 @@ type InstanceStandardTestSuite struct {
func (suite *InstanceStandardTestSuite) SetupSuite() { func (suite *InstanceStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -47,7 +47,6 @@ type ListsStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -64,7 +63,6 @@ type ListsStandardTestSuite struct {
func (suite *ListsStandardTestSuite) SetupSuite() { func (suite *ListsStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -62,7 +62,6 @@ type MediaCreateTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -101,7 +100,7 @@ func (suite *MediaCreateTestSuite) SetupTest() {
) )
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor( suite.processor = testrig.NewTestProcessor(
@ -117,7 +116,6 @@ func (suite *MediaCreateTestSuite) SetupTest() {
// setup test data // setup test data
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -60,7 +60,6 @@ type MediaUpdateTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -99,7 +98,7 @@ func (suite *MediaUpdateTestSuite) SetupTest() {
) )
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor( suite.processor = testrig.NewTestProcessor(
@ -115,7 +114,6 @@ func (suite *MediaUpdateTestSuite) SetupTest() {
// setup test data // setup test data
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -56,7 +56,6 @@ type MutesTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -67,7 +66,6 @@ type MutesTestSuite struct {
func (suite *MutesTestSuite) SetupSuite() { func (suite *MutesTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -48,7 +48,6 @@ type NotificationsTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -63,7 +62,6 @@ type NotificationsTestSuite struct {
func (suite *NotificationsTestSuite) SetupSuite() { func (suite *NotificationsTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -48,7 +48,6 @@ type PollsStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -61,7 +60,6 @@ type PollsStandardTestSuite struct {
func (suite *PollsStandardTestSuite) SetupSuite() { func (suite *PollsStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -47,7 +47,6 @@ type PushTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -59,7 +58,6 @@ type PushTestSuite struct {
func (suite *PushTestSuite) SetupSuite() { func (suite *PushTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -47,7 +47,6 @@ type ReportsStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -60,7 +59,6 @@ type ReportsStandardTestSuite struct {
func (suite *ReportsStandardTestSuite) SetupSuite() { func (suite *ReportsStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -55,7 +55,6 @@ type SearchStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -66,7 +65,6 @@ type SearchStandardTestSuite struct {
func (suite *SearchStandardTestSuite) SetupSuite() { func (suite *SearchStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -55,7 +55,6 @@ type StatusStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -176,7 +175,6 @@ func (suite *StatusStandardTestSuite) determinateStatus(rawMap map[string]any) {
func (suite *StatusStandardTestSuite) SetupSuite() { func (suite *StatusStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -61,7 +61,6 @@ type StreamingTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -75,7 +74,6 @@ type StreamingTestSuite struct {
func (suite *StreamingTestSuite) SetupSuite() { func (suite *StreamingTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -56,7 +56,6 @@ type TagsTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -68,7 +67,6 @@ type TagsTestSuite struct {
func (suite *TagsTestSuite) SetupSuite() { func (suite *TagsTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -50,7 +50,6 @@ type UserStandardTestSuite struct {
state state.State state state.State
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -66,7 +65,6 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestLog() testrig.InitTestLog()
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -51,7 +51,6 @@ type FileserverTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -100,7 +99,7 @@ func (suite *FileserverTestSuite) SetupSuite() {
) )
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.fileServer = fileserver.New(suite.processor) suite.fileServer = fileserver.New(suite.processor)
@ -118,7 +117,6 @@ func (suite *FileserverTestSuite) SetupTest() {
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -33,12 +33,17 @@ type Application struct {
// Post-authorization redirect URI for the application (OAuth2). // Post-authorization redirect URI for the application (OAuth2).
// example: https://example.org/callback?some=query // example: https://example.org/callback?some=query
RedirectURI string `json:"redirect_uri,omitempty"` RedirectURI string `json:"redirect_uri,omitempty"`
// Post-authorization redirect URIs for the application (OAuth2).
// example: [https://example.org/callback?some=query]
RedirectURIs []string `json:"redirect_uris,omitempty"`
// Client ID associated with this application. // Client ID associated with this application.
ClientID string `json:"client_id,omitempty"` ClientID string `json:"client_id,omitempty"`
// Client secret associated with this application. // Client secret associated with this application.
ClientSecret string `json:"client_secret,omitempty"` ClientSecret string `json:"client_secret,omitempty"`
// Push API key for this application. // Push API key for this application.
VapidKey string `json:"vapid_key,omitempty"` VapidKey string `json:"vapid_key,omitempty"`
// OAuth scopes for this application.
Scopes []string `json:"scopes,omitempty"`
} }
// ApplicationCreateRequest models app create parameters. // ApplicationCreateRequest models app create parameters.
@ -50,14 +55,15 @@ type ApplicationCreateRequest struct {
// in: formData // in: formData
// required: true // required: true
ClientName string `form:"client_name" json:"client_name" xml:"client_name" binding:"required"` ClientName string `form:"client_name" json:"client_name" xml:"client_name" binding:"required"`
// Where the user should be redirected after authorization. // Single redirect URI or newline-separated list of redirect URIs (optional).
// //
// To display the authorization code to the user instead of redirecting to a web page, use `urn:ietf:wg:oauth:2.0:oob` in this parameter. // To display the authorization code to the user instead of redirecting to a web page, use `urn:ietf:wg:oauth:2.0:oob` in this parameter.
// //
// If no redirect URIs are provided, defaults to `urn:ietf:wg:oauth:2.0:oob`.
//
// in: formData // in: formData
// required: true RedirectURIs string `form:"redirect_uris" json:"redirect_uris" xml:"redirect_uris"`
RedirectURIs string `form:"redirect_uris" json:"redirect_uris" xml:"redirect_uris" binding:"required"` // Space separated list of scopes (optional).
// Space separated list of scopes.
// //
// If no scopes are provided, defaults to `read`. // If no scopes are provided, defaults to `read`.
// //

View file

@ -18,15 +18,21 @@
package util package util
import ( import (
"context"
"errors" "errors"
"slices" "slices"
"strings" "strings"
"codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
) )
// Auth wraps an authorized token, application, user, and account. // Auth wraps an authorized token, application, user, and account.
@ -150,3 +156,51 @@ func(hasScope string) bool {
return a, nil return a, nil
} }
// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest.
func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler {
return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) {
application, err := state.DB.GetApplicationByClientID(
gtscontext.SetBarebones(ctx),
tgr.ClientID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
log.Errorf(ctx, "database error getting application: %v", err)
return false, err
}
if application == nil {
err := gtserror.Newf("no application found with client id %s", tgr.ClientID)
return false, err
}
// Normalize scope.
if strings.TrimSpace(tgr.Scope) == "" {
tgr.Scope = "read"
}
// Make sure requested scopes are all
// within scopes permitted by application.
hasScopes := strings.Split(application.Scopes, " ")
wantsScopes := strings.Split(tgr.Scope, " ")
for _, wantsScope := range wantsScopes {
thisOK := slices.ContainsFunc(
hasScopes,
func(hasScope string) bool {
has := Scope(hasScope)
wants := Scope(wantsScope)
return has.Permits(wants)
},
)
if !thisOK {
// Requested unpermitted
// scope for this app.
return false, nil
}
}
// All OK.
return true, nil
}
}

View file

@ -93,11 +93,29 @@
// scope permits the wanted scope. // scope permits the wanted scope.
func (has Scope) Permits(wanted Scope) bool { func (has Scope) Permits(wanted Scope) bool {
if has == wanted { if has == wanted {
// Exact match. // Exact match on either a
// top-level or granular scope.
return true return true
} }
// Check if we have a parent scope of what's wanted, // Ensure we have a
// eg., we have scope "admin", we want "admin:read". // known top-level scope.
return strings.HasPrefix(string(wanted), string(has)) switch has {
case ScopeProfile,
ScopePush,
ScopeRead,
ScopeWrite,
ScopeAdmin,
ScopeAdminRead,
ScopeAdminWrite:
// Check if top-level includes wanted,
// eg., have "admin", want "admin:read".
return strings.HasPrefix(string(wanted), string(has)+":")
default:
// Unknown top-level scope,
// can't permit anything.
return false
}
} }

View file

@ -89,6 +89,16 @@ func TestScopes(t *testing.T) {
WantsScope: util.ScopeWrite, WantsScope: util.ScopeWrite,
Expect: false, Expect: false,
}, },
{
HasScope: util.ScopeProfile,
WantsScope: util.ScopePush,
Expect: false,
},
{
HasScope: util.Scope("p"),
WantsScope: util.ScopePush,
Expect: false,
},
} { } {
res := test.HasScope.Permits(test.WantsScope) res := test.HasScope.Permits(test.WantsScope)
if res != test.Expect { if res != test.Expect {

View file

@ -50,7 +50,6 @@ type WebfingerStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -63,7 +62,6 @@ type WebfingerStandardTestSuite struct {
func (suite *WebfingerStandardTestSuite) SetupSuite() { func (suite *WebfingerStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
@ -102,7 +100,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.mediaManager, suite.mediaManager,
) )
suite.webfingerModule = webfinger.New(suite.processor) suite.webfingerModule = webfinger.New(suite.processor)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
} }

View file

@ -94,7 +94,7 @@ func (suite *WebfingerGetTestSuite) funkifyAccountDomain(host string, accountDom
subscriptions.New(&suite.state, suite.federator.TransportController(), suite.tc), subscriptions.New(&suite.state, suite.federator.TransportController(), suite.tc),
suite.tc, suite.tc,
suite.federator, suite.federator,
testrig.NewTestOauthServer(suite.db), testrig.NewTestOauthServer(&suite.state),
testrig.NewTestMediaManager(&suite.state), testrig.NewTestMediaManager(&suite.state),
&suite.state, &suite.state,
suite.emailSender, suite.emailSender,

View file

@ -69,7 +69,6 @@ func (c *Caches) Init() {
c.initBlock() c.initBlock()
c.initBlockIDs() c.initBlockIDs()
c.initBoostOfIDs() c.initBoostOfIDs()
c.initClient()
c.initConversation() c.initConversation()
c.initConversationLastStatusIDs() c.initConversationLastStatusIDs()
c.initDomainAllow() c.initDomainAllow()
@ -161,7 +160,6 @@ func (c *Caches) Sweep(threshold float64) {
c.DB.Block.Trim(threshold) c.DB.Block.Trim(threshold)
c.DB.BlockIDs.Trim(threshold) c.DB.BlockIDs.Trim(threshold)
c.DB.BoostOfIDs.Trim(threshold) c.DB.BoostOfIDs.Trim(threshold)
c.DB.Client.Trim(threshold)
c.DB.Conversation.Trim(threshold) c.DB.Conversation.Trim(threshold)
c.DB.ConversationLastStatusIDs.Trim(threshold) c.DB.ConversationLastStatusIDs.Trim(threshold)
c.DB.Emoji.Trim(threshold) c.DB.Emoji.Trim(threshold)

29
internal/cache/db.go vendored
View file

@ -52,9 +52,6 @@ type DBCaches struct {
// BoostOfIDs provides access to the boost of IDs list database cache. // BoostOfIDs provides access to the boost of IDs list database cache.
BoostOfIDs SliceCache[string] BoostOfIDs SliceCache[string]
// Client provides access to the gtsmodel Client database cache.
Client StructCache[*gtsmodel.Client]
// Conversation provides access to the gtsmodel Conversation database cache. // Conversation provides access to the gtsmodel Conversation database cache.
Conversation StructCache[*gtsmodel.Conversation] Conversation StructCache[*gtsmodel.Conversation]
@ -489,32 +486,6 @@ func (c *Caches) initBoostOfIDs() {
c.DB.BoostOfIDs.Init(0, cap) c.DB.BoostOfIDs.Init(0, cap)
} }
func (c *Caches) initClient() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofClient(), // model in-mem size.
config.GetCacheClientMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client {
c2 := new(gtsmodel.Client)
*c2 = *c1
return c2
}
c.DB.Client.Init(structr.CacheConfig[*gtsmodel.Client]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
Copy: copyF,
Invalidate: c.OnInvalidateClient,
})
}
func (c *Caches) initConversation() { func (c *Caches) initConversation() {
cap := calculateResultCacheMax( cap := calculateResultCacheMax(
sizeofConversation(), // model in-mem size. sizeofConversation(), // model in-mem size.

View file

@ -62,8 +62,7 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {
} }
func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) { func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) {
// Invalidate cached client of this application. // TODO: invalidate tokens?
c.DB.Client.Invalidate("ID", app.ClientID)
} }
func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
@ -79,11 +78,6 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
c.DB.BlockIDs.Invalidate(block.AccountID) c.DB.BlockIDs.Invalidate(block.AccountID)
} }
func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) {
// Invalidate any tokens under this client.
c.DB.Token.Invalidate("ClientID", client.ID)
}
func (c *Caches) OnInvalidateConversation(conversation *gtsmodel.Conversation) { func (c *Caches) OnInvalidateConversation(conversation *gtsmodel.Conversation) {
// Invalidate owning account's conversation list. // Invalidate owning account's conversation list.
c.DB.ConversationLastStatusIDs.Invalidate(conversation.AccountID) c.DB.ConversationLastStatusIDs.Invalidate(conversation.AccountID)

View file

@ -302,15 +302,14 @@ func sizeofAccountStats() uintptr {
func sizeofApplication() uintptr { func sizeofApplication() uintptr {
return uintptr(size.Of(&gtsmodel.Application{ return uintptr(size.Of(&gtsmodel.Application{
ID: exampleID, ID: exampleID,
CreatedAt: exampleTime, Name: exampleUsername,
UpdatedAt: exampleTime, Website: exampleURI,
Name: exampleUsername, RedirectURIs: []string{exampleURI},
Website: exampleURI, ClientID: exampleID,
RedirectURI: exampleURI, ClientSecret: exampleID,
ClientID: exampleID, Scopes: exampleTextSmall,
ClientSecret: exampleID, ManagedByUserID: exampleID,
Scopes: exampleTextSmall,
})) }))
} }
@ -325,17 +324,6 @@ func sizeofBlock() uintptr {
})) }))
} }
func sizeofClient() uintptr {
return uintptr(size.Of(&gtsmodel.Client{
ID: exampleID,
CreatedAt: exampleTime,
UpdatedAt: exampleTime,
Secret: exampleID,
Domain: exampleURI,
UserID: exampleID,
}))
}
func sizeofConversation() uintptr { func sizeofConversation() uintptr {
return uintptr(size.Of(&gtsmodel.Conversation{ return uintptr(size.Of(&gtsmodel.Conversation{
ID: exampleID, ID: exampleID,
@ -752,8 +740,7 @@ func sizeofThreadMute() uintptr {
func sizeofToken() uintptr { func sizeofToken() uintptr {
return uintptr(size.Of(&gtsmodel.Token{ return uintptr(size.Of(&gtsmodel.Token{
ID: exampleID, ID: exampleID,
CreatedAt: exampleTime, LastUsed: exampleTime,
UpdatedAt: exampleTime,
ClientID: exampleID, ClientID: exampleID,
UserID: exampleID, UserID: exampleID,
RedirectURI: exampleURI, RedirectURI: exampleURI,

View file

@ -36,15 +36,6 @@ type Application interface {
// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. // DeleteApplicationByClientID deletes the application with corresponding client_id value from the database.
DeleteApplicationByClientID(ctx context.Context, clientID string) error DeleteApplicationByClientID(ctx context.Context, clientID string) error
// GetClientByID fetches the application client from database with ID.
GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error)
// PutClient puts the given application client in the database.
PutClient(ctx context.Context, client *gtsmodel.Client) error
// DeleteClientByID deletes the application client from database with ID.
DeleteClientByID(ctx context.Context, id string) error
// GetAllTokens fetches all client oauth tokens from database. // GetAllTokens fetches all client oauth tokens from database.
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
@ -63,6 +54,9 @@ type Application interface {
// PutToken puts given client oauth token in the database. // PutToken puts given client oauth token in the database.
PutToken(ctx context.Context, token *gtsmodel.Token) error PutToken(ctx context.Context, token *gtsmodel.Token) error
// UpdateToken updates the given token. Update all columns if no specific columns given.
UpdateToken(ctx context.Context, token *gtsmodel.Token, columns ...string) error
// DeleteTokenByID deletes client oauth token from database with ID. // DeleteTokenByID deletes client oauth token from database with ID.
DeleteTokenByID(ctx context.Context, id string) error DeleteTokenByID(ctx context.Context, id string) error

View file

@ -341,6 +341,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error {
// instance account's ID so this is an easy check. // instance account's ID so this is an easy check.
instanceAcct, err := a.state.DB.GetInstanceAccount(ctx, "") instanceAcct, err := a.state.DB.GetInstanceAccount(ctx, "")
if err != nil { if err != nil {
err := gtserror.Newf("db error getting instance account: %w", err)
return err return err
} }
@ -369,18 +370,14 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error {
clientID := instanceAcct.ID clientID := instanceAcct.ID
clientSecret := uuid.NewString() clientSecret := uuid.NewString()
appID, err := id.NewRandomULID()
if err != nil {
return err
}
// Generate the application // Generate the application
// to put in the database. // to put in the database.
app := &gtsmodel.Application{ app := &gtsmodel.Application{
ID: appID, ID: id.NewULID(),
Name: host + " instance application", Name: host + " instance application",
Website: url, Website: url,
RedirectURI: url, RedirectURIs: []string{url},
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
Scopes: "write:accounts", Scopes: "write:accounts",
@ -388,19 +385,11 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error {
// Store it. // Store it.
if err := a.state.DB.PutApplication(ctx, app); err != nil { if err := a.state.DB.PutApplication(ctx, app); err != nil {
err := gtserror.Newf("db error storing instance application: %w", err)
return err return err
} }
// Model an oauth client return nil
// from the application.
oc := &gtsmodel.Client{
ID: clientID,
Secret: clientSecret,
Domain: url,
}
// Store it.
return a.state.DB.PutClient(ctx, oc)
} }
func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) {

View file

@ -97,41 +97,6 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
return nil return nil
} }
func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) {
return a.state.Caches.DB.Client.LoadOne("ID", func() (*gtsmodel.Client, error) {
var client gtsmodel.Client
if err := a.db.NewSelect().
Model(&client).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil {
return nil, err
}
return &client, nil
}, id)
}
func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error {
return a.state.Caches.DB.Client.Store(client, func() error {
_, err := a.db.NewInsert().Model(client).Exec(ctx)
return err
})
}
func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error {
_, err := a.db.NewDelete().
Table("clients").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
if err != nil {
return err
}
a.state.Caches.DB.Client.Invalidate("ID", id)
return nil
}
func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) { func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) {
var tokenIDs []string var tokenIDs []string
@ -233,6 +198,21 @@ func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) err
}) })
} }
func (a *applicationDB) UpdateToken(ctx context.Context, token *gtsmodel.Token, columns ...string) error {
_, err := a.db.
NewUpdate().
Model(token).
Column(columns...).
Where("? = ?", bun.Ident("id"), token.ID).
Exec(ctx)
if err != nil {
return err
}
a.state.Caches.DB.Token.Invalidate("ID", token.ID)
return nil
}
func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
_, err := a.db.NewDelete(). _, err := a.db.NewDelete().
Table("tokens"). Table("tokens").

View file

@ -22,7 +22,6 @@
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -45,12 +44,6 @@ func (suite *ApplicationTestSuite) TestGetApplicationBy() {
// isEqual checks if 2 application models are equal. // isEqual checks if 2 application models are equal.
isEqual := func(a1, a2 gtsmodel.Application) bool { isEqual := func(a1, a2 gtsmodel.Application) bool {
// Clear database-set fields.
a1.CreatedAt = time.Time{}
a2.CreatedAt = time.Time{}
a1.UpdatedAt = time.Time{}
a2.UpdatedAt = time.Time{}
return reflect.DeepEqual(a1, a2) return reflect.DeepEqual(a1, a2)
} }

View file

@ -35,7 +35,6 @@ type BunDBStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -62,7 +61,6 @@ type BunDBStandardTestSuite struct {
func (suite *BunDBStandardTestSuite) SetupSuite() { func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -0,0 +1,200 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
oldmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
newmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20250224105654_token_app_client_refactor"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Drop unused clients table.
if _, err := tx.
NewDropTable().
Table("clients").
IfExists().
Exec(ctx); err != nil {
return err
}
// Select all old model
// applications into memory.
oldApps := []*oldmodel.Application{}
if err := tx.
NewSelect().
Model(&oldApps).
Scan(ctx); err != nil {
return err
}
// Drop the old applications table.
if _, err := tx.
NewDropTable().
Table("applications").
IfExists().
Exec(ctx); err != nil {
return err
}
// Create the new applications table.
if _, err := tx.
NewCreateTable().
Model((*newmodel.Application)(nil)).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to new applications table.
if _, err := tx.
NewCreateIndex().
Table("applications").
Index("applications_client_id_idx").
Column("client_id").
IfNotExists().
Exec(ctx); err != nil {
return err
}
if _, err := tx.
NewCreateIndex().
Table("applications").
Index("applications_managed_by_user_id_idx").
Column("managed_by_user_id").
IfNotExists().
Exec(ctx); err != nil {
return err
}
if len(oldApps) != 0 {
// Convert all the old model applications into new ones.
newApps := make([]*newmodel.Application, 0, len(oldApps))
for _, oldApp := range oldApps {
newApps = append(newApps, &newmodel.Application{
ID: id.NewULIDFromTime(oldApp.CreatedAt),
Name: oldApp.Name,
Website: oldApp.Website,
RedirectURIs: []string{oldApp.RedirectURI},
ClientID: oldApp.ClientID,
ClientSecret: oldApp.ClientSecret,
Scopes: oldApp.Scopes,
})
}
// Whack all the new apps in
// there. Lads lads lads lads!
if _, err := tx.
NewInsert().
Model(&newApps).
Exec(ctx); err != nil {
return err
}
}
// Select all the old model
// tokens into memory.
oldTokens := []*oldmodel.Token{}
if err := tx.
NewSelect().
Model(&oldTokens).
Scan(ctx); err != nil {
return err
}
// Drop the old token table.
if _, err := tx.
NewDropTable().
Table("tokens").
IfExists().
Exec(ctx); err != nil {
return err
}
// Create the new token table.
if _, err := tx.
NewCreateTable().
Model((*newmodel.Token)(nil)).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add access index to new token table.
if _, err := tx.
NewCreateIndex().
Table("tokens").
Index("tokens_access_idx").
Column("access").
IfNotExists().
Exec(ctx); err != nil {
return err
}
if len(oldTokens) != 0 {
// Convert all the old model tokens into new ones.
newTokens := make([]*newmodel.Token, 0, len(oldTokens))
for _, oldToken := range oldTokens {
newTokens = append(newTokens, &newmodel.Token{
ID: id.NewULIDFromTime(oldToken.CreatedAt),
ClientID: oldToken.ClientID,
UserID: oldToken.UserID,
RedirectURI: oldToken.RedirectURI,
Scope: oldToken.Scope,
Code: oldToken.Code,
CodeChallenge: oldToken.CodeChallenge,
CodeChallengeMethod: oldToken.CodeChallengeMethod,
CodeCreateAt: oldToken.CodeCreateAt,
CodeExpiresAt: oldToken.CodeExpiresAt,
Access: oldToken.Access,
AccessCreateAt: oldToken.AccessCreateAt,
AccessExpiresAt: oldToken.AccessExpiresAt,
Refresh: oldToken.Refresh,
RefreshCreateAt: oldToken.RefreshCreateAt,
RefreshExpiresAt: oldToken.RefreshExpiresAt,
})
}
// Whack all the new tokens in
// there. Lads lads lads lads!
if _, err := tx.
NewInsert().
Model(&newTokens).
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return nil
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -15,6 +15,15 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
package oauth_test package gtsmodel
// TODO: write tests type Application struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
Name string `bun:",notnull"`
Website string `bun:",nullzero"`
RedirectURIs []string `bun:"redirect_uris,array"`
ClientID string `bun:"type:CHAR(26),nullzero,notnull"`
ClientSecret string `bun:",nullzero,notnull"`
Scopes string `bun:",notnull"`
ManagedByUserID string `bun:"type:CHAR(26),nullzero"`
}

View file

@ -0,0 +1,42 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// Token is a translation of the gotosocial token
// with the ExpiresIn fields replaced with ExpiresAt.
type Token struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
LastUsed time.Time `bun:"type:timestamptz,nullzero"` // approximate time when this token was last used
ClientID string `bun:"type:CHAR(26),nullzero,notnull"` // ID of the client who owns this token
UserID string `bun:"type:CHAR(26),nullzero"` // ID of the user who owns this token
RedirectURI string `bun:",nullzero,notnull"` // Oauth redirect URI for this token
Scope string `bun:",nullzero,notnull,default:'read'"` // Oauth scope
Code string `bun:",pk,nullzero,notnull,default:''"` // Code, if present
CodeChallenge string `bun:",nullzero"` // Code challenge, if code present
CodeChallengeMethod string `bun:",nullzero"` // Code challenge method, if code present
CodeCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Code created time, if code present
CodeExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Code expires at -- null means the code never expires
Access string `bun:",pk,nullzero,notnull,default:''"` // User level access token, if present
AccessCreateAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token created time, if access present
AccessExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token expires at -- null means the token never expires
Refresh string `bun:",pk,nullzero,notnull,default:''"` // Refresh token, if present
RefreshCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh created at, if refresh present
RefreshExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh expires at -- null means the refresh token never expires
}

View file

@ -42,7 +42,6 @@ type FederatingDBTestSuite struct {
state state.State state state.State
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -61,7 +60,6 @@ func (suite *FederatingDBTestSuite) getFederatorMsg(timeout time.Duration) (*mes
func (suite *FederatingDBTestSuite) SetupSuite() { func (suite *FederatingDBTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -34,7 +34,6 @@ type FilterStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -49,7 +48,6 @@ type FilterStandardTestSuite struct {
func (suite *FilterStandardTestSuite) SetupSuite() { func (suite *FilterStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -17,18 +17,39 @@
package gtsmodel package gtsmodel
import "time" import "strings"
// Application represents an application that can perform actions on behalf of a user. // Application represents an application that
// It is used to authorize tokens etc, and is associated with an oauth client id in the database. // can perform actions on behalf of a user.
//
// It is equivalent to an OAuth client.
type Application struct { type Application struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created Name string `bun:",notnull"` // name of the application given when it was created (eg., 'tusky')
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated Website string `bun:",nullzero"` // website for the application given when it was created (eg., 'https://tusky.app')
Name string `bun:",notnull"` // name of the application given when it was created (eg., 'tusky') RedirectURIs []string `bun:"redirect_uris,array"` // redirect uris requested by the application for oauth2 flow
Website string `bun:",nullzero"` // website for the application given when it was created (eg., 'https://tusky.app') ClientID string `bun:"type:CHAR(26),nullzero,notnull"` // id of the associated oauth client entity in the db
RedirectURI string `bun:",nullzero,notnull"` // redirect uri requested by the application for oauth2 flow ClientSecret string `bun:",nullzero,notnull"` // secret of the associated oauth client entity in the db
ClientID string `bun:"type:CHAR(26),nullzero,notnull"` // id of the associated oauth client entity in the db Scopes string `bun:",notnull"` // scopes requested when this app was created
ClientSecret string `bun:",nullzero,notnull"` // secret of the associated oauth client entity in the db ManagedByUserID string `bun:"type:CHAR(26),nullzero"` // id of the user that manages this application, if it was created through the settings panel
Scopes string `bun:",notnull"` // scopes requested when this app was created }
// Implements oauth2.ClientInfo.
func (a *Application) GetID() string {
return a.ClientID
}
// Implements oauth2.ClientInfo.
func (a *Application) GetSecret() string {
return a.ClientSecret
}
// Implements oauth2.ClientInfo.
func (a *Application) GetDomain() string {
return strings.Join(a.RedirectURIs, "\n")
}
// Implements oauth2.ClientInfo.
func (a *Application) GetUserID() string {
return a.ManagedByUserID
} }

View file

@ -1,30 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// Client is a wrapper for OAuth client details.
type Client struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
Secret string `bun:",nullzero,notnull"` // secret generated when client was created
Domain string `bun:",nullzero,notnull"` // domain requested for client
UserID string `bun:"type:CHAR(26),nullzero"` // id of the user that this client acts on behalf of
}

View file

@ -22,22 +22,21 @@
// Token is a translation of the gotosocial token // Token is a translation of the gotosocial token
// with the ExpiresIn fields replaced with ExpiresAt. // with the ExpiresIn fields replaced with ExpiresAt.
type Token struct { type Token struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created LastUsed time.Time `bun:"type:timestamptz,nullzero"` // approximate time when this token was last used
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated ClientID string `bun:"type:CHAR(26),nullzero,notnull"` // ID of the client who owns this token
ClientID string `bun:"type:CHAR(26),nullzero,notnull"` // ID of the client who owns this token UserID string `bun:"type:CHAR(26),nullzero"` // ID of the user who owns this token
UserID string `bun:"type:CHAR(26),nullzero"` // ID of the user who owns this token RedirectURI string `bun:",nullzero,notnull"` // Oauth redirect URI for this token
RedirectURI string `bun:",nullzero,notnull"` // Oauth redirect URI for this token Scope string `bun:",nullzero,notnull,default:'read'"` // Oauth scope // Oauth scope
Scope string `bun:",notnull"` // Oauth scope Code string `bun:",pk,nullzero,notnull,default:''"` // Code, if present
Code string `bun:",pk,nullzero,notnull,default:''"` // Code, if present CodeChallenge string `bun:",nullzero"` // Code challenge, if code present
CodeChallenge string `bun:",nullzero"` // Code challenge, if code present CodeChallengeMethod string `bun:",nullzero"` // Code challenge method, if code present
CodeChallengeMethod string `bun:",nullzero"` // Code challenge method, if code present CodeCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Code created time, if code present
CodeCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Code created time, if code present CodeExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Code expires at -- null means the code never expires
CodeExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Code expires at -- null means the code never expires Access string `bun:",pk,nullzero,notnull,default:''"` // User level access token, if present
Access string `bun:",pk,nullzero,notnull,default:''"` // User level access token, if present AccessCreateAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token created time, if access present
AccessCreateAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token created time, if access present AccessExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token expires at -- null means the token never expires
AccessExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // User level access token expires at -- null means the token never expires Refresh string `bun:",pk,nullzero,notnull,default:''"` // Refresh token, if present
Refresh string `bun:",pk,nullzero,notnull,default:''"` // Refresh token, if present RefreshCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh created at, if refresh present
RefreshCreateAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh created at, if refresh present RefreshExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh expires at -- null means the refresh token never expires
RefreshExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Refresh expires at -- null means the refresh token never expires
} }

View file

@ -21,45 +21,29 @@
"context" "context"
"codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
type clientStore struct { type clientStore struct {
db db.DB state *state.State
} }
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. // NewClientStore returns a minimal implementation of
func NewClientStore(db db.DB) oauth2.ClientStore { // oauth2.ClientStore interface, using state as storage.
pts := &clientStore{ //
db: db, // Only GetByID is implemented, Set and Delete are stubs.
} func NewClientStore(state *state.State) oauth2.ClientStore {
return pts return &clientStore{state: state}
} }
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
client, err := cs.db.GetClientByID(ctx, clientID) return cs.state.DB.GetApplicationByClientID(ctx, clientID)
if err != nil {
return nil, err
}
return models.New(
client.ID,
client.Secret,
client.Domain,
client.UserID,
), nil
} }
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { func (cs *clientStore) Set(_ context.Context, _ string, _ oauth2.ClientInfo) error {
return cs.db.PutClient(ctx, &gtsmodel.Client{ return nil
ID: cli.GetID(),
Secret: cli.GetSecret(),
Domain: cli.GetDomain(),
UserID: cli.GetUserID(),
})
} }
func (cs *clientStore) Delete(ctx context.Context, id string) error { func (cs *clientStore) Delete(_ context.Context, _ string) error {
return cs.db.DeleteClientByID(ctx, id) return nil
} }

View file

@ -25,89 +25,55 @@
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/admin" "github.com/superseriousbusiness/gotosocial/internal/admin"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
type PgClientStoreTestSuite struct { type ClientStoreTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
state state.State state state.State
testClientID string testApplications map[string]*gtsmodel.Application
testClientSecret string
testClientDomain string
testClientUserID string
} }
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout func (suite *ClientStoreTestSuite) SetupSuite() {
func (suite *PgClientStoreTestSuite) SetupSuite() { suite.testApplications = testrig.NewTestApplications()
suite.testClientID = "01FCVB74EW6YBYAEY7QG9CQQF6"
suite.testClientSecret = "4cc87402-259b-4a35-9485-2c8bf54f3763"
suite.testClientDomain = "https://example.org"
suite.testClientUserID = "01FEGYXKVCDB731QF9MVFXA4F5"
} }
// SetupTest creates a postgres connection and creates the oauth_clients table before each test func (suite *ClientStoreTestSuite) SetupTest() {
func (suite *PgClientStoreTestSuite) SetupTest() {
suite.state.Caches.Init() suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog()
suite.db = testrig.NewTestDB(&suite.state) suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db suite.state.DB = suite.db
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers) suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
} }
// TearDownTest drops the oauth_clients table and closes the pg connection after each test func (suite *ClientStoreTestSuite) TearDownTest() {
func (suite *PgClientStoreTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
} }
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { func (suite *ClientStoreTestSuite) TestClientStoreGet() {
// set a new client in the store testApp := suite.testApplications["application_1"]
cs := oauth.NewClientStore(suite.db) cs := oauth.NewClientStore(&suite.state)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
// fetch that client from the store // Fetch clientInfo from the store.
client, err := cs.GetByID(context.Background(), suite.testClientID) clientInfo, err := cs.GetByID(context.Background(), testApp.ClientID)
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
// check that the values are the same // Check expected values.
suite.NotNil(client) suite.NotNil(clientInfo)
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client) suite.Equal(testApp.ClientID, clientInfo.GetID())
suite.Equal(testApp.ClientSecret, clientInfo.GetSecret())
suite.Equal(testApp.RedirectURIs[0], clientInfo.GetDomain())
suite.Equal(testApp.ManagedByUserID, clientInfo.GetUserID())
} }
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { func TestClientStoreTestSuite(t *testing.T) {
// set a new client in the store suite.Run(t, new(ClientStoreTestSuite))
cs := oauth.NewClientStore(suite.db)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
// fetch the client from the store
client, err := cs.GetByID(context.Background(), suite.testClientID)
if err != nil {
suite.FailNow(err.Error())
}
// check that the values are the same
suite.NotNil(client)
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
suite.FailNow(err.Error())
}
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
suite.Assert().EqualValues(db.ErrNoEntries, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {
suite.Run(t, new(PgClientStoreTestSuite))
} }

View file

@ -22,6 +22,8 @@
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"slices"
"strings" "strings"
"codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4"
@ -30,7 +32,10 @@
"codeberg.org/superseriousbusiness/oauth2/v4/server" "codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
) )
const ( const (
@ -75,17 +80,58 @@ type s struct {
} }
// New returns a new oauth server that implements the Server interface // New returns a new oauth server that implements the Server interface
func New(ctx context.Context, database db.DB) Server { func New(
ts := newTokenStore(ctx, database) ctx context.Context,
cs := NewClientStore(database) state *state.State,
clientScopeHandler server.ClientScopeHandler,
) Server {
ts := newTokenStore(ctx, state)
cs := NewClientStore(state)
manager := manage.NewDefaultManager() manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts) manager.MapTokenStorage(ts)
manager.MapClientStorage(cs) manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(&manage.Config{ manager.SetAuthorizeCodeTokenCfg(&manage.Config{
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked // Following the Mastodon API,
IsGenerateRefresh: false, // don't use refresh tokens // access tokens don't expire.
AccessTokenExp: 0,
// Don't use refresh tokens.
IsGenerateRefresh: false,
}) })
manager.SetValidateURIHandler(func(hasRedirectList, wantsRedirect string) error {
wantsRedirectURI, err := url.Parse(wantsRedirect)
if err != nil {
return err
}
// Redirect URIs are given to us as
// a list of URIs, newline-separated.
//
// Ensure that one of them matches
// requested redirectURI.
hasRedirects := strings.Split(hasRedirectList, "\n")
if slices.ContainsFunc(
hasRedirects,
func(hasRedirect string) bool {
hasRedirectURI, err := url.Parse(hasRedirect)
if err != nil {
log.Errorf(nil, "error parsing hasRedirect: %v", err)
return false
}
// Want an exact match.
// See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
return wantsRedirectURI.String() == hasRedirectURI.String()
},
) {
return nil
}
return oautherr.ErrInvalidRedirectURI
})
sc := &server.Config{ sc := &server.Config{
TokenType: "Bearer", TokenType: "Bearer",
// Must follow the spec. // Must follow the spec.
@ -106,6 +152,19 @@ func New(ctx context.Context, database db.DB) Server {
} }
srv := server.NewServer(sc, manager) srv := server.NewServer(sc, manager)
srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
// Use provided scope or
// fall back to default "read".
scope := r.FormValue("scope")
if strings.TrimSpace(scope) == "" {
scope = "read"
}
return scope, nil
})
srv.SetClientScopeHandler(clientScopeHandler)
srv.SetInternalErrorHandler(func(err error) *oautherr.Response { srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
log.Errorf(nil, "internal oauth error: %s", err) log.Errorf(nil, "internal oauth error: %s", err)
return nil return nil
@ -122,10 +181,10 @@ func New(ctx context.Context, database db.DB) Server {
} }
return userID, nil return userID, nil
}) })
srv.SetClientInfoHandler(server.ClientFormHandler) srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv, return &s{srv}
}
} }
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function // HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
@ -143,31 +202,42 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
} }
ti, err := s.server.GetAccessToken(ctx, gt, tgr) ti, err := s.server.GetAccessToken(ctx, gt, tgr)
if err != nil { switch {
case err == nil:
// No problem.
break
case errors.Is(err, oautherr.ErrInvalidScope):
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
default:
help := fmt.Sprintf("could not get access token: %s", err) help := fmt.Sprintf("could not get access token: %s", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
} }
// Wrangle data a bit.
data := s.server.GetTokenData(ti) data := s.server.GetTokenData(ti)
// Add created_at for Mastodon API compatibility.
data["created_at"] = ti.GetAccessCreateAt().Unix()
// If expires_in is 0 or less, omit it
// from serialization so that clients don't
// interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok { if expiresInI, ok := data["expires_in"]; ok {
switch expiresIn := expiresInI.(type) { expiresIn, ok := expiresInI.(int64)
case int64: if !ok {
// remove this key from the returned map log.Panicf(ctx, "could not cast expires_in %T as int64", expiresInI)
// if the value is 0 or less, so that clients return nil, nil
// don't interpret the token as already expired }
if expiresIn <= 0 {
delete(data, "expires_in") if expiresIn <= 0 {
} delete(data, "expires_in")
default:
err := errors.New("expires_in was set on token response, but was not an int64")
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
} }
} }
// add this for mastodon api compatibility
data["created_at"] = ti.GetAccessCreateAt().Unix()
return data, nil return data, nil
} }
@ -207,7 +277,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
} }
req.UserID = userID req.UserID = userID
// specify the scope of authorization // Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil { if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r) scope, err := fn(w, r)
if err != nil { if err != nil {
@ -217,7 +287,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
} }
} }
// specify the expiration time of access token // Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil { if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r) exp, err := fn(w, r)
if err != nil { if err != nil {
@ -231,13 +301,28 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req) return s.errorOrRedirect(err, w, req)
} }
// If the redirect URI is empty, the default domain provided by the client is used. // If the redirect URI is empty, use the
// first of the client's redirect URIs.
if req.RedirectURI == "" { if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID) client, err := s.server.Manager.GetClient(ctx, req.ClientID)
if err != nil { if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
return gtserror.NewErrorInternalError(err)
}
if util.IsNil(client) {
// Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice) return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
} }
req.RedirectURI = client.GetDomain()
app, ok := client.(*gtsmodel.Application)
if !ok {
log.Panicf(ctx, "could not cast %T to *gtsmodel.Application", client)
return nil
}
req.RedirectURI = app.RedirectURIs[0]
} }
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti)) uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))

View file

@ -22,30 +22,32 @@
"errors" "errors"
"time" "time"
"codeberg.org/gruf/go-mutexes"
"codeberg.org/superseriousbusiness/oauth2/v4" "codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models" "codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
) )
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct { type tokenStore struct {
oauth2.TokenStore oauth2.TokenStore
db db.DB state *state.State
lastUsedLocks mutexes.MutexMap
} }
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
// //
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired. // the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore {
ts := &tokenStore{ ts := &tokenStore{state: state}
db: db,
}
// set the token store to clean out expired tokens once per minute, or return if we're done // Set the token store to clean out expired tokens
// once per minute, or return if we're done.
go func(ctx context.Context, ts *tokenStore) { go func(ctx context.Context, ts *tokenStore) {
cleanloop: cleanloop:
for { for {
@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
return ts return ts
} }
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. // sweep clears out old tokens that have expired;
// it should be run on a loop about once per minute or so.
func (ts *tokenStore) sweep(ctx context.Context) error { func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db // Select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. //
tokens, err := ts.db.GetAllTokens(ctx) // TODO: if this becomes expensive
// (ie., there are fucking LOADS of
// tokens) then figure out a better way.
tokens, err := ts.state.DB.GetAllTokens(ctx)
if err != nil { if err != nil {
return err return err
} }
// iterate through and remove expired tokens // Remove any expired tokens, bearing
// in mind that zero time = no expiry.
now := time.Now() now := time.Now()
for _, dbt := range tokens { for _, token := range tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: var expired bool
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set. switch {
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now):
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { log.Tracef(ctx, "code token %s is expired", token.ID)
return err expired = true
}
case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now):
log.Tracef(ctx, "refresh token %s is expired", token.ID)
expired = true
case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now):
log.Tracef(ctx, "access token %s is expired", token.ID)
expired = true
}
if !expired {
// Token's
// still good.
continue
}
if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil {
err := gtserror.Newf("db error expiring token %s: %w", token.ID, err)
return err
} }
} }
@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error {
} }
// Create creates and store the new token information. // Create creates and store the new token information.
// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token) t, ok := info.(*models.Token)
if !ok { if !ok {
@ -99,55 +123,97 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
dbt := TokenToDBToken(t) dbt := TokenToDBToken(t)
if dbt.ID == "" { if dbt.ID == "" {
dbtID, err := id.NewRandomULID() dbt.ID = id.NewULID()
if err != nil {
return err
}
dbt.ID = dbtID
} }
return ts.db.PutToken(ctx, dbt) return ts.state.DB.PutToken(ctx, dbt)
} }
// RemoveByCode deletes a token from the DB based on the Code field // RemoveByCode deletes a token from the DB based on the Code field
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return ts.db.DeleteTokenByCode(ctx, code) return ts.state.DB.DeleteTokenByCode(ctx, code)
} }
// RemoveByAccess deletes a token from the DB based on the Access field // RemoveByAccess deletes a token from the DB based on the Access field
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return ts.db.DeleteTokenByAccess(ctx, access) return ts.state.DB.DeleteTokenByAccess(ctx, access)
} }
// RemoveByRefresh deletes a token from the DB based on the Refresh field // RemoveByRefresh deletes a token from the DB based on the Refresh field
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return ts.db.DeleteTokenByRefresh(ctx, refresh) return ts.state.DB.DeleteTokenByRefresh(ctx, refresh)
} }
// GetByCode selects a token from the DB based on the Code field // GetByCode selects a token from
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { // the DB based on the Code field
token, err := ts.db.GetTokenByCode(ctx, code) func (ts *tokenStore) GetByCode(
if err != nil { ctx context.Context,
return nil, err code string,
} ) (oauth2.TokenInfo, error) {
return DBTokenToToken(token), nil return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByCode,
code,
)
} }
// GetByAccess selects a token from the DB based on the Access field // GetByAccess selects a token from
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { // the DB based on the Access field.
token, err := ts.db.GetTokenByAccess(ctx, access) func (ts *tokenStore) GetByAccess(
if err != nil { ctx context.Context,
return nil, err access string,
} ) (oauth2.TokenInfo, error) {
return DBTokenToToken(token), nil return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByAccess,
access,
)
} }
// GetByRefresh selects a token from the DB based on the Refresh field // GetByRefresh selects a token from
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { // the DB based on the Refresh field
token, err := ts.db.GetTokenByRefresh(ctx, refresh) func (ts *tokenStore) GetByRefresh(
ctx context.Context,
refresh string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByRefresh,
refresh,
)
}
// package-internal function for getting a token
// and potentially updating its last_used value.
func (ts *tokenStore) getUpdateToken(
ctx context.Context,
getBy func(context.Context, string) (*gtsmodel.Token, error),
key string,
) (oauth2.TokenInfo, error) {
// Hold a lock to get the token based on
// whatever func + key we've been given.
unlock := ts.lastUsedLocks.Lock(key)
token, err := getBy(ctx, key)
if err != nil { if err != nil {
// Unlock on error.
unlock()
return nil, err return nil, err
} }
// If token was last used more than
// an hour ago, update this in the db.
wasLastUsed := token.LastUsed
if time.Since(wasLastUsed) > 1*time.Hour {
token.LastUsed = time.Now()
if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil {
err := gtserror.Newf("error updating last_used on token: %w", err)
return nil, err
}
}
// We're done, unlock.
unlock()
return DBTokenToToken(token), nil return DBTokenToToken(token), nil
} }

View file

@ -55,7 +55,6 @@ type AccountStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -76,7 +75,6 @@ func (suite *AccountStandardTestSuite) getClientMsg(timeout time.Duration) (*mes
func (suite *AccountStandardTestSuite) SetupSuite() { func (suite *AccountStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -113,11 +113,6 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
} }
for _, t := range tokens { for _, t := range tokens {
// Delete any OAuth clients associated with this token.
if err := p.state.DB.DeleteByID(ctx, t.ClientID, &[]*gtsmodel.Client{}); err != nil {
return gtserror.Newf("db error deleting client: %w", err)
}
// Delete any OAuth applications associated with this token. // Delete any OAuth applications associated with this token.
if err := p.state.DB.DeleteApplicationByClientID(ctx, t.ClientID); err != nil { if err := p.state.DB.DeleteApplicationByClientID(ctx, t.ClientID); err != nil {
return gtserror.Newf("db error deleting application: %w", err) return gtserror.Newf("db error deleting application: %w", err)

View file

@ -58,7 +58,6 @@ type AdminStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -73,7 +72,6 @@ type AdminStandardTestSuite struct {
func (suite *AdminStandardTestSuite) SetupSuite() { func (suite *AdminStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
@ -103,7 +101,7 @@ func (suite *AdminStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager) suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)

View file

@ -19,6 +19,9 @@
import ( import (
"context" "context"
"fmt"
"net/url"
"strings"
"github.com/google/uuid" "github.com/google/uuid"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -26,10 +29,12 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
func (p *Processor) AppCreate(ctx context.Context, authed *apiutil.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, gtserror.WithCode) { func (p *Processor) AppCreate(ctx context.Context, authed *apiutil.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, gtserror.WithCode) {
// set default 'read' for scopes if it's not set // Set default 'read' for
// scopes if it's not set.
var scopes string var scopes string
if form.Scopes == "" { if form.Scopes == "" {
scopes = "read" scopes = "read"
@ -37,48 +42,47 @@ func (p *Processor) AppCreate(ctx context.Context, authed *apiutil.Auth, form *a
scopes = form.Scopes scopes = form.Scopes
} }
// generate new IDs for this application and its associated client // Normalize + parse requested redirect URIs.
form.RedirectURIs = strings.TrimSpace(form.RedirectURIs)
var redirectURIs []string
if form.RedirectURIs != "" {
// Redirect URIs can be just one value, or can be passed
// as a newline-separated list of strings. Ensure each URI
// is parseable + normalize it by reconstructing from *url.URL.
for _, redirectStr := range strings.Split(form.RedirectURIs, "\n") {
redirectURI, err := url.Parse(redirectStr)
if err != nil {
errText := fmt.Sprintf("error parsing redirect URI: %v", err)
return nil, gtserror.NewErrorBadRequest(err, errText)
}
redirectURIs = append(redirectURIs, redirectURI.String())
}
} else {
// No redirect URI(s) provided, just set default oob.
redirectURIs = append(redirectURIs, oauth.OOBURI)
}
// Generate random client ID.
clientID, err := id.NewRandomULID() clientID, err := id.NewRandomULID()
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
clientSecret := uuid.NewString()
appID, err := id.NewRandomULID() // Generate + store app
if err != nil { // to put in the database.
return nil, gtserror.NewErrorInternalError(err)
}
// generate the application to put in the database
app := &gtsmodel.Application{ app := &gtsmodel.Application{
ID: appID, ID: id.NewULID(),
Name: form.ClientName, Name: form.ClientName,
Website: form.Website, Website: form.Website,
RedirectURI: form.RedirectURIs, RedirectURIs: redirectURIs,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: uuid.NewString(),
Scopes: scopes, Scopes: scopes,
} }
// chuck it in the db
if err := p.state.DB.PutApplication(ctx, app); err != nil { if err := p.state.DB.PutApplication(ctx, app); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// now we need to model an oauth client from the application that the oauth library can use
oc := &gtsmodel.Client{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
// This client isn't yet associated with a specific user, it's just an app client right now
UserID: "",
}
// chuck it in the db
if err := p.state.DB.PutClient(ctx, oc); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
apiApp, err := p.converter.AppToAPIAppSensitive(ctx, app) apiApp, err := p.converter.AppToAPIAppSensitive(ctx, app)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)

View file

@ -57,7 +57,6 @@ type ConversationsTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -84,7 +83,6 @@ func (suite *ConversationsTestSuite) getClientMsg(timeout time.Duration) (*messa
func (suite *ConversationsTestSuite) SetupSuite() { func (suite *ConversationsTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -45,7 +45,6 @@ type MediaStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -59,7 +58,6 @@ type MediaStandardTestSuite struct {
func (suite *MediaStandardTestSuite) SetupSuite() { func (suite *MediaStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -58,7 +58,6 @@ type ProcessingStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -77,7 +76,6 @@ type ProcessingStandardTestSuite struct {
func (suite *ProcessingStandardTestSuite) SetupSuite() { func (suite *ProcessingStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
@ -124,7 +122,7 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient) suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager) suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../web/template/", nil)
suite.processor = processing.NewProcessor( suite.processor = processing.NewProcessor(

View file

@ -50,7 +50,6 @@ type StatusStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -65,7 +64,6 @@ type StatusStandardTestSuite struct {
func (suite *StatusStandardTestSuite) SetupSuite() { func (suite *StatusStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -52,7 +52,7 @@ func (suite *StreamTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state) suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db suite.state.DB = suite.db
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers) suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.streamProcessor = stream.New(&suite.state, suite.oauthServer) suite.streamProcessor = stream.New(&suite.state, suite.oauthServer)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)

View file

@ -54,7 +54,7 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state) suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db suite.state.DB = suite.db
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers) suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
suite.oauthServer = testrig.NewTestOauthServer(suite.state.DB) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
@ -62,7 +62,7 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.user = user.New(&suite.state, typeutils.NewConverter(&suite.state), testrig.NewTestOauthServer(suite.db), suite.emailSender) suite.user = user.New(&suite.state, typeutils.NewConverter(&suite.state), testrig.NewTestOauthServer(&suite.state), suite.emailSender)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
} }

View file

@ -39,7 +39,6 @@ type WorkersTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -57,7 +56,6 @@ type WorkersTestSuite struct {
func (suite *WorkersTestSuite) SetupSuite() { func (suite *WorkersTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -37,7 +37,6 @@ type TextStandardTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -53,7 +52,6 @@ type TextStandardTestSuite struct {
func (suite *TextStandardTestSuite) SetupSuite() { func (suite *TextStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -50,7 +50,6 @@ type TransportTestSuite struct {
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -60,7 +59,6 @@ type TransportTestSuite struct {
func (suite *TransportTestSuite) SetupSuite() { func (suite *TransportTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()

View file

@ -626,10 +626,12 @@ func (c *Converter) AppToAPIAppSensitive(ctx context.Context, a *gtsmodel.Applic
ID: a.ID, ID: a.ID,
Name: a.Name, Name: a.Name,
Website: a.Website, Website: a.Website,
RedirectURI: a.RedirectURI, RedirectURI: strings.Join(a.RedirectURIs, "\n"),
RedirectURIs: a.RedirectURIs,
ClientID: a.ClientID, ClientID: a.ClientID,
ClientSecret: a.ClientSecret, ClientSecret: a.ClientSecret,
VapidKey: vapidKeyPair.Public, VapidKey: vapidKeyPair.Public,
Scopes: strings.Split(a.Scopes, " "),
}, nil }, nil
} }

View file

@ -23,6 +23,7 @@
"net/http" "net/http"
"testing" "testing"
"time" "time"
// for go:linkname // for go:linkname
_ "unsafe" _ "unsafe"
@ -102,7 +103,7 @@ func (suite *RealSenderStandardTestSuite) SetupTest() {
suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient) suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager) suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(&suite.state)
suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../web/template/", nil)
suite.webPushSender = newSenderWith( suite.webPushSender = newSenderWith(

View file

@ -68,7 +68,6 @@
&gtsmodel.Notification{}, &gtsmodel.Notification{},
&gtsmodel.RouterSession{}, &gtsmodel.RouterSession{},
&gtsmodel.Token{}, &gtsmodel.Token{},
&gtsmodel.Client{},
&gtsmodel.EmojiCategory{}, &gtsmodel.EmojiCategory{},
&gtsmodel.Tombstone{}, &gtsmodel.Tombstone{},
&gtsmodel.Report{}, &gtsmodel.Report{},
@ -132,12 +131,6 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
} }
} }
for _, v := range NewTestClients() {
if err := db.Put(ctx, v); err != nil {
log.Panic(ctx, err)
}
}
for _, v := range NewTestApplications() { for _, v := range NewTestApplications() {
if err := db.Put(ctx, v); err != nil { if err := db.Put(ctx, v); err != nil {
log.Panic(ctx, err) log.Panic(ctx, err)

View file

@ -20,11 +20,17 @@
import ( import (
"context" "context"
"github.com/superseriousbusiness/gotosocial/internal/db" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
) )
// NewTestOauthServer returns an oauth server with the given db // NewTestOauthServer returns an oauth server with the given db
func NewTestOauthServer(db db.DB) oauth.Server { func NewTestOauthServer(state *state.State) oauth.Server {
return oauth.New(context.Background(), db) ctx := context.Background()
return oauth.New(
ctx,
state,
apiutil.GetClientScopeHandler(ctx, state),
)
} }

View file

@ -51,7 +51,7 @@ func NewTestProcessor(
), ),
typeutils.NewConverter(state), typeutils.NewConverter(state),
federator, federator,
NewTestOauthServer(state.DB), NewTestOauthServer(state),
mediaManager, mediaManager,
state, state,
emailSender, emailSender,

View file

@ -70,6 +70,7 @@ func NewTestTokens() map[string]*gtsmodel.Token {
ID: "01P9SVWS9J3SPHZQ3KCMBEN70N", ID: "01P9SVWS9J3SPHZQ3KCMBEN70N",
ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70", ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70",
RedirectURI: "http://localhost:8080", RedirectURI: "http://localhost:8080",
Scope: "read write push",
Access: "ZTK1MWMWZDGTMGMXOS0ZY2UXLWI5ZWETMWEZYZZIYTLHMZI4", Access: "ZTK1MWMWZDGTMGMXOS0ZY2UXLWI5ZWETMWEZYZZIYTLHMZI4",
AccessCreateAt: TimeMustParse("2022-06-10T15:22:08Z"), AccessCreateAt: TimeMustParse("2022-06-10T15:22:08Z"),
AccessExpiresAt: TimeMustParse("2050-01-01T15:22:08Z"), AccessExpiresAt: TimeMustParse("2050-01-01T15:22:08Z"),
@ -79,6 +80,7 @@ func NewTestTokens() map[string]*gtsmodel.Token {
ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70", ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70",
UserID: "01F8MGVGPHQ2D3P3X0454H54Z5", UserID: "01F8MGVGPHQ2D3P3X0454H54Z5",
RedirectURI: "http://localhost:8080", RedirectURI: "http://localhost:8080",
Scope: "read write push",
Code: "ZJYYMZQ0MTQTZTU1NC0ZNJK4LWE2ZWITYTM1MDHHOTAXNJHL", Code: "ZJYYMZQ0MTQTZTU1NC0ZNJK4LWE2ZWITYTM1MDHHOTAXNJHL",
CodeCreateAt: TimeMustParse("2022-06-10T15:22:08Z"), CodeCreateAt: TimeMustParse("2022-06-10T15:22:08Z"),
CodeExpiresAt: TimeMustParse("2050-01-01T15:22:08Z"), CodeExpiresAt: TimeMustParse("2050-01-01T15:22:08Z"),
@ -107,37 +109,6 @@ func NewTestTokens() map[string]*gtsmodel.Token {
return tokens return tokens
} }
// NewTestClients returns a map of Clients keyed according to which account they are used by.
func NewTestClients() map[string]*gtsmodel.Client {
clients := map[string]*gtsmodel.Client{
"instance_application": {
ID: "01AY6P665V14JJR0AFVRT7311Y",
Secret: "baedee87-6d00-4cf5-87b9-4d78ee58ef01",
Domain: "http://localhost:8080",
UserID: "",
},
"admin_account": {
ID: "01F8MGWSJCND9BWBD4WGJXBM93",
Secret: "dda8e835-2c9c-4bd2-9b8b-77c2e26d7a7a",
Domain: "http://localhost:8080",
UserID: "01F8MGWYWKVKS3VS8DV1AMYPGE", // admin_account
},
"local_account_1": {
ID: "01F8MGV8AC3NGSJW0FE8W1BV70",
Secret: "c3724c74-dc3b-41b2-a108-0ea3d8399830",
Domain: "http://localhost:8080",
UserID: "01F8MGVGPHQ2D3P3X0454H54Z5", // local_account_1
},
"local_account_2": {
ID: "01F8MGW47HN8ZXNHNZ7E47CDMQ",
Secret: "8f5603a5-c721-46cd-8f1b-2e368f51379f",
Domain: "http://localhost:8080",
UserID: "01F8MH1VYJAE00TVVGMM5JNJ8X", // local_account_2
},
}
return clients
}
// NewTestApplications returns a map of applications keyed to which number application they are. // NewTestApplications returns a map of applications keyed to which number application they are.
func NewTestApplications() map[string]*gtsmodel.Application { func NewTestApplications() map[string]*gtsmodel.Application {
apps := map[string]*gtsmodel.Application{ apps := map[string]*gtsmodel.Application{
@ -145,7 +116,7 @@ func NewTestApplications() map[string]*gtsmodel.Application {
ID: "01HT5P2YHDMPAAD500NDAY8JW1", ID: "01HT5P2YHDMPAAD500NDAY8JW1",
Name: "localhost:8080 instance application", Name: "localhost:8080 instance application",
Website: "http://localhost:8080", Website: "http://localhost:8080",
RedirectURI: "http://localhost:8080", RedirectURIs: []string{"http://localhost:8080"},
ClientID: "01AY6P665V14JJR0AFVRT7311Y", // instance account ID ClientID: "01AY6P665V14JJR0AFVRT7311Y", // instance account ID
ClientSecret: "baedee87-6d00-4cf5-87b9-4d78ee58ef01", ClientSecret: "baedee87-6d00-4cf5-87b9-4d78ee58ef01",
Scopes: "write:accounts", Scopes: "write:accounts",
@ -154,28 +125,28 @@ func NewTestApplications() map[string]*gtsmodel.Application {
ID: "01F8MGXQRHYF5QPMTMXP78QC2F", ID: "01F8MGXQRHYF5QPMTMXP78QC2F",
Name: "superseriousbusiness", Name: "superseriousbusiness",
Website: "https://superserious.business", Website: "https://superserious.business",
RedirectURI: "http://localhost:8080", RedirectURIs: []string{"http://localhost:8080"},
ClientID: "01F8MGWSJCND9BWBD4WGJXBM93", // admin client ClientID: "01F8MGWSJCND9BWBD4WGJXBM93", // admin client
ClientSecret: "dda8e835-2c9c-4bd2-9b8b-77c2e26d7a7a", // admin client ClientSecret: "dda8e835-2c9c-4bd2-9b8b-77c2e26d7a7a", // admin client
Scopes: "read write follow push", Scopes: "read write push",
}, },
"application_1": { "application_1": {
ID: "01F8MGY43H3N2C8EWPR2FPYEXG", ID: "01F8MGY43H3N2C8EWPR2FPYEXG",
Name: "really cool gts application", Name: "really cool gts application",
Website: "https://reallycool.app", Website: "https://reallycool.app",
RedirectURI: "http://localhost:8080", RedirectURIs: []string{"http://localhost:8080"},
ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70", // client_1 ClientID: "01F8MGV8AC3NGSJW0FE8W1BV70", // client_1
ClientSecret: "c3724c74-dc3b-41b2-a108-0ea3d8399830", // client_1 ClientSecret: "c3724c74-dc3b-41b2-a108-0ea3d8399830", // client_1
Scopes: "read write follow push", Scopes: "read write push",
}, },
"application_2": { "application_2": {
ID: "01F8MGYG9E893WRHW0TAEXR8GJ", ID: "01F8MGYG9E893WRHW0TAEXR8GJ",
Name: "kindaweird", Name: "kindaweird",
Website: "https://kindaweird.app", Website: "https://kindaweird.app",
RedirectURI: "http://localhost:8080", RedirectURIs: []string{"http://localhost:8080"},
ClientID: "01F8MGW47HN8ZXNHNZ7E47CDMQ", // client_2 ClientID: "01F8MGW47HN8ZXNHNZ7E47CDMQ", // client_2
ClientSecret: "8f5603a5-c721-46cd-8f1b-2e368f51379f", // client_2 ClientSecret: "8f5603a5-c721-46cd-8f1b-2e368f51379f", // client_2
Scopes: "read write follow push", Scopes: "read write push",
}, },
} }
return apps return apps

View file

@ -82,7 +82,7 @@ func SetupTestStructs(
transportController := NewTestTransportController(&state, httpClient) transportController := NewTestTransportController(&state, httpClient)
mediaManager := NewTestMediaManager(&state) mediaManager := NewTestMediaManager(&state)
federator := NewTestFederator(&state, transportController, mediaManager) federator := NewTestFederator(&state, transportController, mediaManager)
oauthServer := NewTestOauthServer(db) oauthServer := NewTestOauthServer(&state)
emailSender := NewEmailSender(rTemplatePath, nil) emailSender := NewEmailSender(rTemplatePath, nil)
webPushSender := NewWebPushMockSender() webPushSender := NewWebPushMockSender()