Compare commits

..

4 Commits

14 changed files with 187 additions and 94 deletions

View File

@ -1,76 +0,0 @@
package app
import (
"grain/server/db"
relay "grain/server/types"
"html/template"
"net/http"
)
type PageData struct {
Title string
Theme string
Events []relay.Event
}
func RootHandler(w http.ResponseWriter, r *http.Request) {
// Fetch the top ten most recent events
client := db.GetClient()
events, err := FetchTopTenRecentEvents(client)
if err != nil {
http.Error(w, "Unable to fetch events", http.StatusInternalServerError)
return
}
data := PageData{
Title: "GRAIN Dashboard",
Events: events,
}
RenderTemplate(w, data, "index.html")
}
// Define the base directories for views and templates
const (
viewsDir = "app/views/"
templatesDir = "app/views/templates/"
)
// Define the common layout templates filenames
var templateFiles = []string{
"layout.html",
"header.html",
"footer.html",
}
// Initialize the common templates with full paths
var layout = PrependDir(templatesDir, templateFiles)
func RenderTemplate(w http.ResponseWriter, data PageData, view string) {
// Append the specific template for the route
templates := append(layout, viewsDir+view)
// Parse all templates
tmpl, err := template.ParseFiles(templates...)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Execute the "layout" template
err = tmpl.ExecuteTemplate(w, "layout", data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
// Helper function to prepend a directory path to a list of filenames
func PrependDir(dir string, files []string) []string {
var fullPaths []string
for _, file := range files {
fullPaths = append(fullPaths, dir+file)
}
return fullPaths
}

View File

@ -1,7 +1,9 @@
package routes package routes
import ( import (
app "grain/app/src" app "grain/app/src/types"
"grain/app/src/utils"
"net/http" "net/http"
) )
@ -11,5 +13,5 @@ func ImportEvents(w http.ResponseWriter, r *http.Request) {
} }
// Call RenderTemplate with the specific template for this route // Call RenderTemplate with the specific template for this route
app.RenderTemplate(w, data, "importEvents.html") utils.RenderTemplate(w, data, "importEvents.html")
} }

17
app/src/routes/index.go Normal file
View File

@ -0,0 +1,17 @@
package routes
import (
app "grain/app/src/types"
"grain/app/src/utils"
"net/http"
)
func IndexHandler(w http.ResponseWriter, r *http.Request) {
data := app.PageData{
Title: "GRAIN Dashboard",
}
utils.RenderTemplate(w, data, "index.html")
}

View File

@ -0,0 +1,5 @@
package types
type PageData struct {
Title string
}

View File

