configurable rate limit by category of event type

This commit is contained in:
0ceanSlim 2024-07-25 09:03:34 -04:00
parent 80e80c4215
commit 4920f61a99
6 changed files with 99 additions and 50 deletions

View File

@ -6,10 +6,10 @@ server:
# Rate Limits Integers are per second # Rate Limits Integers are per second
# burst is an override for the limit, this is to handle spikes in traffic # burst is an override for the limit, this is to handle spikes in traffic
rate_limit: rate_limit:
event_limit: 25
event_burst: 50
ws_limit: 50 ws_limit: 50
ws_burst: 100 ws_burst: 100
event_limit: 25
event_burst: 50
kind_limits: kind_limits:
- kind: 0 - kind: 0
limit: 1 limit: 1

42
main.go
View File

@ -7,7 +7,6 @@ import (
"grain/relay" "grain/relay"
"grain/relay/db" "grain/relay/db"
"grain/relay/handlers"
"grain/relay/utils" "grain/relay/utils"
"grain/web" "grain/web"
@ -15,8 +14,6 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
var rl *utils.RateLimiter
func main() { func main() {
// Load configuration // Load configuration
config, err := utils.LoadConfig("config.yml") config, err := utils.LoadConfig("config.yml")
@ -31,19 +28,31 @@ func main() {
} }
defer db.DisconnectDB() defer db.DisconnectDB()
// Initialize rate limiter // Initialize RateLimiter
rl = utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst) rateLimiter := utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst)
for _, kindLimit := range config.RateLimit.KindLimits { for _, kindLimit := range config.RateLimit.KindLimits {
rl.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst) rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst)
} }
handlers.SetRateLimiter(rl) rateLimiter.AddCategoryLimit("regular", rate.Limit(config.RateLimit.CategoryLimits.Regular.Limit), config.RateLimit.CategoryLimits.Regular.Burst)
rateLimiter.AddCategoryLimit("replaceable", rate.Limit(config.RateLimit.CategoryLimits.Replaceable.Limit), config.RateLimit.CategoryLimits.Replaceable.Burst)
rateLimiter.AddCategoryLimit("parameterized_replaceable", rate.Limit(config.RateLimit.CategoryLimits.ParameterizedReplaceable.Limit), config.RateLimit.CategoryLimits.ParameterizedReplaceable.Burst)
rateLimiter.AddCategoryLimit("ephemeral", rate.Limit(config.RateLimit.CategoryLimits.Ephemeral.Limit), config.RateLimit.CategoryLimits.Ephemeral.Burst)
// Create a new ServeMux // Create a new ServeMux
mux := http.NewServeMux() mux := http.NewServeMux()
// Handle the root path // Handle the root path
mux.HandleFunc("/", ListenAndServe) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
websocket.Handler(func(ws *websocket.Conn) {
relay.WebSocketHandler(ws, rateLimiter)
}).ServeHTTP(w, r)
} else {
web.RootHandler(w, r)
}
})
// Serve static files // Serve static files
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static")))) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
@ -53,25 +62,10 @@ func main() {
http.ServeFile(w, r, "web/static/img/favicon.ico") http.ServeFile(w, r, "web/static/img/favicon.ico")
}) })
// Start the Relay // Start the server
fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port) fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port)
err = http.ListenAndServe(config.Server.Port, mux) err = http.ListenAndServe(config.Server.Port, mux)
if err != nil { if err != nil {
fmt.Println("Error starting server:", err) fmt.Println("Error starting server:", err)
} }
} }
// Listener serves both WebSocket and HTML
func ListenAndServe(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
websocket.Handler(func(ws *websocket.Conn) {
if !rl.AllowWs() {
ws.Close()
return
}
relay.WebSocketHandler(ws)
}).ServeHTTP(w, r)
} else {
web.RootHandler(w, r)
}
}

View File

@ -13,13 +13,7 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
var rl *utils.RateLimiter func HandleEvent(ws *websocket.Conn, message []interface{}, rateLimiter *utils.RateLimiter) {
func SetRateLimiter(rateLimiter *utils.RateLimiter) {
rl = rateLimiter
}
func HandleEvent(ws *websocket.Conn, message []interface{}) {
if len(message) != 2 { if len(message) != 2 {
fmt.Println("Invalid EVENT message format") fmt.Println("Invalid EVENT message format")
return return
@ -43,9 +37,11 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
return return
} }
// Check rate limits // Determine the category based on the kind
if !rl.AllowEvent(evt.Kind) { category := getCategory(evt.Kind)
kinds.SendNotice(ws, evt.ID, "rate limit exceeded")
if !rateLimiter.AllowEvent(evt.Kind, category) {
fmt.Printf("Event rate limit exceeded for kind: %d, category: %s\n", evt.Kind, category)
return return
} }
@ -55,6 +51,21 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
fmt.Println("Event processed:", evt.ID) fmt.Println("Event processed:", evt.ID)
} }
func getCategory(kind int) string {
switch {
case kind == 0 || kind == 3 || (kind >= 10000 && kind < 20000):
return "replaceable"
case kind >= 20000 && kind < 30000:
return "ephemeral"
case kind >= 30000 && kind < 40000:
return "parameterized_replaceable"
case (kind >= 4 && kind < 45) || (kind >= 1000 && kind < 10000) || kind == 1:
return "regular"
default:
return "unknown"
}
}
func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) { func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
if !utils.CheckSignature(evt) { if !utils.CheckSignature(evt) {
sendOK(ws, evt.ID, false, "invalid: signature verification failed") sendOK(ws, evt.ID, false, "invalid: signature verification failed")
@ -94,4 +105,4 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
} }
sendOK(ws, evt.ID, true, "") sendOK(ws, evt.ID, true, "")
} }

