diff --git a/arcgis-go b/arcgis-go index a99b4a72..e47b350f 160000 --- a/arcgis-go +++ b/arcgis-go @@ -1 +1 @@ -Subproject commit a99b4a72b2bb3dcff642209f029eb4e7d746fa8d +Subproject commit e47b350f9231a16c815b927947f9d718cec2d3fe diff --git a/arcgis.go b/arcgis.go index a5664592..c045610b 100644 --- a/arcgis.go +++ b/arcgis.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "crypto/rand" "crypto/sha256" @@ -12,8 +13,10 @@ import ( "log/slog" "net/http" "net/url" + "os" "strconv" "strings" + "sync" "time" "github.com/Gleipnir-Technology/arcgis-go" @@ -24,13 +27,28 @@ import ( "github.com/jackc/pgx/v5" ) +var NewOAuthTokenChannel chan struct{} var CodeVerifier string = "random_secure_string_min_43_chars_long_should_be_stored_in_session" +type ErrorResponse struct { + Error ErrorResponseContent `json:"error"` +} + +type ErrorResponseContent struct { + Code int `json:"code"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + Message string `json:"message"` + Details []string `json:"details"` +} + type OAuthTokenResponse struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Username string `json:"username"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresIn int `json:"refresh_token_expires_in"` + SSL bool `json:"ssl"` + Username string `json:"username"` } // Build the ArcGIS authorization URL with PKCE @@ -66,12 +84,13 @@ func generateCodeVerifier() string { } // Find out what we can about this user -func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, expires time.Time, refresh_token string) { +func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, access_token_expires time.Time, refresh_token string, refresh_token_expires time.Time) { client := arcgis.NewArcGIS( arcgis.AuthenticatorOAuth{ - AccessToken: access_token, - Expires: expires, - RefreshToken: refresh_token, + AccessToken: access_token, + AccessTokenExpires: access_token_expires, + RefreshToken: refresh_token, + RefreshTokenExpires: refresh_token_expires, }, ) portal, err := client.PortalsSelf() @@ -158,49 +177,25 @@ func handleOauthAccessCode(ctx context.Context, user *models.User, code string) return fmt.Errorf("Failed to create request: %v", err) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - client := http.Client{} - slog.Info("POST", slog.String("url", baseURL)) - resp, err := client.Do(req) + token, err := handleTokenRequest(ctx, req) if err != nil { - return fmt.Errorf("Failed to do request: %v", err) + return fmt.Errorf("Failed to exchange authorization code for token: %v", err) } - defer resp.Body.Close() - bodyBytes, err := io.ReadAll(resp.Body) - slog.Info("Response", slog.Int("status", resp.StatusCode)) - if resp.StatusCode >= http.StatusBadRequest { - if err != nil { - return fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err) - } - bodyString := string(bodyBytes) - var errorResp map[string]interface{} - if err := json.Unmarshal(bodyBytes, &errorResp); err == nil { - return fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp) - } - return fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString) - } - var tokenResponse OAuthTokenResponse - err = json.Unmarshal(bodyBytes, &tokenResponse) - if err != nil { - return fmt.Errorf("Failed to unmarshal JSON: %v", err) - } - slog.Info("Oauth token acquired", - slog.String("refresh token", tokenResponse.RefreshToken), - slog.String("access token", tokenResponse.AccessToken), - slog.Int("expires", tokenResponse.ExpiresIn), - ) - - expires := futureUTCTimestamp(tokenResponse.ExpiresIn) + accessExpires := futureUTCTimestamp(token.ExpiresIn) + refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn) setter := models.OauthTokenSetter{ - AccessToken: omit.From(tokenResponse.AccessToken), - Expires: omit.From(expires), - RefreshToken: omit.From(tokenResponse.RefreshToken), - Username: omit.From(tokenResponse.Username), + AccessToken: omit.From(token.AccessToken), + AccessTokenExpires: omit.From(accessExpires), + RefreshToken: omit.From(token.RefreshToken), + RefreshTokenExpires: omit.From(refreshExpires), + Username: omit.From(token.Username), } err = user.InsertUserOauthTokens(ctx, PGInstance.BobDB, &setter) if err != nil { return fmt.Errorf("Failed to save token to database: %v", err) } - go updateArcgisUserData(context.Background(), user, tokenResponse.AccessToken, expires, tokenResponse.RefreshToken) + go updateArcgisUserData(context.Background(), user, token.AccessToken, accessExpires, token.RefreshToken, refreshExpires) + NewOAuthTokenChannel <- struct{}{} return nil } @@ -216,14 +211,174 @@ func redirectURL() string { } // This is a goroutine that is in charge of getting Fieldseeker data and keeping it fresh. -func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan int) { +func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan struct{}) { for { + workerCtx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + + oauths, err := models.OauthTokens.Query().All(ctx, PGInstance.BobDB) + if err != nil { + slog.Error("Failed to get oauths", slog.String("err", err.Error())) + return + } + for _, oauth := range oauths { + wg.Add(1) + go func() { + defer wg.Done() + maintainOAuth(workerCtx, oauth) + }() + } select { case <-ctx.Done(): - slog.Info("Exiting refresh worker") + slog.Info("Exiting refresh worker...") + cancel() + wg.Wait() return - case id := <-newOauthCh: - slog.Info("Adding oauth to background work", slog.Int("oauth id", id)) + case <-newOauthCh: + slog.Info("Updating oauth background work") + cancel() + wg.Wait() } } } + +func maintainOAuth(ctx context.Context, oauth *models.OauthToken) { + refreshDelay := time.Until(oauth.AccessTokenExpires) + slog.Info("Need to refresh oauth", slog.Int("id", int(oauth.ID)), slog.Float64("seconds", refreshDelay.Seconds())) + if oauth.AccessTokenExpires.Before(time.Now()) { + err := refreshOAuth(ctx, oauth) + if err != nil { + slog.Error("Failed to refresh token", slog.String("err", err.Error())) + return + } + } + ticker := time.NewTicker(refreshDelay) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + + } + } + +} + +func refreshOAuth(ctx context.Context, oauth *models.OauthToken) error { + baseURL := "https://www.arcgis.com/sharing/rest/oauth2/token/" + + form := url.Values{ + "grant_type": []string{"refresh_token"}, + "client_id": []string{ClientID}, + "refresh_token": []string{oauth.RefreshToken}, + } + + req, err := http.NewRequest("POST", baseURL, strings.NewReader(form.Encode())) + if err != nil { + return fmt.Errorf("Failed to create request: %v", err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + token, err := handleTokenRequest(ctx, req) + if err != nil { + return fmt.Errorf("Failed to handle request: %v", err) + } + accessExpires := futureUTCTimestamp(token.ExpiresIn) + refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn) + setter := models.OauthTokenSetter{ + AccessToken: omit.From(token.AccessToken), + AccessTokenExpires: omit.From(accessExpires), + RefreshToken: omit.From(token.RefreshToken), + RefreshTokenExpires: omit.From(refreshExpires), + Username: omit.From(token.Username), + } + err = oauth.Update(ctx, PGInstance.BobDB, &setter) + if err != nil { + return fmt.Errorf("Failed to update oauth in database: %v", err) + } + return nil +} + +func handleTokenRequest(ctx context.Context, req *http.Request) (*OAuthTokenResponse, error) { + client := http.Client{} + slog.Info("POST", slog.String("url", req.URL.String())) + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("Failed to do request: %v", err) + } + defer resp.Body.Close() + bodyBytes, err := io.ReadAll(resp.Body) + slog.Info("Token request", slog.Int("status", resp.StatusCode)) + saveResponse(bodyBytes, "token.json") + if resp.StatusCode >= http.StatusBadRequest { + if err != nil { + return nil, fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err) + } + bodyString := string(bodyBytes) + var errorResp map[string]interface{} + if err := json.Unmarshal(bodyBytes, &errorResp); err == nil { + return nil, fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp) + } + return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString) + } + //logResponseHeaders(resp) + var tokenResponse OAuthTokenResponse + err = json.Unmarshal(bodyBytes, &tokenResponse) + if err != nil { + return nil, fmt.Errorf("Failed to unmarshal JSON: %v", err) + } + // Just because we got a 200-level status code doesn't mean it worked. Experience has taught us that + // we can get errors without anything indicated in the headers or the status code + if tokenResponse == (OAuthTokenResponse{}) { + var errorResponse ErrorResponse + err = json.Unmarshal(bodyBytes, &errorResponse) + if err != nil { + return nil, fmt.Errorf("Failed to unmarshal error JSON: %v", err) + } + if errorResponse.Error.Code > 0 { + return nil, errors.New(fmt.Sprintf("API error %d: %s: %s (%s)", + errorResponse.Error.Code, + errorResponse.Error.Error, + errorResponse.Error.ErrorDescription, + errorResponse.Error.Message, + )) + } + } + slog.Info("Oauth token acquired", + slog.String("refresh token", tokenResponse.RefreshToken), + slog.String("access token", tokenResponse.AccessToken), + slog.Int("access expires", tokenResponse.ExpiresIn), + slog.Int("refresh expires", tokenResponse.RefreshTokenExpiresIn), + ) + return &tokenResponse, nil +} + +func logResponseHeaders(resp *http.Response) { + if resp == nil { + slog.Info("Response is nil") + return + } + + slog.Info("HTTP Response headers", + "status", resp.Status, + "statusCode", resp.StatusCode) + + for name, values := range resp.Header { + slog.Info("Header", + "name", name, + "values", values) + } +} + +func saveResponse(data []byte, filename string) { + dest, err := os.Create(filename) + if err != nil { + slog.Error("Failed to create file", slog.String("filename", filename), slog.String("err", err.Error())) + return + } + _, err = io.Copy(dest, bytes.NewReader(data)) + if err != nil { + slog.Error("Failed to write", slog.String("filename", filename), slog.String("err", err.Error())) + return + } + slog.Info("Wrote response", slog.String("filename", filename)) +} diff --git a/dbinfo/oauth_token.bob.go b/dbinfo/oauth_token.bob.go index 0762c7ac..cfb49837 100644 --- a/dbinfo/oauth_token.bob.go +++ b/dbinfo/oauth_token.bob.go @@ -33,8 +33,8 @@ var OauthTokens = Table[ Generated: false, AutoIncr: false, }, - Expires: column{ - Name: "expires", + AccessTokenExpires: column{ + Name: "access_token_expires", DBType: "timestamp without time zone", Default: "", Comment: "", @@ -87,6 +87,15 @@ var OauthTokens = Table[ Generated: false, AutoIncr: false, }, + RefreshTokenExpires: column{ + Name: "refresh_token_expires", + DBType: "timestamp without time zone", + Default: "CURRENT_TIMESTAMP", + Comment: "", + Nullable: false, + Generated: false, + AutoIncr: false, + }, }, Indexes: oauthTokenIndexes{ OauthTokenPkey: index{ @@ -130,17 +139,18 @@ var OauthTokens = Table[ type oauthTokenColumns struct { ID column AccessToken column - Expires column + AccessTokenExpires column RefreshToken column Username column UserID column ArcgisID column ArcgisLicenseTypeID column + RefreshTokenExpires column } func (c oauthTokenColumns) AsSlice() []column { return []column{ - c.ID, c.AccessToken, c.Expires, c.RefreshToken, c.Username, c.UserID, c.ArcgisID, c.ArcgisLicenseTypeID, + c.ID, c.AccessToken, c.AccessTokenExpires, c.RefreshToken, c.Username, c.UserID, c.ArcgisID, c.ArcgisLicenseTypeID, c.RefreshTokenExpires, } } diff --git a/factory/bobfactory_main.bob.go b/factory/bobfactory_main.bob.go index 56ddd221..5f009b21 100644 --- a/factory/bobfactory_main.bob.go +++ b/factory/bobfactory_main.bob.go @@ -72,12 +72,13 @@ func (f *Factory) FromExistingOauthToken(m *models.OauthToken) *OauthTokenTempla o.ID = func() int32 { return m.ID } o.AccessToken = func() string { return m.AccessToken } - o.Expires = func() time.Time { return m.Expires } + o.AccessTokenExpires = func() time.Time { return m.AccessTokenExpires } o.RefreshToken = func() string { return m.RefreshToken } o.Username = func() string { return m.Username } o.UserID = func() int32 { return m.UserID } o.ArcgisID = func() null.Val[string] { return m.ArcgisID } o.ArcgisLicenseTypeID = func() null.Val[string] { return m.ArcgisLicenseTypeID } + o.RefreshTokenExpires = func() time.Time { return m.RefreshTokenExpires } ctx := context.Background() if m.R.UserUser != nil { diff --git a/factory/oauth_token.bob.go b/factory/oauth_token.bob.go index de841132..44c1aa65 100644 --- a/factory/oauth_token.bob.go +++ b/factory/oauth_token.bob.go @@ -39,12 +39,13 @@ func (mods OauthTokenModSlice) Apply(ctx context.Context, n *OauthTokenTemplate) type OauthTokenTemplate struct { ID func() int32 AccessToken func() string - Expires func() time.Time + AccessTokenExpires func() time.Time RefreshToken func() string Username func() string UserID func() int32 ArcgisID func() null.Val[string] ArcgisLicenseTypeID func() null.Val[string] + RefreshTokenExpires func() time.Time r oauthTokenR f *Factory @@ -91,9 +92,9 @@ func (o OauthTokenTemplate) BuildSetter() *models.OauthTokenSetter { val := o.AccessToken() m.AccessToken = omit.From(val) } - if o.Expires != nil { - val := o.Expires() - m.Expires = omit.From(val) + if o.AccessTokenExpires != nil { + val := o.AccessTokenExpires() + m.AccessTokenExpires = omit.From(val) } if o.RefreshToken != nil { val := o.RefreshToken() @@ -115,6 +116,10 @@ func (o OauthTokenTemplate) BuildSetter() *models.OauthTokenSetter { val := o.ArcgisLicenseTypeID() m.ArcgisLicenseTypeID = omitnull.FromNull(val) } + if o.RefreshTokenExpires != nil { + val := o.RefreshTokenExpires() + m.RefreshTokenExpires = omit.From(val) + } return m } @@ -143,8 +148,8 @@ func (o OauthTokenTemplate) Build() *models.OauthToken { if o.AccessToken != nil { m.AccessToken = o.AccessToken() } - if o.Expires != nil { - m.Expires = o.Expires() + if o.AccessTokenExpires != nil { + m.AccessTokenExpires = o.AccessTokenExpires() } if o.RefreshToken != nil { m.RefreshToken = o.RefreshToken() @@ -161,6 +166,9 @@ func (o OauthTokenTemplate) Build() *models.OauthToken { if o.ArcgisLicenseTypeID != nil { m.ArcgisLicenseTypeID = o.ArcgisLicenseTypeID() } + if o.RefreshTokenExpires != nil { + m.RefreshTokenExpires = o.RefreshTokenExpires() + } o.setModelRels(m) @@ -185,9 +193,9 @@ func ensureCreatableOauthToken(m *models.OauthTokenSetter) { val := random_string(nil) m.AccessToken = omit.From(val) } - if !(m.Expires.IsValue()) { + if !(m.AccessTokenExpires.IsValue()) { val := random_time_Time(nil) - m.Expires = omit.From(val) + m.AccessTokenExpires = omit.From(val) } if !(m.RefreshToken.IsValue()) { val := random_string(nil) @@ -322,12 +330,13 @@ func (m oauthTokenMods) RandomizeAllColumns(f *faker.Faker) OauthTokenMod { return OauthTokenModSlice{ OauthTokenMods.RandomID(f), OauthTokenMods.RandomAccessToken(f), - OauthTokenMods.RandomExpires(f), + OauthTokenMods.RandomAccessTokenExpires(f), OauthTokenMods.RandomRefreshToken(f), OauthTokenMods.RandomUsername(f), OauthTokenMods.RandomUserID(f), OauthTokenMods.RandomArcgisID(f), OauthTokenMods.RandomArcgisLicenseTypeID(f), + OauthTokenMods.RandomRefreshTokenExpires(f), } } @@ -394,31 +403,31 @@ func (m oauthTokenMods) RandomAccessToken(f *faker.Faker) OauthTokenMod { } // Set the model columns to this value -func (m oauthTokenMods) Expires(val time.Time) OauthTokenMod { +func (m oauthTokenMods) AccessTokenExpires(val time.Time) OauthTokenMod { return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { - o.Expires = func() time.Time { return val } + o.AccessTokenExpires = func() time.Time { return val } }) } // Set the Column from the function -func (m oauthTokenMods) ExpiresFunc(f func() time.Time) OauthTokenMod { +func (m oauthTokenMods) AccessTokenExpiresFunc(f func() time.Time) OauthTokenMod { return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { - o.Expires = f + o.AccessTokenExpires = f }) } // Clear any values for the column -func (m oauthTokenMods) UnsetExpires() OauthTokenMod { +func (m oauthTokenMods) UnsetAccessTokenExpires() OauthTokenMod { return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { - o.Expires = nil + o.AccessTokenExpires = nil }) } // Generates a random value for the column using the given faker // if faker is nil, a default faker is used -func (m oauthTokenMods) RandomExpires(f *faker.Faker) OauthTokenMod { +func (m oauthTokenMods) RandomAccessTokenExpires(f *faker.Faker) OauthTokenMod { return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { - o.Expires = func() time.Time { + o.AccessTokenExpires = func() time.Time { return random_time_Time(f) } }) @@ -623,6 +632,37 @@ func (m oauthTokenMods) RandomArcgisLicenseTypeIDNotNull(f *faker.Faker) OauthTo }) } +// Set the model columns to this value +func (m oauthTokenMods) RefreshTokenExpires(val time.Time) OauthTokenMod { + return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { + o.RefreshTokenExpires = func() time.Time { return val } + }) +} + +// Set the Column from the function +func (m oauthTokenMods) RefreshTokenExpiresFunc(f func() time.Time) OauthTokenMod { + return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { + o.RefreshTokenExpires = f + }) +} + +// Clear any values for the column +func (m oauthTokenMods) UnsetRefreshTokenExpires() OauthTokenMod { + return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { + o.RefreshTokenExpires = nil + }) +} + +// Generates a random value for the column using the given faker +// if faker is nil, a default faker is used +func (m oauthTokenMods) RandomRefreshTokenExpires(f *faker.Faker) OauthTokenMod { + return OauthTokenModFunc(func(_ context.Context, o *OauthTokenTemplate) { + o.RefreshTokenExpires = func() time.Time { + return random_time_Time(f) + } + }) +} + func (m oauthTokenMods) WithParentsCascading() OauthTokenMod { return OauthTokenModFunc(func(ctx context.Context, o *OauthTokenTemplate) { if isDone, _ := oauthTokenWithParentsCascadingCtx.Value(ctx); isDone { diff --git a/main.go b/main.go index ced13d0f..ebe69626 100644 --- a/main.go +++ b/main.go @@ -82,14 +82,14 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - newTokenChannel := make(chan int, 10) + NewOAuthTokenChannel = make(chan struct{}, 10) var waitGroup sync.WaitGroup waitGroup.Add(1) go func() { defer waitGroup.Done() - refreshFieldseekerData(ctx, newTokenChannel) + refreshFieldseekerData(ctx, NewOAuthTokenChannel) }() server := &http.Server{ diff --git a/migrations/00006_add_oauth_refresh_expires.sql b/migrations/00006_add_oauth_refresh_expires.sql new file mode 100644 index 00000000..65c2bffd --- /dev/null +++ b/migrations/00006_add_oauth_refresh_expires.sql @@ -0,0 +1,7 @@ +-- +goose Up +ALTER TABLE oauth_token RENAME COLUMN expires TO access_token_expires; +ALTER TABLE oauth_token ADD COLUMN refresh_token_expires TIMESTAMP NOT NULL DEFAULT current_timestamp; + +-- +goose Down +ALTER TABLE oauth_token DROP COLUMN refresh_token_expires; +ALTER TABLE oauth_token RENAME COLUMN access_token_expires TO expires; diff --git a/models/oauth_token.bob.go b/models/oauth_token.bob.go index 38b582c5..89d20aa4 100644 --- a/models/oauth_token.bob.go +++ b/models/oauth_token.bob.go @@ -28,12 +28,13 @@ import ( type OauthToken struct { ID int32 `db:"id,pk" ` AccessToken string `db:"access_token" ` - Expires time.Time `db:"expires" ` + AccessTokenExpires time.Time `db:"access_token_expires" ` RefreshToken string `db:"refresh_token" ` Username string `db:"username" ` UserID int32 `db:"user_id" ` ArcgisID null.Val[string] `db:"arcgis_id" ` ArcgisLicenseTypeID null.Val[string] `db:"arcgis_license_type_id" ` + RefreshTokenExpires time.Time `db:"refresh_token_expires" ` R oauthTokenR `db:"-" ` } @@ -56,17 +57,18 @@ type oauthTokenR struct { func buildOauthTokenColumns(alias string) oauthTokenColumns { return oauthTokenColumns{ ColumnsExpr: expr.NewColumnsExpr( - "id", "access_token", "expires", "refresh_token", "username", "user_id", "arcgis_id", "arcgis_license_type_id", + "id", "access_token", "access_token_expires", "refresh_token", "username", "user_id", "arcgis_id", "arcgis_license_type_id", "refresh_token_expires", ).WithParent("oauth_token"), tableAlias: alias, ID: psql.Quote(alias, "id"), AccessToken: psql.Quote(alias, "access_token"), - Expires: psql.Quote(alias, "expires"), + AccessTokenExpires: psql.Quote(alias, "access_token_expires"), RefreshToken: psql.Quote(alias, "refresh_token"), Username: psql.Quote(alias, "username"), UserID: psql.Quote(alias, "user_id"), ArcgisID: psql.Quote(alias, "arcgis_id"), ArcgisLicenseTypeID: psql.Quote(alias, "arcgis_license_type_id"), + RefreshTokenExpires: psql.Quote(alias, "refresh_token_expires"), } } @@ -75,12 +77,13 @@ type oauthTokenColumns struct { tableAlias string ID psql.Expression AccessToken psql.Expression - Expires psql.Expression + AccessTokenExpires psql.Expression RefreshToken psql.Expression Username psql.Expression UserID psql.Expression ArcgisID psql.Expression ArcgisLicenseTypeID psql.Expression + RefreshTokenExpires psql.Expression } func (c oauthTokenColumns) Alias() string { @@ -97,24 +100,25 @@ func (oauthTokenColumns) AliasedAs(alias string) oauthTokenColumns { type OauthTokenSetter struct { ID omit.Val[int32] `db:"id,pk" ` AccessToken omit.Val[string] `db:"access_token" ` - Expires omit.Val[time.Time] `db:"expires" ` + AccessTokenExpires omit.Val[time.Time] `db:"access_token_expires" ` RefreshToken omit.Val[string] `db:"refresh_token" ` Username omit.Val[string] `db:"username" ` UserID omit.Val[int32] `db:"user_id" ` ArcgisID omitnull.Val[string] `db:"arcgis_id" ` ArcgisLicenseTypeID omitnull.Val[string] `db:"arcgis_license_type_id" ` + RefreshTokenExpires omit.Val[time.Time] `db:"refresh_token_expires" ` } func (s OauthTokenSetter) SetColumns() []string { - vals := make([]string, 0, 8) + vals := make([]string, 0, 9) if s.ID.IsValue() { vals = append(vals, "id") } if s.AccessToken.IsValue() { vals = append(vals, "access_token") } - if s.Expires.IsValue() { - vals = append(vals, "expires") + if s.AccessTokenExpires.IsValue() { + vals = append(vals, "access_token_expires") } if s.RefreshToken.IsValue() { vals = append(vals, "refresh_token") @@ -131,6 +135,9 @@ func (s OauthTokenSetter) SetColumns() []string { if !s.ArcgisLicenseTypeID.IsUnset() { vals = append(vals, "arcgis_license_type_id") } + if s.RefreshTokenExpires.IsValue() { + vals = append(vals, "refresh_token_expires") + } return vals } @@ -141,8 +148,8 @@ func (s OauthTokenSetter) Overwrite(t *OauthToken) { if s.AccessToken.IsValue() { t.AccessToken = s.AccessToken.MustGet() } - if s.Expires.IsValue() { - t.Expires = s.Expires.MustGet() + if s.AccessTokenExpires.IsValue() { + t.AccessTokenExpires = s.AccessTokenExpires.MustGet() } if s.RefreshToken.IsValue() { t.RefreshToken = s.RefreshToken.MustGet() @@ -159,6 +166,9 @@ func (s OauthTokenSetter) Overwrite(t *OauthToken) { if !s.ArcgisLicenseTypeID.IsUnset() { t.ArcgisLicenseTypeID = s.ArcgisLicenseTypeID.MustGetNull() } + if s.RefreshTokenExpires.IsValue() { + t.RefreshTokenExpires = s.RefreshTokenExpires.MustGet() + } } func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) { @@ -167,7 +177,7 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) { }) q.AppendValues(bob.ExpressionFunc(func(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) { - vals := make([]bob.Expression, 8) + vals := make([]bob.Expression, 9) if s.ID.IsValue() { vals[0] = psql.Arg(s.ID.MustGet()) } else { @@ -180,8 +190,8 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) { vals[1] = psql.Raw("DEFAULT") } - if s.Expires.IsValue() { - vals[2] = psql.Arg(s.Expires.MustGet()) + if s.AccessTokenExpires.IsValue() { + vals[2] = psql.Arg(s.AccessTokenExpires.MustGet()) } else { vals[2] = psql.Raw("DEFAULT") } @@ -216,6 +226,12 @@ func (s *OauthTokenSetter) Apply(q *dialect.InsertQuery) { vals[7] = psql.Raw("DEFAULT") } + if s.RefreshTokenExpires.IsValue() { + vals[8] = psql.Arg(s.RefreshTokenExpires.MustGet()) + } else { + vals[8] = psql.Raw("DEFAULT") + } + return bob.ExpressSlice(ctx, w, d, start, vals, "", ", ", "") })) } @@ -225,7 +241,7 @@ func (s OauthTokenSetter) UpdateMod() bob.Mod[*dialect.UpdateQuery] { } func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression { - exprs := make([]bob.Expression, 0, 8) + exprs := make([]bob.Expression, 0, 9) if s.ID.IsValue() { exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{ @@ -241,10 +257,10 @@ func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression { }}) } - if s.Expires.IsValue() { + if s.AccessTokenExpires.IsValue() { exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{ - psql.Quote(append(prefix, "expires")...), - psql.Arg(s.Expires), + psql.Quote(append(prefix, "access_token_expires")...), + psql.Arg(s.AccessTokenExpires), }}) } @@ -283,6 +299,13 @@ func (s OauthTokenSetter) Expressions(prefix ...string) []bob.Expression { }}) } + if s.RefreshTokenExpires.IsValue() { + exprs = append(exprs, expr.Join{Sep: " = ", Exprs: []bob.Expression{ + psql.Quote(append(prefix, "refresh_token_expires")...), + psql.Arg(s.RefreshTokenExpires), + }}) + } + return exprs } @@ -584,12 +607,13 @@ func (oauthToken0 *OauthToken) AttachUserUser(ctx context.Context, exec bob.Exec type oauthTokenWhere[Q psql.Filterable] struct { ID psql.WhereMod[Q, int32] AccessToken psql.WhereMod[Q, string] - Expires psql.WhereMod[Q, time.Time] + AccessTokenExpires psql.WhereMod[Q, time.Time] RefreshToken psql.WhereMod[Q, string] Username psql.WhereMod[Q, string] UserID psql.WhereMod[Q, int32] ArcgisID psql.WhereNullMod[Q, string] ArcgisLicenseTypeID psql.WhereNullMod[Q, string] + RefreshTokenExpires psql.WhereMod[Q, time.Time] } func (oauthTokenWhere[Q]) AliasedAs(alias string) oauthTokenWhere[Q] { @@ -600,12 +624,13 @@ func buildOauthTokenWhere[Q psql.Filterable](cols oauthTokenColumns) oauthTokenW return oauthTokenWhere[Q]{ ID: psql.Where[Q, int32](cols.ID), AccessToken: psql.Where[Q, string](cols.AccessToken), - Expires: psql.Where[Q, time.Time](cols.Expires), + AccessTokenExpires: psql.Where[Q, time.Time](cols.AccessTokenExpires), RefreshToken: psql.Where[Q, string](cols.RefreshToken), Username: psql.Where[Q, string](cols.Username), UserID: psql.Where[Q, int32](cols.UserID), ArcgisID: psql.WhereNull[Q, string](cols.ArcgisID), ArcgisLicenseTypeID: psql.WhereNull[Q, string](cols.ArcgisLicenseTypeID), + RefreshTokenExpires: psql.Where[Q, time.Time](cols.RefreshTokenExpires), } } diff --git a/sql/oauth_by_user_id.bob.go b/sql/oauth_by_user_id.bob.go index 590ac15b..69e7b765 100644 --- a/sql/oauth_by_user_id.bob.go +++ b/sql/oauth_by_user_id.bob.go @@ -21,7 +21,7 @@ import ( //go:embed oauth_by_user_id.bob.sql var formattedQueries_oauth_by_user_id string -var oauthTokenByUserIdSQL = formattedQueries_oauth_by_user_id[156:550] +var oauthTokenByUserIdSQL = formattedQueries_oauth_by_user_id[156:642] type OauthTokenByUserIdQuery = orm.ModQuery[*dialect.SelectQuery, oauthTokenByUserId, OauthTokenByUserIdRow, []OauthTokenByUserIdRow, oauthTokenByUserIdTransformer] @@ -44,12 +44,13 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery { var t OauthTokenByUserIdRow row.ScheduleScanByIndex(0, &t.ID) row.ScheduleScanByIndex(1, &t.AccessToken) - row.ScheduleScanByIndex(2, &t.Expires) + row.ScheduleScanByIndex(2, &t.AccessTokenExpires) row.ScheduleScanByIndex(3, &t.RefreshToken) row.ScheduleScanByIndex(4, &t.Username) row.ScheduleScanByIndex(5, &t.UserID) row.ScheduleScanByIndex(6, &t.ArcgisID) row.ScheduleScanByIndex(7, &t.ArcgisLicenseTypeID) + row.ScheduleScanByIndex(8, &t.RefreshTokenExpires) return &t, nil }, func(v any) (OauthTokenByUserIdRow, error) { return *(v.(*OauthTokenByUserIdRow)), nil @@ -57,9 +58,9 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery { }, }, Mod: bob.ModFunc[*dialect.SelectQuery](func(q *dialect.SelectQuery) { - q.AppendSelect(expressionTypArgs.subExpr(7, 357)) - q.SetTable(expressionTypArgs.subExpr(363, 374)) - q.AppendWhere(expressionTypArgs.subExpr(382, 394)) + q.AppendSelect(expressionTypArgs.subExpr(7, 449)) + q.SetTable(expressionTypArgs.subExpr(455, 466)) + q.AppendWhere(expressionTypArgs.subExpr(474, 486)) }), } } @@ -67,12 +68,13 @@ func OauthTokenByUserId(UserID int32) *OauthTokenByUserIdQuery { type OauthTokenByUserIdRow = struct { ID int32 `db:"id"` AccessToken string `db:"access_token"` - Expires time.Time `db:"expires"` + AccessTokenExpires time.Time `db:"access_token_expires"` RefreshToken string `db:"refresh_token"` Username string `db:"username"` UserID int32 `db:"user_id"` ArcgisID null.Val[string] `db:"arcgis_id"` ArcgisLicenseTypeID null.Val[string] `db:"arcgis_license_type_id"` + RefreshTokenExpires time.Time `db:"refresh_token_expires"` } type oauthTokenByUserIdTransformer = bob.SliceTransformer[OauthTokenByUserIdRow, []OauthTokenByUserIdRow] @@ -85,8 +87,8 @@ func (o oauthTokenByUserId) args() iter.Seq[orm.ArgWithPosition] { return func(yield func(arg orm.ArgWithPosition) bool) { if !yield(orm.ArgWithPosition{ Name: "userID", - Start: 392, - Stop: 394, + Start: 484, + Stop: 486, Expression: o.UserID, }) { return diff --git a/sql/oauth_by_user_id.bob.sql b/sql/oauth_by_user_id.bob.sql index 4d372c2b..a2e32b78 100644 --- a/sql/oauth_by_user_id.bob.sql +++ b/sql/oauth_by_user_id.bob.sql @@ -2,5 +2,5 @@ -- This file is meant to be re-generated in place and/or deleted at any time. -- OauthTokenByUserId -SELECT "oauth_token"."id" AS "id", "oauth_token"."access_token" AS "access_token", "oauth_token"."expires" AS "expires", "oauth_token"."refresh_token" AS "refresh_token", "oauth_token"."username" AS "username", "oauth_token"."user_id" AS "user_id", "oauth_token"."arcgis_id" AS "arcgis_id", "oauth_token"."arcgis_license_type_id" AS "arcgis_license_type_id" FROM oauth_token WHERE +SELECT "oauth_token"."id" AS "id", "oauth_token"."access_token" AS "access_token", "oauth_token"."access_token_expires" AS "access_token_expires", "oauth_token"."refresh_token" AS "refresh_token", "oauth_token"."username" AS "username", "oauth_token"."user_id" AS "user_id", "oauth_token"."arcgis_id" AS "arcgis_id", "oauth_token"."arcgis_license_type_id" AS "arcgis_license_type_id", "oauth_token"."refresh_token_expires" AS "refresh_token_expires" FROM oauth_token WHERE user_id = $1; diff --git a/sql/oauth_by_user_id.bob_test.go b/sql/oauth_by_user_id.bob_test.go index cf3425ce..09a040c0 100644 --- a/sql/oauth_by_user_id.bob_test.go +++ b/sql/oauth_by_user_id.bob_test.go @@ -84,8 +84,8 @@ func TestOauthTokenByUserId(t *testing.T) { t.Fatal(err) } - if len(columns) != 8 { - t.Fatalf("expected %d columns, got %d", 8, len(columns)) + if len(columns) != 9 { + t.Fatalf("expected %d columns, got %d", 9, len(columns)) } if columns[0] != "id" { @@ -96,8 +96,8 @@ func TestOauthTokenByUserId(t *testing.T) { t.Fatalf("expected column %d to be %s, got %s", 1, "access_token", columns[1]) } - if columns[2] != "expires" { - t.Fatalf("expected column %d to be %s, got %s", 2, "expires", columns[2]) + if columns[2] != "access_token_expires" { + t.Fatalf("expected column %d to be %s, got %s", 2, "access_token_expires", columns[2]) } if columns[3] != "refresh_token" { @@ -119,5 +119,9 @@ func TestOauthTokenByUserId(t *testing.T) { if columns[7] != "arcgis_license_type_id" { t.Fatalf("expected column %d to be %s, got %s", 7, "arcgis_license_type_id", columns[7]) } + + if columns[8] != "refresh_token_expires" { + t.Fatalf("expected column %d to be %s, got %s", 8, "refresh_token_expires", columns[8]) + } }) }