diff --git a/db/connection.go b/db/connection.go index 0e446a33..d0f550fb 100644 --- a/db/connection.go +++ b/db/connection.go @@ -111,6 +111,10 @@ func InitializeDatabase(ctx context.Context, uri string) error { return fmt.Errorf("Failed to get database current: %w", err) } log.Info().Str("database", current).Msg("Connected to database") + err = prepareStatements(ctx) + if err != nil { + return fmt.Errorf("Failed to initialize prepared statements: %w", err) + } return nil } diff --git a/db/prepared.go b/db/prepared.go index 06498e20..c6b38260 100644 --- a/db/prepared.go +++ b/db/prepared.go @@ -10,6 +10,8 @@ import ( //"github.com/stephenafamo/bob" //"github.com/stephenafamo/bob/dialect/psql" "github.com/rs/zerolog/log" + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/dialect/psql" ) //go:embed prepared_functions/*.sql @@ -20,19 +22,21 @@ var sqlFiles embed.FS // preparing statements that will be used later. func prepareStatements(ctx context.Context) error { // Get a list of all embedded SQL files - entries, err := sqlFiles.ReadDir(".") + entries, err := sqlFiles.ReadDir("prepared_functions") if err != nil { return fmt.Errorf("failed to read SQL directory: %w", err) } + log.Info().Int("len", len(entries)).Msg("Reading prepared functions") // Process each SQL file for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") { + log.Info().Str("name", entry.Name()).Msg("Skipping") continue } // Read the SQL file content - content, err := sqlFiles.ReadFile(entry.Name()) + content, err := sqlFiles.ReadFile(filepath.Join("prepared_functions", entry.Name())) if err != nil { return fmt.Errorf("failed to read SQL file %s: %w", entry.Name(), err) } @@ -59,5 +63,23 @@ func prepareStatements(ctx context.Context) error { return nil } func TestPreparedQuery(ctx context.Context) error { + query := psql.RawQuery("EXECUTE test_function") + result, err := bob.Exec(ctx, PGInstance.BobDB, query) + if err != nil { + return fmt.Errorf("Failed to exectue test function: %w", err) + } + /*insert_id, err := result.LastInsertId() + if err != nil { + log.Error().Err(err).Msg("failed insert id") + return fmt.Errorf("Failed to get insert ID: %w", err) + }*/ + rows_affected, err := result.RowsAffected() + if err != nil { + log.Error().Err(err).Msg("failed rows affected") + return fmt.Errorf("Failed to get rows affected: %w", err) + } + //log.Info().Int64("insert id", insert_id).Int64("rows", rows_affected).Msg("bah") + log.Info().Int64("rows", rows_affected).Msg("got rows") + return nil } diff --git a/db/prepared_functions/test_function.sql b/db/prepared_functions/test_function.sql new file mode 100644 index 00000000..ae7b8617 --- /dev/null +++ b/db/prepared_functions/test_function.sql @@ -0,0 +1,3 @@ +PREPARE test_function AS + SELECT version(); +