diff --git a/tests/config_test.go b/tests/config_test.go new file mode 100644 index 0000000..93c3dfc --- /dev/null +++ b/tests/config_test.go @@ -0,0 +1,77 @@ +package tests + +import ( + "testing" + + "grain/relay/utils" +) + +func TestConfigValidity(t *testing.T) { + config, err := utils.LoadConfig("../config.yml") + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + + // Check MongoDB settings + if config.MongoDB.URI == "" { + t.Error("MongoDB URI is required") + } + if config.MongoDB.Database == "" { + t.Error("MongoDB database name is required") + } + + // Check Server settings + if config.Server.Port == "" { + t.Error("Server port is required") + } + + // Check Rate Limit settings + if config.RateLimit.WsLimit == 0 { + t.Error("WebSocket limit is required") + } + if config.RateLimit.WsBurst == 0 { + t.Error("WebSocket burst is required") + } + if config.RateLimit.EventLimit == 0 { + t.Error("Event limit is required") + } + if config.RateLimit.EventBurst == 0 { + t.Error("Event burst is required") + } + if config.RateLimit.ReqLimit == 0 { + t.Error("REQ limit is required") + } + if config.RateLimit.ReqBurst == 0 { + t.Error("REQ burst is required") + } + + // Check Category Limits + if len(config.RateLimit.CategoryLimits) == 0 { + t.Log("Warning: No category limits set") + } + + // Check Kind Limits + if len(config.RateLimit.KindLimits) == 0 { + t.Log("Warning: No kind limits set") + } + + // Validate individual category limits + for category, limits := range config.RateLimit.CategoryLimits { + if limits.Limit == 0 { + t.Errorf("Limit is required for category: %s", category) + } + if limits.Burst == 0 { + t.Errorf("Burst is required for category: %s", category) + } + } + + // Validate individual kind limits + for _, kindLimit := range config.RateLimit.KindLimits { + if kindLimit.Limit == 0 { + t.Errorf("Limit is required for kind: %d", kindLimit.Kind) + } + if kindLimit.Burst == 0 { + t.Errorf("Burst is required for kind: %d", kindLimit.Kind) + } + } +} \ No newline at end of file diff --git a/tests/rateLimits_test.go b/tests/rateLimits_test.go new file mode 100644 index 0000000..0bd55b3 --- /dev/null +++ b/tests/rateLimits_test.go @@ -0,0 +1,108 @@ +package tests + +import ( + "testing" + + "grain/relay/utils" + + "golang.org/x/time/rate" +) + +func TestWebSocketRateLimit(t *testing.T) { + rateLimiter := utils.NewRateLimiter(rate.Limit(1), 1, rate.Limit(100), 200, rate.Limit(100), 200) + + // First message should be allowed + if allowed, _ := rateLimiter.AllowWs(); !allowed { + t.Error("First WebSocket message should be allowed") + } + + // Second message should be rate-limited + if allowed, msg := rateLimiter.AllowWs(); allowed { + t.Error("Second WebSocket message should be rate-limited") + } else { + expectedMsg := "WebSocket message rate limit exceeded" + if msg != expectedMsg { + t.Errorf("Expected message: %s, got: %s", expectedMsg, msg) + } + } +} + +func TestEventRateLimit(t *testing.T) { + rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(1), 1, rate.Limit(100), 200) + rateLimiter.AddKindLimit(1, rate.Limit(1), 1) + rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1) + + // First event should be allowed + if allowed, _ := rateLimiter.AllowEvent(1, "regular"); !allowed { + t.Error("First event should be allowed") + } + + // Second event should be rate-limited + if allowed, msg := rateLimiter.AllowEvent(1, "regular"); allowed { + t.Error("Second event should be rate-limited") + } else { + expectedMsg := "Global event rate limit exceeded" + if msg != expectedMsg { + t.Errorf("Expected message: %s, got: %s", expectedMsg, msg) + } + } +} + +func TestReqRateLimit(t *testing.T) { + rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(1), 1) + + // First REQ should be allowed + if allowed, _ := rateLimiter.AllowReq(); !allowed { + t.Error("First REQ message should be allowed") + } + + // Second REQ should be rate-limited + if allowed, msg := rateLimiter.AllowReq(); allowed { + t.Error("Second REQ message should be rate-limited") + } else { + expectedMsg := "REQ rate limit exceeded" + if msg != expectedMsg { + t.Errorf("Expected message: %s, got: %s", expectedMsg, msg) + } + } +} + +func TestKindRateLimit(t *testing.T) { + rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200) + rateLimiter.AddKindLimit(1, rate.Limit(1), 1) + + // First event of kind 1 should be allowed + if allowed, _ := rateLimiter.AllowEvent(1, "regular"); !allowed { + t.Error("First event of kind 1 should be allowed") + } + + // Second event of kind 1 should be rate-limited + if allowed, msg := rateLimiter.AllowEvent(1, "regular"); allowed { + t.Error("Second event of kind 1 should be rate-limited") + } else { + expectedMsg := "Rate limit exceeded for kind: 1" + if msg != expectedMsg { + t.Errorf("Expected message: %s, got: %s", expectedMsg, msg) + } + } +} + +func TestCategoryRateLimit(t *testing.T) { + rateLimiter := utils.NewRateLimiter(rate.Limit(100), 200, rate.Limit(100), 200, rate.Limit(100), 200) + rateLimiter.AddCategoryLimit("regular", rate.Limit(1), 1) + + // First event in category "regular" should be allowed + if allowed, _ := rateLimiter.AllowEvent(1, "regular"); !allowed { + t.Error("First event in category 'regular' should be allowed") + } + + // Second event in category "regular" should be rate-limited + if allowed, msg := rateLimiter.AllowEvent(1, "regular"); allowed { + t.Error("Second event in category 'regular' should be rate-limited") + } else { + expectedMsg := "Rate limit exceeded for category: regular" + if msg != expectedMsg { + t.Errorf("Expected message: %s, got: %s", expectedMsg, msg) + } + } +}