diff --git a/main.go b/main.go index e6723db..87a1b41 100644 --- a/main.go +++ b/main.go @@ -21,8 +21,7 @@ import ( ) func main() { - utils.EnsureFileExists("config.yml", "app/static/examples/config.example.yml") - utils.EnsureFileExists("relay_metadata.json", "app/static/examples/relay_metadata.example.json") + utils.ClearTemporaryBans() restartChan := make(chan struct{}) go utils.WatchConfigFile("config.yml", restartChan) @@ -46,6 +45,8 @@ func main() { config.SetupRateLimiter(cfg) config.SetupSizeLimiter(cfg) + utils.ClearTemporaryBans() + err = utils.LoadRelayMetadataJSON() if err != nil { log.Fatal("Failed to load relay metadata: ", err) @@ -57,15 +58,17 @@ func main() { select { case <-restartChan: log.Println("Restarting server...") + + // Close server before restart server.Close() - db.DisconnectDB(client) - wg.Wait() // Wait for the server to fully shut down before restarting - time.Sleep(3 * time.Second) // Add a delay before restarting + wg.Wait() + + time.Sleep(3 * time.Second) case <-signalChan: log.Println("Shutting down server...") server.Close() db.DisconnectDB(client) - wg.Wait() // Wait for the server to fully shut down before exiting + wg.Wait() return } } diff --git a/server/handlers/event.go b/server/handlers/event.go index 1c338fe..21c2583 100644 --- a/server/handlers/event.go +++ b/server/handlers/event.go @@ -72,6 +72,12 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn, eventS } } + // Check against manual blacklist + if blacklisted, msg := utils.CheckBlacklist(evt.PubKey, evt.Content); blacklisted { + response.SendOK(ws, evt.ID, false, msg) + return + } + // Check if the kind is whitelisted if config.GetConfig().KindWhitelist.Enabled && !utils.IsKindWhitelisted(evt.Kind) { response.SendOK(ws, evt.ID, false, "not allowed: event kind is not whitelisted") @@ -84,13 +90,6 @@ func HandleKind(ctx context.Context, evt relay.Event, ws *websocket.Conn, eventS return } - // Check against manual blacklist - blacklisted, msg := utils.CheckBlacklist(evt.PubKey, evt.Content) - if blacklisted { - response.SendOK(ws, evt.ID, false, msg) - return - } - category := determineCategory(evt.Kind) if allowed, msg := rateLimiter.AllowEvent(evt.Kind, category); !allowed { diff --git a/server/utils/checkBlacklist.go b/server/utils/checkBlacklist.go index 80ccd12..63fa695 100644 --- a/server/utils/checkBlacklist.go +++ b/server/utils/checkBlacklist.go @@ -6,10 +6,30 @@ import ( cfg "grain/config/types" "os" "strings" + "sync" + "time" "gopkg.in/yaml.v2" ) +// Structure to manage temporary bans with timestamps +type tempBanEntry struct { + count int // Number of temporary bans + unbanTime time.Time // Time when the pubkey should be unbanned +} + +var ( + tempBannedPubkeys = make(map[string]*tempBanEntry) + mu sync.Mutex +) + +func ClearTemporaryBans() { + mu.Lock() + defer mu.Unlock() + tempBannedPubkeys = make(map[string]*tempBanEntry) +} + + // CheckBlacklist checks if a pubkey is in the blacklist based on event content func CheckBlacklist(pubkey, eventContent string) (bool, string) { cfg := config.GetConfig().Blacklist @@ -23,10 +43,14 @@ func CheckBlacklist(pubkey, eventContent string) (bool, string) { return true, fmt.Sprintf("pubkey %s is permanently blacklisted", pubkey) } + // Check for temporary ban + if isPubKeyTemporarilyBlacklisted(pubkey) { + return true, fmt.Sprintf("pubkey %s is temporarily blacklisted", pubkey) + } + // Check for permanent ban based on wordlist for _, word := range cfg.PermanentBanWords { if strings.Contains(eventContent, word) { - // Permanently ban the pubkey err := AddToPermanentBlacklist(pubkey) if err != nil { return true, fmt.Sprintf("pubkey %s is permanently banned and failed to save: %v", pubkey, err) @@ -35,11 +59,74 @@ func CheckBlacklist(pubkey, eventContent string) (bool, string) { } } + // Check for temporary ban based on wordlist + for _, word := range cfg.TempBanWords { + if strings.Contains(eventContent, word) { + err := AddToTemporaryBlacklist(pubkey) + if err != nil { + return true, fmt.Sprintf("pubkey %s is temporarily banned and failed to save: %v", pubkey, err) + } + return true, fmt.Sprintf("pubkey %s is temporarily banned for containing forbidden words", pubkey) + } + } + return false, "" } -func isPubKeyPermanentlyBlacklisted(pubKey string) bool { + +// Checks if a pubkey is temporarily blacklisted +func isPubKeyTemporarilyBlacklisted(pubkey string) bool { + mu.Lock() + defer mu.Unlock() + + entry, exists := tempBannedPubkeys[pubkey] + if !exists { + return false + } + + // If the ban has expired, remove it from the temporary ban list + if time.Now().After(entry.unbanTime) { + delete(tempBannedPubkeys, pubkey) + return false + } + + return true +} + +// Adds a pubkey to the temporary blacklist +func AddToTemporaryBlacklist(pubkey string) error { + mu.Lock() + defer mu.Unlock() + cfg := config.GetConfig().Blacklist + + // Check if the pubkey is already temporarily banned + entry, exists := tempBannedPubkeys[pubkey] + if !exists { + entry = &tempBanEntry{ + count: 1, + unbanTime: time.Now().Add(time.Duration(cfg.TempBanDuration) * time.Second), + } + tempBannedPubkeys[pubkey] = entry + } + + // Increment the temporary ban count and set the unban time + entry.count++ + entry.unbanTime = time.Now().Add(time.Duration(cfg.TempBanDuration) * time.Second) + + // If the count exceeds max_temp_bans, move to permanent blacklist + if entry.count >= cfg.MaxTempBans { + delete(tempBannedPubkeys, pubkey) + return AddToPermanentBlacklist(pubkey) + } + + return nil +} + +// Checks if a pubkey is permanently blacklisted (only using config.yml) +func isPubKeyPermanentlyBlacklisted(pubKey string) bool { + cfg := config.GetConfig().Blacklist // Get the latest configuration + if !cfg.Enabled { return false } @@ -67,6 +154,9 @@ func isPubKeyPermanentlyBlacklisted(pubKey string) bool { } func AddToPermanentBlacklist(pubkey string) error { + mu.Lock() + defer mu.Unlock() + cfg := config.GetConfig().Blacklist // Check if already blacklisted @@ -82,10 +172,10 @@ func AddToPermanentBlacklist(pubkey string) error { } func saveBlacklistConfig(blacklistConfig cfg.BlacklistConfig) error { - cfg := config.GetConfig() - cfg.Blacklist = blacklistConfig + configData := config.GetConfig() + configData.Blacklist = blacklistConfig - data, err := yaml.Marshal(cfg) + data, err := yaml.Marshal(configData) if err != nil { return fmt.Errorf("failed to marshal config: %v", err) } @@ -97,3 +187,4 @@ func saveBlacklistConfig(blacklistConfig cfg.BlacklistConfig) error { return nil } +