diff --git a/config.example.yml b/config.example.yml index 50e264c..979d1df 100644 --- a/config.example.yml +++ b/config.example.yml @@ -6,10 +6,10 @@ server: # Rate Limits Integers are per second # burst is an override for the limit, this is to handle spikes in traffic rate_limit: - event_limit: 25 - event_burst: 50 ws_limit: 50 ws_burst: 100 + event_limit: 25 + event_burst: 50 kind_limits: - kind: 0 limit: 1 diff --git a/main.go b/main.go index 3266ad0..47d0ca0 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "grain/relay" "grain/relay/db" - "grain/relay/handlers" "grain/relay/utils" "grain/web" @@ -15,8 +14,6 @@ import ( "golang.org/x/time/rate" ) -var rl *utils.RateLimiter - func main() { // Load configuration config, err := utils.LoadConfig("config.yml") @@ -31,19 +28,31 @@ func main() { } defer db.DisconnectDB() - // Initialize rate limiter - rl = utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst) + // Initialize RateLimiter + rateLimiter := utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst) + for _, kindLimit := range config.RateLimit.KindLimits { - rl.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst) + rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst) } - handlers.SetRateLimiter(rl) + rateLimiter.AddCategoryLimit("regular", rate.Limit(config.RateLimit.CategoryLimits.Regular.Limit), config.RateLimit.CategoryLimits.Regular.Burst) + rateLimiter.AddCategoryLimit("replaceable", rate.Limit(config.RateLimit.CategoryLimits.Replaceable.Limit), config.RateLimit.CategoryLimits.Replaceable.Burst) + rateLimiter.AddCategoryLimit("parameterized_replaceable", rate.Limit(config.RateLimit.CategoryLimits.ParameterizedReplaceable.Limit), config.RateLimit.CategoryLimits.ParameterizedReplaceable.Burst) + rateLimiter.AddCategoryLimit("ephemeral", rate.Limit(config.RateLimit.CategoryLimits.Ephemeral.Limit), config.RateLimit.CategoryLimits.Ephemeral.Burst) // Create a new ServeMux mux := http.NewServeMux() // Handle the root path - mux.HandleFunc("/", ListenAndServe) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + websocket.Handler(func(ws *websocket.Conn) { + relay.WebSocketHandler(ws, rateLimiter) + }).ServeHTTP(w, r) + } else { + web.RootHandler(w, r) + } + }) // Serve static files mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static")))) @@ -53,25 +62,10 @@ func main() { http.ServeFile(w, r, "web/static/img/favicon.ico") }) - // Start the Relay + // Start the server fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port) err = http.ListenAndServe(config.Server.Port, mux) if err != nil { fmt.Println("Error starting server:", err) } } - -// Listener serves both WebSocket and HTML -func ListenAndServe(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - websocket.Handler(func(ws *websocket.Conn) { - if !rl.AllowWs() { - ws.Close() - return - } - relay.WebSocketHandler(ws) - }).ServeHTTP(w, r) - } else { - web.RootHandler(w, r) - } -} diff --git a/relay/handlers/event.go b/relay/handlers/event.go index 009d700..9ebe061 100644 --- a/relay/handlers/event.go +++ b/relay/handlers/event.go @@ -13,13 +13,7 @@ import ( "golang.org/x/net/websocket" ) -var rl *utils.RateLimiter - -func SetRateLimiter(rateLimiter *utils.RateLimiter) { - rl = rateLimiter -} - -func HandleEvent(ws *websocket.Conn, message []interface{}) { +func HandleEvent(ws *websocket.Conn, message []interface{}, rateLimiter *utils.RateLimiter) { if len(message) != 2 { fmt.Println("Invalid EVENT message format") return @@ -43,9 +37,11 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) { return } - // Check rate limits - if !rl.AllowEvent(evt.Kind) { - kinds.SendNotice(ws, evt.ID, "rate limit exceeded") + // Determine the category based on the kind + category := getCategory(evt.Kind) + + if !rateLimiter.AllowEvent(evt.Kind, category) { + fmt.Printf("Event rate limit exceeded for kind: %d, category: %s\n", evt.Kind, category) return } @@ -55,6 +51,21 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) { fmt.Println("Event processed:", evt.ID) } +func getCategory(kind int) string { + switch { + case kind == 0 || kind == 3 || (kind >= 10000 && kind < 20000): + return "replaceable" + case kind >= 20000 && kind < 30000: + return "ephemeral" + case kind >= 30000 && kind < 40000: + return "parameterized_replaceable" + case (kind >= 4 && kind < 45) || (kind >= 1000 && kind < 10000) || kind == 1: + return "regular" + default: + return "unknown" + } +} + func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) { if !utils.CheckSignature(evt) { sendOK(ws, evt.ID, false, "invalid: signature verification failed") @@ -94,4 +105,4 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) { } sendOK(ws, evt.ID, true, "") -} +} \ No newline at end of file diff --git a/relay/relay.go b/relay/relay.go index 92b0d13..f537b93 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -5,11 +5,13 @@ import ( "fmt" "grain/relay/handlers" + "grain/relay/utils" + "golang.org/x/net/websocket" ) // WebSocketHandler handles incoming WebSocket connections -func WebSocketHandler(ws *websocket.Conn) { +func WebSocketHandler(ws *websocket.Conn, rateLimiter *utils.RateLimiter) { defer ws.Close() var msg string @@ -39,9 +41,14 @@ func WebSocketHandler(ws *websocket.Conn) { continue } + if !rateLimiter.AllowWs() { + fmt.Println("WebSocket message rate limit exceeded") + continue + } + switch messageType { case "EVENT": - handlers.HandleEvent(ws, message) + handlers.HandleEvent(ws, message, rateLimiter) case "REQ": handlers.HandleReq(ws, message) case "CLOSE": @@ -50,4 +57,4 @@ func WebSocketHandler(ws *websocket.Conn) { fmt.Println("Unknown message type:", messageType) } } -} +} \ No newline at end of file diff --git a/relay/utils/loadConfig.go b/relay/utils/loadConfig.go index 79535b7..a138def 100644 --- a/relay/utils/loadConfig.go +++ b/relay/utils/loadConfig.go @@ -7,11 +7,12 @@ import ( ) type RateLimitConfig struct { - EventLimit float64 `yaml:"event_limit"` - EventBurst int `yaml:"event_burst"` WsLimit float64 `yaml:"ws_limit"` WsBurst int `yaml:"ws_burst"` + EventLimit float64 `yaml:"event_limit"` + EventBurst int `yaml:"event_burst"` KindLimits []KindLimitConfig `yaml:"kind_limits"` + CategoryLimits CategoryLimitConfig `yaml:"category_limits"` } type KindLimitConfig struct { @@ -20,6 +21,18 @@ type KindLimitConfig struct { Burst int `yaml:"burst"` } +type CategoryLimitConfig struct { + Regular LimitBurst `yaml:"regular"` + Replaceable LimitBurst `yaml:"replaceable"` + ParameterizedReplaceable LimitBurst `yaml:"parameterized_replaceable"` + Ephemeral LimitBurst `yaml:"ephemeral"` +} + +type LimitBurst struct { + Limit float64 `yaml:"limit"` + Burst int `yaml:"burst"` +} + type Config struct { MongoDB struct { URI string `yaml:"uri"` diff --git a/relay/utils/rateLimiter.go b/relay/utils/rateLimiter.go index 83a847b..97c88fd 100644 --- a/relay/utils/rateLimiter.go +++ b/relay/utils/rateLimiter.go @@ -12,18 +12,26 @@ type KindLimiter struct { Burst int } +type CategoryLimiter struct { + Limiter *rate.Limiter + Limit rate.Limit + Burst int +} + type RateLimiter struct { - eventLimiter *rate.Limiter - wsLimiter *rate.Limiter - kindLimiters map[int]*KindLimiter - mu sync.RWMutex + eventLimiter *rate.Limiter + wsLimiter *rate.Limiter + kindLimiters map[int]*KindLimiter + categoryLimiters map[string]*CategoryLimiter + mu sync.RWMutex } func NewRateLimiter(eventLimit rate.Limit, eventBurst int, wsLimit rate.Limit, wsBurst int) *RateLimiter { return &RateLimiter{ - eventLimiter: rate.NewLimiter(eventLimit, eventBurst), - wsLimiter: rate.NewLimiter(wsLimit, wsBurst), - kindLimiters: make(map[int]*KindLimiter), + eventLimiter: rate.NewLimiter(eventLimit, eventBurst), + wsLimiter: rate.NewLimiter(wsLimit, wsBurst), + kindLimiters: make(map[int]*KindLimiter), + categoryLimiters: make(map[string]*CategoryLimiter), } } @@ -37,7 +45,17 @@ func (rl *RateLimiter) AddKindLimit(kind int, limit rate.Limit, burst int) { } } -func (rl *RateLimiter) AllowEvent(kind int) bool { +func (rl *RateLimiter) AddCategoryLimit(category string, limit rate.Limit, burst int) { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.categoryLimiters[category] = &CategoryLimiter{ + Limiter: rate.NewLimiter(limit, burst), + Limit: limit, + Burst: burst, + } +} + +func (rl *RateLimiter) AllowEvent(kind int, category string) bool { rl.mu.RLock() defer rl.mu.RUnlock() @@ -51,6 +69,12 @@ func (rl *RateLimiter) AllowEvent(kind int) bool { } } + if categoryLimiter, exists := rl.categoryLimiters[category]; exists { + if !categoryLimiter.Limiter.Allow() { + return false + } + } + return true }