package manage import ( "context" "time" "github.com/superseriousbusiness/oauth2/v4" "github.com/superseriousbusiness/oauth2/v4/errors" "github.com/superseriousbusiness/oauth2/v4/generates" "github.com/superseriousbusiness/oauth2/v4/models" ) // NewDefaultManager create to default authorization management instance func NewDefaultManager() *Manager { m := NewManager() // default implementation m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) m.MapAccessGenerate(generates.NewAccessGenerate()) return m } // NewManager create to authorization management instance func NewManager() *Manager { return &Manager{ gtcfg: make(map[oauth2.GrantType]*Config), validateURI: DefaultValidateURI, } } // Manager provide authorization management type Manager struct { codeExp time.Duration gtcfg map[oauth2.GrantType]*Config rcfg *RefreshingConfig validateURI ValidateURIHandler authorizeGenerate oauth2.AuthorizeGenerate accessGenerate oauth2.AccessGenerate tokenStore oauth2.TokenStore clientStore oauth2.ClientStore } // get grant type config func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { if c, ok := m.gtcfg[gt]; ok && c != nil { return c } switch gt { case oauth2.AuthorizationCode: return DefaultAuthorizeCodeTokenCfg case oauth2.Implicit: return DefaultImplicitTokenCfg case oauth2.PasswordCredentials: return DefaultPasswordTokenCfg case oauth2.ClientCredentials: return DefaultClientTokenCfg } return &Config{} } // SetAuthorizeCodeExp set the authorization code expiration time func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { m.codeExp = exp } // SetAuthorizeCodeTokenCfg set the authorization code grant token config func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { m.gtcfg[oauth2.AuthorizationCode] = cfg } // SetImplicitTokenCfg set the implicit grant token config func (m *Manager) SetImplicitTokenCfg(cfg *Config) { m.gtcfg[oauth2.Implicit] = cfg } // SetPasswordTokenCfg set the password grant token config func (m *Manager) SetPasswordTokenCfg(cfg *Config) { m.gtcfg[oauth2.PasswordCredentials] = cfg } // SetClientTokenCfg set the client grant token config func (m *Manager) SetClientTokenCfg(cfg *Config) { m.gtcfg[oauth2.ClientCredentials] = cfg } // SetRefreshTokenCfg set the refreshing token config func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { m.rcfg = cfg } // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { m.validateURI = handler } // MapAuthorizeGenerate mapping the authorize code generate interface func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.authorizeGenerate = gen } // MapAccessGenerate mapping the access token generate interface func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { m.accessGenerate = gen } // MapClientStorage mapping the client store interface func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { m.clientStore = stor } // MustClientStorage mandatory mapping the client store interface func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { panic(err.Error()) } m.clientStore = stor } // MapTokenStorage mapping the token store interface func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { m.tokenStore = stor } // MustTokenStorage mandatory mapping the token store interface func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } m.tokenStore = stor } // GetClient get the client information func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { cli, err = m.clientStore.GetByID(ctx, clientID) if err != nil { return } else if cli == nil { err = errors.ErrInvalidClient } return } // GenerateAuthToken generate the authorization token(code) func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } switch rt { case oauth2.Code: codeExp := m.codeExp if codeExp == 0 { codeExp = DefaultCodeExp } ti.SetCodeCreateAt(createAt) ti.SetCodeExpiresIn(codeExp) if exp := tgr.AccessTokenExp; exp > 0 { ti.SetAccessExpiresIn(exp) } if tgr.CodeChallenge != "" { ti.SetCodeChallenge(tgr.CodeChallenge) ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod) } tv, err := m.authorizeGenerate.Token(ctx, td) if err != nil { return nil, err } ti.SetCode(tv) case oauth2.Token: // set access token expires icfg := m.grantConfig(oauth2.Implicit) aexp := icfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessCreateAt(createAt) ti.SetAccessExpiresIn(aexp) if icfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) } tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // get authorization code data func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { ti, err := m.tokenStore.GetByCode(ctx, code) if err != nil { return nil, err } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { err = errors.ErrInvalidAuthorizeCode return nil, errors.ErrInvalidAuthorizeCode } return ti, nil } // delete authorization code data func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { return m.tokenStore.RemoveByCode(ctx, code) } // get and delete authorization code data func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { code := tgr.Code ti, err := m.getAuthorizationCode(ctx, code) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidAuthorizeCode } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { return nil, errors.ErrInvalidAuthorizeCode } err = m.delAuthorizationCode(ctx, code) if err != nil { return nil, err } return ti, nil } func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { cc := ti.GetCodeChallenge() // early return if cc == "" && ver == "" { return nil } if cc == "" { return errors.ErrMissingCodeVerifier } if ver == "" { return errors.ErrMissingCodeVerifier } ccm := ti.GetCodeChallengeMethod() if ccm.String() == "" { ccm = oauth2.CodeChallengePlain } if !ccm.Validate(cc, ver) { return errors.ErrInvalidCodeChallenge } return nil } // GenerateAccessToken generate the access token func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { if !cliPass.VerifyPassword(tgr.ClientSecret) { return nil, errors.ErrInvalidClient } } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } if tgr.RedirectURI != "" { if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { return nil, err } } if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { return nil, err } if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil { return nil, err } tgr.UserID = ti.GetUserID() tgr.Scope = ti.GetScope() if exp := ti.GetAccessExpiresIn(); exp > 0 { tgr.AccessTokenExp = exp } } ti := models.NewToken() ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) createAt := time.Now() ti.SetAccessCreateAt(createAt) // set access token expires gcfg := m.grantConfig(gt) aexp := gcfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessExpiresIn(aexp) if gcfg.IsGenerateRefresh { ti.SetRefreshCreateAt(createAt) ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) } td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: createAt, TokenInfo: ti, Request: tgr.Request, } av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(av) if rv != "" { ti.SetRefresh(rv) } err = m.tokenStore.Create(ctx, ti) if err != nil { return nil, err } return ti, nil } // RefreshAccessToken refreshing an access token func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) if err != nil { return nil, err } else if tgr.ClientSecret != cli.GetSecret() { return nil, errors.ErrInvalidClient } ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { return nil, err } else if ti.GetClientID() != tgr.ClientID { return nil, errors.ErrInvalidRefreshToken } oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), CreateAt: time.Now(), TokenInfo: ti, Request: tgr.Request, } rcfg := DefaultRefreshTokenCfg if v := m.rcfg; v != nil { rcfg = v } ti.SetAccessCreateAt(td.CreateAt) if v := rcfg.AccessTokenExp; v > 0 { ti.SetAccessExpiresIn(v) } if v := rcfg.RefreshTokenExp; v > 0 { ti.SetRefreshExpiresIn(v) } if rcfg.IsResetRefreshTime { ti.SetRefreshCreateAt(td.CreateAt) } if scope := tgr.Scope; scope != "" { ti.SetScope(scope) } tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) if err != nil { return nil, err } ti.SetAccess(tv) if rv != "" { ti.SetRefresh(rv) } if err := m.tokenStore.Create(ctx, ti); err != nil { return nil, err } if rcfg.IsRemoveAccess { // remove the old access token if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { return nil, err } } if rcfg.IsRemoveRefreshing && rv != "" { // remove the old refresh token if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { return nil, err } } if rv == "" { ti.SetRefresh("") ti.SetRefreshCreateAt(time.Now()) ti.SetRefreshExpiresIn(0) } return ti, nil } // RemoveAccessToken use the access token to delete the token information func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { if access == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByAccess(ctx, access) } // RemoveRefreshToken use the refresh token to delete the token information func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { if refresh == "" { return errors.ErrInvalidAccessToken } return m.tokenStore.RemoveByRefresh(ctx, refresh) } // LoadAccessToken according to the access token for corresponding token information func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, errors.ErrInvalidAccessToken } ct := time.Now() ti, err := m.tokenStore.GetByAccess(ctx, access) if err != nil { return nil, err } else if ti == nil || ti.GetAccess() != access { return nil, errors.ErrInvalidAccessToken } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { return nil, errors.ErrExpiredRefreshToken } else if ti.GetAccessExpiresIn() != 0 && ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { return nil, errors.ErrExpiredAccessToken } return ti, nil } // LoadRefreshToken according to the refresh token for corresponding token information func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, errors.ErrInvalidRefreshToken } ti, err := m.tokenStore.GetByRefresh(ctx, refresh) if err != nil { return nil, err } else if ti == nil || ti.GetRefresh() != refresh { return nil, errors.ErrInvalidRefreshToken } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { return nil, errors.ErrExpiredRefreshToken } return ti, nil }