Track access token and refresh token expiry
Also make a bunch more progress on actually updating the tokens when we need them updated.
This commit is contained in:
parent
cf01c8c5c6
commit
109495b702
11 changed files with 348 additions and 104 deletions
|
|
@ -1 +1 @@
|
|||
Subproject commit a99b4a72b2bb3dcff642209f029eb4e7d746fa8d
|
||||
Subproject commit e47b350f9231a16c815b927947f9d718cec2d3fe
|
||||
249
arcgis.go
249
arcgis.go
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
|
|
@ -12,8 +13,10 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Gleipnir-Technology/arcgis-go"
|
||||
|
|
@ -24,13 +27,28 @@ import (
|
|||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var NewOAuthTokenChannel chan struct{}
|
||||
var CodeVerifier string = "random_secure_string_min_43_chars_long_should_be_stored_in_session"
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error ErrorResponseContent `json:"error"`
|
||||
}
|
||||
|
||||
type ErrorResponseContent struct {
|
||||
Code int `json:"code"`
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
Message string `json:"message"`
|
||||
Details []string `json:"details"`
|
||||
}
|
||||
|
||||
type OAuthTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Username string `json:"username"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RefreshTokenExpiresIn int `json:"refresh_token_expires_in"`
|
||||
SSL bool `json:"ssl"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// Build the ArcGIS authorization URL with PKCE
|
||||
|
|
@ -66,12 +84,13 @@ func generateCodeVerifier() string {
|
|||
}
|
||||
|
||||
// Find out what we can about this user
|
||||
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, expires time.Time, refresh_token string) {
|
||||
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, access_token_expires time.Time, refresh_token string, refresh_token_expires time.Time) {
|
||||
client := arcgis.NewArcGIS(
|
||||
arcgis.AuthenticatorOAuth{
|
||||
AccessToken: access_token,
|
||||
Expires: expires,
|
||||
RefreshToken: refresh_token,
|
||||
AccessToken: access_token,
|
||||
AccessTokenExpires: access_token_expires,
|
||||
RefreshToken: refresh_token,
|
||||
RefreshTokenExpires: refresh_token_expires,
|
||||
},
|
||||
)
|
||||
portal, err := client.PortalsSelf()
|
||||
|
|
@ -158,49 +177,25 @@ func handleOauthAccessCode(ctx context.Context, user *models.User, code string)
|
|||
return fmt.Errorf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := http.Client{}
|
||||
slog.Info("POST", slog.String("url", baseURL))
|
||||
resp, err := client.Do(req)
|
||||
token, err := handleTokenRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to do request: %v", err)
|
||||
return fmt.Errorf("Failed to exchange authorization code for token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
slog.Info("Response", slog.Int("status", resp.StatusCode))
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if err != nil {
|
||||
return fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
var errorResp map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
|
||||
return fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
|
||||
}
|
||||
return fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
|
||||
}
|
||||
var tokenResponse OAuthTokenResponse
|
||||
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
slog.Info("Oauth token acquired",
|
||||
slog.String("refresh token", tokenResponse.RefreshToken),
|
||||
slog.String("access token", tokenResponse.AccessToken),
|
||||
slog.Int("expires", tokenResponse.ExpiresIn),
|
||||
)
|
||||
|
||||
expires := futureUTCTimestamp(tokenResponse.ExpiresIn)
|
||||
accessExpires := futureUTCTimestamp(token.ExpiresIn)
|
||||
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
|
||||
setter := models.OauthTokenSetter{
|
||||
AccessToken: omit.From(tokenResponse.AccessToken),
|
||||
Expires: omit.From(expires),
|
||||
RefreshToken: omit.From(tokenResponse.RefreshToken),
|
||||
Username: omit.From(tokenResponse.Username),
|
||||
AccessToken: omit.From(token.AccessToken),
|
||||
AccessTokenExpires: omit.From(accessExpires),
|
||||
RefreshToken: omit.From(token.RefreshToken),
|
||||
RefreshTokenExpires: omit.From(refreshExpires),
|
||||
Username: omit.From(token.Username),
|
||||
}
|
||||
err = user.InsertUserOauthTokens(ctx, PGInstance.BobDB, &setter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to save token to database: %v", err)
|
||||
}
|
||||
go updateArcgisUserData(context.Background(), user, tokenResponse.AccessToken, expires, tokenResponse.RefreshToken)
|
||||
go updateArcgisUserData(context.Background(), user, token.AccessToken, accessExpires, token.RefreshToken, refreshExpires)
|
||||
NewOAuthTokenChannel <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -216,14 +211,174 @@ func redirectURL() string {
|
|||
}
|
||||
|
||||
// This is a goroutine that is in charge of getting Fieldseeker data and keeping it fresh.
|
||||
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan int) {
|
||||
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan struct{}) {
|
||||
for {
|
||||
workerCtx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
oauths, err := models.OauthTokens.Query().All(ctx, PGInstance.BobDB)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get oauths", slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
for _, oauth := range oauths {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
maintainOAuth(workerCtx, oauth)
|
||||
}()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("Exiting refresh worker")
|
||||
slog.Info("Exiting refresh worker...")
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return
|
||||
case id := <-newOauthCh:
|
||||
slog.Info("Adding oauth to background work", slog.Int("oauth id", id))
|
||||
case <-newOauthCh:
|
||||
slog.Info("Updating oauth background work")
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func maintainOAuth(ctx context.Context, oauth *models.OauthToken) {
|
||||
refreshDelay := time.Until(oauth.AccessTokenExpires)
|
||||
slog.Info("Need to refresh oauth", slog.Int("id", int(oauth.ID)), slog.Float64("seconds", refreshDelay.Seconds()))
|
||||
if oauth.AccessTokenExpires.Before(time.Now()) {
|
||||
err := refreshOAuth(ctx, oauth)
|
||||
if err != nil {
|
||||
slog.Error("Failed to refresh token", slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
ticker := time.NewTicker(refreshDelay)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func refreshOAuth(ctx context.Context, oauth *models.OauthToken) error {
|
||||
baseURL := "https://www.arcgis.com/sharing/rest/oauth2/token/"
|
||||
|
||||
form := url.Values{
|
||||
"grant_type": []string{"refresh_token"},
|
||||
"client_id": []string{ClientID},
|
||||
"refresh_token": []string{oauth.RefreshToken},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", baseURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
token, err := handleTokenRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to handle request: %v", err)
|
||||
}
|
||||
accessExpires := futureUTCTimestamp(token.ExpiresIn)
|
||||
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
|
||||
setter := models.OauthTokenSetter{
|
||||
AccessToken: omit.From(token.AccessToken),
|
||||
AccessTokenExpires: omit.From(accessExpires),
|
||||
RefreshToken: omit.From(token.RefreshToken),
|
||||
RefreshTokenExpires: omit.From(refreshExpires),
|
||||
Username: omit.From(token.Username),
|
||||
}
|
||||
err = oauth.Update(ctx, PGInstance.BobDB, &setter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to update oauth in database: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleTokenRequest(ctx context.Context, req *http.Request) (*OAuthTokenResponse, error) {
|
||||
client := http.Client{}
|
||||
slog.Info("POST", slog.String("url", req.URL.String()))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to do request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
slog.Info("Token request", slog.Int("status", resp.StatusCode))
|
||||
saveResponse(bodyBytes, "token.json")
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
var errorResp map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
|
||||
return nil, fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
|
||||
}
|
||||
return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
|
||||
}
|
||||
//logResponseHeaders(resp)
|
||||
var tokenResponse OAuthTokenResponse
|
||||
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
// Just because we got a 200-level status code doesn't mean it worked. Experience has taught us that
|
||||
// we can get errors without anything indicated in the headers or the status code
|
||||
if tokenResponse == (OAuthTokenResponse{}) {
|
||||
var errorResponse ErrorResponse
|
||||
err = json.Unmarshal(bodyBytes, &errorResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal error JSON: %v", err)
|
||||
}
|
||||
if errorResponse.Error.Code > 0 {
|
||||
return nil, errors.New(fmt.Sprintf("API error %d: %s: %s (%s)",
|
||||
errorResponse.Error.Code,
|
||||
errorResponse.Error.Error,
|
||||
errorResponse.Error.ErrorDescription,
|
||||
errorResponse.Error.Message,
|
||||
))
|
||||
}
|
||||
}
|
||||
slog.Info("Oauth token acquired",
|
||||
slog.String("refresh token", tokenResponse.RefreshToken),
|
||||
slog.String("access token", tokenResponse.AccessToken),
|
||||
slog.Int("access expires", tokenResponse.ExpiresIn),
|
||||
slog.Int("refresh expires", tokenResponse.RefreshTokenExpiresIn),
|
||||
)
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func logResponseHeaders(resp *http.Response) {
|
||||
if resp == nil {
|
||||
slog.Info("Response is nil")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("HTTP Response headers",
|
||||
"status", resp.Status,
|
||||
"statusCode", resp.StatusCode)
|
||||
|
||||
for name, values := range resp.Header {
|
||||
slog.Info("Header",
|
||||
"name", name,
|
||||
"values", values)
|
||||
}
|
||||
}
|
||||
|
||||
func saveResponse(data []byte, filename string) {
|
||||
dest, err := os.Create(filename)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create file", slog.String("filename", filename), slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(dest, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
slog.Error("Failed to write", slog.String("filename", filename), slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
slog.Info("Wrote response", slog.String("filename", filename))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ var OauthTokens = Table[
|
|||
Generated: false,
|
||||
AutoIncr: false,
|
||||
},
|
||||
Expires: column{
|
||||
Name: "expires",
|
||||
AccessTokenExpires: column{
|
||||
Name: "access_token_expires",
|
||||
DBType: "timestamp without time zone",
|
||||
Default: "",
|
||||
Comment: "",
|
||||
|
|
@ -87,6 +87,15 @@ var OauthTokens = Table[
|
|||
Generated: false,
|
||||
AutoIncr: false,
|
||||
},
|
||||
RefreshTokenExpires: column{
|
||||
Name: "refresh_token_expires",
|
||||
DBType: "timestamp without time zone",
|
||||
Default: "CURRENT_TIMESTAMP",
|
||||
Comment: "",
|
||||
Nullable: false,
|
||||
Generated: false,
|
||||
AutoIncr: false,
|
||||
},
|
||||
},
|
||||
Indexes: oauthTokenIndexes{
|
||||
OauthTokenPkey: index{
|
||||
|
|
@ -130,17 +139,18 @@ var OauthTokens = Table[
|
|||
type oauthTokenColumns struct {
|
||||
ID column
|
||||
AccessToken column
|
||||
Expires column
|
||||
AccessTokenExpires column
|
||||
RefreshToken column
|
||||
Username column
|
||||
UserID column
|
||||
ArcgisID column
|
||||
ArcgisLicenseTypeID column
|
||||
RefreshTokenExpires column
|
||||
}
|
||||
|
||||
func (c oauthTokenColumns) AsSlice() []column {
|
||||
return []column{
|
||||
c.ID, c.AccessToken, c.Expires, c.RefreshToken, c.Username, c.UserID, c.ArcgisID, c.ArcgisLicenseTypeID,
|
||||
c.ID, c.AccessToken, c.AccessTokenExpires, c.RefreshToken, c.Username, c.UserID, c.ArcgisID, c.ArcgisLicenseTypeID, c.RefreshTokenExpires,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -72,12 +72,13 @@ func (f *Factory) FromExistingOauthToken(m *models.OauthToken) *OauthTokenTempla
|
|||
|
||||
o.ID = func() int32 { return m.ID }
|
||||
o.AccessToken = func() string { return m.AccessToken }
|
||||
o.Expires = func() time.Time { return m.Expires }
|
||||
o.AccessTokenExpires = func() time.Time { return m.AccessTokenExpires }
|
||||
o.RefreshToken = func() string { return m.RefreshToken }
|
||||
o.Username = func() string { return m.Username }
|
||||
o.UserID = func() int32 { return m.UserID }
|
||||
o.ArcgisID = func() null.Val[string] { return m.ArcgisID }
|
||||
o.ArcgisLicenseTypeID = func() null.Val[string] { return m.ArcgisLicenseTypeID }
|
||||
o.RefreshTokenExpires = func() time.Time { return m.RefreshTokenExpires }
|
||||
|
||||
ctx := context.Background()
|
||||
if m.R.UserUser != nil {
|
||||
|
|
|
|||
|
|
@ -39,12 +39,13 @@ func (mods OauthTokenModSlice) Apply(ctx context.Context, n *OauthTokenTemplate)
|
|||
type OauthTokenTemplate struct {
|
||||
ID func() int32
|
||||
AccessToken func() string
|
||||
Expires func() time.Time
|
||||
AccessTokenExpires func() time.Time
|
||||
RefreshToken func() string
|
||||
Username func() string
|
||||
UserID func() int32
|
||||
ArcgisID func() null.Val[string]
|
||||
ArcgisLicenseTypeID func() null.Val[string]
|
||||
RefreshTokenExpires func() time.Time
|
||||
|
||||
r oauthTokenR
|
||||
f *Factory
|
||||
|
|
@ -91,9 +92,9 @@ func (o OauthTokenTemplate) BuildSetter() *models.OauthTokenSetter {
|
|||
val := o.AccessToken()
|
||||
m.AccessToken = omit.From(val)
|
||||
}
|
||||
if o.Expires != nil {
|
||||
val := o.Expires()
|
||||
m.Expires = omit.From(val)
|
||||
if o.AccessTokenExpires != nil {
|
||||
val := o.AccessTokenExpires()
|
||||
m.AccessTokenExpires = omit.From(val)
|
||||
}
|
||||
if o.RefreshToken != nil {
|
||||
val := o.RefreshToken()
|
||||
|
|
@ -115,6 +116,10 @@ func (o OauthTokenTemplate) BuildSetter() *models.OauthTokenSetter {
|
|||
val := o.ArcgisLicenseTypeID()
|
||||
m.ArcgisLicenseTypeID = omitnull.FromNull(val)
|
||||
}
|
||||
if o.RefreshTokenExpires != nil {
|
||||
val := o.RefreshTokenExpires()
|
||||
m.RefreshTokenExpires = omit.From(val)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -143,8 +148,8 @@ func (o OauthTokenTemplate) Build() *models.OauthToken {
|
|||
if o.AccessToken != nil {
|
||||
m.AccessToken = o.AccessToken()
|
||||
}
|
||||
if o.Expires != nil {
|
||||
m.Expires = o.Expires()
|
||||
if o.AccessTokenExpires != nil {
|
||||
m.AccessTokenExpires = o.AccessTokenExpires()
|
||||
}
|
||||
if o.RefreshToken != nil {
|
||||
m.RefreshToken = o.RefreshToken()
|
||||
|
|
@ -161,6 +166,9 @@ func (o OauthTokenTemplate) Build() *models.OauthToken {
|
|||
if o.ArcgisLicenseTypeID != nil {
|
||||
m.ArcgisLicenseTypeID = o.ArcgisLicenseTypeID()
|
||||
}
|
||||
if o.RefreshTokenExpires != nil {
|
||||
m.RefreshTokenExpires = o.RefreshTokenExpires()
|
||||
}
|
||||
|
||||
o.setModelRels(m)
|
||||
|
||||
|
|
@ -185,9 +193,9 @@ func ensureCreatableOauthToken(m *models.OauthTokenSetter) {
|
|||
val := random_string(nil)
|
||||
m.AccessToken = omit.From(val)
|
||||
}
|
||||
if !(m.Expires.IsValue()) {
|
||||
if !(m.AccessTokenExpires.IsValue()) {
|
||||
val := random_time_Time(nil)
|
||||
m.Expires = omit.From(val)
|
||||
m.AccessTokenExpires = omit.From(val)
|
||||
}
|
||||
if !(m.RefreshToken.IsValue()) {
|
||||
val := random_string(nil)
|
||||
|
|
@ -322,12 +330,13 @@ func (m oauthTokenMods) RandomizeAllColumns(f *faker.Faker) OauthTokenMod {
|
|||
return OauthTokenModSlice{
|
||||
OauthTokenMods.RandomID(f),
|
||||
OauthTokenMods.RandomAccessToken(f),
|
||||
OauthTokenMods.RandomExpires(f),
|
||||
OauthTokenMods.RandomAccessTokenExpires(f),
|
||||
OauthTokenMods.RandomRefreshToken(f),
|
||||
OauthTokenMods.RandomUsername(f),
|
||||
OauthTokenMods.RandomUserID(f),
|
||||
OauthTokenMods.RandomArcgisID(f),
|
||||
OauthTokenMods.RandomArcgisLicenseTypeID(f),
|
||||
OauthTokenMods.RandomRefreshTokenExpires(f),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -394,31 +403,31 @@ func (m oauthTokenMods) RandomAccessToken(f *faker.Faker) OauthTokenMod {
|
|||
}
|
||||
|
||||
// Set the model columns to this value
|
||||
func (m oauthTokenMods) Expires(val time.Time) OauthTokenMod {
|
||||
func (m oauthTokenMods) AccessTokenExpires(val time.Time) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.Expires = func() time.Time { return val }
|
||||
o.AccessTokenExpires = func() time.Time { return val }
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Column from the function
|
||||
func (m oauthTokenMods) ExpiresFunc(f func() time.Time) OauthTokenMod {
|
||||
func (m oauthTokenMods) AccessTokenExpiresFunc(f func() time.Time) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.Expires = f
|
||||
o.AccessTokenExpires = f
|
||||
})
|
||||
}
|
||||
|
||||
// Clear any values for the column
|
||||
func (m oauthTokenMods) UnsetExpires() OauthTokenMod {
|
||||
func (m oauthTokenMods) UnsetAccessTokenExpires() OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.Expires = nil
|
||||
o.AccessTokenExpires = nil
|
||||
})
|
||||
}
|
||||
|
||||
// Generates a random value for the column using the given faker
|
||||
// if faker is nil, a default faker is used
|
||||
func (m oauthTokenMods) RandomExpires(f *faker.Faker) OauthTokenMod {
|
||||
func (m oauthTokenMods) RandomAccessTokenExpires(f *faker.Faker) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.Expires = func() time.Time {
|
||||
o.AccessTokenExpires = func() time.Time {
|
||||
return random_time_Time(f)
|
||||
}
|
||||
})
|
||||
|
|
@ -623,6 +632,37 @@ func (m oauthTokenMods) RandomArcgisLicenseTypeIDNotNull(f *faker.Faker) OauthTo
|
|||
})
|
||||
}
|
||||
|
||||
// Set the model columns to this value
|
||||
func (m oauthTokenMods) RefreshTokenExpires(val time.Time) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.RefreshTokenExpires = func() time.Time { return val }
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Column from the function
|
||||
func (m oauthTokenMods) RefreshTokenExpiresFunc(f func() time.Time) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.RefreshTokenExpires = f
|
||||
})
|
||||
}
|
||||
|
||||
// Clear any values for the column
|
||||
func (m oauthTokenMods) UnsetRefreshTokenExpires() OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.RefreshTokenExpires = nil
|
||||
})
|
||||
}
|
||||
|
||||
// Generates a random value for the column using the given faker
|
||||
// if faker is nil, a default faker is used
|
||||
func (m oauthTokenMods) RandomRefreshTokenExpires(f *faker.Faker) OauthTokenMod {
|
||||
return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) {
|
||||
o.RefreshTokenExpires = func() time.Time {
|
||||
return random_time_Time(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (m oauthTokenMods) WithParentsCascading() OauthTokenMod {
|
||||
return OauthTokenModFunc(func(ctx context.Context, o *OauthTokenTemplate) {
|
||||
if isDone, _ := oauthTokenWithParentsCascadingCtx.Value(ctx); isDone {
|
||||
|
|
|
|||
4
main.go
4
main.go
|
|
@ -82,14 +82,14 @@ func main() {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
newTokenChannel := make(chan int, 10)
|
||||
NewOAuthTokenChannel = make(chan struct{}, 10)
|
||||
|
||||
var waitGroup sync.WaitGroup
|
||||
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
refreshFieldseekerData(ctx, newTokenChannel)
|
||||
refreshFieldseekerData(ctx, NewOAuthTokenChannel)
|
||||
}()
|
||||
|
||||
server := &http.Server{
|
||||
|
|
|
|||
7
migrations/00006_add_oauth_refresh_expires.sql
Normal file
7
migrations/00006_add_oauth_refresh_expires.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- +goose Up
|
||||
ALTER TABLE oauth_token RENAME COLUMN expires TO access_token_expires;
|
||||
ALTER TABLE oauth_token ADD COLUMN refresh_token_expires TIMESTAMP NOT NULL DEFAULT current_timestamp;
|
||||
|
||||
-- +goose Down
|
||||
ALTER TABLE oauth_token DROP COLUMN refresh_token_expires;
|
||||
ALTER TABLE oauth_token RENAME COLUMN access_token_expires TO expires;
|
||||
|
|
@ -28,12 +28,13 @@ import (
|
|||
type OauthToken struct {
|
||||
ID int32 `db:"id,pk" `
|
||||
AccessToken string `db:"access_token" `
|
||||
Expires time.Time `db:"expires" `
|
||||
AccessTokenExpires time.Time `db:"access_token_expires" `
|
||||
RefreshToken string `db:"refresh_token" `
|
||||
Username string `db:"username" `
|
||||
UserID int32 `db:"user_id" `
|
||||
ArcgisID null.Val[string] `db:"arcgis_id" `
|
||||
ArcgisLicenseTypeID null.Val[string] `db:"arcgis_license_type_id" `
|
||||
RefreshTokenExpires time.Time `db:"refresh_token_expires" `
|
||||
|
||||
R oauthTokenR `db:"-" `
|
||||
}
|
||||
|
|
@ -56,17 +57,18 @@ type oauthTokenR struct {
|
|||
func buildOauthTokenColumns(alias string) oauthTokenColumns {
|
||||
return oauthTokenColumns{
|
||||
ColumnsExpr: expr.NewColumnsExpr(
|
||||
"id", "access_token", "expires", "refresh_token", "username", "user_id", "arcgis_id", "arcgis_license_type_id",
|
||||
"id", "access_token", "access_token_expires", "refresh_token", "username", "user_id", "arcgis_id", "arcgis_license_type_id", "refresh_token_expires",
|
||||
).WithParent("oauth_token"),
|
||||
tableAlias: alias,
|
||||
ID: psql.Quote(alias, "id"),
|
||||
AccessToken: psql.Quote(alias, "access_token"),
|
||||
Expires: psql.Quote(alias, "expires"),
|
||||
AccessTokenExpires: psql.Quote(alias, "access_token_expires"),
|
||||
RefreshToken: psql.Quote(alias, "refresh_token"),
|
||||
Username: psql.Quote(alias, "username"),
|
||||
UserID: psql.Quote(alias, "user_id"),
|
||||
ArcgisID: psql.Quote(alias, "arcgis_id"),
|
||||
ArcgisLicenseTypeID: psql.Quote(alias, "arcgis_license_type_id"),
|
||||
RefreshTokenExpires: psql.Quote(alias, "refresh_token_expires"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,12 +77,13 @@ type oauthTokenColumns struct {
|
|||
tableAlias string
|
||||
ID psql.Expression
|
||||
AccessToken psql.Expression
|
||||
Expires psql.Expression
|
||||
AccessTokenExpires psql.Expression
|
||||
RefreshToken psql.Expression
|
||||
Username psql.Expression
|
||||
UserID psql.Expression
|
||||
ArcgisID psql.Expression
|
||||
ArcgisLicenseTypeID psql.Expression
|
||||
RefreshTokenExpires psql.Expression
|
||||
}
|
||||
|
||||
func (c oauthTokenColumns) Alias() string {
|
||||
|
|
@ -97,24 +100,25 @@ func (oauthTokenColumns) AliasedAs(alias string) oauthTokenColumns {
|
|||
type OauthTokenSetter struct {
|
||||
ID omit.Val[int32] `db:"id,pk" `
|
||||
AccessToken omit.Val[string] `db:"access_token" `
|
||||
Expires omit.Val[time.Time] `db:"expires" `
|
||||
AccessTokenExpires omit.Val[time.Time] `db:"access_token_expires" `
|
||||
RefreshToken omit.Val[string] `db:"refresh_token" `
|
||||
Username omit.Val[string] `db:"username" `
|
||||
UserID omit.Val[int32] `db:"user_id" `
|
||||
ArcgisID omitnull.Val[string] `db:"arcgis_id" `
|
||||
ArcgisLicenseTypeID omitnull.Val[string] `db:"arcgis_license_type_id" `
|
||||
RefreshTokenExpires omit.Val[time.Time] `db:"refresh_token_expires" `
|
||||
}
|
||||
|
||||
func (s OauthTokenSetter) SetColumns() []string {
|
||||
vals := make([]string, 0, 8)
|
||||
vals := make([]string, 0, 9)
|
||||
if s.ID.IsValue() {
|
||||
vals = append(vals, "id")
|
||||
}
|
||||
if s.AccessToken.IsValue() {
|
||||
vals = append(vals, "access_token")
|
||||
}
|
||||
if s.Expires.IsValue() {
|
||||
vals = append(vals, "expires")
|
||||
if s.AccessTokenExpires.IsValue() {
|
||||
vals = append(vals, "access_token_expires")
|
||||
}
|
||||
if s.RefreshToken.IsValue() {
|
||||
vals = append(vals, "refresh_token")
|
||||
|
|
@ -131,6 +135,9 @@ func (s OauthTokenSetter) SetColumns() []string {
|
|||
if !s.ArcgisLicenseTypeID.IsUnset() {
|
||||
vals = append(vals, "arcgis_license_type_id")
|
||||
}
|
||||
if s.RefreshTokenExpires.IsValue() {
|
||||
vals = append(vals, "refresh_token_expires")
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
|
|
@ -141,8 +148,8 @@ func (s OauthTokenSetter) Overwrite(t *OauthToken) {
|
|||
if s.AccessToken.IsValue() {
|
||||
t.AccessToken = s.AccessToken.MustGet()
|
||||
}
|
||||
if s.Expires.IsValue() {
|
||||
t.Expires = s.Expires.MustGet()
|
||||
if s.AccessTokenExpires.IsValue() {
|
||||
t.AccessTokenExpires = s.AccessTokenExpires.MustGet()
|
||||
}
|
||||
if s.RefreshToken.IsValue() {
|
||||
t.RefreshToken = s.RefreshToken.MustGet()
|
||||
|
|
@ -159,6 +166,9 @@ func (s OauthTokenSetter) Overwrite(t *OauthToken) {
|
|||
if !s.ArcgisLicenseTypeID.IsUnset() {
|
||||
t.ArcgisLicenseTypeID = s.ArcgisLicenseTypeID.MustGetNull()
|
||||
}
|
||||
if s.RefreshTokenExpires.IsValue() {
|
||||
t.RefreshTokenExpires = s.RefreshTokenExpires.MustGet()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) {
|
||||
|
|
@ -167,7 +177,7 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) {
|
|||
})
|
||||
|
||||
q.AppendValues(bob.ExpressionFunc(func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
|
||||
vals := make([]bob.Expression, 8)
|
||||
vals := make([]bob.Expression, 9)
|
||||
if s.ID.IsValue() {
|
||||
vals[0] = psql.Arg(s.ID.MustGet())
|
||||
} else {
|
||||
|
|
@ -180,8 +190,8 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) {
|
|||
vals[1] = psql.Raw("DEFAULT")
|
||||
}
|
||||
|
||||
if s.Expires.IsValue() {
|
||||
vals[2] = psql.Arg(s.Expires.MustGet())
|
||||
if s.AccessTokenExpires.IsValue() {
|
||||
vals[2] = psql.Arg(s.AccessTokenExpires.MustGet())
|
||||
} else {
|
||||
vals[2] = psql.Raw("DEFAULT")
|
||||
}
|
||||
|
|
@ -216,6 +226,12 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) {
|
|||
vals[7] = psql.Raw("DEFAULT")
|
||||
}
|
||||
|
||||
if s.RefreshTokenExpires.IsValue() {
|
||||
vals[8] = psql.Arg(s.RefreshTokenExpires.MustGet())
|
||||
} else {
|
||||
vals[8] = psql.Raw("DEFAULT")
|
||||
}
|
||||
|
||||
return bob.ExpressSlice(ctx, w, d, start, vals, "", ", ", "")
|
||||
}))
|
||||
}
|
||||
|
|
@ -225,7 +241,7 @@ func (s OauthTokenSetter) UpdateMod() bob.Mod[*dialect.UpdateQuery] {
|
|||
}
|
||||
|
||||
func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression {
|
||||
exprs := make([]bob.Expression, 0, 8)
|
||||
exprs := make([]bob.Expression, 0, 9)
|
||||
|
||||
if s.ID.IsValue() {
|
||||
exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{
|
||||
|
|
@ -241,10 +257,10 @@ func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression {
|
|||
}})
|
||||
}
|
||||
|
||||
if s.Expires.IsValue() {
|
||||
if s.AccessTokenExpires.IsValue() {
|
||||
exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{
|
||||
psql.Quote(append(prefix, "expires")...),
|
||||
psql.Arg(s.Expires),
|
||||
psql.Quote(append(prefix, "access_token_expires")...),
|
||||
psql.Arg(s.AccessTokenExpires),
|
||||
}})
|
||||
}
|
||||
|
||||
|
|
@ -283,6 +299,13 @@ func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression {
|
|||
}})
|
||||
}
|
||||
|
||||
if s.RefreshTokenExpires.IsValue() {
|
||||
exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{
|
||||
psql.Quote(append(prefix, "refresh_token_expires")...),
|
||||
psql.Arg(s.RefreshTokenExpires),
|
||||
}})
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
|
|
@ -584,12 +607,13 @@ func (oauthToken0 *OauthToken) AttachUserUser(ctx context.Context, exec bob.Exec
|
|||
type oauthTokenWhere[Q psql.Filterable] struct {
|
||||
ID psql.WhereMod[Q, int32]
|
||||
AccessToken psql.WhereMod[Q, string]
|
||||
Expires psql.WhereMod[Q, time.Time]
|
||||
AccessTokenExpires psql.WhereMod[Q, time.Time]
|
||||
RefreshToken psql.WhereMod[Q, string]
|
||||
Username psql.WhereMod[Q, string]
|
||||
UserID psql.WhereMod[Q, int32]
|
||||
ArcgisID psql.WhereNullMod[Q, string]
|
||||
ArcgisLicenseTypeID psql.WhereNullMod[Q, string]
|
||||
RefreshTokenExpires psql.WhereMod[Q, time.Time]
|
||||
}
|
||||
|
||||
func (oauthTokenWhere[Q]) AliasedAs(alias string) oauthTokenWhere[Q] {
|
||||
|
|
@ -600,12 +624,13 @@ func buildOauthTokenWhere[Q psql.Filterable](cols oauthTokenColumns) oauthTokenW
|
|||
return oauthTokenWhere[Q]{
|
||||
ID: psql.Where[Q, int32](cols.ID),
|
||||
AccessToken: psql.Where[Q, string](cols.AccessToken),
|
||||
Expires: psql.Where[Q, time.Time](cols.Expires),
|
||||
AccessTokenExpires: psql.Where[Q, time.Time](cols.AccessTokenExpires),
|
||||
RefreshToken: psql.Where[Q, string](cols.RefreshToken),
|
||||
Username: psql.Where[Q, string](cols.Username),
|
||||
UserID: psql.Where[Q, int32](cols.UserID),
|
||||
ArcgisID: psql.WhereNull[Q, string](cols.ArcgisID),
|
||||
ArcgisLicenseTypeID: psql.WhereNull[Q, string](cols.ArcgisLicenseTypeID),
|
||||
RefreshTokenExpires: psql.Where[Q, time.Time](cols.RefreshTokenExpires),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import (
|
|||
//go:embed oauth_by_user_id.bob.sql
|
||||
var formattedQueries_oauth_by_user_id string
|
||||
|
||||
var oauthTokenByUserIdSQL = formattedQueries_oauth_by_user_id[156:550]
|
||||
var oauthTokenByUserIdSQL = formattedQueries_oauth_by_user_id[156:642]
|
||||
|
||||
type OauthTokenByUserIdQuery = orm.ModQuery[*dialect.SelectQuery, oauthTokenByUserId, OauthTokenByUserIdRow, []OauthTokenByUserIdRow, oauthTokenByUserIdTransformer]
|
||||
|
||||
|
|
@ -44,12 +44,13 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery {
|
|||
var t OauthTokenByUserIdRow
|
||||
row.ScheduleScanByIndex(0, &t.ID)
|
||||
row.ScheduleScanByIndex(1, &t.AccessToken)
|
||||
row.ScheduleScanByIndex(2, &t.Expires)
|
||||
row.ScheduleScanByIndex(2, &t.AccessTokenExpires)
|
||||
row.ScheduleScanByIndex(3, &t.RefreshToken)
|
||||
row.ScheduleScanByIndex(4, &t.Username)
|
||||
row.ScheduleScanByIndex(5, &t.UserID)
|
||||
row.ScheduleScanByIndex(6, &t.ArcgisID)
|
||||
row.ScheduleScanByIndex(7, &t.ArcgisLicenseTypeID)
|
||||
row.ScheduleScanByIndex(8, &t.RefreshTokenExpires)
|
||||
return &t, nil
|
||||
}, func(v any) (OauthTokenByUserIdRow, error) {
|
||||
return *(v.(*OauthTokenByUserIdRow)), nil
|
||||
|
|
@ -57,9 +58,9 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery {
|
|||
},
|
||||
},
|
||||
Mod: bob.ModFunc[*dialect.SelectQuery](func(q *dialect.SelectQuery) {
|
||||
q.AppendSelect(expressionTypArgs.subExpr(7, 357))
|
||||
q.SetTable(expressionTypArgs.subExpr(363, 374))
|
||||
q.AppendWhere(expressionTypArgs.subExpr(382, 394))
|
||||
q.AppendSelect(expressionTypArgs.subExpr(7, 449))
|
||||
q.SetTable(expressionTypArgs.subExpr(455, 466))
|
||||
q.AppendWhere(expressionTypArgs.subExpr(474, 486))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
@ -67,12 +68,13 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery {
|
|||
type OauthTokenByUserIdRow = struct {
|
||||
ID int32 `db:"id"`
|
||||
AccessToken string `db:"access_token"`
|
||||
Expires time.Time `db:"expires"`
|
||||
AccessTokenExpires time.Time `db:"access_token_expires"`
|
||||
RefreshToken string `db:"refresh_token"`
|
||||
Username string `db:"username"`
|
||||
UserID int32 `db:"user_id"`
|
||||
ArcgisID null.Val[string] `db:"arcgis_id"`
|
||||
ArcgisLicenseTypeID null.Val[string] `db:"arcgis_license_type_id"`
|
||||
RefreshTokenExpires time.Time `db:"refresh_token_expires"`
|
||||
}
|
||||
|
||||
type oauthTokenByUserIdTransformer = bob.SliceTransformer[OauthTokenByUserIdRow, []OauthTokenByUserIdRow]
|
||||
|
|
@ -85,8 +87,8 @@ func (o oauthTokenByUserId) args() iter.Seq[orm.ArgWithPosition] {
|
|||
return func(yield func(arg orm.ArgWithPosition) bool) {
|
||||
if !yield(orm.ArgWithPosition{
|
||||
Name: "userID",
|
||||
Start: 392,
|
||||
Stop: 394,
|
||||
Start: 484,
|
||||
Stop: 486,
|
||||
Expression: o.UserID,
|
||||
}) {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -2,5 +2,5 @@
|
|||
-- This file is meant to be re-generated in place and/or deleted at any time.
|
||||
|
||||
-- OauthTokenByUserId
|
||||
SELECT "oauth_token"."id" AS "id", "oauth_token"."access_token" AS "access_token", "oauth_token"."expires" AS "expires", "oauth_token"."refresh_token" AS "refresh_token", "oauth_token"."username" AS "username", "oauth_token"."user_id" AS "user_id", "oauth_token"."arcgis_id" AS "arcgis_id", "oauth_token"."arcgis_license_type_id" AS "arcgis_license_type_id" FROM oauth_token WHERE
|
||||
SELECT "oauth_token"."id" AS "id", "oauth_token"."access_token" AS "access_token", "oauth_token"."access_token_expires" AS "access_token_expires", "oauth_token"."refresh_token" AS "refresh_token", "oauth_token"."username" AS "username", "oauth_token"."user_id" AS "user_id", "oauth_token"."arcgis_id" AS "arcgis_id", "oauth_token"."arcgis_license_type_id" AS "arcgis_license_type_id", "oauth_token"."refresh_token_expires" AS "refresh_token_expires" FROM oauth_token WHERE
|
||||
user_id = $1;
|
||||
|
|
|
|||
|
|
@ -84,8 +84,8 @@ func TestOauthTokenByUserId(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(columns) != 8 {
|
||||
t.Fatalf("expected %d columns, got %d", 8, len(columns))
|
||||
if len(columns) != 9 {
|
||||
t.Fatalf("expected %d columns, got %d", 9, len(columns))
|
||||
}
|
||||
|
||||
if columns[0] != "id" {
|
||||
|
|
@ -96,8 +96,8 @@ func TestOauthTokenByUserId(t *testing.T) {
|
|||
t.Fatalf("expected column %d to be %s, got %s", 1, "access_token", columns[1])
|
||||
}
|
||||
|
||||
if columns[2] != "expires" {
|
||||
t.Fatalf("expected column %d to be %s, got %s", 2, "expires", columns[2])
|
||||
if columns[2] != "access_token_expires" {
|
||||
t.Fatalf("expected column %d to be %s, got %s", 2, "access_token_expires", columns[2])
|
||||
}
|
||||
|
||||
if columns[3] != "refresh_token" {
|
||||
|
|
@ -119,5 +119,9 @@ func TestOauthTokenByUserId(t *testing.T) {
|
|||
if columns[7] != "arcgis_license_type_id" {
|
||||
t.Fatalf("expected column %d to be %s, got %s", 7, "arcgis_license_type_id", columns[7])
|
||||
}
|
||||
|
||||
if columns[8] != "refresh_token_expires" {
|
||||
t.Fatalf("expected column %d to be %s, got %s", 8, "refresh_token_expires", columns[8])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue