diff --git a/app/static/examples/config.example.yml b/app/static/examples/config.example.yml index 8888173..cf830c2 100644 --- a/app/static/examples/config.example.yml +++ b/app/static/examples/config.example.yml @@ -4,7 +4,10 @@ mongodb: server: port: ":8080" # Port for the server to listen on - +whitelist: + enabled: false + pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4", + #"cac0e43235806da094f0787a5b04e29ad04cb1a3c7ea5cf61edc1c338734082b"] rate_limit: ws_limit: 100 # Global rate limit for WebSocket messages (50 messages per second) ws_burst: 200 # Global burst limit for WebSocket messages (allows a burst of 100 messages) diff --git a/config/loadConfig.go b/config/loadConfig.go index 7f9f5b7..c8a8104 100644 --- a/config/loadConfig.go +++ b/config/loadConfig.go @@ -2,24 +2,37 @@ package config import ( "os" + "sync" - config "grain/config/types" + configTypes "grain/config/types" "gopkg.in/yaml.v2" ) -func LoadConfig(filename string) (*config.ServerConfig, error) { +var ( + cfg *configTypes.ServerConfig + once sync.Once +) + +func LoadConfig(filename string) (*configTypes.ServerConfig, error) { data, err := os.ReadFile(filename) if err != nil { return nil, err } - var config config.ServerConfig - + var config configTypes.ServerConfig err = yaml.Unmarshal(data, &config) if err != nil { return nil, err } - return &config, nil -} \ No newline at end of file + once.Do(func() { + cfg = &config + }) + + return cfg, nil +} + +func GetConfig() *configTypes.ServerConfig { + return cfg +} diff --git a/config/rateLimiter.go b/config/rateLimiter.go index 8c39d97..0ff606a 100644 --- a/config/rateLimiter.go +++ b/config/rateLimiter.go @@ -30,7 +30,7 @@ type RateLimiter struct { } var rateLimiterInstance *RateLimiter -var once sync.Once +var rateOnce sync.Once func SetupRateLimiter(cfg *config.ServerConfig) { rateLimiter := NewRateLimiter( @@ -54,7 +54,7 @@ func SetupRateLimiter(cfg *config.ServerConfig) { } func SetRateLimiter(rl *RateLimiter) { - once.Do(func() { + rateOnce.Do(func() { rateLimiterInstance = rl }) } diff --git a/config/types/serverConfig.go b/config/types/serverConfig.go index fbb1185..ac83aa0 100644 --- a/config/types/serverConfig.go +++ b/config/types/serverConfig.go @@ -9,4 +9,5 @@ type ServerConfig struct { Port string `yaml:"port"` } `yaml:"server"` RateLimit RateLimitConfig `yaml:"rate_limit"` -} \ No newline at end of file + Whitelist WhitelistConfig `yaml:"whitelist"` +} diff --git a/config/types/whitelistConfig.go b/config/types/whitelistConfig.go new file mode 100644 index 0000000..38ac601 --- /dev/null +++ b/config/types/whitelistConfig.go @@ -0,0 +1,6 @@ +package config + +type WhitelistConfig struct { + Enabled bool `yaml:"enabled"` + Pubkeys []string `yaml:"pubkeys"` +} diff --git a/server/handlers/event.go b/server/handlers/event.go index cbeabc5..f8dea04 100644 --- a/server/handlers/event.go +++ b/server/handlers/event.go @@ -59,6 +59,12 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn, eventS rateLimiter := config.GetRateLimiter() sizeLimiter := config.GetSizeLimiter() + // Check whitelist + if !isWhitelisted(evt.PubKey) { + response.SendOK(ws, evt.ID, false, "not allowed: pubkey is not whitelisted") + return + } + category := determineCategory(evt.Kind) if allowed, msg := rateLimiter.AllowEvent(evt.Kind, category); !allowed { @@ -121,3 +127,17 @@ func determineCategory(kind int) string { return "unknown" } } + +// Helper function to check if a pubkey is whitelisted +func isWhitelisted(pubKey string) bool { + cfg := config.GetConfig() + if !cfg.Whitelist.Enabled { + return true + } + for _, whitelistedKey := range cfg.Whitelist.Pubkeys { + if pubKey == whitelistedKey { + return true + } + } + return false +}