Compare commits

..

No commits in common. "33706b4200668084ac993c2305d09b49a0c3b865" and "7009533c8dc6df0e2b80974e9811e65fa4673edf" have entirely different histories.

18 changed files with 289 additions and 273 deletions

1
.gitignore vendored
View File

@ -3,4 +3,3 @@ config.yml
relay_metadata.json relay_metadata.json
grain.exe grain.exe
/build /build
/logs

View File

@ -7,7 +7,7 @@ import (
"grain/config" "grain/config"
configTypes "grain/config/types" configTypes "grain/config/types"
relay "grain/server" relay "grain/server"
"grain/server/db/mongo" "grain/server/db"
"grain/server/utils" "grain/server/utils"
"log" "log"
"net/http" "net/http"
@ -41,7 +41,7 @@ func main() {
config.SetResourceLimit(&cfg.ResourceLimits) // Apply limits once before starting the server config.SetResourceLimit(&cfg.ResourceLimits) // Apply limits once before starting the server
client, err := mongo.InitDB(cfg) client, err := db.InitDB(cfg)
if err != nil { if err != nil {
log.Fatal("Error initializing database: ", err) log.Fatal("Error initializing database: ", err)
} }
@ -71,7 +71,7 @@ func main() {
case <-signalChan: case <-signalChan:
log.Println("Shutting down server...") log.Println("Shutting down server...")
server.Close() // Stop the server server.Close() // Stop the server
mongo.DisconnectDB(client) // Disconnect from MongoDB db.DisconnectDB(client) // Disconnect from MongoDB
wg.Wait() // Wait for all goroutines to finish wg.Wait() // Wait for all goroutines to finish
return return
} }

View File

@ -1,4 +1,4 @@
package mongo package db
import ( import (
"context" "context"

View File

@ -1,112 +0,0 @@
package mongo
import (
"context"
"fmt"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// QueryEvents queries events from the MongoDB collection(s) based on filters
func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) {
var results []relay.Event
var combinedFilters []bson.M
// Build MongoDB filters for each relay.Filter
for _, filter := range filters {
filterBson := bson.M{}
if len(filter.IDs) > 0 {
filterBson["id"] = bson.M{"$in": filter.IDs}
}
if len(filter.Authors) > 0 {
filterBson["pubkey"] = bson.M{"$in": filter.Authors}
}
if len(filter.Kinds) > 0 {
filterBson["kind"] = bson.M{"$in": filter.Kinds}
}
if filter.Tags != nil {
for key, values := range filter.Tags {
if len(values) > 0 {
filterBson["tags."+key] = bson.M{"$in": values}
}
}
}
if filter.Since != nil {
filterBson["created_at"] = bson.M{"$gte": *filter.Since}
}
if filter.Until != nil {
if filterBson["created_at"] == nil {
filterBson["created_at"] = bson.M{"$lte": *filter.Until}
} else {
filterBson["created_at"].(bson.M)["$lte"] = *filter.Until
}
}
combinedFilters = append(combinedFilters, filterBson)
}
// Combine all filter conditions using the $or operator
query := bson.M{}
if len(combinedFilters) > 0 {
query["$or"] = combinedFilters
}
// Apply sorting by creation date (descending)
opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}})
// Apply limit if set in any filter
for _, filter := range filters {
if filter.Limit != nil {
opts.SetLimit(int64(*filter.Limit))
}
}
// If no kinds are specified in any filter, query all collections
var collections []string
if len(filters) > 0 && len(filters[0].Kinds) == 0 {
collections, _ = client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{})
} else {
// Collect all kinds from filters and query those collections
kindsMap := make(map[int]bool)
for _, filter := range filters {
for _, kind := range filter.Kinds {
kindsMap[kind] = true
}
}
// Construct collection names based on kinds
for kind := range kindsMap {
collectionName := fmt.Sprintf("event-kind%d", kind)
collections = append(collections, collectionName)
}
}
// Query each collection
for _, collectionName := range collections {
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), query, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
// Handle cursor errors
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
return results, nil
}

109
server/db/queryMongo.go Normal file
View File

