Track access token and refresh token expiry
Also make a bunch more progress on actually updating the tokens when we need them updated.
This commit is contained in:
parent
cf01c8c5c6
commit
109495b702
11 changed files with 348 additions and 104 deletions
249
arcgis.go
249
arcgis.go
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
|
|
@ -12,8 +13,10 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Gleipnir-Technology/arcgis-go"
|
||||
|
|
@ -24,13 +27,28 @@ import (
|
|||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var NewOAuthTokenChannel chan struct{}
|
||||
var CodeVerifier string = "random_secure_string_min_43_chars_long_should_be_stored_in_session"
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error ErrorResponseContent `json:"error"`
|
||||
}
|
||||
|
||||
type ErrorResponseContent struct {
|
||||
Code int `json:"code"`
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
Message string `json:"message"`
|
||||
Details []string `json:"details"`
|
||||
}
|
||||
|
||||
type OAuthTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Username string `json:"username"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RefreshTokenExpiresIn int `json:"refresh_token_expires_in"`
|
||||
SSL bool `json:"ssl"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// Build the ArcGIS authorization URL with PKCE
|
||||
|
|
@ -66,12 +84,13 @@ func generateCodeVerifier() string {
|
|||
}
|
||||
|
||||
// Find out what we can about this user
|
||||
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, expires time.Time, refresh_token string) {
|
||||
func updateArcgisUserData(ctx context.Context, user *models.User, access_token string, access_token_expires time.Time, refresh_token string, refresh_token_expires time.Time) {
|
||||
client := arcgis.NewArcGIS(
|
||||
arcgis.AuthenticatorOAuth{
|
||||
AccessToken: access_token,
|
||||
Expires: expires,
|
||||
RefreshToken: refresh_token,
|
||||
AccessToken: access_token,
|
||||
AccessTokenExpires: access_token_expires,
|
||||
RefreshToken: refresh_token,
|
||||
RefreshTokenExpires: refresh_token_expires,
|
||||
},
|
||||
)
|
||||
portal, err := client.PortalsSelf()
|
||||
|
|
@ -158,49 +177,25 @@ func handleOauthAccessCode(ctx context.Context, user *models.User, code string)
|
|||
return fmt.Errorf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := http.Client{}
|
||||
slog.Info("POST", slog.String("url", baseURL))
|
||||
resp, err := client.Do(req)
|
||||
token, err := handleTokenRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to do request: %v", err)
|
||||
return fmt.Errorf("Failed to exchange authorization code for token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
slog.Info("Response", slog.Int("status", resp.StatusCode))
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if err != nil {
|
||||
return fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
var errorResp map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
|
||||
return fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
|
||||
}
|
||||
return fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
|
||||
}
|
||||
var tokenResponse OAuthTokenResponse
|
||||
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
slog.Info("Oauth token acquired",
|
||||
slog.String("refresh token", tokenResponse.RefreshToken),
|
||||
slog.String("access token", tokenResponse.AccessToken),
|
||||
slog.Int("expires", tokenResponse.ExpiresIn),
|
||||
)
|
||||
|
||||
expires := futureUTCTimestamp(tokenResponse.ExpiresIn)
|
||||
accessExpires := futureUTCTimestamp(token.ExpiresIn)
|
||||
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
|
||||
setter := models.OauthTokenSetter{
|
||||
AccessToken: omit.From(tokenResponse.AccessToken),
|
||||
Expires: omit.From(expires),
|
||||
RefreshToken: omit.From(tokenResponse.RefreshToken),
|
||||
Username: omit.From(tokenResponse.Username),
|
||||
AccessToken: omit.From(token.AccessToken),
|
||||
AccessTokenExpires: omit.From(accessExpires),
|
||||
RefreshToken: omit.From(token.RefreshToken),
|
||||
RefreshTokenExpires: omit.From(refreshExpires),
|
||||
Username: omit.From(token.Username),
|
||||
}
|
||||
err = user.InsertUserOauthTokens(ctx, PGInstance.BobDB, &setter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to save token to database: %v", err)
|
||||
}
|
||||
go updateArcgisUserData(context.Background(), user, tokenResponse.AccessToken, expires, tokenResponse.RefreshToken)
|
||||
go updateArcgisUserData(context.Background(), user, token.AccessToken, accessExpires, token.RefreshToken, refreshExpires)
|
||||
NewOAuthTokenChannel <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -216,14 +211,174 @@ func redirectURL() string {
|
|||
}
|
||||
|
||||
// This is a goroutine that is in charge of getting Fieldseeker data and keeping it fresh.
|
||||
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan int) {
|
||||
func refreshFieldseekerData(ctx context.Context, newOauthCh <-chan struct{}) {
|
||||
for {
|
||||
workerCtx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
oauths, err := models.OauthTokens.Query().All(ctx, PGInstance.BobDB)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get oauths", slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
for _, oauth := range oauths {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
maintainOAuth(workerCtx, oauth)
|
||||
}()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("Exiting refresh worker")
|
||||
slog.Info("Exiting refresh worker...")
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return
|
||||
case id := <-newOauthCh:
|
||||
slog.Info("Adding oauth to background work", slog.Int("oauth id", id))
|
||||
case <-newOauthCh:
|
||||
slog.Info("Updating oauth background work")
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func maintainOAuth(ctx context.Context, oauth *models.OauthToken) {
|
||||
refreshDelay := time.Until(oauth.AccessTokenExpires)
|
||||
slog.Info("Need to refresh oauth", slog.Int("id", int(oauth.ID)), slog.Float64("seconds", refreshDelay.Seconds()))
|
||||
if oauth.AccessTokenExpires.Before(time.Now()) {
|
||||
err := refreshOAuth(ctx, oauth)
|
||||
if err != nil {
|
||||
slog.Error("Failed to refresh token", slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
ticker := time.NewTicker(refreshDelay)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func refreshOAuth(ctx context.Context, oauth *models.OauthToken) error {
|
||||
baseURL := "https://www.arcgis.com/sharing/rest/oauth2/token/"
|
||||
|
||||
form := url.Values{
|
||||
"grant_type": []string{"refresh_token"},
|
||||
"client_id": []string{ClientID},
|
||||
"refresh_token": []string{oauth.RefreshToken},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", baseURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
token, err := handleTokenRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to handle request: %v", err)
|
||||
}
|
||||
accessExpires := futureUTCTimestamp(token.ExpiresIn)
|
||||
refreshExpires := futureUTCTimestamp(token.RefreshTokenExpiresIn)
|
||||
setter := models.OauthTokenSetter{
|
||||
AccessToken: omit.From(token.AccessToken),
|
||||
AccessTokenExpires: omit.From(accessExpires),
|
||||
RefreshToken: omit.From(token.RefreshToken),
|
||||
RefreshTokenExpires: omit.From(refreshExpires),
|
||||
Username: omit.From(token.Username),
|
||||
}
|
||||
err = oauth.Update(ctx, PGInstance.BobDB, &setter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to update oauth in database: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleTokenRequest(ctx context.Context, req *http.Request) (*OAuthTokenResponse, error) {
|
||||
client := http.Client{}
|
||||
slog.Info("POST", slog.String("url", req.URL.String()))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to do request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
slog.Info("Token request", slog.Int("status", resp.StatusCode))
|
||||
saveResponse(bodyBytes, "token.json")
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Got status code %d and failed to read response body: %v", resp.StatusCode, err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
var errorResp map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &errorResp); err == nil {
|
||||
return nil, fmt.Errorf("API response JSON error: %d: %v", resp.StatusCode, errorResp)
|
||||
}
|
||||
return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, bodyString)
|
||||
}
|
||||
//logResponseHeaders(resp)
|
||||
var tokenResponse OAuthTokenResponse
|
||||
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
// Just because we got a 200-level status code doesn't mean it worked. Experience has taught us that
|
||||
// we can get errors without anything indicated in the headers or the status code
|
||||
if tokenResponse == (OAuthTokenResponse{}) {
|
||||
var errorResponse ErrorResponse
|
||||
err = json.Unmarshal(bodyBytes, &errorResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to unmarshal error JSON: %v", err)
|
||||
}
|
||||
if errorResponse.Error.Code > 0 {
|
||||
return nil, errors.New(fmt.Sprintf("API error %d: %s: %s (%s)",
|
||||
errorResponse.Error.Code,
|
||||
errorResponse.Error.Error,
|
||||
errorResponse.Error.ErrorDescription,
|
||||
errorResponse.Error.Message,
|
||||
))
|
||||
}
|
||||
}
|
||||
slog.Info("Oauth token acquired",
|
||||
slog.String("refresh token", tokenResponse.RefreshToken),
|
||||
slog.String("access token", tokenResponse.AccessToken),
|
||||
slog.Int("access expires", tokenResponse.ExpiresIn),
|
||||
slog.Int("refresh expires", tokenResponse.RefreshTokenExpiresIn),
|
||||
)
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func logResponseHeaders(resp *http.Response) {
|
||||
if resp == nil {
|
||||
slog.Info("Response is nil")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("HTTP Response headers",
|
||||
"status", resp.Status,
|
||||
"statusCode", resp.StatusCode)
|
||||
|
||||
for name, values := range resp.Header {
|
||||
slog.Info("Header",
|
||||
"name", name,
|
||||
"values", values)
|
||||
}
|
||||
}
|
||||
|
||||
func saveResponse(data []byte, filename string) {
|
||||
dest, err := os.Create(filename)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create file", slog.String("filename", filename), slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(dest, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
slog.Error("Failed to write", slog.String("filename", filename), slog.String("err", err.Error()))
|
||||
return
|
||||
}
|
||||
slog.Info("Wrote response", slog.String("filename", filename))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue