mirror of
https://github.com/0ceanSlim/grain.git
synced 2024-11-21 16:17:13 +00:00
client info, config max connections/subscriptions, drop oldest sub for newest at limit
This commit is contained in:
parent
bda932c06d
commit
25ff327a7f
@ -7,6 +7,9 @@ server:
|
||||
read_timeout: 10 # Read timeout in seconds
|
||||
write_timeout: 10 # Write timeout in seconds
|
||||
idle_timeout: 120 # Idle timeout in seconds
|
||||
max_connections: 100 # Maximum number of concurrent connections
|
||||
max_subscriptions_per_client: 10 # Maximum number of concurrent subscriptions per client
|
||||
|
||||
pubkey_whitelist:
|
||||
enabled: false
|
||||
pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4",
|
||||
|
@ -6,10 +6,12 @@ type ServerConfig struct {
|
||||
Database string `yaml:"database"`
|
||||
} `yaml:"mongodb"`
|
||||
Server struct {
|
||||
Port string `yaml:"port"`
|
||||
ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds
|
||||
WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds
|
||||
IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds
|
||||
Port string `yaml:"port"`
|
||||
ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds
|
||||
WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds
|
||||
IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds
|
||||
MaxConnections int `yaml:"max_connections"` // Maximum number of concurrent connections
|
||||
MaxSubscriptionsPerClient int `yaml:"max_subscriptions_per_client"` // Maximum number of subscriptions per client
|
||||
} `yaml:"server"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"`
|
||||
|
@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"grain/config"
|
||||
"grain/server/db"
|
||||
"grain/server/handlers/response"
|
||||
relay "grain/server/types"
|
||||
"grain/server/utils"
|
||||
"sync"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
@ -16,6 +18,7 @@ import (
|
||||
)
|
||||
|
||||
var subscriptions = make(map[string]relay.Subscription)
|
||||
var mu sync.Mutex // Protect concurrent access to subscriptions map
|
||||
|
||||
func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) {
|
||||
if len(message) < 3 {
|
||||
@ -31,6 +34,21 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Check the current number of subscriptions for the client
|
||||
if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient {
|
||||
// Find and remove the oldest subscription (FIFO)
|
||||
var oldestSubID string
|
||||
for id := range subscriptions {
|
||||
oldestSubID = id
|
||||
break
|
||||
}
|
||||
delete(subscriptions, oldestSubID)
|
||||
fmt.Println("Dropped oldest subscription:", oldestSubID)
|
||||
}
|
||||
|
||||
filters := make([]relay.Filter, len(message)-2)
|
||||
for i, filter := range message[2:] {
|
||||
filterData, ok := filter.(map[string]interface{})
|
||||
@ -52,7 +70,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
|
||||
filters[i] = f
|
||||
}
|
||||
|
||||
// Update or add the subscription for the given subID
|
||||
// Add the new subscription or update the existing one
|
||||
subscriptions[subID] = filters
|
||||
fmt.Println("Subscription updated:", subID)
|
||||
|
||||
|
@ -3,22 +3,56 @@ package relay
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"grain/server/handlers"
|
||||
|
||||
"grain/config"
|
||||
|
||||
"grain/server/handlers"
|
||||
relay "grain/server/types"
|
||||
"grain/server/utils"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
// Global connection count
|
||||
var (
|
||||
currentConnections int
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
// Client subscription count
|
||||
var clientSubscriptions = make(map[*websocket.Conn]int)
|
||||
|
||||
func WebSocketHandler(ws *websocket.Conn) {
|
||||
defer ws.Close()
|
||||
defer func() {
|
||||
mu.Lock()
|
||||
currentConnections--
|
||||
delete(clientSubscriptions, ws)
|
||||
mu.Unlock()
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
mu.Lock()
|
||||
if currentConnections >= config.GetConfig().Server.MaxConnections {
|
||||
websocket.Message.Send(ws, `{"error": "too many connections"}`)
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
currentConnections++
|
||||
mu.Unlock()
|
||||
|
||||
clientInfo := relay.ClientInfo{
|
||||
IP: utils.GetClientIP(ws.Request()),
|
||||
UserAgent: ws.Request().Header.Get("User-Agent"),
|
||||
Origin: ws.Request().Header.Get("Origin"),
|
||||
}
|
||||
|
||||
log.Printf("New connection from IP: %s, User-Agent: %s, Origin: %s", clientInfo.IP, clientInfo.UserAgent, clientInfo.Origin)
|
||||
|
||||
var msg string
|
||||
rateLimiter := config.GetRateLimiter()
|
||||
|
||||
subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection
|
||||
clientSubscriptions[ws] = 0
|
||||
|
||||
for {
|
||||
err := websocket.Message.Receive(ws, &msg)
|
||||
@ -57,6 +91,14 @@ func WebSocketHandler(ws *websocket.Conn) {
|
||||
case "EVENT":
|
||||
handlers.HandleEvent(ws, message)
|
||||
case "REQ":
|
||||
mu.Lock()
|
||||
if clientSubscriptions[ws] >= config.GetConfig().Server.MaxSubscriptionsPerClient {
|
||||
websocket.Message.Send(ws, `{"error": "too many subscriptions"}`)
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
clientSubscriptions[ws]++
|
||||
mu.Unlock()
|
||||
if allowed, msg := rateLimiter.AllowReq(); !allowed {
|
||||
websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg))
|
||||
ws.Close()
|
||||
|
7
server/types/clientInfo.go
Normal file
7
server/types/clientInfo.go
Normal file
@ -0,0 +1,7 @@
|
||||
package relay
|
||||
|
||||
type ClientInfo struct {
|
||||
IP string
|
||||
UserAgent string
|
||||
Origin string
|
||||
}
|
22
server/utils/getClientInfo.go
Normal file
22
server/utils/getClientInfo.go
Normal file
@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetClientIP(r *http.Request) string {
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
remoteAddr := r.RemoteAddr
|
||||
if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 {
|
||||
return remoteAddr[:idx]
|
||||
}
|
||||
return remoteAddr
|
||||
}
|
Loading…
Reference in New Issue
Block a user