configuration refactor

This commit is contained in:
0ceanSlim 2024-07-30 11:27:38 -04:00
parent 567d9010a4
commit ba95bc849b
9 changed files with 46 additions and 44 deletions

View File

@ -1,4 +1,4 @@
package utils package config
import ( import (
"os" "os"

View File

@ -1,4 +1,4 @@
package utils package config
import ( import (
"fmt" "fmt"

View File

@ -1,4 +1,4 @@
package utils package config
import ( import (
"sync" "sync"

50
main.go
View File

@ -5,9 +5,9 @@ import (
"log" "log"
"net/http" "net/http"
"grain/config"
"grain/relay" "grain/relay"
"grain/relay/db" "grain/relay/db"
"grain/relay/utils"
"grain/web" "grain/web"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
@ -15,19 +15,19 @@ import (
) )
func main() { func main() {
config, err := loadConfiguration() cfg, err := loadConfiguration()
if err != nil { if err != nil {
log.Fatal("Error loading config: ", err) log.Fatal("Error loading config: ", err)
} }
err = initializeDatabase(config) err = initializeDatabase(cfg)
if err != nil { if err != nil {
log.Fatal("Error initializing database: ", err) log.Fatal("Error initializing database: ", err)
} }
defer db.DisconnectDB() defer db.DisconnectDB()
setupRateLimiter(config) setupRateLimiter(cfg)
setupSizeLimiter(config) setupSizeLimiter(cfg)
err = loadRelayMetadata() err = loadRelayMetadata()
if err != nil { if err != nil {
@ -36,46 +36,46 @@ func main() {
mux := setupRoutes() mux := setupRoutes()
startServer(config, mux) startServer(cfg, mux)
} }
func loadConfiguration() (*utils.Config, error) { func loadConfiguration() (*config.Config, error) {
return utils.LoadConfig("config.yml") return config.LoadConfig("config.yml")
} }
func initializeDatabase(config *utils.Config) error { func initializeDatabase(config *config.Config) error {
_, err := db.InitDB(config.MongoDB.URI, config.MongoDB.Database) _, err := db.InitDB(config.MongoDB.URI, config.MongoDB.Database)
return err return err
} }
func setupRateLimiter(config *utils.Config) { func setupRateLimiter(cfg *config.Config) {
rateLimiter := utils.NewRateLimiter( rateLimiter := config.NewRateLimiter(
rate.Limit(config.RateLimit.WsLimit), rate.Limit(cfg.RateLimit.WsLimit),
config.RateLimit.WsBurst, cfg.RateLimit.WsBurst,
rate.Limit(config.RateLimit.EventLimit), rate.Limit(cfg.RateLimit.EventLimit),
config.RateLimit.EventBurst, cfg.RateLimit.EventBurst,
rate.Limit(config.RateLimit.ReqLimit), rate.Limit(cfg.RateLimit.ReqLimit),
config.RateLimit.ReqBurst, cfg.RateLimit.ReqBurst,
) )
for _, kindLimit := range config.RateLimit.KindLimits { for _, kindLimit := range cfg.RateLimit.KindLimits {
rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst) rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst)
} }
for category, categoryLimit := range config.RateLimit.CategoryLimits { for category, categoryLimit := range cfg.RateLimit.CategoryLimits {
rateLimiter.AddCategoryLimit(category, rate.Limit(categoryLimit.Limit), categoryLimit.Burst) rateLimiter.AddCategoryLimit(category, rate.Limit(categoryLimit.Limit), categoryLimit.Burst)
} }
utils.SetRateLimiter(rateLimiter) config.SetRateLimiter(rateLimiter)
} }
func setupSizeLimiter(config *utils.Config) { func setupSizeLimiter(cfg *config.Config) {
sizeLimiter := utils.NewSizeLimiter(config.RateLimit.MaxEventSize) sizeLimiter := config.NewSizeLimiter(cfg.RateLimit.MaxEventSize)
for _, kindSizeLimit := range config.RateLimit.KindSizeLimits { for _, kindSizeLimit := range cfg.RateLimit.KindSizeLimits {
sizeLimiter.AddKindSizeLimit(kindSizeLimit.Kind, kindSizeLimit.MaxSize) sizeLimiter.AddKindSizeLimit(kindSizeLimit.Kind, kindSizeLimit.MaxSize)
} }
utils.SetSizeLimiter(sizeLimiter) config.SetSizeLimiter(sizeLimiter)
} }
func loadRelayMetadata() error { func loadRelayMetadata() error {
@ -92,7 +92,7 @@ func setupRoutes() *http.ServeMux {
return mux return mux
} }
func startServer(config *utils.Config, mux *http.ServeMux) { func startServer(config *config.Config, mux *http.ServeMux) {
fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port) fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port)
err := http.ListenAndServe(config.Server.Port, mux) err := http.ListenAndServe(config.Server.Port, mux)
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/config"
"grain/relay/db" "grain/relay/db"
"grain/relay/handlers/kinds" "grain/relay/handlers/kinds"
"grain/relay/handlers/response" "grain/relay/handlers/response"
@ -55,8 +56,8 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn, eventS
} }
collection := db.GetCollection(evt.Kind) collection := db.GetCollection(evt.Kind)
rateLimiter := utils.GetRateLimiter() rateLimiter := config.GetRateLimiter()
sizeLimiter := utils.GetSizeLimiter() sizeLimiter := config.GetSizeLimiter()
category := determineCategory(evt.Kind) category := determineCategory(evt.Kind)

View File

@ -4,7 +4,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/relay/handlers" "grain/relay/handlers"
"grain/relay/utils"
"grain/config"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -13,7 +14,7 @@ func WebSocketHandler(ws *websocket.Conn) {
defer ws.Close() defer ws.Close()
var msg string var msg string
rateLimiter := utils.GetRateLimiter() rateLimiter := config.GetRateLimiter()
for { for {
err := websocket.Message.Receive(ws, &msg) err := websocket.Message.Receive(ws, &msg)

View File

@ -3,11 +3,11 @@ package tests
import ( import (
"testing" "testing"
"grain/relay/utils" "grain/config"
) )
func TestConfigValidity(t *testing.T) { func TestConfigValidity(t *testing.T) {
config, err := utils.LoadConfig("../config.yml") config, err := config.LoadConfig("../config.yml")
if err != nil { if err != nil {
t.Fatalf("Error loading config: %v", err) t.Fatalf("Error loading config: %v", err)
} }

View File

@ -3,13 +3,13 @@ package tests
import ( import (
"testing" "testing"
"grain/relay/utils" "grain/config"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
func TestWebSocketRateLimit(t *testing.T) { func TestWebSocketRateLimit(t *testing.T) {
rateLimiter := utils.NewRateLimiter(rate.Limit(1), 1, rate.Limit(100), 200, rate.Limit(100), 200) rateLimiter := config.NewRateLimiter(rate.Limit(1), 1, rate.Limit(100), 200, rate.Limit(100), 200)
// First message should be allowed // First message should be allowed
if allowed, _ := rateLimiter.AllowWs(); !allowed { if allowed, _ := rateLimiter.AllowWs(); !allowed {
@ -28,7 +28,7 @@ func TestWebSocketRateLimit(t *testing.T) {
} }
func TestEventRateLimit(t *testing.T) { func TestEventRateLimit(t *testing.T) {
rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(1), 1, rate.Limit(100), 200) rateLimiter := config.NewRateLimiter(rate.Limit(100), 200, rate.Limit(1), 1, rate.Limit(100), 200)
rateLimiter.AddKindLimit(1, rate.Limit(1), 1) rateLimiter.AddKindLimit(1, rate.Limit(1), 1)
rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1) rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1)
@ -49,7 +49,7 @@ func TestEventRateLimit(t *testing.T) {
} }
func TestReqRateLimit(t *testing.T) { func TestReqRateLimit(t *testing.T) {
rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(1), 1) rateLimiter := config.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(1), 1)
// First REQ should be allowed // First REQ should be allowed
if allowed, _ := rateLimiter.AllowReq(); !allowed { if allowed, _ := rateLimiter.AllowReq(); !allowed {
@ -68,7 +68,7 @@ func TestReqRateLimit(t *testing.T) {
} }
func TestKindRateLimit(t *testing.T) { func TestKindRateLimit(t *testing.T) {
rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200) rateLimiter := config.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200)
rateLimiter.AddKindLimit(1, rate.Limit(1), 1) rateLimiter.AddKindLimit(1, rate.Limit(1), 1)
// First event of kind 1 should be allowed // First event of kind 1 should be allowed
@ -88,7 +88,7 @@ func TestKindRateLimit(t *testing.T) {
} }
func TestCategoryRateLimit(t *testing.T) { func TestCategoryRateLimit(t *testing.T) {
rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200) rateLimiter := config.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200)
rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1) rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1)
// First event in category "regular" should be allowed // First event in category "regular" should be allowed

View File

@ -1,12 +1,12 @@
package tests package tests
import ( import (
"grain/relay/utils" "grain/config"
"testing" "testing"
) )
func TestSizeLimiterGlobalMaxSize(t *testing.T) { func TestSizeLimiterGlobalMaxSize(t *testing.T) {
sizeLimiter := utils.NewSizeLimiter(1024) // Set global max size to 1024 bytes sizeLimiter := config.NewSizeLimiter(1024) // Set global max size to 1024 bytes
// Test that an event within the global max size is allowed // Test that an event within the global max size is allowed
if allowed, _ := sizeLimiter.AllowSize(0, 512); !allowed { if allowed, _ := sizeLimiter.AllowSize(0, 512); !allowed {
@ -25,7 +25,7 @@ func TestSizeLimiterGlobalMaxSize(t *testing.T) {
} }
func TestSizeLimiterKindSpecificSize(t *testing.T) { func TestSizeLimiterKindSpecificSize(t *testing.T) {
sizeLimiter := utils.NewSizeLimiter(1024) // Set global max size to 1024 bytes sizeLimiter := config.NewSizeLimiter(1024) // Set global max size to 1024 bytes
sizeLimiter.AddKindSizeLimit(1, 512) // Set max size for kind 1 to 512 bytes sizeLimiter.AddKindSizeLimit(1, 512) // Set max size for kind 1 to 512 bytes
// Test that an event within the kind-specific max size is allowed // Test that an event within the kind-specific max size is allowed
@ -55,7 +55,7 @@ func TestSizeLimiterKindSpecificSize(t *testing.T) {
} }
func TestSizeLimiterNoKindSpecificLimit(t *testing.T) { func TestSizeLimiterNoKindSpecificLimit(t *testing.T) {
sizeLimiter := utils.NewSizeLimiter(1024) // Set global max size to 1024 bytes sizeLimiter := config.NewSizeLimiter(1024) // Set global max size to 1024 bytes
// Test that an event for a kind without a specific limit is governed by the global limit // Test that an event for a kind without a specific limit is governed by the global limit
if allowed, _ := sizeLimiter.AllowSize(2, 512); !allowed { if allowed, _ := sizeLimiter.AllowSize(2, 512); !allowed {