@ -0,0 +1,109 @@
package db
import (
"context"
"fmt"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// QueryEvents queries events from the MongoDB collection(s) based on filters
func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) {
var results []relay.Event
for _, filter := range filters {
filterBson := bson.M{}
// Construct the BSON query based on the filters
if len(filter.IDs) > 0 {
filterBson["id"] = bson.M{"$in": filter.IDs}
}
if len(filter.Authors) > 0 {
filterBson["pubkey"] = bson.M{"$in": filter.Authors}
}
if len(filter.Kinds) > 0 {
filterBson["kind"] = bson.M{"$in": filter.Kinds}
}
if filter.Tags != nil {
for key, values := range filter.Tags {
if len(values) > 0 {
filterBson["tags."+key] = bson.M{"$in": values}
}
}
}
if filter.Since != nil {
filterBson["created_at"] = bson.M{"$gte": *filter.Since}
}
if filter.Until != nil {
if filterBson["created_at"] == nil {
filterBson["created_at"] = bson.M{"$lte": *filter.Until}
} else {
filterBson["created_at"].(bson.M)["$lte"] = *filter.Until
}
}
opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}})
if filter.Limit != nil {
opts.SetLimit(int64(*filter.Limit))
}
// If no specific kinds are specified, query all collections in the database
if len(filter.Kinds) == 0 {
collections, err := client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{})
if err != nil {
return nil, fmt.Errorf("error listing collections: %v", err)
}
for _, collectionName := range collections {
fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson)
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), filterBson, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
} else {
// Query specific collections based on kinds
for _, kind := range filter.Kinds {
collectionName := fmt.Sprintf("event-kind%d", kind)
fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson)
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), filterBson, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
}
}
return results, nil
}

View File

