diff --git a/app/static/examples/config.example.yml b/app/static/examples/config.example.yml index 9974d5e..416e8a2 100644 --- a/app/static/examples/config.example.yml +++ b/app/static/examples/config.example.yml @@ -7,6 +7,9 @@ server: read_timeout: 10 # Read timeout in seconds write_timeout: 10 # Write timeout in seconds idle_timeout: 120 # Idle timeout in seconds + max_connections: 100 # Maximum number of concurrent connections + max_subscriptions_per_client: 10 # Maximum number of concurrent subscriptions per client + pubkey_whitelist: enabled: false pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4", diff --git a/config/types/serverConfig.go b/config/types/serverConfig.go index b4f0836..6ea9aea 100644 --- a/config/types/serverConfig.go +++ b/config/types/serverConfig.go @@ -6,10 +6,12 @@ type ServerConfig struct { Database string `yaml:"database"` } `yaml:"mongodb"` Server struct { - Port string `yaml:"port"` - ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds - WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds - IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds + Port string `yaml:"port"` + ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds + WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds + IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds + MaxConnections int `yaml:"max_connections"` // Maximum number of concurrent connections + MaxSubscriptionsPerClient int `yaml:"max_subscriptions_per_client"` // Maximum number of subscriptions per client } `yaml:"server"` RateLimit RateLimitConfig `yaml:"rate_limit"` PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"` diff --git a/server/handlers/req.go b/server/handlers/req.go index 21dc925..526f028 100644 --- a/server/handlers/req.go +++ b/server/handlers/req.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "grain/config" "grain/server/db" "grain/server/handlers/response" relay "grain/server/types" "grain/server/utils" + "sync" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -16,6 +18,7 @@ import ( ) var subscriptions = make(map[string]relay.Subscription) +var mu sync.Mutex // Protect concurrent access to subscriptions map func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) { if len(message) < 3 { @@ -31,6 +34,21 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri return } + mu.Lock() + defer mu.Unlock() + + // Check the current number of subscriptions for the client + if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { + // Find and remove the oldest subscription (FIFO) + var oldestSubID string + for id := range subscriptions { + oldestSubID = id + break + } + delete(subscriptions, oldestSubID) + fmt.Println("Dropped oldest subscription:", oldestSubID) + } + filters := make([]relay.Filter, len(message)-2) for i, filter := range message[2:] { filterData, ok := filter.(map[string]interface{}) @@ -52,7 +70,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri filters[i] = f } - // Update or add the subscription for the given subID + // Add the new subscription or update the existing one subscriptions[subID] = filters fmt.Println("Subscription updated:", subID) diff --git a/server/relay.go b/server/relay.go index d092ba9..d45b641 100644 --- a/server/relay.go +++ b/server/relay.go @@ -3,22 +3,56 @@ package relay import ( "encoding/json" "fmt" - "grain/server/handlers" - "grain/config" - + "grain/server/handlers" relay "grain/server/types" + "grain/server/utils" + "log" + "sync" "golang.org/x/net/websocket" ) +// Global connection count +var ( + currentConnections int + mu sync.Mutex +) + +// Client subscription count +var clientSubscriptions = make(map[*websocket.Conn]int) + func WebSocketHandler(ws *websocket.Conn) { - defer ws.Close() + defer func() { + mu.Lock() + currentConnections-- + delete(clientSubscriptions, ws) + mu.Unlock() + ws.Close() + }() + + mu.Lock() + if currentConnections >= config.GetConfig().Server.MaxConnections { + websocket.Message.Send(ws, `{"error": "too many connections"}`) + mu.Unlock() + return + } + currentConnections++ + mu.Unlock() + + clientInfo := relay.ClientInfo{ + IP: utils.GetClientIP(ws.Request()), + UserAgent: ws.Request().Header.Get("User-Agent"), + Origin: ws.Request().Header.Get("Origin"), + } + + log.Printf("New connection from IP: %s, User-Agent: %s, Origin: %s", clientInfo.IP, clientInfo.UserAgent, clientInfo.Origin) var msg string rateLimiter := config.GetRateLimiter() subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection + clientSubscriptions[ws] = 0 for { err := websocket.Message.Receive(ws, &msg) @@ -57,6 +91,14 @@ func WebSocketHandler(ws *websocket.Conn) { case "EVENT": handlers.HandleEvent(ws, message) case "REQ": + mu.Lock() + if clientSubscriptions[ws] >= config.GetConfig().Server.MaxSubscriptionsPerClient { + websocket.Message.Send(ws, `{"error": "too many subscriptions"}`) + mu.Unlock() + continue + } + clientSubscriptions[ws]++ + mu.Unlock() if allowed, msg := rateLimiter.AllowReq(); !allowed { websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg)) ws.Close() diff --git a/server/types/clientInfo.go b/server/types/clientInfo.go new file mode 100644 index 0000000..8f08429 --- /dev/null +++ b/server/types/clientInfo.go @@ -0,0 +1,7 @@ +package relay + +type ClientInfo struct { + IP string + UserAgent string + Origin string +} \ No newline at end of file diff --git a/server/utils/getClientInfo.go b/server/utils/getClientInfo.go new file mode 100644 index 0000000..99613e4 --- /dev/null +++ b/server/utils/getClientInfo.go @@ -0,0 +1,22 @@ +package utils + +import ( + "net/http" + "strings" +) + +func GetClientIP(r *http.Request) string { + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + remoteAddr := r.RemoteAddr + if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 { + return remoteAddr[:idx] + } + return remoteAddr +} \ No newline at end of file