client info, config max connections/subscriptions, drop oldest sub for newest at limit

This commit is contained in:
0ceanSlim 2024-08-09 08:38:18 -04:00
parent bda932c06d
commit 25ff327a7f
6 changed files with 103 additions and 9 deletions

View File

@ -7,6 +7,9 @@ server:
read_timeout: 10 # Read timeout in seconds read_timeout: 10 # Read timeout in seconds
write_timeout: 10 # Write timeout in seconds write_timeout: 10 # Write timeout in seconds
idle_timeout: 120 # Idle timeout in seconds idle_timeout: 120 # Idle timeout in seconds
max_connections: 100 # Maximum number of concurrent connections
max_subscriptions_per_client: 10 # Maximum number of concurrent subscriptions per client
pubkey_whitelist: pubkey_whitelist:
enabled: false enabled: false
pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4", pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4",

View File

@ -6,10 +6,12 @@ type ServerConfig struct {
Database string `yaml:"database"` Database string `yaml:"database"`
} `yaml:"mongodb"` } `yaml:"mongodb"`
Server struct { Server struct {
Port string `yaml:"port"` Port string `yaml:"port"`
ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds
WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds
IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds
MaxConnections int `yaml:"max_connections"` // Maximum number of concurrent connections
MaxSubscriptionsPerClient int `yaml:"max_subscriptions_per_client"` // Maximum number of subscriptions per client
} `yaml:"server"` } `yaml:"server"`
RateLimit RateLimitConfig `yaml:"rate_limit"` RateLimit RateLimitConfig `yaml:"rate_limit"`
PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"` PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"`

View File

@ -4,10 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/config"
"grain/server/db" "grain/server/db"
"grain/server/handlers/response" "grain/server/handlers/response"
relay "grain/server/types" relay "grain/server/types"
"grain/server/utils" "grain/server/utils"
"sync"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
@ -16,6 +18,7 @@ import (
) )
var subscriptions = make(map[string]relay.Subscription) var subscriptions = make(map[string]relay.Subscription)
var mu sync.Mutex // Protect concurrent access to subscriptions map
func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) { func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) {
if len(message) < 3 { if len(message) < 3 {
@ -31,6 +34,21 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
return return
} }
mu.Lock()
defer mu.Unlock()
// Check the current number of subscriptions for the client
if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient {
// Find and remove the oldest subscription (FIFO)
var oldestSubID string
for id := range subscriptions {
oldestSubID = id
break
}
delete(subscriptions, oldestSubID)
fmt.Println("Dropped oldest subscription:", oldestSubID)
}
filters := make([]relay.Filter, len(message)-2) filters := make([]relay.Filter, len(message)-2)
for i, filter := range message[2:] { for i, filter := range message[2:] {
filterData, ok := filter.(map[string]interface{}) filterData, ok := filter.(map[string]interface{})
@ -52,7 +70,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
filters[i] = f filters[i] = f
} }
// Update or add the subscription for the given subID // Add the new subscription or update the existing one
subscriptions[subID] = filters subscriptions[subID] = filters
fmt.Println("Subscription updated:", subID) fmt.Println("Subscription updated:", subID)

View File

@ -3,22 +3,56 @@ package relay
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/server/handlers"
"grain/config" "grain/config"
"grain/server/handlers"
relay "grain/server/types" relay "grain/server/types"
"grain/server/utils"
"log"
"sync"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
// Global connection count
var (
currentConnections int
mu sync.Mutex
)
// Client subscription count
var clientSubscriptions = make(map[*websocket.Conn]int)
func WebSocketHandler(ws *websocket.Conn) { func WebSocketHandler(ws *websocket.Conn) {
defer ws.Close() defer func() {
mu.Lock()
currentConnections--
delete(clientSubscriptions, ws)
mu.Unlock()
ws.Close()
}()
mu.Lock()
if currentConnections >= config.GetConfig().Server.MaxConnections {
websocket.Message.Send(ws, `{"error": "too many connections"}`)
mu.Unlock()
return
}
currentConnections++
mu.Unlock()
clientInfo := relay.ClientInfo{
IP: utils.GetClientIP(ws.Request()),
UserAgent: ws.Request().Header.Get("User-Agent"),
Origin: ws.Request().Header.Get("Origin"),
}
log.Printf("New connection from IP: %s, User-Agent: %s, Origin: %s", clientInfo.IP, clientInfo.UserAgent, clientInfo.Origin)
var msg string var msg string
rateLimiter := config.GetRateLimiter() rateLimiter := config.GetRateLimiter()
subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection
clientSubscriptions[ws] = 0
for { for {
err := websocket.Message.Receive(ws, &msg) err := websocket.Message.Receive(ws, &msg)
@ -57,6 +91,14 @@ func WebSocketHandler(ws *websocket.Conn) {
case "EVENT": case "EVENT":
handlers.HandleEvent(ws, message) handlers.HandleEvent(ws, message)
case "REQ": case "REQ":
mu.Lock()
if clientSubscriptions[ws] >= config.GetConfig().Server.MaxSubscriptionsPerClient {
websocket.Message.Send(ws, `{"error": "too many subscriptions"}`)
mu.Unlock()
continue
}
clientSubscriptions[ws]++
mu.Unlock()
if allowed, msg := rateLimiter.AllowReq(); !allowed { if allowed, msg := rateLimiter.AllowReq(); !allowed {
websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg)) websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg))
ws.Close() ws.Close()

View File

@ -0,0 +1,7 @@
package relay
type ClientInfo struct {
IP string
UserAgent string
Origin string
}

View File

@ -0,0 +1,22 @@
package utils
import (
"net/http"
"strings"
)
func GetClientIP(r *http.Request) string {
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
remoteAddr := r.RemoteAddr
if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 {
return remoteAddr[:idx]
}
return remoteAddr
}