@ -1,9 +1,9 @@
package mongo package db
import ( import (
"context" "context"
"fmt" "fmt"
"grain/server/db/mongo/kinds" "grain/server/handlers/kinds"
"grain/server/handlers/response" "grain/server/handlers/response"
nostr "grain/server/types" nostr "grain/server/types"
@ -16,15 +16,15 @@ func StoreMongoEvent(ctx context.Context, evt nostr.Event, ws *websocket.Conn) {
var err error var err error
switch { switch {
case evt.Kind == 0: case evt.Kind == 0:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws) err = kinds.HandleKind0(ctx, evt, collection, ws)
case evt.Kind == 1: case evt.Kind == 1:
err = kinds.HandleRegularKind(ctx, evt, collection, ws) err = kinds.HandleKind1(ctx, evt, collection, ws)
case evt.Kind == 2: case evt.Kind == 2:
err = kinds.HandleDeprecatedKind(ctx, evt, ws) err = kinds.HandleKind2(ctx, evt, ws)
case evt.Kind == 3: case evt.Kind == 3:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws) err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind == 5: case evt.Kind == 5:
err = kinds.HandleDeleteKind(ctx, evt, GetClient(), ws) err = kinds.HandleKind5(ctx, evt, GetClient(), ws)
case evt.Kind >= 4 && evt.Kind < 45: case evt.Kind >= 4 && evt.Kind < 45:
err = kinds.HandleRegularKind(ctx, evt, collection, ws) err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 1000 && evt.Kind < 10000: case evt.Kind >= 1000 && evt.Kind < 10000:

View File

@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/config" "grain/config"
"grain/server/db/mongo" "grain/server/db"
"grain/server/handlers/response" "grain/server/handlers/response"
"grain/server/utils" "grain/server/utils"
@ -61,7 +61,7 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
} }
// This is where I'll handle storage for multiple database types in the future // This is where I'll handle storage for multiple database types in the future
mongo.StoreMongoEvent(context.TODO(), evt, ws) db.StoreMongoEvent(context.TODO(), evt, ws)
fmt.Println("Event processed:", evt.ID) fmt.Println("Event processed:", evt.ID)

View File

@ -0,0 +1,50 @@
package kinds
import (
"context"
"fmt"
"grain/server/handlers/response"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/net/websocket"
)
func HandleKind0(ctx context.Context, evt relay.Event, collection *mongo.Collection, ws *websocket.Conn) error {
filter := bson.M{"pubkey": evt.PubKey}
var existingEvent relay.Event
err := collection.FindOne(ctx, filter).Decode(&existingEvent)
if err != nil && err != mongo.ErrNoDocuments {
return fmt.Errorf("error finding existing event: %v", err)
}
if err != mongo.ErrNoDocuments {
if existingEvent.CreatedAt >= evt.CreatedAt {
response.SendOK(ws, evt.ID, false, "blocked: a newer kind 0 event already exists for this pubkey")
return nil
}
}
update := bson.M{
"$set": bson.M{
"id": evt.ID,
"created_at": evt.CreatedAt,
"kind": evt.Kind,
"tags": evt.Tags,
"content": evt.Content,
"sig": evt.Sig,
},
}
opts := options.Update().SetUpsert(true)
_, err = collection.UpdateOne(ctx, filter, update, opts)
if err != nil {
response.SendOK(ws, evt.ID, false, "error: could not connect to the database")
return fmt.Errorf("error updating/inserting event kind 0 into MongoDB: %v", err)
}
response.SendOK(ws, evt.ID, true, "")
fmt.Println("Upserted event kind 0 into MongoDB:", evt.ID)
return nil
}

View File

@ -0,0 +1,24 @@
// kinds/kind1.go
package kinds
import (
"context"
"fmt"
"grain/server/handlers/response"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/mongo"
"golang.org/x/net/websocket"
)
func HandleKind1(ctx context.Context, evt relay.Event, collection *mongo.Collection, ws *websocket.Conn) error {
_, err := collection.InsertOne(ctx, evt)
if err != nil {
response.SendOK(ws, evt.ID, false, "error: could not connect to the database")
return fmt.Errorf("error inserting event into MongoDB: %v", err)
}
fmt.Println("Inserted event kind 1 into MongoDB:", evt.ID)
response.SendOK(ws, evt.ID, true, "")
return nil
}

View File

@ -8,7 +8,7 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
func HandleDeprecatedKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) error { func HandleKind2(ctx context.Context, evt relay.Event, ws *websocket.Conn) error {
// Send an OK message to indicate the event was not accepted // Send an OK message to indicate the event was not accepted
response.SendOK(ws, evt.ID, false, "invalid: kind 2 is deprecated, use kind 10002 (NIP65)") response.SendOK(ws, evt.ID, false, "invalid: kind 2 is deprecated, use kind 10002 (NIP65)")

View File

@ -12,7 +12,7 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
func HandleDeleteKind(ctx context.Context, evt relay.Event, dbClient *mongo.Client, ws *websocket.Conn) error { func HandleKind5(ctx context.Context, evt relay.Event, dbClient *mongo.Client, ws *websocket.Conn) error {
for _, tag := range evt.Tags { for _, tag := range evt.Tags {
if len(tag) < 2 { if len(tag) < 2 {
continue continue

View File

@ -4,7 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/config" "grain/config"
"grain/server/db/mongo" "grain/server/db"
"grain/server/handlers/response" "grain/server/handlers/response"
relay "grain/server/types" relay "grain/server/types"
"grain/server/utils" "grain/server/utils"
@ -49,7 +49,6 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
} }
} }
// processRequest handles the actual processing of each request // processRequest handles the actual processing of each request
func processRequest(ws *websocket.Conn, message []interface{}) { func processRequest(ws *websocket.Conn, message []interface{}) {
if len(message) < 3 { if len(message) < 3 {
@ -59,17 +58,18 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
} }
subID, ok := message[1].(string) subID, ok := message[1].(string)
if !ok || len(subID) == 0 || len(subID) > 64 { if !ok {
fmt.Println("Invalid subscription ID format or length") fmt.Println("Invalid subscription ID format")
response.SendClosed(ws, "", "invalid: subscription ID must be between 1 and 64 characters long") response.SendClosed(ws, "", "invalid: invalid subscription ID format")
return return
} }
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
// Remove oldest subscription if needed // Check the current number of subscriptions for the client
if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient {
// Find and remove the oldest subscription (FIFO)
var oldestSubID string var oldestSubID string
for id := range subscriptions { for id := range subscriptions {
oldestSubID = id oldestSubID = id
@ -79,7 +79,7 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
fmt.Println("Dropped oldest subscription:", oldestSubID) fmt.Println("Dropped oldest subscription:", oldestSubID)
} }
// Parse and validate filters // Prepare filters based on the incoming message
filters := make([]relay.Filter, len(message)-2) filters := make([]relay.Filter, len(message)-2)
for i, filter := range message[2:] { for i, filter := range message[2:] {
filterData, ok := filter.(map[string]interface{}) filterData, ok := filter.(map[string]interface{})
@ -101,30 +101,25 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
filters[i] = f filters[i] = f
} }
// Validate filters // Add the new subscription or update the existing one
if !utils.ValidateFilters(filters) {
fmt.Println("Invalid filters: hex values not valid")
response.SendClosed(ws, subID, "invalid: filters contain invalid hex values")
return
}
// Add subscription
subscriptions[subID] = relay.Subscription{Filters: filters} subscriptions[subID] = relay.Subscription{Filters: filters}
fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters)) fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters))
// Query the database with filters and send back the results // Query the database with filters and send back the results
queriedEvents, err := mongo.QueryEvents(filters, mongo.GetClient(), "grain") queriedEvents, err := db.QueryEvents(filters, db.GetClient(), "grain")
if err != nil { if err != nil {
fmt.Println("Error querying events:", err) fmt.Println("Error querying events:", err)
response.SendClosed(ws, subID, "error: could not query events") response.SendClosed(ws, subID, "error: could not query events")
return return
} }
// Log the number of events retrieved
fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID) fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID)
if len(queriedEvents) == 0 { if len(queriedEvents) == 0 {
fmt.Printf("No matching events found for subscription %s\n", subID) fmt.Printf("No matching events found for subscription %s\n", subID)
} }
// Send each event back to the client
for _, evt := range queriedEvents { for _, evt := range queriedEvents {
msg := []interface{}{"EVENT", subID, evt} msg := []interface{}{"EVENT", subID, evt}
msgBytes, _ := json.Marshal(msg) msgBytes, _ := json.Marshal(msg)
@ -136,7 +131,7 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
} }
} }
// Send EOSE message // Indicate end of stored events
eoseMsg := []interface{}{"EOSE", subID} eoseMsg := []interface{}{"EOSE", subID}
eoseBytes, _ := json.Marshal(eoseMsg) eoseBytes, _ := json.Marshal(eoseMsg)
err = websocket.Message.Send(ws, string(eoseBytes)) err = websocket.Message.Send(ws, string(eoseBytes))

