refactoring

This commit is contained in:
Chris kerr 2024-09-01 20:51:02 -04:00
parent 0a52bebe14
commit c8ed954a9d
7 changed files with 271 additions and 265 deletions

204
config/Blacklist.go Normal file
View File

@ -0,0 +1,204 @@
package config
import (
"fmt"
types "grain/config/types"
"grain/server/utils"
"log"
"os"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
)
// CheckBlacklist checks if a pubkey is in the blacklist based on event content
func CheckBlacklist(pubkey, eventContent string) (bool, string) {
blacklistConfig := GetConfig().Blacklist
if !blacklistConfig.Enabled {
return false, ""
}
log.Printf("Checking blacklist for pubkey: %s", pubkey)
// Check for permanent blacklist by pubkey or npub
if isPubKeyPermanentlyBlacklisted(pubkey, blacklistConfig) {
log.Printf("Pubkey %s is permanently blacklisted", pubkey)
return true, fmt.Sprintf("pubkey %s is permanently blacklisted", pubkey)
}
// Check for temporary ban
if isPubKeyTemporarilyBlacklisted(pubkey) {
log.Printf("Pubkey %s is temporarily blacklisted", pubkey)
return true, fmt.Sprintf("pubkey %s is temporarily blacklisted", pubkey)
}
// Check for permanent ban based on wordlist
for _, word := range blacklistConfig.PermanentBanWords {
if strings.Contains(eventContent, word) {
err := AddToPermanentBlacklist(pubkey)
if err != nil {
return true, fmt.Sprintf("pubkey %s is permanently banned and failed to save: %v", pubkey, err)
}
return true, "blocked: pubkey is permanently banned"
}
}
// Check for temporary ban based on wordlist
for _, word := range blacklistConfig.TempBanWords {
if strings.Contains(eventContent, word) {
err := AddToTemporaryBlacklist(pubkey, blacklistConfig)
if err != nil {
return true, fmt.Sprintf("pubkey %s is temporarily banned and failed to save: %v", pubkey, err)
}
return true, "blocked: pubkey is temporarily banned"
}
}
return false, ""
}
// Checks if a pubkey is temporarily blacklisted
func isPubKeyTemporarilyBlacklisted(pubkey string) bool {
mu.Lock()
defer mu.Unlock()
entry, exists := tempBannedPubkeys[pubkey]
if !exists {
log.Printf("Pubkey %s not found in temporary blacklist", pubkey)
return false
}
now := time.Now()
if now.After(entry.unbanTime) {
log.Printf("Temporary ban for pubkey %s has expired. Count: %d", pubkey, entry.count)
return false
}
log.Printf("Pubkey %s is currently temporarily blacklisted. Count: %d, Unban time: %s", pubkey, entry.count, entry.unbanTime)
return true
}
func ClearTemporaryBans() {
mu.Lock()
defer mu.Unlock()
tempBannedPubkeys = make(map[string]*tempBanEntry)
}
var (
tempBannedPubkeys = make(map[string]*tempBanEntry)
mu sync.Mutex
)
type tempBanEntry struct {
count int
unbanTime time.Time
}
// Adds a pubkey to the temporary blacklist
func AddToTemporaryBlacklist(pubkey string, blacklistConfig types.BlacklistConfig) error {
mu.Lock()
defer mu.Unlock()
entry, exists := tempBannedPubkeys[pubkey]
if !exists {
log.Printf("Creating new temporary ban entry for pubkey %s", pubkey)
entry = &tempBanEntry{
count: 0,
unbanTime: time.Now(),
}
tempBannedPubkeys[pubkey] = entry
} else {
log.Printf("Updating existing temporary ban entry for pubkey %s. Current count: %d", pubkey, entry.count)
if time.Now().After(entry.unbanTime) {
log.Printf("Previous ban for pubkey %s has expired. Keeping count at %d", pubkey, entry.count)
}
}
// Increment the count
entry.count++
entry.unbanTime = time.Now().Add(time.Duration(blacklistConfig.TempBanDuration) * time.Second)
log.Printf("Pubkey %s temporary ban count updated to: %d, MaxTempBans: %d, New unban time: %s", pubkey, entry.count, blacklistConfig.MaxTempBans, entry.unbanTime)
if entry.count > blacklistConfig.MaxTempBans {
log.Printf("Attempting to move pubkey %s to permanent blacklist", pubkey)
delete(tempBannedPubkeys, pubkey)
// Release the lock before calling AddToPermanentBlacklist
mu.Unlock()
err := AddToPermanentBlacklist(pubkey)
mu.Lock() // Re-acquire the lock
if err != nil {
log.Printf("Error adding pubkey %s to permanent blacklist: %v", pubkey, err)
return err
}
log.Printf("Successfully added pubkey %s to permanent blacklist", pubkey)
}
return nil
}
// Checks if a pubkey is permanently blacklisted (only using config.yml)
func isPubKeyPermanentlyBlacklisted(pubKey string, blacklistConfig types.BlacklistConfig) bool {
if !blacklistConfig.Enabled {
return false
}
// Check pubkeys
for _, blacklistedKey := range blacklistConfig.PermanentBlacklistPubkeys {
if pubKey == blacklistedKey {
return true
}
}
// Check npubs
for _, npub := range blacklistConfig.PermanentBlacklistNpubs {
decodedPubKey, err := utils.DecodeNpub(npub)
if err != nil {
fmt.Println("Error decoding npub:", err)
continue
}
if pubKey == decodedPubKey {
return true
}
}
return false
}
func AddToPermanentBlacklist(pubkey string) error {
// Remove the mutex lock from here
blacklistConfig := GetConfig().Blacklist
// Check if already blacklisted
if isPubKeyPermanentlyBlacklisted(pubkey, blacklistConfig) {
return fmt.Errorf("pubkey %s is already in the permanent blacklist", pubkey)
}
// Add pubkey to the blacklist
blacklistConfig.PermanentBlacklistPubkeys = append(blacklistConfig.PermanentBlacklistPubkeys, pubkey)
// Persist changes to config.yml
return saveBlacklistConfig(blacklistConfig)
}
func saveBlacklistConfig(blacklistConfig types.BlacklistConfig) error {
configData := GetConfig()
configData.Blacklist = blacklistConfig
data, err := yaml.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %v", err)
}
err = os.WriteFile("config.yml", data, 0644)
if err != nil {
return fmt.Errorf("failed to write config to file: %v", err)
}
return nil
}

