mirror of
https://github.com/0ceanSlim/grain.git
synced 2024-10-30 01:26:32 +00:00
fix category rate limits
This commit is contained in:
parent
8050e2e74f
commit
d5f2366ff8
41
main.go
41
main.go
@ -28,31 +28,29 @@ func main() {
|
||||
}
|
||||
defer db.DisconnectDB()
|
||||
|
||||
// Initialize RateLimiter
|
||||
rateLimiter := utils.NewRateLimiter(rate.Limit(config.RateLimit.EventLimit), config.RateLimit.EventBurst, rate.Limit(config.RateLimit.WsLimit), config.RateLimit.WsBurst)
|
||||
// Initialize Rate Limiter
|
||||
rateLimiter := utils.NewRateLimiter(
|
||||
rate.Limit(config.RateLimit.EventLimit),
|
||||
config.RateLimit.EventBurst,
|
||||
rate.Limit(config.RateLimit.WsLimit),
|
||||
config.RateLimit.WsBurst,
|
||||
)
|
||||
|
||||
for _, kindLimit := range config.RateLimit.KindLimits {
|
||||
rateLimiter.AddKindLimit(kindLimit.Kind, rate.Limit(kindLimit.Limit), kindLimit.Burst)
|
||||
}
|
||||
|
||||
rateLimiter.AddCategoryLimit("regular", rate.Limit(config.RateLimit.CategoryLimits.Regular.Limit), config.RateLimit.CategoryLimits.Regular.Burst)
|
||||
rateLimiter.AddCategoryLimit("replaceable", rate.Limit(config.RateLimit.CategoryLimits.Replaceable.Limit), config.RateLimit.CategoryLimits.Replaceable.Burst)
|
||||
rateLimiter.AddCategoryLimit("parameterized_replaceable", rate.Limit(config.RateLimit.CategoryLimits.ParameterizedReplaceable.Limit), config.RateLimit.CategoryLimits.ParameterizedReplaceable.Burst)
|
||||
rateLimiter.AddCategoryLimit("ephemeral", rate.Limit(config.RateLimit.CategoryLimits.Ephemeral.Limit), config.RateLimit.CategoryLimits.Ephemeral.Burst)
|
||||
for category, categoryLimit := range config.RateLimit.CategoryLimits {
|
||||
rateLimiter.AddCategoryLimit(category, rate.Limit(categoryLimit.Limit), categoryLimit.Burst)
|
||||
}
|
||||
|
||||
utils.SetRateLimiter(rateLimiter)
|
||||
|
||||
// Create a new ServeMux
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Handle the root path
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
websocket.Handler(func(ws *websocket.Conn) {
|
||||
relay.WebSocketHandler(ws, rateLimiter)
|
||||
}).ServeHTTP(w, r)
|
||||
} else {
|
||||
web.RootHandler(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/", ListenAndServe)
|
||||
|
||||
// Serve static files
|
||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
|
||||
@ -62,10 +60,21 @@ func main() {
|
||||
http.ServeFile(w, r, "web/static/img/favicon.ico")
|
||||
})
|
||||
|
||||
// Start the server
|
||||
// Start the Relay
|
||||
fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port)
|
||||
err = http.ListenAndServe(config.Server.Port, mux)
|
||||
if err != nil {
|
||||
fmt.Println("Error starting server:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Listener serves both WebSocket and HTML
|
||||
func ListenAndServe(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
websocket.Handler(func(ws *websocket.Conn) {
|
||||
relay.WebSocketHandler(ws)
|
||||
}).ServeHTTP(w, r)
|
||||
} else {
|
||||
web.RootHandler(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func HandleEvent(ws *websocket.Conn, message []interface{}, rateLimiter *utils.RateLimiter) {
|
||||
func HandleEvent(ws *websocket.Conn, message []interface{}) {
|
||||
if len(message) != 2 {
|
||||
fmt.Println("Invalid EVENT message format")
|
||||
return
|
||||
@ -37,35 +37,12 @@ func HandleEvent(ws *websocket.Conn, message []interface{}, rateLimiter *utils.R
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the category based on the kind
|
||||
category := getCategory(evt.Kind)
|
||||
|
||||
if !rateLimiter.AllowEvent(evt.Kind, category) {
|
||||
fmt.Printf("Event rate limit exceeded for kind: %d, category: %s\n", evt.Kind, category)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the HandleKind function
|
||||
HandleKind(context.TODO(), evt, ws)
|
||||
|
||||
fmt.Println("Event processed:", evt.ID)
|
||||
}
|
||||
|
||||
func getCategory(kind int) string {
|
||||
switch {
|
||||
case kind == 0 || kind == 3 || (kind >= 10000 && kind < 20000):
|
||||
return "replaceable"
|
||||
case kind >= 20000 && kind < 30000:
|
||||
return "ephemeral"
|
||||
case kind >= 30000 && kind < 40000:
|
||||
return "parameterized_replaceable"
|
||||
case (kind >= 4 && kind < 45) || (kind >= 1000 && kind < 10000) || kind == 1:
|
||||
return "regular"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
|
||||
if !utils.CheckSignature(evt) {
|
||||
sendOK(ws, evt.ID, false, "invalid: signature verification failed")
|
||||
@ -74,6 +51,36 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
|
||||
|
||||
collection := db.GetCollection(evt.Kind)
|
||||
|
||||
rateLimiter := utils.GetRateLimiter()
|
||||
var category string
|
||||
switch {
|
||||
case evt.Kind == 0:
|
||||
category = "replaceable"
|
||||
case evt.Kind == 1:
|
||||
category = "regular"
|
||||
case evt.Kind == 2:
|
||||
category = "deprecated"
|
||||
case evt.Kind == 3:
|
||||
category = "replaceable"
|
||||
case evt.Kind >= 4 && evt.Kind < 45:
|
||||
category = "regular"
|
||||
case evt.Kind >= 1000 && evt.Kind < 10000:
|
||||
category = "regular"
|
||||
case evt.Kind >= 10000 && evt.Kind < 20000:
|
||||
category = "replaceable"
|
||||
case evt.Kind >= 20000 && evt.Kind < 30000:
|
||||
category = "ephemeral"
|
||||
case evt.Kind >= 30000 && evt.Kind < 40000:
|
||||
category = "parameterized_replaceable"
|
||||
default:
|
||||
category = "unknown"
|
||||
}
|
||||
|
||||
if !rateLimiter.AllowEvent(evt.Kind, category) {
|
||||
sendOK(ws, evt.ID, false, fmt.Sprintf("rate limit exceeded for category: %s", category))
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
switch {
|
||||
case evt.Kind == 0:
|
||||
@ -106,3 +113,4 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
|
||||
|
||||
sendOK(ws, evt.ID, true, "")
|
||||
}
|
||||
|
||||
|
@ -5,13 +5,10 @@ import (
|
||||
"fmt"
|
||||
"grain/relay/handlers"
|
||||
|
||||
"grain/relay/utils"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
// WebSocketHandler handles incoming WebSocket connections
|
||||
func WebSocketHandler(ws *websocket.Conn, rateLimiter *utils.RateLimiter) {
|
||||
func WebSocketHandler(ws *websocket.Conn) {
|
||||
defer ws.Close()
|
||||
|
||||
var msg string
|
||||
@ -41,14 +38,9 @@ func WebSocketHandler(ws *websocket.Conn, rateLimiter *utils.RateLimiter) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !rateLimiter.AllowWs() {
|
||||
fmt.Println("WebSocket message rate limit exceeded")
|
||||
continue
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case "EVENT":
|
||||
handlers.HandleEvent(ws, message, rateLimiter)
|
||||
handlers.HandleEvent(ws, message)
|
||||
case "REQ":
|
||||
handlers.HandleReq(ws, message)
|
||||
case "CLOSE":
|
||||
|
@ -12,7 +12,7 @@ type RateLimitConfig struct {
|
||||
EventLimit float64 `yaml:"event_limit"`
|
||||
EventBurst int `yaml:"event_burst"`
|
||||
KindLimits []KindLimitConfig `yaml:"kind_limits"`
|
||||
CategoryLimits CategoryLimitConfig `yaml:"category_limits"`
|
||||
CategoryLimits map[string]KindLimitConfig `yaml:"category_limits"`
|
||||
}
|
||||
|
||||
type KindLimitConfig struct {
|
||||
|
@ -19,18 +19,21 @@ type CategoryLimiter struct {
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
eventLimiter *rate.Limiter
|
||||
wsLimiter *rate.Limiter
|
||||
kindLimiters map[int]*KindLimiter
|
||||
eventLimiter *rate.Limiter
|
||||
wsLimiter *rate.Limiter
|
||||
kindLimiters map[int]*KindLimiter
|
||||
categoryLimiters map[string]*CategoryLimiter
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var rateLimiterInstance *RateLimiter
|
||||
var once sync.Once
|
||||
|
||||
func NewRateLimiter(eventLimit rate.Limit, eventBurst int, wsLimit rate.Limit, wsBurst int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
eventLimiter: rate.NewLimiter(eventLimit, eventBurst),
|
||||
wsLimiter: rate.NewLimiter(wsLimit, wsBurst),
|
||||
kindLimiters: make(map[int]*KindLimiter),
|
||||
eventLimiter: rate.NewLimiter(eventLimit, eventBurst),
|
||||
wsLimiter: rate.NewLimiter(wsLimit, wsBurst),
|
||||
kindLimiters: make(map[int]*KindLimiter),
|
||||
categoryLimiters: make(map[string]*CategoryLimiter),
|
||||
}
|
||||
}
|
||||
@ -81,3 +84,13 @@ func (rl *RateLimiter) AllowEvent(kind int, category string) bool {
|
||||
func (rl *RateLimiter) AllowWs() bool {
|
||||
return rl.wsLimiter.Allow()
|
||||
}
|
||||
|
||||
func SetRateLimiter(rl *RateLimiter) {
|
||||
once.Do(func() {
|
||||
rateLimiterInstance = rl
|
||||
})
|
||||
}
|
||||
|
||||
func GetRateLimiter() *RateLimiter {
|
||||
return rateLimiterInstance
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user