@ -1,4 +1,4 @@
package app package utils
import ( import (
"context" "context"

View File

@ -0,0 +1,10 @@
package utils
// Helper function to prepend a directory path to a list of filenames
func PrependDir(dir string, files []string) []string {
var fullPaths []string
for _, file := range files {
fullPaths = append(fullPaths, dir+file)
}
return fullPaths
}

View File

@ -0,0 +1,44 @@
package utils
import (
app "grain/app/src/types"
"html/template"
"net/http"
)
// Define the base directories for views and templates
const (
viewsDir = "app/views/"
templatesDir = "app/views/templates/"
)
// Define the common layout templates filenames
var templateFiles = []string{
"layout.html",
"header.html",
"footer.html",
}
// Initialize the common templates with full paths
var layout = PrependDir(templatesDir, templateFiles)
func RenderTemplate(w http.ResponseWriter, data app.PageData, view string) {
// Append the specific template for the route
templates := append(layout, viewsDir+view)
// Parse all templates
tmpl, err := template.ParseFiles(templates...)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Execute the "layout" template
err = tmpl.ExecuteTemplate(w, "layout", data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@ -7,6 +7,9 @@ server:
read_timeout: 10 # Read timeout in seconds read_timeout: 10 # Read timeout in seconds
write_timeout: 10 # Write timeout in seconds write_timeout: 10 # Write timeout in seconds
idle_timeout: 120 # Idle timeout in seconds idle_timeout: 120 # Idle timeout in seconds
max_connections: 100 # Maximum number of concurrent connections
max_subscriptions_per_client: 10 # Maximum number of concurrent subscriptions per client
pubkey_whitelist: pubkey_whitelist:
enabled: false enabled: false
pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4", pubkeys: #["3fe0ab6cbdb7ee27148202249e3fb3b89423c6f6cda6ef43ea5057c3d93088e4",

View File

@ -10,6 +10,8 @@ type ServerConfig struct {
ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds ReadTimeout int `yaml:"read_timeout"` // Timeout in seconds
WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds WriteTimeout int `yaml:"write_timeout"` // Timeout in seconds
IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds IdleTimeout int `yaml:"idle_timeout"` // Timeout in seconds
MaxConnections int `yaml:"max_connections"` // Maximum number of concurrent connections
MaxSubscriptionsPerClient int `yaml:"max_subscriptions_per_client"` // Maximum number of subscriptions per client
} `yaml:"server"` } `yaml:"server"`
RateLimit RateLimitConfig `yaml:"rate_limit"` RateLimit RateLimitConfig `yaml:"rate_limit"`
PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"` PubkeyWhitelist PubkeyWhitelistConfig `yaml:"pubkey_whitelist"`

View File

@ -2,14 +2,12 @@ package main
import ( import (
"fmt" "fmt"
app "grain/app/src"
"grain/app/src/api" "grain/app/src/api"
"grain/app/src/routes" "grain/app/src/routes"
"grain/config" "grain/config"
configTypes "grain/config/types" configTypes "grain/config/types"
relay "grain/server" relay "grain/server"
"grain/server/db" "grain/server/db"
"grain/server/nip"
"grain/server/utils" "grain/server/utils"
"log" "log"
"net/http" "net/http"
@ -36,7 +34,7 @@ func main() {
config.SetupRateLimiter(cfg) config.SetupRateLimiter(cfg)
config.SetupSizeLimiter(cfg) config.SetupSizeLimiter(cfg)
err = nip.LoadRelayMetadataJSON() err = utils.LoadRelayMetadataJSON()
if err != nil { if err != nil {
log.Fatal("Failed to load relay metadata: ", err) log.Fatal("Failed to load relay metadata: ", err)
} }
@ -85,8 +83,8 @@ func ListenAndServe(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" { if r.Header.Get("Upgrade") == "websocket" {
wsServer.ServeHTTP(w, r) wsServer.ServeHTTP(w, r)
} else if r.Header.Get("Accept") == "application/nostr+json" { } else if r.Header.Get("Accept") == "application/nostr+json" {
nip.RelayInfoHandler(w, r) utils.RelayInfoHandler(w, r)
} else { } else {
app.RootHandler(w, r) routes.IndexHandler(w, r)
} }
} }

View File

@ -4,10 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/config"
"grain/server/db" "grain/server/db"
"grain/server/handlers/response" "grain/server/handlers/response"
relay "grain/server/types" relay "grain/server/types"
"grain/server/utils" "grain/server/utils"
"sync"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
@ -16,6 +18,7 @@ import (
) )
var subscriptions = make(map[string]relay.Subscription) var subscriptions = make(map[string]relay.Subscription)
var mu sync.Mutex // Protect concurrent access to subscriptions map
func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) { func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[string][]relay.Filter) {
if len(message) < 3 { if len(message) < 3 {
@ -31,6 +34,21 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
return return
} }
mu.Lock()
defer mu.Unlock()
// Check the current number of subscriptions for the client
if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient {
// Find and remove the oldest subscription (FIFO)
var oldestSubID string
for id := range subscriptions {
oldestSubID = id
break
}
delete(subscriptions, oldestSubID)
fmt.Println("Dropped oldest subscription:", oldestSubID)
}
filters := make([]relay.Filter, len(message)-2) filters := make([]relay.Filter, len(message)-2)
for i, filter := range message[2:] { for i, filter := range message[2:] {
filterData, ok := filter.(map[string]interface{}) filterData, ok := filter.(map[string]interface{})
@ -52,7 +70,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
filters[i] = f filters[i] = f
} }
// Update or add the subscription for the given subID // Add the new subscription or update the existing one
subscriptions[subID] = filters subscriptions[subID] = filters
fmt.Println("Subscription updated:", subID) fmt.Println("Subscription updated:", subID)

View File

@ -3,22 +3,56 @@ package relay
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/server/handlers"
"grain/config" "grain/config"
"grain/server/handlers"
relay "grain/server/types" relay "grain/server/types"
"grain/server/utils"
"log"
"sync"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
// Global connection count
var (
currentConnections int
mu sync.Mutex
)
// Client subscription count
var clientSubscriptions = make(map[*websocket.Conn]int)
func WebSocketHandler(ws *websocket.Conn) { func WebSocketHandler(ws *websocket.Conn) {
defer ws.Close() defer func() {
mu.Lock()
currentConnections--
delete(clientSubscriptions, ws)
mu.Unlock()
ws.Close()
}()
mu.Lock()
if currentConnections >= config.GetConfig().Server.MaxConnections {
websocket.Message.Send(ws, `{"error": "too many connections"}`)
mu.Unlock()
return
}
currentConnections++
mu.Unlock()
clientInfo := utils.ClientInfo{
IP: utils.GetClientIP(ws.Request()),
UserAgent: ws.Request().Header.Get("User-Agent"),
Origin: ws.Request().Header.Get("Origin"),
}
log.Printf("New connection from IP: %s, User-Agent: %s, Origin: %s", clientInfo.IP, clientInfo.UserAgent, clientInfo.Origin)
var msg string var msg string
rateLimiter := config.GetRateLimiter() rateLimiter := config.GetRateLimiter()
subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection subscriptions := make(map[string][]relay.Filter) // Subscription map scoped to the connection
clientSubscriptions[ws] = 0
for { for {
err := websocket.Message.Receive(ws, &msg) err := websocket.Message.Receive(ws, &msg)
@ -57,6 +91,14 @@ func WebSocketHandler(ws *websocket.Conn) {
case "EVENT": case "EVENT":
handlers.HandleEvent(ws, message) handlers.HandleEvent(ws, message)
case "REQ": case "REQ":
mu.Lock()
if clientSubscriptions[ws] >= config.GetConfig().Server.MaxSubscriptionsPerClient {
websocket.Message.Send(ws, `{"error": "too many subscriptions"}`)
mu.Unlock()
continue
}
clientSubscriptions[ws]++
mu.Unlock()
if allowed, msg := rateLimiter.AllowReq(); !allowed { if allowed, msg := rateLimiter.AllowReq(); !allowed {
websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg)) websocket.Message.Send(ws, fmt.Sprintf(`{"error": "%s"}`, msg))
ws.Close() ws.Close()

View File

@ -0,0 +1,28 @@
package utils
import (
"net/http"
"strings"
)
type ClientInfo struct {
IP string
UserAgent string
Origin string
}
func GetClientIP(r *http.Request) string {
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
remoteAddr := r.RemoteAddr
if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 {
return remoteAddr[:idx]
}
return remoteAddr
}

View File

@ -1,4 +1,4 @@
package nip package utils
import ( import (
"encoding/json" "encoding/json"