mirror of
https://github.com/0ceanSlim/grain.git
synced 2024-11-22 08:37:13 +00:00
configurable rate limit by category of event type
This commit is contained in:
parent
80e80c4215
commit
4920f61a99
@ -6,10 +6,10 @@ server:
|
|||||||
# Rate Limits Integers are per second
|
# Rate Limits Integers are per second
|
||||||
# burst is an override for the limit, this is to handle spikes in traffic
|
# burst is an override for the limit, this is to handle spikes in traffic
|
||||||
rate_limit:
|
rate_limit:
|
||||||
event_limit: 25
|
|
||||||
event_burst: 50
|
|
||||||
ws_limit: 50
|
ws_limit: 50
|
||||||
ws_burst: 100
|
ws_burst: 100
|
||||||
|
event_limit: 25
|
||||||
|
event_burst: 50
|
||||||
kind_limits:
|
kind_limits:
|
||||||
- kind: 0
|
- kind: 0
|
||||||
limit: 1
|
limit: 1
|
||||||
|
42
main.go
42
main.go
@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"grain/relay"
|
"grain/relay"
|
||||||
"grain/relay/db"
|
"grain/relay/db"
|
||||||
"grain/relay/handlers"
|
|
||||||
"grain/relay/utils"
|
"grain/relay/utils"
|
||||||
"grain/web"
|
"grain/web"
|
||||||
|
|
||||||
@ -15,8 +14,6 @@ import (
|
|||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var rl *utils.RateLimiter
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Load configuration
|
// Load configuration
|
||||||
config, err := utils.LoadConfig("config.yml")
|
config, err := utils.LoadConfig("config.yml")
|
||||||
@ -31,19 +28,31 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer db.DisconnectDB()
|
defer db.DisconnectDB()
|
||||||
|
|
||||||
// Initialize rate limiter
|
// Initialize RateLimiter
|
||||||
rl = utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst)
|
rateLimiter := utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst)
|
||||||
|
|
||||||
for _, kindLimit := range config.RateLimit.KindLimits {
|
for _, kindLimit := range config.RateLimit.KindLimits {
|
||||||
rl.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst)
|
rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst)
|
||||||
}
|
}
|
||||||
|
|
||||||
handlers.SetRateLimiter(rl)
|
rateLimiter.AddCategoryLimit("regular", rate.Limit(config.RateLimit.CategoryLimits.Regular.Limit), config.RateLimit.CategoryLimits.Regular.Burst)
|
||||||
|
rateLimiter.AddCategoryLimit("replaceable", rate.Limit(config.RateLimit.CategoryLimits.Replaceable.Limit), config.RateLimit.CategoryLimits.Replaceable.Burst)
|
||||||
|
rateLimiter.AddCategoryLimit("parameterized_replaceable", rate.Limit(config.RateLimit.CategoryLimits.ParameterizedReplaceable.Limit), config.RateLimit.CategoryLimits.ParameterizedReplaceable.Burst)
|
||||||
|
rateLimiter.AddCategoryLimit("ephemeral", rate.Limit(config.RateLimit.CategoryLimits.Ephemeral.Limit), config.RateLimit.CategoryLimits.Ephemeral.Burst)
|
||||||
|
|
||||||
// Create a new ServeMux
|
// Create a new ServeMux
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// Handle the root path
|
// Handle the root path
|
||||||
mux.HandleFunc("/", ListenAndServe)
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Upgrade") == "websocket" {
|
||||||
|
websocket.Handler(func(ws *websocket.Conn) {
|
||||||
|
relay.WebSocketHandler(ws, rateLimiter)
|
||||||
|
}).ServeHTTP(w, r)
|
||||||
|
} else {
|
||||||
|
web.RootHandler(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Serve static files
|
// Serve static files
|
||||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
|
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
|
||||||
@ -53,25 +62,10 @@ func main() {
|
|||||||
http.ServeFile(w, r, "web/static/img/favicon.ico")
|
http.ServeFile(w, r, "web/static/img/favicon.ico")
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start the Relay
|
// Start the server
|
||||||
fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port)
|
fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port)
|
||||||
err = http.ListenAndServe(config.Server.Port, mux)
|
err = http.ListenAndServe(config.Server.Port, mux)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error starting server:", err)
|
fmt.Println("Error starting server:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listener serves both WebSocket and HTML
|
|
||||||
func ListenAndServe(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Header.Get("Upgrade") == "websocket" {
|
|
||||||
websocket.Handler(func(ws *websocket.Conn) {
|
|
||||||
if !rl.AllowWs() {
|
|
||||||
ws.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
relay.WebSocketHandler(ws)
|
|
||||||
}).ServeHTTP(w, r)
|
|
||||||
} else {
|
|
||||||
web.RootHandler(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -13,13 +13,7 @@ import (
|
|||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
var rl *utils.RateLimiter
|
func HandleEvent(ws *websocket.Conn, message []interface{}, rateLimiter *utils.RateLimiter) {
|
||||||
|
|
||||||
func SetRateLimiter(rateLimiter *utils.RateLimiter) {
|
|
||||||
rl = rateLimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandleEvent(ws *websocket.Conn, message []interface{}) {
|
|
||||||
if len(message) != 2 {
|
if len(message) != 2 {
|
||||||
fmt.Println("Invalid EVENT message format")
|
fmt.Println("Invalid EVENT message format")
|
||||||
return
|
return
|
||||||
@ -43,9 +37,11 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check rate limits
|
// Determine the category based on the kind
|
||||||
if !rl.AllowEvent(evt.Kind) {
|
category := getCategory(evt.Kind)
|
||||||
kinds.SendNotice(ws, evt.ID, "rate limit exceeded")
|
|
||||||
|
if !rateLimiter.AllowEvent(evt.Kind, category) {
|
||||||
|
fmt.Printf("Event rate limit exceeded for kind: %d, category: %s\n", evt.Kind, category)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,6 +51,21 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
|
|||||||
fmt.Println("Event processed:", evt.ID)
|
fmt.Println("Event processed:", evt.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCategory(kind int) string {
|
||||||
|
switch {
|
||||||
|
case kind == 0 || kind == 3 || (kind >= 10000 && kind < 20000):
|
||||||
|
return "replaceable"
|
||||||
|
case kind >= 20000 && kind < 30000:
|
||||||
|
return "ephemeral"
|
||||||
|
case kind >= 30000 && kind < 40000:
|
||||||
|
return "parameterized_replaceable"
|
||||||
|
case (kind >= 4 && kind < 45) || (kind >= 1000 && kind < 10000) || kind == 1:
|
||||||
|
return "regular"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
|
func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
|
||||||
if !utils.CheckSignature(evt) {
|
if !utils.CheckSignature(evt) {
|
||||||
sendOK(ws, evt.ID, false, "invalid: signature verification failed")
|
sendOK(ws, evt.ID, false, "invalid: signature verification failed")
|
||||||
|
@ -5,11 +5,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"grain/relay/handlers"
|
"grain/relay/handlers"
|
||||||
|
|
||||||
|
"grain/relay/utils"
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WebSocketHandler handles incoming WebSocket connections
|
// WebSocketHandler handles incoming WebSocket connections
|
||||||
func WebSocketHandler(ws *websocket.Conn) {
|
func WebSocketHandler(ws *websocket.Conn, rateLimiter *utils.RateLimiter) {
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
var msg string
|
var msg string
|
||||||
@ -39,9 +41,14 @@ func WebSocketHandler(ws *websocket.Conn) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !rateLimiter.AllowWs() {
|
||||||
|
fmt.Println("WebSocket message rate limit exceeded")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch messageType {
|
switch messageType {
|
||||||
case "EVENT":
|
case "EVENT":
|
||||||
handlers.HandleEvent(ws, message)
|
handlers.HandleEvent(ws, message, rateLimiter)
|
||||||
case "REQ":
|
case "REQ":
|
||||||
handlers.HandleReq(ws, message)
|
handlers.HandleReq(ws, message)
|
||||||
case "CLOSE":
|
case "CLOSE":
|
||||||
|
@ -7,11 +7,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RateLimitConfig struct {
|
type RateLimitConfig struct {
|
||||||
EventLimit float64 `yaml:"event_limit"`
|
|
||||||
EventBurst int `yaml:"event_burst"`
|
|
||||||
WsLimit float64 `yaml:"ws_limit"`
|
WsLimit float64 `yaml:"ws_limit"`
|
||||||
WsBurst int `yaml:"ws_burst"`
|
WsBurst int `yaml:"ws_burst"`
|
||||||
|
EventLimit float64 `yaml:"event_limit"`
|
||||||
|
EventBurst int `yaml:"event_burst"`
|
||||||
KindLimits []KindLimitConfig `yaml:"kind_limits"`
|
KindLimits []KindLimitConfig `yaml:"kind_limits"`
|
||||||
|
CategoryLimits CategoryLimitConfig `yaml:"category_limits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type KindLimitConfig struct {
|
type KindLimitConfig struct {
|
||||||
@ -20,6 +21,18 @@ type KindLimitConfig struct {
|
|||||||
Burst int `yaml:"burst"`
|
Burst int `yaml:"burst"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CategoryLimitConfig struct {
|
||||||
|
Regular LimitBurst `yaml:"regular"`
|
||||||
|
Replaceable LimitBurst `yaml:"replaceable"`
|
||||||
|
ParameterizedReplaceable LimitBurst `yaml:"parameterized_replaceable"`
|
||||||
|
Ephemeral LimitBurst `yaml:"ephemeral"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LimitBurst struct {
|
||||||
|
Limit float64 `yaml:"limit"`
|
||||||
|
Burst int `yaml:"burst"`
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
MongoDB struct {
|
MongoDB struct {
|
||||||
URI string `yaml:"uri"`
|
URI string `yaml:"uri"`
|
||||||
|
@ -12,10 +12,17 @@ type KindLimiter struct {
|
|||||||
Burst int
|
Burst int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CategoryLimiter struct {
|
||||||
|
Limiter *rate.Limiter
|
||||||
|
Limit rate.Limit
|
||||||
|
Burst int
|
||||||
|
}
|
||||||
|
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
eventLimiter *rate.Limiter
|
eventLimiter *rate.Limiter
|
||||||
wsLimiter *rate.Limiter
|
wsLimiter *rate.Limiter
|
||||||
kindLimiters map[int]*KindLimiter
|
kindLimiters map[int]*KindLimiter
|
||||||
|
categoryLimiters map[string]*CategoryLimiter
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,6 +31,7 @@ func NewRateLimiter(eventLimit rate.Limit, eventBurst int, wsLimit rate.Limit, w
|
|||||||
eventLimiter: rate.NewLimiter(eventLimit, eventBurst),
|
eventLimiter: rate.NewLimiter(eventLimit, eventBurst),
|
||||||
wsLimiter: rate.NewLimiter(wsLimit, wsBurst),
|
wsLimiter: rate.NewLimiter(wsLimit, wsBurst),
|
||||||
kindLimiters: make(map[int]*KindLimiter),
|
kindLimiters: make(map[int]*KindLimiter),
|
||||||
|
categoryLimiters: make(map[string]*CategoryLimiter),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,7 +45,17 @@ func (rl *RateLimiter) AddKindLimit(kind int, limit rate.Limit, burst int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *RateLimiter) AllowEvent(kind int) bool {
|
func (rl *RateLimiter) AddCategoryLimit(category string, limit rate.Limit, burst int) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
rl.categoryLimiters[category] = &CategoryLimiter{
|
||||||
|
Limiter: rate.NewLimiter(limit, burst),
|
||||||
|
Limit: limit,
|
||||||
|
Burst: burst,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RateLimiter) AllowEvent(kind int, category string) bool {
|
||||||
rl.mu.RLock()
|
rl.mu.RLock()
|
||||||
defer rl.mu.RUnlock()
|
defer rl.mu.RUnlock()
|
||||||
|
|
||||||
@ -51,6 +69,12 @@ func (rl *RateLimiter) AllowEvent(kind int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if categoryLimiter, exists := rl.categoryLimiters[category]; exists {
|
||||||
|
if !categoryLimiter.Limiter.Allow() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user