diff --git a/api/routes.go b/api/routes.go index cfd04814..58769dfe 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,66 +1,64 @@ package api import ( - "github.com/go-chi/chi/v5" - "github.com/go-chi/render" - "github.com/Gleipnir-Technology/nidus-sync/auth" "github.com/Gleipnir-Technology/nidus-sync/platform/file" + "github.com/gorilla/mux" ) -func AddRoutes(r chi.Router) { - r.Use(render.SetContentType(render.ContentTypeJSON)) +func AddRoutes(r *mux.Router) { + //r.Use(render.SetContentType(render.ContentTypeJSON)) // Unauthenticated endpoints - r.Post("/signin", handlerJSONPost(postSignin)) - r.Post("/signup", handlerJSONPost(postSignup)) + r.HandleFunc("/signin", handlerJSONPost(postSignin)) + r.HandleFunc("/signup", handlerJSONPost(postSignup)) // Authenticated endpoints - r.Method("POST", "/audio/{uuid}", auth.NewEnsureAuth(apiAudioPost)) - r.Method("POST", "/audio/{uuid}/content", auth.NewEnsureAuth(apiAudioContentPost)) - r.Method("POST", "/avatar", authenticatedHandlerPostMultipart(avatarPost, file.CollectionAvatar)) - r.Method("GET", "/client/ios", auth.NewEnsureAuth(handleClientIos)) - r.Method("GET", "/communication", authenticatedHandlerJSON(listCommunication)) - r.Method("POST", "/configuration/integration/arcgis", authenticatedHandlerJSONPost(postConfigurationIntegrationArcgis)) - r.Method("GET", "/events", auth.NewEnsureAuth(streamEvents)) - r.Method("POST", "/image/{uuid}", auth.NewEnsureAuth(apiImagePost)) - r.Method("GET", "/image/{uuid}/content", auth.NewEnsureAuth(apiImageContentGet)) - r.Method("POST", "/image/{uuid}/content", auth.NewEnsureAuth(apiImageContentPost)) - r.Method("GET", "/leads", authenticatedHandlerJSON(listLead)) - r.Method("POST", "/leads", authenticatedHandlerJSONPost(postLeads)) - r.Method("GET", "/mosquito-source", auth.NewEnsureAuth(apiMosquitoSource)) - r.Method("POST", "/publicreport/invalid", authenticatedHandlerJSONPost(postPublicreportInvalid)) - r.Method("POST", "/publicreport/signal", authenticatedHandlerJSONPost(postPublicreportSignal)) - r.Method("POST", "/publicreport/message", authenticatedHandlerJSONPost(postPublicreportMessage)) - r.Method("POST", "/review/pool", authenticatedHandlerJSONPost(postReviewPool)) - r.Method("GET", "/review-task", authenticatedHandlerJSON(listReviewTask)) - r.Method("GET", "/service-request", auth.NewEnsureAuth(apiServiceRequest)) - r.Method("GET", "/signal", authenticatedHandlerJSON(listSignal)) - r.Method("POST", "/sudo/email", authenticatedHandlerJSONPost(postSudoEmail)) - r.Method("POST", "/sudo/sms", authenticatedHandlerJSONPost(postSudoSMS)) - r.Method("POST", "/sudo/sse", authenticatedHandlerJSONPost(postSudoSSE)) - r.Method("GET", "/trap-data", auth.NewEnsureAuth(apiTrapData)) - r.Method("GET", "/tile/{z}/{y}/{x}", auth.NewEnsureAuth(getTile)) - r.Method("POST", "/upload/pool/flyover", authenticatedHandlerPostMultipart(postUploadPoolFlyoverCreate, file.CollectionCSV)) - r.Method("POST", "/upload/pool/custom", authenticatedHandlerPostMultipart(postUploadPoolCustomCreate, file.CollectionCSV)) - r.Method("GET", "/upload", authenticatedHandlerJSON(getUploadList)) - r.Method("GET", "/upload/{id}", authenticatedHandlerJSON(getUploadByID)) - r.Method("POST", "/upload/{id}/commit", authenticatedHandlerJSONPost(postUploadCommit)) - r.Method("POST", "/upload/{id}/discard", authenticatedHandlerJSONPost(postUploadDiscard)) - r.Method("GET", "/user/self", authenticatedHandlerJSON(getUserSelf)) - r.Method("GET", "/user/suggestion", authenticatedHandlerJSON(listUserSuggestion)) - r.Method("GET", "/user", authenticatedHandlerJSON(listUser)) - r.Method("PUT", "/user/{id}", authenticatedHandlerJSONPut(userPut)) + r.Handle("/audio/{uuid}", auth.NewEnsureAuth(apiAudioPost)).Methods("POST") + r.Handle("/audio/{uuid}/content", auth.NewEnsureAuth(apiAudioContentPost)).Methods("POST") + r.Handle("/avatar", authenticatedHandlerPostMultipart(avatarPost, file.CollectionAvatar)).Methods("POST") + r.Handle("/client/ios", auth.NewEnsureAuth(handleClientIos)).Methods("GET") + r.Handle("/communication", authenticatedHandlerJSON(listCommunication)).Methods("GET") + r.Handle("/configuration/integration/arcgis", authenticatedHandlerJSONPost(postConfigurationIntegrationArcgis)).Methods("POST") + r.Handle("/events", auth.NewEnsureAuth(streamEvents)).Methods("GET") + r.Handle("/image/{uuid}", auth.NewEnsureAuth(apiImagePost)).Methods("POST") + r.Handle("/image/{uuid}/content", auth.NewEnsureAuth(apiImageContentGet)).Methods("GET") + r.Handle("/image/{uuid}/content", auth.NewEnsureAuth(apiImageContentPost)).Methods("POST") + r.Handle("/leads", authenticatedHandlerJSON(listLead)).Methods("GET") + r.Handle("/leads", authenticatedHandlerJSONPost(postLeads)).Methods("POST") + r.Handle("/mosquito-source", auth.NewEnsureAuth(apiMosquitoSource)).Methods("GET") + r.Handle("/publicreport/invalid", authenticatedHandlerJSONPost(postPublicreportInvalid)).Methods("POST") + r.Handle("/publicreport/signal", authenticatedHandlerJSONPost(postPublicreportSignal)).Methods("POST") + r.Handle("/publicreport/message", authenticatedHandlerJSONPost(postPublicreportMessage)).Methods("POST") + r.Handle("/review/pool", authenticatedHandlerJSONPost(postReviewPool)).Methods("POST") + r.Handle("/review-task", authenticatedHandlerJSON(listReviewTask)).Methods("GET") + r.Handle("/service-request", auth.NewEnsureAuth(apiServiceRequest)).Methods("GET") + r.Handle("/signal", authenticatedHandlerJSON(listSignal)).Methods("GET") + r.Handle("/sudo/email", authenticatedHandlerJSONPost(postSudoEmail)).Methods("POST") + r.Handle("/sudo/sms", authenticatedHandlerJSONPost(postSudoSMS)).Methods("POST") + r.Handle("/sudo/sse", authenticatedHandlerJSONPost(postSudoSSE)).Methods("POST") + r.Handle("/trap-data", auth.NewEnsureAuth(apiTrapData)).Methods("GET") + r.Handle("/tile/{z}/{y}/{x}", auth.NewEnsureAuth(getTile)).Methods("GET") + r.Handle("/upload/pool/flyover", authenticatedHandlerPostMultipart(postUploadPoolFlyoverCreate, file.CollectionCSV)).Methods("POST") + r.Handle("/upload/pool/custom", authenticatedHandlerPostMultipart(postUploadPoolCustomCreate, file.CollectionCSV)).Methods("POST") + r.Handle("/upload", authenticatedHandlerJSON(getUploadList)).Methods("GET") + r.Handle("/upload/{id}", authenticatedHandlerJSON(getUploadByID)).Methods("GET") + r.Handle("/upload/{id}/commit", authenticatedHandlerJSONPost(postUploadCommit)).Methods("POST") + r.Handle("/upload/{id}/discard", authenticatedHandlerJSONPost(postUploadDiscard)).Methods("POST") + r.Handle("/user/self", authenticatedHandlerJSON(getUserSelf)).Methods("GET") + r.Handle("/user/suggestion", authenticatedHandlerJSON(listUserSuggestion)).Methods("GET") + r.Handle("/user", authenticatedHandlerJSON(listUser)).Methods("GET") + r.Handle("/user/{id}", authenticatedHandlerJSONPut(userPut)).Methods("PUT") // Unauthenticated endpoints - r.Get("/district", apiGetDistrict) - r.Get("/district/{slug}/logo", apiGetDistrictLogo) - r.Get("/compliance-request/image/pool/{public_id}", getComplianceRequestImagePool) - r.Post("/twilio/call", twilioCallPost) - r.Post("/twilio/call/status", twilioCallStatusPost) - r.Post("/twilio/message", twilioMessagePost) - r.Post("/twilio/text", twilioTextPost) - r.Post("/twilio/text/status", twilioTextStatusPost) - r.Get("/voipms/text", voipmsTextGet) - r.Post("/voipms/text", voipmsTextPost) - r.Get("/webhook/fieldseeker", webhookFieldseeker) - r.Post("/webhook/fieldseeker", webhookFieldseeker) + r.HandleFunc("/district", apiGetDistrict).Methods("GET") + r.HandleFunc("/district/{slug}/logo", apiGetDistrictLogo).Methods("GET") + r.HandleFunc("/compliance-request/image/pool/{public_id}", getComplianceRequestImagePool).Methods("GET") + r.HandleFunc("/twilio/call", twilioCallPost).Methods("POST") + r.HandleFunc("/twilio/call/status", twilioCallStatusPost).Methods("POST") + r.HandleFunc("/twilio/message", twilioMessagePost).Methods("POST") + r.HandleFunc("/twilio/text", twilioTextPost).Methods("POST") + r.HandleFunc("/twilio/text/status", twilioTextStatusPost).Methods("POST") + r.HandleFunc("/voipms/text", voipmsTextGet).Methods("GET") + r.HandleFunc("/voipms/text", voipmsTextPost).Methods("POST") + r.HandleFunc("/webhook/fieldseeker", webhookFieldseeker).Methods("GET") + r.HandleFunc("/webhook/fieldseeker", webhookFieldseeker).Methods("POST") } diff --git a/go.mod b/go.mod index 794f6ffb..6e9eb46d 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/gobwas/ws v1.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang/mock v1.6.0 // indirect + github.com/gorilla/mux v1.8.1 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/go.sum b/go.sum index 2d18d6f4..a0fc122e 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,8 @@ github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfh github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E= github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= diff --git a/main.go b/main.go index 87de8a02..f383faa9 100644 --- a/main.go +++ b/main.go @@ -17,15 +17,14 @@ import ( "github.com/Gleipnir-Technology/nidus-sync/db" "github.com/Gleipnir-Technology/nidus-sync/html" "github.com/Gleipnir-Technology/nidus-sync/llm" + "github.com/Gleipnir-Technology/nidus-sync/middleware" "github.com/Gleipnir-Technology/nidus-sync/platform" "github.com/Gleipnir-Technology/nidus-sync/rmo" nidussync "github.com/Gleipnir-Technology/nidus-sync/sync" "github.com/getsentry/sentry-go" sentryhttp "github.com/getsentry/sentry-go/http" "github.com/getsentry/sentry-go/zerolog" - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/hostrouter" + "github.com/gorilla/mux" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -105,7 +104,8 @@ func main() { sentryMiddleware := sentryhttp.New(sentryhttp.Options{ Repanic: true, }) - r := chi.NewRouter() + //r := chi.NewRouter() + r := mux.NewRouter() r.Use(LoggerMiddleware(&router_logger)) r.Use(middleware.RequestID) @@ -115,15 +115,14 @@ func main() { r.Use(sentryMiddleware.Handle) r.Use(auth.NewSessionManager().LoadAndSave) - hr := hostrouter.New() + sync_router := r.Host(config.DomainNidus).Subrouter() + rmo_router := r.Host(config.DomainRMO).Subrouter() // Set up routing by hostname - sr := nidussync.Router() - hr.Map("", sr) // default - hr.Map("*", sr) // default - hr.Map(config.DomainRMO, rmo.Router()) // report.mosquitoes.online - hr.Map(config.DomainNidus, sr) - r.Mount("/", hr) + nidussync.Router(sync_router) + rmo.Router(rmo_router) + //hr.Map("", sr) // default + //hr.Map("*", sr) // default log.Debug().Str("report url", config.DomainRMO).Str("sync url", config.DomainNidus).Msg("Serving at URLs") diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 00000000..cff9bd20 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "bytes" + "context" + "log" + "net/http" + "os" + "runtime" + "time" +) + +var ( + // LogEntryCtxKey is the context.Context key to store the request log entry. + LogEntryCtxKey = &contextKey{"LogEntry"} + + // DefaultLogger is called by the Logger middleware handler to log each request. + // Its made a package-level variable so that it can be reconfigured for custom + // logging configurations. + DefaultLogger func(next http.Handler) http.Handler +) + +// Logger is a middleware that logs the start and end of each request, along +// with some useful data about what was requested, what the response status was, +// and how long it took to return. When standard output is a TTY, Logger will +// print in color, otherwise it will print in black and white. Logger prints a +// request ID if one is provided. +// +// Alternatively, look at https://github.com/goware/httplog for a more in-depth +// http logger with structured logging support. +// +// IMPORTANT NOTE: Logger should go before any other middleware that may change +// the response, such as middleware.Recoverer. Example: +// +// r := chi.NewRouter() +// r.Use(middleware.Logger) // <--<< Logger should come before Recoverer +// r.Use(middleware.Recoverer) +// r.Get("/", handler) +func Logger(next http.Handler) http.Handler { + return DefaultLogger(next) +} + +// RequestLogger returns a logger handler using a custom LogFormatter. +func RequestLogger(f LogFormatter) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + entry := f.NewLogEntry(r) + ww := NewWrapResponseWriter(w, r.ProtoMajor) + + t1 := time.Now() + defer func() { + entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), nil) + }() + + next.ServeHTTP(ww, WithLogEntry(r, entry)) + } + return http.HandlerFunc(fn) + } +} + +// LogFormatter initiates the beginning of a new LogEntry per request. +// See DefaultLogFormatter for an example implementation. +type LogFormatter interface { + NewLogEntry(r *http.Request) LogEntry +} + +// LogEntry records the final log when a request completes. +// See defaultLogEntry for an example implementation. +type LogEntry interface { + Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) + Panic(v interface{}, stack []byte) +} + +// GetLogEntry returns the in-context LogEntry for a request. +func GetLogEntry(r *http.Request) LogEntry { + entry, _ := r.Context().Value(LogEntryCtxKey).(LogEntry) + return entry +} + +// WithLogEntry sets the in-context LogEntry for a request. +func WithLogEntry(r *http.Request, entry LogEntry) *http.Request { + r = r.WithContext(context.WithValue(r.Context(), LogEntryCtxKey, entry)) + return r +} + +// LoggerInterface accepts printing to stdlib logger or compatible logger. +type LoggerInterface interface { + Print(v ...interface{}) +} + +// DefaultLogFormatter is a simple logger that implements a LogFormatter. +type DefaultLogFormatter struct { + Logger LoggerInterface + NoColor bool +} + +// NewLogEntry creates a new LogEntry for the request. +func (l *DefaultLogFormatter) NewLogEntry(r *http.Request) LogEntry { + useColor := !l.NoColor + entry := &defaultLogEntry{ + DefaultLogFormatter: l, + request: r, + buf: &bytes.Buffer{}, + useColor: useColor, + } + + reqID := GetReqID(r.Context()) + if reqID != "" { + cW(entry.buf, useColor, nYellow, "[%s] ", reqID) + } + cW(entry.buf, useColor, nCyan, "\"") + cW(entry.buf, useColor, bMagenta, "%s ", r.Method) + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + cW(entry.buf, useColor, nCyan, "%s://%s%s %s\" ", scheme, r.Host, r.RequestURI, r.Proto) + + entry.buf.WriteString("from ") + entry.buf.WriteString(r.RemoteAddr) + entry.buf.WriteString(" - ") + + return entry +} + +type defaultLogEntry struct { + *DefaultLogFormatter + request *http.Request + buf *bytes.Buffer + useColor bool +} + +func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + switch { + case status < 200: + cW(l.buf, l.useColor, bBlue, "%03d", status) + case status < 300: + cW(l.buf, l.useColor, bGreen, "%03d", status) + case status < 400: + cW(l.buf, l.useColor, bCyan, "%03d", status) + case status < 500: + cW(l.buf, l.useColor, bYellow, "%03d", status) + default: + cW(l.buf, l.useColor, bRed, "%03d", status) + } + + cW(l.buf, l.useColor, bBlue, " %dB", bytes) + + l.buf.WriteString(" in ") + if elapsed < 500*time.Millisecond { + cW(l.buf, l.useColor, nGreen, "%s", elapsed) + } else if elapsed < 5*time.Second { + cW(l.buf, l.useColor, nYellow, "%s", elapsed) + } else { + cW(l.buf, l.useColor, nRed, "%s", elapsed) + } + + l.Logger.Print(l.buf.String()) +} + +func (l *defaultLogEntry) Panic(v interface{}, stack []byte) { + PrintPrettyStack(v) +} + +func init() { + color := true + if runtime.GOOS == "windows" { + color = false + } + DefaultLogger = RequestLogger(&DefaultLogFormatter{Logger: log.New(os.Stdout, "", log.LstdFlags), NoColor: !color}) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 00000000..cc371e00 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,23 @@ +package middleware + +import "net/http" + +// New will create a new middleware handler from a http.Handler. +func New(h http.Handler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) + } +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "chi/middleware context value " + k.name +} diff --git a/middleware/realip.go b/middleware/realip.go new file mode 100644 index 00000000..3d521ae6 --- /dev/null +++ b/middleware/realip.go @@ -0,0 +1,56 @@ +package middleware + +// Ported from Chi's middleware, source: +// https://github.com/go-chi/chi/blob/master/middleware/realip.go + +import ( + "net" + "net/http" + "strings" +) + +var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") +var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") +var xRealIP = http.CanonicalHeaderKey("X-Real-IP") + +// RealIP is a middleware that sets a http.Request's RemoteAddr to the results +// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers +// (in that order). +// +// This middleware should be inserted fairly early in the middleware stack to +// ensure that subsequent layers (e.g., request loggers) which examine the +// RemoteAddr will see the intended value. +// +// You should only use this middleware if you can trust the headers passed to +// you (in particular, the three headers this middleware uses), for example +// because you have placed a reverse proxy like HAProxy or nginx in front of +// chi. If your reverse proxies are configured to pass along arbitrary header +// values from the client, or if you use this middleware without a reverse +// proxy, malicious clients will be able to make you very sad (or, depending on +// how you're using RemoteAddr, vulnerable to an attack of some sort). +func RealIP(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if rip := realIP(r); rip != "" { + r.RemoteAddr = rip + } + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +func realIP(r *http.Request) string { + var ip string + + if tcip := r.Header.Get(trueClientIP); tcip != "" { + ip = tcip + } else if xrip := r.Header.Get(xRealIP); xrip != "" { + ip = xrip + } else if xff := r.Header.Get(xForwardedFor); xff != "" { + ip, _, _ = strings.Cut(xff, ",") + } + if ip == "" || net.ParseIP(ip) == nil { + return "" + } + return ip +} diff --git a/middleware/recoverer.go b/middleware/recoverer.go new file mode 100644 index 00000000..0e75b55c --- /dev/null +++ b/middleware/recoverer.go @@ -0,0 +1,203 @@ +package middleware + +// The original work was derived from Chi's middleware, source: +// https://github.com/go-chi/chi/blob/master/middleware/recoverer.go + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "os" + "runtime/debug" + "strings" +) + +// Recoverer is a middleware that recovers from panics, logs the panic (and a +// backtrace), and returns a HTTP 500 (Internal Server Error) status if +// possible. Recoverer prints a request ID if one is provided. +// +// Alternatively, look at https://github.com/go-chi/httplog middleware pkgs. +func Recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + if rvr == http.ErrAbortHandler { + // we don't recover http.ErrAbortHandler so the response + // to the client is aborted, this should not be logged + panic(rvr) + } + + logEntry := GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + PrintPrettyStack(rvr) + } + + if r.Header.Get("Connection") != "Upgrade" { + w.WriteHeader(http.StatusInternalServerError) + } + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +// for ability to test the PrintPrettyStack function +var recovererErrorWriter io.Writer = os.Stderr + +func PrintPrettyStack(rvr interface{}) { + debugStack := debug.Stack() + s := prettyStack{} + out, err := s.parse(debugStack, rvr) + if err == nil { + recovererErrorWriter.Write(out) + } else { + // print stdlib output as a fallback + os.Stderr.Write(debugStack) + } +} + +type prettyStack struct { +} + +func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) { + var err error + useColor := true + buf := &bytes.Buffer{} + + cW(buf, false, bRed, "\n") + cW(buf, useColor, bCyan, " panic: ") + cW(buf, useColor, bBlue, "%v", rvr) + cW(buf, false, bWhite, "\n \n") + + // process debug stack info + stack := strings.Split(string(debugStack), "\n") + lines := []string{} + + // locate panic line, as we may have nested panics + for i := len(stack) - 1; i > 0; i-- { + lines = append(lines, stack[i]) + if strings.HasPrefix(stack[i], "panic(") { + lines = lines[0 : len(lines)-2] // remove boilerplate + break + } + } + + // reverse + for i := len(lines)/2 - 1; i >= 0; i-- { + opp := len(lines) - 1 - i + lines[i], lines[opp] = lines[opp], lines[i] + } + + // decorate + for i, line := range lines { + lines[i], err = s.decorateLine(line, useColor, i) + if err != nil { + return nil, err + } + } + + for _, l := range lines { + fmt.Fprintf(buf, "%s", l) + } + return buf.Bytes(), nil +} + +func (s prettyStack) decorateLine(line string, useColor bool, num int) (string, error) { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "\t") || strings.Contains(line, ".go:") { + return s.decorateSourceLine(line, useColor, num) + } + if strings.HasSuffix(line, ")") { + return s.decorateFuncCallLine(line, useColor, num) + } + if strings.HasPrefix(line, "\t") { + return strings.Replace(line, "\t", " ", 1), nil + } + return fmt.Sprintf(" %s\n", line), nil +} + +func (s prettyStack) decorateFuncCallLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, "(") + if idx < 0 { + return "", errors.New("not a func call line") + } + + buf := &bytes.Buffer{} + pkg := line[0:idx] + // addr := line[idx:] + method := "" + + if idx := strings.LastIndex(pkg, string(os.PathSeparator)); idx < 0 { + if idx := strings.Index(pkg, "."); idx > 0 { + method = pkg[idx:] + pkg = pkg[0:idx] + } + } else { + method = pkg[idx+1:] + pkg = pkg[0 : idx+1] + if idx := strings.Index(method, "."); idx > 0 { + pkg += method[0:idx] + method = method[idx:] + } + } + pkgColor := nYellow + methodColor := bGreen + + if num == 0 { + cW(buf, useColor, bRed, " -> ") + pkgColor = bMagenta + methodColor = bRed + } else { + cW(buf, useColor, bWhite, " ") + } + cW(buf, useColor, pkgColor, "%s", pkg) + cW(buf, useColor, methodColor, "%s\n", method) + // cW(buf, useColor, nBlack, "%s", addr) + return buf.String(), nil +} + +func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (string, error) { + idx := strings.LastIndex(line, ".go:") + if idx < 0 { + return "", errors.New("not a source line") + } + + buf := &bytes.Buffer{} + path := line[0 : idx+3] + lineno := line[idx+3:] + + idx = strings.LastIndex(path, string(os.PathSeparator)) + dir := path[0 : idx+1] + file := path[idx+1:] + + idx = strings.Index(lineno, " ") + if idx > 0 { + lineno = lineno[0:idx] + } + fileColor := bCyan + lineColor := bGreen + + if num == 1 { + cW(buf, useColor, bRed, " -> ") + fileColor = bRed + lineColor = bMagenta + } else { + cW(buf, false, bWhite, " ") + } + cW(buf, useColor, bWhite, "%s", dir) + cW(buf, useColor, fileColor, "%s", file) + cW(buf, useColor, lineColor, "%s", lineno) + if num == 1 { + cW(buf, false, bWhite, "\n") + } + cW(buf, false, bWhite, "\n") + + return buf.String(), nil +} diff --git a/middleware/request_id.go b/middleware/request_id.go new file mode 100644 index 00000000..c4f8e033 --- /dev/null +++ b/middleware/request_id.go @@ -0,0 +1,96 @@ +package middleware + +// Ported from chi's middleware, source: +// https://github.com/go-chi/chi/blob/master/middleware/request_id.go + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + "sync/atomic" +) + +// Key to use when setting the request ID. +type ctxKeyRequestID int + +// RequestIDKey is the key that holds the unique request ID in a request context. +const RequestIDKey ctxKeyRequestID = 0 + +// RequestIDHeader is the name of the HTTP Header which contains the request id. +// Exported so that it can be changed by developers +var RequestIDHeader = "X-Request-Id" + +var prefix string +var reqid atomic.Uint64 + +// A quick note on the statistics here: we're trying to calculate the chance that +// two randomly generated base62 prefixes will collide. We use the formula from +// http://en.wikipedia.org/wiki/Birthday_problem +// +// P[m, n] \approx 1 - e^{-m^2/2n} +// +// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server +// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ +// +// For a $k$ character base-62 identifier, we have $n(k) = 62^k$ +// +// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for +// our purposes, and is surely more than anyone would ever need in practice -- a +// process that is rebooted a handful of times a day for a hundred years has less +// than a millionth of a percent chance of generating two colliding IDs. + +func init() { + hostname, err := os.Hostname() + if hostname == "" || err != nil { + hostname = "localhost" + } + var buf [12]byte + var b64 string + for len(b64) < 10 { + rand.Read(buf[:]) + b64 = base64.StdEncoding.EncodeToString(buf[:]) + b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) + } + + prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) +} + +// RequestID is a middleware that injects a request ID into the context of each +// request. A request ID is a string of the form "host.example.com/random-0001", +// where "random" is a base62 random string that uniquely identifies this go +// process, and where the last number is an atomically incremented request +// counter. +func RequestID(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := r.Header.Get(RequestIDHeader) + if requestID == "" { + myid := reqid.Add(1) + requestID = fmt.Sprintf("%s-%06d", prefix, myid) + } + ctx = context.WithValue(ctx, RequestIDKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// GetReqID returns a request ID from the given context if one is present. +// Returns the empty string if a request ID cannot be found. +func GetReqID(ctx context.Context) string { + if ctx == nil { + return "" + } + if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + return reqID + } + return "" +} + +// NextRequestID generates the next request ID in the sequence. +func NextRequestID() uint64 { + return reqid.Add(1) +} diff --git a/middleware/terminal.go b/middleware/terminal.go new file mode 100644 index 00000000..5ead7b92 --- /dev/null +++ b/middleware/terminal.go @@ -0,0 +1,63 @@ +package middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "fmt" + "io" + "os" +) + +var ( + // Normal colors + nBlack = []byte{'\033', '[', '3', '0', 'm'} + nRed = []byte{'\033', '[', '3', '1', 'm'} + nGreen = []byte{'\033', '[', '3', '2', 'm'} + nYellow = []byte{'\033', '[', '3', '3', 'm'} + nBlue = []byte{'\033', '[', '3', '4', 'm'} + nMagenta = []byte{'\033', '[', '3', '5', 'm'} + nCyan = []byte{'\033', '[', '3', '6', 'm'} + nWhite = []byte{'\033', '[', '3', '7', 'm'} + // Bright colors + bBlack = []byte{'\033', '[', '3', '0', ';', '1', 'm'} + bRed = []byte{'\033', '[', '3', '1', ';', '1', 'm'} + bGreen = []byte{'\033', '[', '3', '2', ';', '1', 'm'} + bYellow = []byte{'\033', '[', '3', '3', ';', '1', 'm'} + bBlue = []byte{'\033', '[', '3', '4', ';', '1', 'm'} + bMagenta = []byte{'\033', '[', '3', '5', ';', '1', 'm'} + bCyan = []byte{'\033', '[', '3', '6', ';', '1', 'm'} + bWhite = []byte{'\033', '[', '3', '7', ';', '1', 'm'} + + reset = []byte{'\033', '[', '0', 'm'} +) + +var IsTTY bool + +func init() { + // This is sort of cheating: if stdout is a character device, we assume + // that means it's a TTY. Unfortunately, there are many non-TTY + // character devices, but fortunately stdout is rarely set to any of + // them. + // + // We could solve this properly by pulling in a dependency on + // code.google.com/p/go.crypto/ssh/terminal, for instance, but as a + // heuristic for whether to print in color or in black-and-white, I'd + // really rather not. + fi, err := os.Stdout.Stat() + if err == nil { + m := os.ModeDevice | os.ModeCharDevice + IsTTY = fi.Mode()&m == m + } +} + +// colorWrite +func cW(w io.Writer, useColor bool, color []byte, s string, args ...interface{}) { + if IsTTY && useColor { + w.Write(color) + } + fmt.Fprintf(w, s, args...) + if IsTTY && useColor { + w.Write(reset) + } +} diff --git a/middleware/wrap_writer.go b/middleware/wrap_writer.go new file mode 100644 index 00000000..367e0fcd --- /dev/null +++ b/middleware/wrap_writer.go @@ -0,0 +1,241 @@ +package middleware + +// The original work was derived from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "bufio" + "io" + "net" + "net/http" +) + +// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to +// hook into various parts of the response process. +func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter { + _, fl := w.(http.Flusher) + + bw := basicWriter{ResponseWriter: w} + + if protoMajor == 2 { + _, ps := w.(http.Pusher) + if fl && ps { + return &http2FancyWriter{bw} + } + } else { + _, hj := w.(http.Hijacker) + _, rf := w.(io.ReaderFrom) + if fl && hj && rf { + return &httpFancyWriter{bw} + } + if fl && hj { + return &flushHijackWriter{bw} + } + if hj { + return &hijackWriter{bw} + } + } + + if fl { + return &flushWriter{bw} + } + + return &bw +} + +// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook +// into various parts of the response process. +type WrapResponseWriter interface { + http.ResponseWriter + // Status returns the HTTP status of the request, or 0 if one has not + // yet been sent. + Status() int + // BytesWritten returns the total number of bytes sent to the client. + BytesWritten() int + // Tee causes the response body to be written to the given io.Writer in + // addition to proxying the writes through. Only one io.Writer can be + // tee'd to at once: setting a second one will overwrite the first. + // Writes will be sent to the proxy before being written to this + // io.Writer. It is illegal for the tee'd writer to be modified + // concurrently with writes. + Tee(io.Writer) + // Unwrap returns the original proxied target. + Unwrap() http.ResponseWriter + // Discard causes all writes to the original ResponseWriter be discarded, + // instead writing only to the tee'd writer if it's set. + // The caller is responsible for calling WriteHeader and Write on the + // original ResponseWriter once the processing is done. + Discard() +} + +// basicWriter wraps a http.ResponseWriter that implements the minimal +// http.ResponseWriter interface. +type basicWriter struct { + http.ResponseWriter + tee io.Writer + code int + bytes int + wroteHeader bool + discard bool +} + +func (b *basicWriter) WriteHeader(code int) { + if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols { + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } else if !b.wroteHeader { + b.code = code + b.wroteHeader = true + if !b.discard { + b.ResponseWriter.WriteHeader(code) + } + } +} + +func (b *basicWriter) Write(buf []byte) (n int, err error) { + b.maybeWriteHeader() + if !b.discard { + n, err = b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. + if err == nil { + err = err2 + } + } + } else if b.tee != nil { + n, err = b.tee.Write(buf) + } else { + n, err = io.Discard.Write(buf) + } + b.bytes += n + return n, err +} + +func (b *basicWriter) maybeWriteHeader() { + if !b.wroteHeader { + b.WriteHeader(http.StatusOK) + } +} + +func (b *basicWriter) Status() int { + return b.code +} + +func (b *basicWriter) BytesWritten() int { + return b.bytes +} + +func (b *basicWriter) Tee(w io.Writer) { + b.tee = w +} + +func (b *basicWriter) Unwrap() http.ResponseWriter { + return b.ResponseWriter +} + +func (b *basicWriter) Discard() { + b.discard = true +} + +// flushWriter ... +type flushWriter struct { + basicWriter +} + +func (f *flushWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &flushWriter{} + +// hijackWriter ... +type hijackWriter struct { + basicWriter +} + +func (f *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Hijacker = &hijackWriter{} + +// flushHijackWriter ... +type flushHijackWriter struct { + basicWriter +} + +func (f *flushHijackWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +var _ http.Flusher = &flushHijackWriter{} +var _ http.Hijacker = &flushHijackWriter{} + +// httpFancyWriter is a HTTP writer that additionally satisfies +// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type httpFancyWriter struct { + basicWriter +} + +func (f *httpFancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + return hj.Hijack() +} + +func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { + return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { + if f.basicWriter.tee != nil { + n, err := io.Copy(&f.basicWriter, r) + f.basicWriter.bytes += int(n) + return n, err + } + rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) + f.basicWriter.maybeWriteHeader() + n, err := rf.ReadFrom(r) + f.basicWriter.bytes += int(n) + return n, err +} + +var _ http.Flusher = &httpFancyWriter{} +var _ http.Hijacker = &httpFancyWriter{} +var _ http.Pusher = &http2FancyWriter{} +var _ io.ReaderFrom = &httpFancyWriter{} + +// http2FancyWriter is a HTTP2 writer that additionally satisfies +// http.Flusher, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type http2FancyWriter struct { + basicWriter +} + +func (f *http2FancyWriter) Flush() { + f.wroteHeader = true + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &http2FancyWriter{} diff --git a/rmo/routes.go b/rmo/routes.go index d1035087..40cf7062 100644 --- a/rmo/routes.go +++ b/rmo/routes.go @@ -3,49 +3,47 @@ package rmo import ( "github.com/Gleipnir-Technology/nidus-sync/html" "github.com/Gleipnir-Technology/nidus-sync/static" - "github.com/go-chi/chi/v5" + "github.com/gorilla/mux" ) -func Router() chi.Router { - r := chi.NewRouter() - r.Get("/", getRoot) - r.Get("/nuisance", getNuisance) - r.Post("/nuisance", postNuisance) - r.Get("/submit-complete", getSubmitComplete) - r.Get("/water", getWater) - r.Post("/water", postWater) +func Router(r *mux.Router) { + r.HandleFunc("/", getRoot).Methods("GET") + r.HandleFunc("/nuisance", getNuisance).Methods("GET") + r.HandleFunc("/nuisance", postNuisance).Methods("POST") + r.HandleFunc("/submit-complete", getSubmitComplete).Methods("GET") + r.HandleFunc("/water", getWater).Methods("GET") + r.HandleFunc("/water", postWater).Methods("POST") - r.Get("/district", getDistrictList) - r.Get("/district/{slug}", getRootDistrict) - r.Get("/district/{slug}/nuisance", getNuisanceDistrict) - //r.Get("/district/{slug}/nuisance-submit-complete", renderMock(mockNuisanceSubmitCompleteT)) - //r.Get("/district/{slug}/status", renderMock(mockStatusT)) - r.Get("/district/{slug}/water", getWaterDistrict) - //r.Post("/district/{slug}/water", postWaterDistrict) - r.Get("/error", getError) + r.HandleFunc("/district", getDistrictList).Methods("GET") + r.HandleFunc("/district/{slug}", getRootDistrict).Methods("GET") + r.HandleFunc("/district/{slug}/nuisance", getNuisanceDistrict).Methods("GET") + //r.HandleFunc("/district/{slug}/nuisance-submit-complete", renderMock(mockNuisanceSubmitCompleteT)).Methods("GET") + //r.HandleFunc("/district/{slug}/status", renderMock(mockStatusT)).Methods("GET") + r.HandleFunc("/district/{slug}/water", getWaterDistrict).Methods("GET") + //r.HandleFunc("/district/{slug}/water", postWaterDistrict).Methods("POST") + r.HandleFunc("/error", getError).Methods("GET") - r.Get("/privacy", getPrivacy) - r.Get("/robots.txt", getRobots) - r.Get("/email/render/{code}", getEmailByCode) - r.Get("/email/confirm", getEmailConfirm) - r.Post("/email/confirm", postEmailConfirm) - r.Get("/email/confirm/complete", getEmailConfirmComplete) - r.Get("/email/unsubscribe", getEmailUnsubscribe) - r.Get("/email/unsubscribe/report/{report_id}", getEmailReportUnsubscribe) - r.Get("/image/{uuid}", getImageByUUID) - r.Get("/mailer/{public_id}", html.MakeGet(getMailer)) - r.Post("/mailer/{public_id}/confirm", html.MakePost(postMailerConfirm)) - r.Get("/mailer/{public_id}/contribute", html.MakeGet(getMailerContribute)) - r.Get("/mailer/{public_id}/evidence", html.MakeGet(getMailerEvidence)) - r.Get("/mailer/{public_id}/schedule", html.MakeGet(getMailerSchedule)) - r.Get("/mailer/{public_id}/update", html.MakeGet(getMailerUpdate)) - r.Post("/register-notifications", postRegisterNotifications) - r.Get("/register-notifications-complete", getRegisterNotificationsComplete) - r.Get("/report/suggest", getReportSuggestion) - r.Get("/scss/*", getScssDebug) - r.Get("/status", getStatus) - r.Get("/status/{report_id}", getStatusByID) - r.Get("/terms-of-service", getTerms) + r.HandleFunc("/privacy", getPrivacy).Methods("GET") + r.HandleFunc("/robots.txt", getRobots).Methods("GET") + r.HandleFunc("/email/render/{code}", getEmailByCode).Methods("GET") + r.HandleFunc("/email/confirm", getEmailConfirm).Methods("GET") + r.HandleFunc("/email/confirm", postEmailConfirm).Methods("POST") + r.HandleFunc("/email/confirm/complete", getEmailConfirmComplete).Methods("GET") + r.HandleFunc("/email/unsubscribe", getEmailUnsubscribe).Methods("GET") + r.HandleFunc("/email/unsubscribe/report/{report_id}", getEmailReportUnsubscribe).Methods("GET") + r.HandleFunc("/image/{uuid}", getImageByUUID).Methods("GET") + r.HandleFunc("/mailer/{public_id}", html.MakeGet(getMailer)).Methods("GET") + r.HandleFunc("/mailer/{public_id}/confirm", html.MakePost(postMailerConfirm)).Methods("POST") + r.HandleFunc("/mailer/{public_id}/contribute", html.MakeGet(getMailerContribute)).Methods("GET") + r.HandleFunc("/mailer/{public_id}/evidence", html.MakeGet(getMailerEvidence)).Methods("GET") + r.HandleFunc("/mailer/{public_id}/schedule", html.MakeGet(getMailerSchedule)).Methods("GET") + r.HandleFunc("/mailer/{public_id}/update", html.MakeGet(getMailerUpdate)).Methods("GET") + r.HandleFunc("/register-notifications", postRegisterNotifications).Methods("POST") + r.HandleFunc("/register-notifications-complete", getRegisterNotificationsComplete).Methods("GET") + r.HandleFunc("/report/suggest", getReportSuggestion).Methods("GET") + r.HandleFunc("/scss/*", getScssDebug).Methods("GET") + r.HandleFunc("/status", getStatus).Methods("GET") + r.HandleFunc("/status/{report_id}", getStatusByID).Methods("GET") + r.HandleFunc("/terms-of-service", getTerms).Methods("GET") static.AddStaticRoute(r, "/static") - return r } diff --git a/static/static.go b/static/static.go index 49b653f8..3a724dd2 100644 --- a/static/static.go +++ b/static/static.go @@ -11,7 +11,7 @@ import ( "time" "github.com/Gleipnir-Technology/nidus-sync/config" - "github.com/go-chi/chi/v5" + "github.com/gorilla/mux" "github.com/rs/zerolog/log" ) @@ -24,7 +24,7 @@ var startedTime time.Time = time.Now() var localFS http.Dir -func AddStaticRoute(r chi.Router, path string) { +func AddStaticRoute(r *mux.Router, path string) { if localFS == "" { localFS = http.Dir("./static") // Useful for debugging embedded file issues @@ -38,20 +38,22 @@ func AddStaticRoute(r chi.Router, path string) { fileServer(r, "/static", localFS, embeddedStaticFS) } -func fileServer(r chi.Router, path string, root http.FileSystem, embeddedFS embed.FS) { +func fileServer(r *mux.Router, path string, root http.FileSystem, embeddedFS embed.FS) { if strings.ContainsAny(path, "{}*") { panic("FileServer does not permit any URL parameters.") } if path != "/" && path[len(path)-1] != '/' { - r.Get(path, http.RedirectHandler(path+"/", 301).ServeHTTP) + r.HandleFunc(path, http.RedirectHandler(path+"/", 301).ServeHTTP) path += "/" } path += "*" - r.Get(path, func(w http.ResponseWriter, r *http.Request) { - rctx := chi.RouteContext(r.Context()) - pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") + r.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + //rctx := chi.RouteContext(r.Context()) + + //pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") + pathPrefix := strings.TrimPrefix(r.URL.Path, "/static") // Determine the actual file path requestedPath := strings.TrimPrefix(r.URL.Path, pathPrefix+"/") diff --git a/sync/mock.go b/sync/mock.go index 3c8c87b9..3ab8406e 100644 --- a/sync/mock.go +++ b/sync/mock.go @@ -3,7 +3,7 @@ package sync import ( "fmt" "github.com/Gleipnir-Technology/nidus-sync/html" - "github.com/go-chi/chi/v5" + "github.com/gorilla/mux" "net/http" //"github.com/rs/zerolog/log" ) @@ -45,12 +45,12 @@ type mock struct { var mocks = []mock{} -func addMock(r chi.Router, path string, template string) { +func addMock(r *mux.Router, path string, template string) { mocks = append(mocks, mock{ Path: path, template: template, }) - r.Get(path, renderMock(template)) + r.HandleFunc(path, renderMock(template)) } type contentMock struct { @@ -61,7 +61,8 @@ type contentMock struct { func renderMock(template_name string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - code := chi.URLParam(r, "code") + vars := mux.Vars(r) + code := vars["code"] if code == "" { code = "abc-123" } diff --git a/sync/routes.go b/sync/routes.go index 2c25bf57..7aea71aa 100644 --- a/sync/routes.go +++ b/sync/routes.go @@ -3,26 +3,24 @@ package sync import ( "github.com/Gleipnir-Technology/nidus-sync/api" "github.com/Gleipnir-Technology/nidus-sync/static" - "github.com/go-chi/chi/v5" + "github.com/gorilla/mux" ) -func Router() chi.Router { - r := chi.NewRouter() - +func Router(r *mux.Router) { // Unauthenticated endpoints - r.Get("/arcgis/oauth/begin", getArcgisOauthBegin) - r.Get("/arcgis/oauth/callback", getArcgisOauthCallback) - r.Get("/mailer/pool/random", getMailerPoolRandom) - r.Get("/mailer/mode-1", getMailer1) - r.Get("/mailer/mode-2", getMailer2) - r.Get("/mailer/mode-3/{code}", getMailer3) - r.Get("/mailer/mode-1/preview", getMailer1Preview) - r.Get("/mailer/mode-2/preview", getMailer2Preview) - r.Get("/mailer/mode-3/{code}/preview", getMailer3Preview) - r.Get("/district", getDistrict) + r.HandleFunc("/arcgis/oauth/begin", getArcgisOauthBegin) + r.HandleFunc("/arcgis/oauth/callback", getArcgisOauthCallback) + r.HandleFunc("/mailer/pool/random", getMailerPoolRandom) + r.HandleFunc("/mailer/mode-1", getMailer1) + r.HandleFunc("/mailer/mode-2", getMailer2) + r.HandleFunc("/mailer/mode-3/{code}", getMailer3) + r.HandleFunc("/mailer/mode-1/preview", getMailer1Preview) + r.HandleFunc("/mailer/mode-2/preview", getMailer2Preview) + r.HandleFunc("/mailer/mode-3/{code}/preview", getMailer3Preview) + r.HandleFunc("/district", getDistrict) // Mock endpoints - r.Get("/mock", renderMockList) + r.HandleFunc("/mock", renderMockList) addMock(r, "/mock/report", "sync/mock/report.html") addMock(r, "/mock/report/{code}", "sync/mock/report-detail.html") addMock(r, "/mock/report/{code}/confirm", "sync/mock/report-confirmation.html") @@ -32,18 +30,17 @@ func Router() chi.Router { addMock(r, "/mock/report/{code}/update", "sync/mock/report-update.html") // Utility endpoints - r.Get("/privacy", getPrivacy) - r.Get("/qr-code/marketing", getQRCodeMarketing) - r.Get("/qr-code/report/{code}", getQRCodeReport) - r.Get("/qr-code/mailer/{code}", getQRCodeMailer) - r.Get("/template-test", getTemplateTest) + r.HandleFunc("/privacy", getPrivacy) + r.HandleFunc("/qr-code/marketing", getQRCodeMarketing) + r.HandleFunc("/qr-code/report/{code}", getQRCodeReport) + r.HandleFunc("/qr-code/mailer/{code}", getQRCodeMailer) + r.HandleFunc("/template-test", getTemplateTest) - // Authenticated endpoints - r.Route("/api", api.AddRoutes) + api_router := r.PathPrefix("/api").Subrouter() + api.AddRoutes(api_router) - r.Get("/", getRoot) - r.Get("/_/*", getRoot) + r.HandleFunc("/", getRoot) + r.HandleFunc("/_/*", getRoot) static.AddStaticRoute(r, "/static") - return r }