View File

@ -1,14 +1,14 @@
package utils package config
import ( import (
"fmt" "fmt"
"grain/config" "grain/server/utils"
"strconv" "strconv"
) )
// Helper function to check if a pubkey or npub is whitelisted // Helper function to check if a pubkey or npub is whitelisted
func IsPubKeyWhitelisted(pubKey string) bool { func IsPubKeyWhitelisted(pubKey string) bool {
cfg := config.GetConfig() cfg := GetConfig()
if !cfg.PubkeyWhitelist.Enabled { if !cfg.PubkeyWhitelist.Enabled {
return true return true
} }
@ -22,7 +22,7 @@ func IsPubKeyWhitelisted(pubKey string) bool {
// Check npubs // Check npubs
for _, npub := range cfg.PubkeyWhitelist.Npubs { for _, npub := range cfg.PubkeyWhitelist.Npubs {
decodedPubKey, err := DecodeNpub(npub) decodedPubKey, err := utils.DecodeNpub(npub)
if err != nil { if err != nil {
fmt.Println("Error decoding npub:", err) fmt.Println("Error decoding npub:", err)
continue continue
@ -36,7 +36,7 @@ func IsPubKeyWhitelisted(pubKey string) bool {
} }
func IsKindWhitelisted(kind int) bool { func IsKindWhitelisted(kind int) bool {
cfg := config.GetConfig() cfg := GetConfig()
if !cfg.KindWhitelist.Enabled { if !cfg.KindWhitelist.Enabled {
return true return true
} }

View File

@ -49,7 +49,7 @@ func main() {
config.SetupRateLimiter(cfg) config.SetupRateLimiter(cfg)
config.SetupSizeLimiter(cfg) config.SetupSizeLimiter(cfg)
utils.ClearTemporaryBans() config.ClearTemporaryBans()
err = utils.LoadRelayMetadataJSON() err = utils.LoadRelayMetadataJSON()
if err != nil { if err != nil {
@ -70,9 +70,9 @@ func main() {
case <-signalChan: case <-signalChan:
log.Println("Shutting down server...") log.Println("Shutting down server...")
server.Close() // Stop the server server.Close() // Stop the server
db.DisconnectDB(client) // Disconnect from MongoDB db.DisconnectDB(client) // Disconnect from MongoDB
wg.Wait() // Wait for all goroutines to finish wg.Wait() // Wait for all goroutines to finish
return return
} }
} }

48
server/db/storeMongo.go Normal file
View File

@ -0,0 +1,48 @@
package db
import (
"context"
"fmt"
"grain/server/handlers/kinds"
"grain/server/handlers/response"
nostr "grain/server/types"
"golang.org/x/net/websocket"
)
func StoreMongoEvent(ctx context.Context, evt nostr.Event, ws *websocket.Conn) {
collection := GetCollection(evt.Kind)
var err error
switch {
case evt.Kind == 0:
err = kinds.HandleKind0(ctx, evt, collection, ws)
case evt.Kind == 1:
err = kinds.HandleKind1(ctx, evt, collection, ws)
case evt.Kind == 2:
err = kinds.HandleKind2(ctx, evt, ws)
case evt.Kind == 3:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind == 5:
err = kinds.HandleKind5(ctx, evt, GetClient(), ws)
case evt.Kind >= 4 && evt.Kind < 45:
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 1000 && evt.Kind < 10000:
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 10000 && evt.Kind < 20000:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind >= 20000 && evt.Kind < 30000:
fmt.Println("Ephemeral event received and ignored:", evt.ID)
case evt.Kind >= 30000 && evt.Kind < 40000:
err = kinds.HandleParameterizedReplaceableKind(ctx, evt, collection, ws)
default:
err = kinds.HandleUnknownKind(ctx, evt, collection, ws)
}
if err != nil {
response.SendOK(ws, evt.ID, false, fmt.Sprintf("error: %v", err))
return
}
response.SendOK(ws, evt.ID, true, "")
}

View File

@ -6,11 +6,11 @@ import (
"fmt" "fmt"
"grain/config" "grain/config"
"grain/server/db" "grain/server/db"
"grain/server/handlers/kinds"
"grain/server/handlers/response" "grain/server/handlers/response"
"grain/server/utils" "grain/server/utils"
relay "grain/server/types" nostr "grain/server/types"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -36,7 +36,7 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
return return
} }
var evt relay.Event var evt nostr.Event
err = json.Unmarshal(eventBytes, &evt) err = json.Unmarshal(eventBytes, &evt)
if err != nil { if err != nil {
fmt.Println("Error unmarshaling event data:", err) fmt.Println("Error unmarshaling event data:", err)
@ -60,13 +60,14 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
return return
} }
storeEvent(context.TODO(), evt, ws) // This is where I'll handle storage for multiple database types in the future
db.StoreMongoEvent(context.TODO(), evt, ws)
fmt.Println("Event processed:", evt.ID) fmt.Println("Event processed:", evt.ID)
}) })
} }
func handleBlacklistAndWhitelist(ws *websocket.Conn, evt relay.Event) bool { func handleBlacklistAndWhitelist(ws *websocket.Conn, evt nostr.Event) bool {
if config.GetConfig().DomainWhitelist.Enabled { if config.GetConfig().DomainWhitelist.Enabled {
domains := config.GetConfig().DomainWhitelist.Domains domains := config.GetConfig().DomainWhitelist.Domains
pubkeys, err := utils.FetchPubkeysFromDomains(domains) pubkeys, err := utils.FetchPubkeysFromDomains(domains)
@ -80,17 +81,17 @@ func handleBlacklistAndWhitelist(ws *websocket.Conn, evt relay.Event) bool {
} }
} }
if blacklisted, msg := utils.CheckBlacklist(evt.PubKey, evt.Content); blacklisted { if blacklisted, msg := config.CheckBlacklist(evt.PubKey, evt.Content); blacklisted {
response.SendOK(ws, evt.ID, false, msg) response.SendOK(ws, evt.ID, false, msg)
return false return false
} }
if config.GetConfig().KindWhitelist.Enabled && !utils.IsKindWhitelisted(evt.Kind) { if config.GetConfig().KindWhitelist.Enabled && !config.IsKindWhitelisted(evt.Kind) {
response.SendOK(ws, evt.ID, false, "not allowed: event kind is not whitelisted") response.SendOK(ws, evt.ID, false, "not allowed: event kind is not whitelisted")
return false return false
} }
if config.GetConfig().PubkeyWhitelist.Enabled && !utils.IsPubKeyWhitelisted(evt.PubKey) { if config.GetConfig().PubkeyWhitelist.Enabled && !config.IsPubKeyWhitelisted(evt.PubKey) {
response.SendOK(ws, evt.ID, false, "not allowed: pubkey or npub is not whitelisted") response.SendOK(ws, evt.ID, false, "not allowed: pubkey or npub is not whitelisted")
return false return false
} }
@ -98,7 +99,7 @@ func handleBlacklistAndWhitelist(ws *websocket.Conn, evt relay.Event) bool {
return true return true
} }
func handleRateAndSizeLimits(ws *websocket.Conn, evt relay.Event, eventSize int) bool { func handleRateAndSizeLimits(ws *websocket.Conn, evt nostr.Event, eventSize int) bool {
rateLimiter := config.GetRateLimiter() rateLimiter := config.GetRateLimiter()
sizeLimiter := config.GetSizeLimiter() sizeLimiter := config.GetSizeLimiter()
category := determineCategory(evt.Kind) category := determineCategory(evt.Kind)
@ -116,43 +117,6 @@ func handleRateAndSizeLimits(ws *websocket.Conn, evt relay.Event, eventSize int)
return true return true
} }
func storeEvent(ctx context.Context, evt relay.Event, ws *websocket.Conn) {
collection := db.GetCollection(evt.Kind)
var err error
switch {
case evt.Kind == 0:
err = kinds.HandleKind0(ctx, evt, collection, ws)
case evt.Kind == 1:
err = kinds.HandleKind1(ctx, evt, collection, ws)
case evt.Kind == 2:
err = kinds.HandleKind2(ctx, evt, ws)
case evt.Kind == 3:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind == 5:
err = kinds.HandleKind5(ctx, evt, db.GetClient(), ws)
case evt.Kind >= 4 && evt.Kind < 45:
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 1000 && evt.Kind < 10000:
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 10000 && evt.Kind < 20000:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind >= 20000 && evt.Kind < 30000:
fmt.Println("Ephemeral event received and ignored:", evt.ID)
case evt.Kind >= 30000 && evt.Kind < 40000:
err = kinds.HandleParameterizedReplaceableKind(ctx, evt, collection, ws)
default:
err = kinds.HandleUnknownKind(ctx, evt, collection, ws)
}
if err != nil {
response.SendOK(ws, evt.ID, false, fmt.Sprintf("error: %v", err))
return
}
response.SendOK(ws, evt.ID, true, "")
}
func determineCategory(kind int) string { func determineCategory(kind int) string {
switch { switch {
case kind == 0, kind == 3, kind >= 10000 && kind < 20000: case kind == 0, kind == 3, kind >= 10000 && kind < 20000:

View File

@ -1,210 +0,0 @@
package utils
import (
"fmt"
"grain/config"
cfg "grain/config/types"
"log"
"os"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
)
// CheckBlacklist checks if a pubkey is in the blacklist based on event content
func CheckBlacklist(pubkey, eventContent string) (bool, string) {
cfg := config.GetConfig().Blacklist
if !cfg.Enabled {
return false, ""
}
log.Printf("Checking blacklist for pubkey: %s", pubkey)
// Check for permanent blacklist by pubkey or npub
if isPubKeyPermanentlyBlacklisted(pubkey) {
log.Printf("Pubkey %s is permanently blacklisted", pubkey)
return true, fmt.Sprintf("pubkey %s is permanently blacklisted", pubkey)
}
// Check for temporary ban
if isPubKeyTemporarilyBlacklisted(pubkey) {
log.Printf("Pubkey %s is temporarily blacklisted", pubkey)
return true, fmt.Sprintf("pubkey %s is temporarily blacklisted", pubkey)
}
// Check for permanent ban based on wordlist
for _, word := range cfg.PermanentBanWords {
if strings.Contains(eventContent, word) {
err := AddToPermanentBlacklist(pubkey)
if err != nil {
return true, fmt.Sprintf("pubkey %s is permanently banned and failed to save: %v", pubkey, err)
}
return true, "blocked: pubkey is permanently banned"
}
}
// Check for temporary ban based on wordlist
for _, word := range cfg.TempBanWords {
if strings.Contains(eventContent, word) {
err := AddToTemporaryBlacklist(pubkey)
if err != nil {
return true, fmt.Sprintf("pubkey %s is temporarily banned and failed to save: %v", pubkey, err)
}
return true, "blocked: pubkey is temporarily banned"
}
}
return false, ""
}
// Checks if a pubkey is temporarily blacklisted
func isPubKeyTemporarilyBlacklisted(pubkey string) bool {
mu.Lock()
defer mu.Unlock()
entry, exists := tempBannedPubkeys[pubkey]
if !exists {
log.Printf("Pubkey %s not found in temporary blacklist", pubkey)
return false
}
now := time.Now()
if now.After(entry.unbanTime) {
log.Printf("Temporary ban for pubkey %s has expired. Count: %d", pubkey, entry.count)
return false
}
log.Printf("Pubkey %s is currently temporarily blacklisted. Count: %d, Unban time: %s", pubkey, entry.count, entry.unbanTime)
return true
}
func ClearTemporaryBans() {
mu.Lock()
defer mu.Unlock()
tempBannedPubkeys = make(map[string]*tempBanEntry)
}
var (
tempBannedPubkeys = make(map[string]*tempBanEntry)
mu sync.Mutex
)
type tempBanEntry struct {
count int
unbanTime time.Time
}
// Adds a pubkey to the temporary blacklist
func AddToTemporaryBlacklist(pubkey string) error {
mu.Lock()
defer mu.Unlock()
cfg := config.GetConfig().Blacklist
entry, exists := tempBannedPubkeys[pubkey]
if !exists {
log.Printf("Creating new temporary ban entry for pubkey %s", pubkey)
entry = &tempBanEntry{
count: 0,
unbanTime: time.Now(),
}
tempBannedPubkeys[pubkey] = entry
} else {
log.Printf("Updating existing temporary ban entry for pubkey %s. Current count: %d", pubkey, entry.count)
if time.Now().After(entry.unbanTime) {
log.Printf("Previous ban for pubkey %s has expired. Keeping count at %d", pubkey, entry.count)
}
}
// Increment the count
entry.count++
entry.unbanTime = time.Now().Add(time.Duration(cfg.TempBanDuration) * time.Second)
log.Printf("Pubkey %s temporary ban count updated to: %d, MaxTempBans: %d, New unban time: %s", pubkey, entry.count, cfg.MaxTempBans, entry.unbanTime)
if entry.count > cfg.MaxTempBans {
log.Printf("Attempting to move pubkey %s to permanent blacklist", pubkey)
delete(tempBannedPubkeys, pubkey)
// Release the lock before calling AddToPermanentBlacklist
mu.Unlock()
err := AddToPermanentBlacklist(pubkey)
mu.Lock() // Re-acquire the lock
if err != nil {
log.Printf("Error adding pubkey %s to permanent blacklist: %v", pubkey, err)
return err
}
log.Printf("Successfully added pubkey %s to permanent blacklist", pubkey)
}
return nil
}
// Checks if a pubkey is permanently blacklisted (only using config.yml)
func isPubKeyPermanentlyBlacklisted(pubKey string) bool {
cfg := config.GetConfig().Blacklist // Get the latest configuration
if !cfg.Enabled {
return false
}
// Check pubkeys
for _, blacklistedKey := range cfg.PermanentBlacklistPubkeys {
if pubKey == blacklistedKey {
return true
}
}
// Check npubs
for _, npub := range cfg.PermanentBlacklistNpubs {
decodedPubKey, err := DecodeNpub(npub)
if err != nil {
fmt.Println("Error decoding npub:", err)
continue
}
if pubKey == decodedPubKey {
return true
}
}
return false
}
func AddToPermanentBlacklist(pubkey string) error {
// Remove the mutex lock from here
cfg := config.GetConfig().Blacklist
// Check if already blacklisted
if isPubKeyPermanentlyBlacklisted(pubkey) {
return fmt.Errorf("pubkey %s is already in the permanent blacklist", pubkey)
}
// Add pubkey to the blacklist
cfg.PermanentBlacklistPubkeys = append(cfg.PermanentBlacklistPubkeys, pubkey)
// Persist changes to config.yml
return saveBlacklistConfig(cfg)
}
func saveBlacklistConfig(blacklistConfig cfg.BlacklistConfig) error {
configData := config.GetConfig()
configData.Blacklist = blacklistConfig
data, err := yaml.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %v", err)
}
err = os.WriteFile("config.yml", data, 0644)
if err != nil {
return fmt.Errorf("failed to write config to file: %v", err)
}
return nil
}