View File

@ -13,31 +13,17 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
) )
// EscapeSpecialChars escapes special characters in the content according to NIP-01 // SerializeEvent manually constructs the JSON string for event serialization
func EscapeSpecialChars(content string) string {
content = strings.ReplaceAll(content, "\\", "\\\\")
content = strings.ReplaceAll(content, "\"", "\\\"")
content = strings.ReplaceAll(content, "\n", "\\n")
content = strings.ReplaceAll(content, "\r", "\\r")
content = strings.ReplaceAll(content, "\t", "\\t")
content = strings.ReplaceAll(content, "\b", "\\b")
content = strings.ReplaceAll(content, "\f", "\\f")
return content
}
// SerializeEvent manually constructs the JSON string for event serialization according to NIP-01
func SerializeEvent(evt relay.Event) string { func SerializeEvent(evt relay.Event) string {
// Escape special characters in the content
escapedContent := EscapeSpecialChars(evt.Content)
// Manually construct the event data as a JSON array string // Manually construct the event data as a JSON array string
// Avoid escaping special characters like "&"
return fmt.Sprintf( return fmt.Sprintf(
`[0,"%s",%d,%d,%s,"%s"]`, `[0,"%s",%d,%d,%s,"%s"]`,
evt.PubKey, evt.PubKey,
evt.CreatedAt, evt.CreatedAt,
evt.Kind, evt.Kind,
serializeTags(evt.Tags), serializeTags(evt.Tags),
escapedContent, // Special characters are escaped evt.Content, // Special characters like "&" are not escaped here
) )
} }
@ -89,16 +75,19 @@ func CheckSignature(evt relay.Event) bool {
return false return false
} }
// Since the public key is 32 bytes, prepend 0x02 (assuming y-coordinate is even) // Parse the public key based on its length
var pubKey *btcec.PublicKey
if len(pubKeyBytes) == 32 { if len(pubKeyBytes) == 32 {
pubKeyBytes = append([]byte{0x02}, pubKeyBytes...) // Handle 32-byte compressed public key (x-coordinate only)
pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...))
} else if len(pubKeyBytes) == 33 || len(pubKeyBytes) == 65 {
// Handle standard compressed (33-byte) or uncompressed (65-byte) public key
pubKey, err = btcec.ParsePubKey(pubKeyBytes)
} else { } else {
log.Printf("Malformed public key: invalid length: %d", len(pubKeyBytes)) log.Printf("Malformed public key: invalid length: %d", len(pubKeyBytes))
return false return false
} }
// Parse the public key
pubKey, err := btcec.ParsePubKey(pubKeyBytes)
if err != nil { if err != nil {
log.Printf("Error parsing public key: %v", err) log.Printf("Error parsing public key: %v", err)
return false return false

View File

@ -1,38 +0,0 @@
package utils
import (
relay "grain/server/types"
"regexp"
)
// isValidHex validates if the given string is a 64-character lowercase hex string
func isValidHex(str string) bool {
return len(str) == 64 && regexp.MustCompile(`^[a-f0-9]{64}$`).MatchString(str)
}
// ValidateFilters ensures the IDs, Authors, and Tags follow the correct hex format
func ValidateFilters(filters []relay.Filter) bool {
for _, f := range filters {
// Validate IDs
for _, id := range f.IDs {
if !isValidHex(id) {
return false
}
}
// Validate Authors
for _, author := range f.Authors {
if !isValidHex(author) {
return false
}
}
// Validate Tags
for _, tags := range f.Tags {
for _, tag := range tags {
if !isValidHex(tag) {
return false
}
}
}
}
return true
}