diff --git a/app/static/examples/config.example.yml b/app/static/examples/config.example.yml index 9a3ddbf..3351225 100644 --- a/app/static/examples/config.example.yml +++ b/app/static/examples/config.example.yml @@ -10,6 +10,12 @@ server: max_connections: 100 max_subscriptions_per_client: 10 +resource_limits: + cpu_cores: 2 # Limit the number of CPU cores the application can use + memory_mb: 1024 # Cap the maximum amount of RAM in MB the application can use + heap_size_mb: 512 # Set a limit on the Go garbage collector's heap size in MB + max_goroutines: 100 # Limit the maximum number of concurrently running Go routines + rate_limit: ws_limit: 100 # WebSocket messages per second ws_burst: 200 # Allowed burst of WebSocket messages diff --git a/config/types/resourceLimits.go b/config/types/resourceLimits.go new file mode 100644 index 0000000..d008df0 --- /dev/null +++ b/config/types/resourceLimits.go @@ -0,0 +1,8 @@ +package config + +type ResourceLimits struct { + CPUCores int `yaml:"cpu_cores"` + MemoryMB int `yaml:"memory_mb"` + HeapSizeMB int `yaml:"heap_size_mb"` + MaxGoroutines int `yaml:"max_goroutines"` +} diff --git a/config/types/serverConfig.go b/config/types/serverConfig.go index 355df70..08b7a43 100644 --- a/config/types/serverConfig.go +++ b/config/types/serverConfig.go @@ -18,4 +18,5 @@ type ServerConfig struct { KindWhitelist KindWhitelistConfig `yaml:"kind_whitelist"` DomainWhitelist DomainWhitelistConfig `yaml:"domain_whitelist"` Blacklist BlacklistConfig `yaml:"blacklist"` + ResourceLimits ResourceLimits `yaml:"resource_limits"` } diff --git a/main.go b/main.go index 3f59517..6143279 100644 --- a/main.go +++ b/main.go @@ -25,19 +25,22 @@ func main() { utils.EnsureFileExists("relay_metadata.json", "app/static/examples/relay_metadata.example.json") restartChan := make(chan struct{}) - go utils.WatchConfigFile("config.yml", restartChan) + go utils.WatchConfigFile("config.yml", restartChan) // Critical goroutine signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) var wg sync.WaitGroup for { - wg.Add(1) + wg.Add(1) // Add to WaitGroup for the server goroutine + cfg, err := config.LoadConfig("config.yml") if err != nil { log.Fatal("Error loading config: ", err) } + utils.ApplyResourceLimits(&cfg.ResourceLimits) // Apply limits once before starting the server + client, err := db.InitDB(cfg) if err != nil { log.Fatal("Error initializing database: ", err) @@ -54,22 +57,22 @@ func main() { } mux := setupRoutes() + + // Start the server server := startServer(cfg, mux, &wg) select { case <-restartChan: log.Println("Restarting server...") - - // Close server before restart - server.Close() - wg.Wait() - + server.Close() // Stop the current server instance + wg.Wait() // Wait for the server goroutine to finish time.Sleep(3 * time.Second) + case <-signalChan: log.Println("Shutting down server...") - server.Close() - db.DisconnectDB(client) - wg.Wait() + server.Close() // Stop the server + db.DisconnectDB(client) // Disconnect from MongoDB + wg.Wait() // Wait for all goroutines to finish return } } @@ -97,13 +100,14 @@ func startServer(config *configTypes.ServerConfig, mux *http.ServeMux, wg *sync. } go func() { - defer wg.Done() // Notify that the server is done shutting down + defer wg.Done() // Notify that the server goroutine is done fmt.Printf("Server is running on http://localhost%s\n", config.Server.Port) err := server.ListenAndServe() if err != nil && err != http.ErrServerClosed { fmt.Println("Error starting server:", err) } }() + return server } diff --git a/server/handlers/event.go b/server/handlers/event.go index 21c2583..dcc6e6a 100644 --- a/server/handlers/event.go +++ b/server/handlers/event.go @@ -16,37 +16,39 @@ import ( ) func HandleEvent(ws *websocket.Conn, message []interface{}) { - if len(message) != 2 { - fmt.Println("Invalid EVENT message format") - response.SendNotice(ws, "", "Invalid EVENT message format") - return - } + utils.LimitedGoRoutine(func() { + if len(message) != 2 { + fmt.Println("Invalid EVENT message format") + response.SendNotice(ws, "", "Invalid EVENT message format") + return + } - eventData, ok := message[1].(map[string]interface{}) - if !ok { - fmt.Println("Invalid event data format") - response.SendNotice(ws, "", "Invalid event data format") - return - } - eventBytes, err := json.Marshal(eventData) - if err != nil { - fmt.Println("Error marshaling event data:", err) - response.SendNotice(ws, "", "Error marshaling event data") - return - } + eventData, ok := message[1].(map[string]interface{}) + if !ok { + fmt.Println("Invalid event data format") + response.SendNotice(ws, "", "Invalid event data format") + return + } + eventBytes, err := json.Marshal(eventData) + if err != nil { + fmt.Println("Error marshaling event data:", err) + response.SendNotice(ws, "", "Error marshaling event data") + return + } - var evt relay.Event - err = json.Unmarshal(eventBytes, &evt) - if err != nil { - fmt.Println("Error unmarshaling event data:", err) - response.SendNotice(ws, "", "Error unmarshaling event data") - return - } + var evt relay.Event + err = json.Unmarshal(eventBytes, &evt) + if err != nil { + fmt.Println("Error unmarshaling event data:", err) + response.SendNotice(ws, "", "Error unmarshaling event data") + return + } - eventSize := len(eventBytes) // Calculate event size - HandleKind(context.TODO(), evt, ws, eventSize) + eventSize := len(eventBytes) // Calculate event size + HandleKind(context.TODO(), evt, ws, eventSize) - fmt.Println("Event processed:", evt.ID) + fmt.Println("Event processed:", evt.ID) + }) } func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn, eventSize int) { @@ -152,4 +154,3 @@ func determineCategory(kind int) string { return "unknown" } } - diff --git a/server/handlers/req.go b/server/handlers/req.go index 526f028..e2f4188 100644 --- a/server/handlers/req.go +++ b/server/handlers/req.go @@ -21,87 +21,89 @@ var subscriptions = make(map[string]relay.Subscription) var mu sync.Mutex // Protect concurrent access to subscriptions map func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) { - if len(message) < 3 { - fmt.Println("Invalid REQ message format") - response.SendClosed(ws, "", "invalid: invalid REQ message format") - return - } - - subID, ok := message[1].(string) - if !ok { - fmt.Println("Invalid subscription ID format") - response.SendClosed(ws, "", "invalid: invalid subscription ID format") - return - } - - mu.Lock() - defer mu.Unlock() - - // Check the current number of subscriptions for the client - if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { - // Find and remove the oldest subscription (FIFO) - var oldestSubID string - for id := range subscriptions { - oldestSubID = id - break + utils.LimitedGoRoutine(func() { + if len(message) < 3 { + fmt.Println("Invalid REQ message format") + response.SendClosed(ws, "", "invalid: invalid REQ message format") + return } - delete(subscriptions, oldestSubID) - fmt.Println("Dropped oldest subscription:", oldestSubID) - } - filters := make([]relay.Filter, len(message)-2) - for i, filter := range message[2:] { - filterData, ok := filter.(map[string]interface{}) + subID, ok := message[1].(string) if !ok { - fmt.Println("Invalid filter format") - response.SendClosed(ws, subID, "invalid: invalid filter format") + fmt.Println("Invalid subscription ID format") + response.SendClosed(ws, "", "invalid: invalid subscription ID format") return } - var f relay.Filter - f.IDs = utils.ToStringArray(filterData["ids"]) - f.Authors = utils.ToStringArray(filterData["authors"]) - f.Kinds = utils.ToIntArray(filterData["kinds"]) - f.Tags = utils.ToTagsMap(filterData["tags"]) - f.Since = utils.ToTime(filterData["since"]) - f.Until = utils.ToTime(filterData["until"]) - f.Limit = utils.ToInt(filterData["limit"]) + mu.Lock() + defer mu.Unlock() - filters[i] = f - } + // Check the current number of subscriptions for the client + if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { + // Find and remove the oldest subscription (FIFO) + var oldestSubID string + for id := range subscriptions { + oldestSubID = id + break + } + delete(subscriptions, oldestSubID) + fmt.Println("Dropped oldest subscription:", oldestSubID) + } - // Add the new subscription or update the existing one - subscriptions[subID] = filters - fmt.Println("Subscription updated:", subID) + filters := make([]relay.Filter, len(message)-2) + for i, filter := range message[2:] { + filterData, ok := filter.(map[string]interface{}) + if !ok { + fmt.Println("Invalid filter format") + response.SendClosed(ws, subID, "invalid: invalid filter format") + return + } - // Query the database with filters and send back the results - queriedEvents, err := QueryEvents(filters, db.GetClient(), "grain") - if err != nil { - fmt.Println("Error querying events:", err) - response.SendClosed(ws, subID, "error: could not query events") - return - } + var f relay.Filter + f.IDs = utils.ToStringArray(filterData["ids"]) + f.Authors = utils.ToStringArray(filterData["authors"]) + f.Kinds = utils.ToIntArray(filterData["kinds"]) + f.Tags = utils.ToTagsMap(filterData["tags"]) + f.Since = utils.ToTime(filterData["since"]) + f.Until = utils.ToTime(filterData["until"]) + f.Limit = utils.ToInt(filterData["limit"]) - for _, evt := range queriedEvents { - msg := []interface{}{"EVENT", subID, evt} - msgBytes, _ := json.Marshal(msg) - err = websocket.Message.Send(ws, string(msgBytes)) + filters[i] = f + } + + // Add the new subscription or update the existing one + subscriptions[subID] = filters + fmt.Println("Subscription updated:", subID) + + // Query the database with filters and send back the results + queriedEvents, err := QueryEvents(filters, db.GetClient(), "grain") if err != nil { - fmt.Println("Error sending event:", err) - response.SendClosed(ws, subID, "error: could not send event") + fmt.Println("Error querying events:", err) + response.SendClosed(ws, subID, "error: could not query events") return } - } - // Indicate end of stored events - eoseMsg := []interface{}{"EOSE", subID} - eoseBytes, _ := json.Marshal(eoseMsg) - err = websocket.Message.Send(ws, string(eoseBytes)) - if err != nil { - fmt.Println("Error sending EOSE:", err) - response.SendClosed(ws, subID, "error: could not send EOSE") - return - } + for _, evt := range queriedEvents { + msg := []interface{}{"EVENT", subID, evt} + msgBytes, _ := json.Marshal(msg) + err = websocket.Message.Send(ws, string(msgBytes)) + if err != nil { + fmt.Println("Error sending event:", err) + response.SendClosed(ws, subID, "error: could not send event") + return + } + } + + // Indicate end of stored events + eoseMsg := []interface{}{"EOSE", subID} + eoseBytes, _ := json.Marshal(eoseMsg) + err = websocket.Message.Send(ws, string(eoseBytes)) + if err != nil { + fmt.Println("Error sending EOSE:", err) + response.SendClosed(ws, subID, "error: could not send EOSE") + return + } + }) } // QueryEvents queries events from the MongoDB collection based on filters diff --git a/server/utils/applyResourceLimits.go b/server/utils/applyResourceLimits.go new file mode 100644 index 0000000..022db6d --- /dev/null +++ b/server/utils/applyResourceLimits.go @@ -0,0 +1,113 @@ +package utils + +import ( + "log" + "runtime" + "runtime/debug" + "sync" + "time" + + configTypes "grain/config/types" +) + +var ( + maxGoroutinesChan chan struct{} + wg sync.WaitGroup + goroutineQueue []func() + goroutineQueueMutex sync.Mutex +) + +func ApplyResourceLimits(cfg *configTypes.ResourceLimits) { + // Set CPU cores + runtime.GOMAXPROCS(cfg.CPUCores) + + // Set maximum heap size + if cfg.HeapSizeMB > 0 { + heapSize := int64(uint64(cfg.HeapSizeMB) * 1024 * 1024) + debug.SetMemoryLimit(heapSize) + log.Printf("Heap size limited to %d MB\n", cfg.HeapSizeMB) + } + + // Start monitoring memory usage + if cfg.MemoryMB > 0 { + go monitorMemoryUsage(cfg.MemoryMB) + log.Printf("Max memory usage limited to %d MB\n", cfg.MemoryMB) + } + + // Set maximum number of Go routines + if cfg.MaxGoroutines > 0 { + maxGoroutinesChan = make(chan struct{}, cfg.MaxGoroutines) + log.Printf("Max goroutines limited to %d\n", cfg.MaxGoroutines) + } +} + +// LimitedGoRoutine starts a goroutine with limit enforcement +func LimitedGoRoutine(f func()) { + // By default, all routines are considered critical + goroutineQueueMutex.Lock() + goroutineQueue = append(goroutineQueue, f) + goroutineQueueMutex.Unlock() + attemptToStartGoroutine() +} + +func attemptToStartGoroutine() { + goroutineQueueMutex.Lock() + defer goroutineQueueMutex.Unlock() + + if len(goroutineQueue) > 0 { + select { + case maxGoroutinesChan <- struct{}{}: + wg.Add(1) + go func(f func()) { + defer func() { + wg.Done() + <-maxGoroutinesChan + attemptToStartGoroutine() + }() + f() + }(goroutineQueue[0]) + + // Remove the started goroutine from the queue + goroutineQueue = goroutineQueue[1:] + + default: + // If the channel is full, consider dropping the oldest non-critical goroutine + dropOldestNonCriticalGoroutine() + } + } +} + +func dropOldestNonCriticalGoroutine() { + goroutineQueueMutex.Lock() + defer goroutineQueueMutex.Unlock() + + if len(goroutineQueue) > 0 { + log.Println("Dropping the oldest non-critical goroutine to free resources.") + goroutineQueue = goroutineQueue[1:] + attemptToStartGoroutine() + } +} + +func WaitForGoroutines() { + wg.Wait() +} + +func monitorMemoryUsage(maxMemoryMB int) { + for { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + usedMemoryMB := int(memStats.Alloc / 1024 / 1024) + if usedMemoryMB > maxMemoryMB { + log.Printf("Memory usage exceeded limit: %d MB used, limit is %d MB\n", usedMemoryMB, maxMemoryMB) + debug.FreeOSMemory() // Attempt to free memory + + // If memory usage is still high, attempt to drop non-critical goroutines + if usedMemoryMB > maxMemoryMB { + dropOldestNonCriticalGoroutine() + } + } + + time.Sleep(1 * time.Second) + } +}