View File

@ -5,11 +5,13 @@ import (
"fmt" "fmt"
"grain/relay/handlers" "grain/relay/handlers"
"grain/relay/utils"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
// WebSocketHandler handles incoming WebSocket connections // WebSocketHandler handles incoming WebSocket connections
func WebSocketHandler(ws *websocket.Conn) { func WebSocketHandler(ws *websocket.Conn, rateLimiter *utils.RateLimiter) {
defer ws.Close() defer ws.Close()
var msg string var msg string
@ -39,9 +41,14 @@ func WebSocketHandler(ws *websocket.Conn) {
continue continue
} }
if !rateLimiter.AllowWs() {
fmt.Println("WebSocket message rate limit exceeded")
continue
}
switch messageType { switch messageType {
case "EVENT": case "EVENT":
handlers.HandleEvent(ws, message) handlers.HandleEvent(ws, message, rateLimiter)
case "REQ": case "REQ":
handlers.HandleReq(ws, message) handlers.HandleReq(ws, message)
case "CLOSE": case "CLOSE":
@ -50,4 +57,4 @@ func WebSocketHandler(ws *websocket.Conn) {
fmt.Println("Unknown message type:", messageType) fmt.Println("Unknown message type:", messageType)
} }
} }
} }

View File

@ -7,11 +7,12 @@ import (
) )
type RateLimitConfig struct { type RateLimitConfig struct {
EventLimit float64 `yaml:"event_limit"`
EventBurst int `yaml:"event_burst"`
WsLimit float64 `yaml:"ws_limit"` WsLimit float64 `yaml:"ws_limit"`
WsBurst int `yaml:"ws_burst"` WsBurst int `yaml:"ws_burst"`
EventLimit float64 `yaml:"event_limit"`
EventBurst int `yaml:"event_burst"`
KindLimits []KindLimitConfig `yaml:"kind_limits"` KindLimits []KindLimitConfig `yaml:"kind_limits"`
CategoryLimits CategoryLimitConfig `yaml:"category_limits"`
} }
type KindLimitConfig struct { type KindLimitConfig struct {
@ -20,6 +21,18 @@ type KindLimitConfig struct {
Burst int `yaml:"burst"` Burst int `yaml:"burst"`
} }
type CategoryLimitConfig struct {
Regular LimitBurst `yaml:"regular"`
Replaceable LimitBurst `yaml:"replaceable"`
ParameterizedReplaceable LimitBurst `yaml:"parameterized_replaceable"`
Ephemeral LimitBurst `yaml:"ephemeral"`
}
type LimitBurst struct {
Limit float64 `yaml:"limit"`
Burst int `yaml:"burst"`
}
type Config struct { type Config struct {
MongoDB struct { MongoDB struct {
URI string `yaml:"uri"` URI string `yaml:"uri"`

View File

@ -12,18 +12,26 @@ type KindLimiter struct {
Burst int Burst int
} }
type CategoryLimiter struct {
Limiter *rate.Limiter
Limit rate.Limit
Burst int
}
type RateLimiter struct { type RateLimiter struct {
eventLimiter *rate.Limiter eventLimiter *rate.Limiter
wsLimiter *rate.Limiter wsLimiter *rate.Limiter
kindLimiters map[int]*KindLimiter kindLimiters map[int]*KindLimiter
mu sync.RWMutex categoryLimiters map[string]*CategoryLimiter
mu sync.RWMutex
} }
func NewRateLimiter(eventLimit rate.Limit, eventBurst int, wsLimit rate.Limit, wsBurst int) *RateLimiter { func NewRateLimiter(eventLimit rate.Limit, eventBurst int, wsLimit rate.Limit, wsBurst int) *RateLimiter {
return &RateLimiter{ return &RateLimiter{
eventLimiter: rate.NewLimiter(eventLimit, eventBurst), eventLimiter: rate.NewLimiter(eventLimit, eventBurst),
wsLimiter: rate.NewLimiter(wsLimit, wsBurst), wsLimiter: rate.NewLimiter(wsLimit, wsBurst),
kindLimiters: make(map[int]*KindLimiter), kindLimiters: make(map[int]*KindLimiter),
categoryLimiters: make(map[string]*CategoryLimiter),
} }
} }
@ -37,7 +45,17 @@ func (rl *RateLimiter) AddKindLimit(kind int, limit rate.Limit, burst int) {
} }
} }
func (rl *RateLimiter) AllowEvent(kind int) bool { func (rl *RateLimiter) AddCategoryLimit(category string, limit rate.Limit, burst int) {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.categoryLimiters[category] = &CategoryLimiter{
Limiter: rate.NewLimiter(limit, burst),
Limit: limit,
Burst: burst,
}
}
func (rl *RateLimiter) AllowEvent(kind int, category string) bool {
rl.mu.RLock() rl.mu.RLock()
defer rl.mu.RUnlock() defer rl.mu.RUnlock()
@ -51,6 +69,12 @@ func (rl *RateLimiter) AllowEvent(kind int) bool {
} }
} }
if categoryLimiter, exists := rl.categoryLimiters[category]; exists {
if !categoryLimiter.Limiter.Allow() {
return false
}
}
return true return true
} }