Track access token and refresh token expiry

Also make a bunch more progress on actually updating the tokens when we
need them updated.
This commit is contained in:
Eli Ribble 2025-11-07 05:46:41 +00:00
parent cf01c8c5c6
commit 109495b702
No known key found for this signature in database
11 changed files with 348 additions and 104 deletions

249
arcgis.go
View file

@ -1,6 +1,7 @@
package main
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
@ -12,8 +13,10 @@ import (
"log/slog"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Gleipnir-Technology/arcgis-go"
@ -24,13 +27,28 @@ import (
"github.com/jackc/pgx/v5"
)
var NewOAuthTokenChannel chan struct{}
var CodeVerifier string = "random_secure_string_min_43_chars_long_should_be_stored_in_session"
type ErrorResponse struct {
Error ErrorResponseContent `json:"error"`
}
type ErrorResponseContent struct {
Code int `json:"code"`
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
Message string `json:"message"`
Details []string `json:"details"`
}
type OAuthTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Username string `json:"username"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresIn int `json:"refresh_token_expires_in"`
SSL bool `json:"ssl"`
Username string `json:"username"`
}
// Build the ArcGIS authorization URL with PKCE
@ -66,12 +84,13 @@ func generateCodeVerifier() string {
}
// Find out what we can about this user
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, expires time.Time, refresh_token string) {
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, access_token_expires time.Time, refresh_token string, refresh_token_expires time.Time) {
client := arcgis.NewArcGIS(
arcgis.AuthenticatorOAuth{
AccessToken: access_token,
Expires: expires,
RefreshToken: refresh_token,
AccessToken: access_token,
AccessTokenExpires: access_token_expires,
RefreshToken: refresh_token,
RefreshTokenExpires: refresh_token_expires,
},
)
portal, err := client.PortalsSelf()
@ -158,49 +177,25 @@ func handleOauthAccessCode(ctx context.Context, user *models.User, code string)
return fmt.Errorf("Failed to create request: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
client := http.Client{}
slog.Info("POST", slog.String("url", baseURL))
resp, err := client.Do(req)
token, err := handleTokenRequest(ctx, req)
if err != nil {
return fmt.Errorf("Failed to do request: %v", err)
return fmt.Errorf("Failed to exchange authorization code for token: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
slog.Info("Response", slog.Int("status", resp.StatusCode))
if resp.StatusCode >= http.StatusBadRequest {
if err != nil {
return fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
}
bodyString := string(bodyBytes)
var errorResp map[string]interface{}
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
return fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
}
return fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
}
var tokenResponse OAuthTokenResponse
err = json.Unmarshal(bodyBytes, &tokenResponse)
if err != nil {
return fmt.Errorf("Failed to unmarshal JSON: %v", err)
}
slog.Info("Oauth token acquired",
slog.String("refresh token", tokenResponse.RefreshToken),
slog.String("access token", tokenResponse.AccessToken),
slog.Int("expires", tokenResponse.ExpiresIn),
)
expires := futureUTCTimestamp(tokenResponse.ExpiresIn)
accessExpires := futureUTCTimestamp(token.ExpiresIn)
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
setter := models.OauthTokenSetter{
AccessToken: omit.From(tokenResponse.AccessToken),
Expires: omit.From(expires),
RefreshToken: omit.From(tokenResponse.RefreshToken),
Username: omit.From(tokenResponse.Username),
AccessToken: omit.From(token.AccessToken),
AccessTokenExpires: omit.From(accessExpires),
RefreshToken: omit.From(token.RefreshToken),
RefreshTokenExpires: omit.From(refreshExpires),
Username: omit.From(token.Username),
}
err = user.InsertUserOauthTokens(ctx, PGInstance.BobDB, &setter)
if err != nil {
return fmt.Errorf("Failed to save token to database: %v", err)
}
go updateArcgisUserData(context.Background(), user, tokenResponse.AccessToken, expires, tokenResponse.RefreshToken)
go updateArcgisUserData(context.Background(), user, token.AccessToken, accessExpires, token.RefreshToken, refreshExpires)
NewOAuthTokenChannel <- struct{}{}
return nil
}
@ -216,14 +211,174 @@ func redirectURL() string {
}
// This is a goroutine that is in charge of getting Fieldseeker data and keeping it fresh.
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan int) {
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan struct{}) {
for {
workerCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
oauths, err := models.OauthTokens.Query().All(ctx, PGInstance.BobDB)
if err != nil {
slog.Error("Failed to get oauths", slog.String("err", err.Error()))
return
}
for _, oauth := range oauths {
wg.Add(1)
go func() {
defer wg.Done()
maintainOAuth(workerCtx, oauth)
}()
}
select {
case <-ctx.Done():
slog.Info("Exiting refresh worker")
slog.Info("Exiting refresh worker...")
cancel()
wg.Wait()
return
case id := <-newOauthCh:
slog.Info("Adding oauth to background work", slog.Int("oauth id", id))
case <-newOauthCh:
slog.Info("Updating oauth background work")
cancel()
wg.Wait()
}
}
}
func maintainOAuth(ctx context.Context, oauth *models.OauthToken) {
refreshDelay := time.Until(oauth.AccessTokenExpires)
slog.Info("Need to refresh oauth", slog.Int("id", int(oauth.ID)), slog.Float64("seconds", refreshDelay.Seconds()))
if oauth.AccessTokenExpires.Before(time.Now()) {
err := refreshOAuth(ctx, oauth)
if err != nil {
slog.Error("Failed to refresh token", slog.String("err", err.Error()))
return
}
}
ticker := time.NewTicker(refreshDelay)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
}
}
func refreshOAuth(ctx context.Context, oauth *models.OauthToken) error {
baseURL := "https://www.arcgis.com/sharing/rest/oauth2/token/"
form := url.Values{
"grant_type": []string{"refresh_token"},
"client_id": []string{ClientID},
"refresh_token": []string{oauth.RefreshToken},
}
req, err := http.NewRequest("POST", baseURL, strings.NewReader(form.Encode()))
if err != nil {
return fmt.Errorf("Failed to create request: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
token, err := handleTokenRequest(ctx, req)
if err != nil {
return fmt.Errorf("Failed to handle request: %v", err)
}
accessExpires := futureUTCTimestamp(token.ExpiresIn)
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
setter := models.OauthTokenSetter{
AccessToken: omit.From(token.AccessToken),
AccessTokenExpires: omit.From(accessExpires),
RefreshToken: omit.From(token.RefreshToken),
RefreshTokenExpires: omit.From(refreshExpires),
Username: omit.From(token.Username),
}
err = oauth.Update(ctx, PGInstance.BobDB, &setter)
if err != nil {
return fmt.Errorf("Failed to update oauth in database: %v", err)
}
return nil
}
func handleTokenRequest(ctx context.Context, req *http.Request) (*OAuthTokenResponse, error) {
client := http.Client{}
slog.Info("POST", slog.String("url", req.URL.String()))
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("Failed to do request: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
slog.Info("Token request", slog.Int("status", resp.StatusCode))
saveResponse(bodyBytes, "token.json")
if resp.StatusCode >= http.StatusBadRequest {
if err != nil {
return nil, fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
}
bodyString := string(bodyBytes)
var errorResp map[string]interface{}
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
return nil, fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
}
return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
}
//logResponseHeaders(resp)
var tokenResponse OAuthTokenResponse
err = json.Unmarshal(bodyBytes, &tokenResponse)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal JSON: %v", err)
}
// Just because we got a 200-level status code doesn't mean it worked. Experience has taught us that
// we can get errors without anything indicated in the headers or the status code
if tokenResponse == (OAuthTokenResponse{}) {
var errorResponse ErrorResponse
err = json.Unmarshal(bodyBytes, &errorResponse)
if err != nil {
return nil, fmt.Errorf("Failed to unmarshal error JSON: %v", err)
}
if errorResponse.Error.Code > 0 {
return nil, errors.New(fmt.Sprintf("API error %d: %s: %s (%s)",
errorResponse.Error.Code,
errorResponse.Error.Error,
errorResponse.Error.ErrorDescription,
errorResponse.Error.Message,
))
}
}
slog.Info("Oauth token acquired",
slog.String("refresh token", tokenResponse.RefreshToken),
slog.String("access token", tokenResponse.AccessToken),
slog.Int("access expires", tokenResponse.ExpiresIn),
slog.Int("refresh expires", tokenResponse.RefreshTokenExpiresIn),
)
return &tokenResponse, nil
}
func logResponseHeaders(resp *http.Response) {
if resp == nil {
slog.Info("Response is nil")
return
}
slog.Info("HTTP Response headers",
"status", resp.Status,
"statusCode", resp.StatusCode)
for name, values := range resp.Header {
slog.Info("Header",
"name", name,
"values", values)
}
}
func saveResponse(data []byte, filename string) {
dest, err := os.Create(filename)
if err != nil {
slog.Error("Failed to create file", slog.String("filename", filename), slog.String("err", err.Error()))
return
}
_, err = io.Copy(dest, bytes.NewReader(data))
if err != nil {
slog.Error("Failed to write", slog.String("filename", filename), slog.String("err", err.Error()))
return
}
slog.Info("Wrote response", slog.String("filename", filename))
}