Compare commits

...

4 Commits

18 changed files with 273 additions and 289 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ config.yml
relay_metadata.json
grain.exe
/build
/logs

View File

@ -7,7 +7,7 @@ import (
"grain/config"
configTypes "grain/config/types"
relay "grain/server"
"grain/server/db"
"grain/server/db/mongo"
"grain/server/utils"
"log"
"net/http"
@ -41,7 +41,7 @@ func main() {
config.SetResourceLimit(&cfg.ResourceLimits) // Apply limits once before starting the server
client, err := db.InitDB(cfg)
client, err := mongo.InitDB(cfg)
if err != nil {
log.Fatal("Error initializing database: ", err)
}
@ -71,7 +71,7 @@ func main() {
case <-signalChan:
log.Println("Shutting down server...")
server.Close() // Stop the server
db.DisconnectDB(client) // Disconnect from MongoDB
mongo.DisconnectDB(client) // Disconnect from MongoDB
wg.Wait() // Wait for all goroutines to finish
return
}

View File

@ -1,4 +1,4 @@
package db
package mongo
import (
"context"

View File

@ -12,7 +12,7 @@ import (
"golang.org/x/net/websocket"
)
func HandleKind5(ctx context.Context, evt relay.Event, dbClient *mongo.Client, ws *websocket.Conn) error {
func HandleDeleteKind(ctx context.Context, evt relay.Event, dbClient *mongo.Client, ws *websocket.Conn) error {
for _, tag := range evt.Tags {
if len(tag) < 2 {
continue

View File

@ -8,7 +8,7 @@ import (
"golang.org/x/net/websocket"
)
func HandleKind2(ctx context.Context, evt relay.Event, ws *websocket.Conn) error {
func HandleDeprecatedKind(ctx context.Context, evt relay.Event, ws *websocket.Conn) error {
// Send an OK message to indicate the event was not accepted
response.SendOK(ws, evt.ID, false, "invalid: kind 2 is deprecated, use kind 10002 (NIP65)")

View File

@ -0,0 +1,112 @@
package mongo
import (
"context"
"fmt"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// QueryEvents queries events from the MongoDB collection(s) based on filters
func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) {
var results []relay.Event
var combinedFilters []bson.M
// Build MongoDB filters for each relay.Filter
for _, filter := range filters {
filterBson := bson.M{}
if len(filter.IDs) > 0 {
filterBson["id"] = bson.M{"$in": filter.IDs}
}
if len(filter.Authors) > 0 {
filterBson["pubkey"] = bson.M{"$in": filter.Authors}
}
if len(filter.Kinds) > 0 {
filterBson["kind"] = bson.M{"$in": filter.Kinds}
}
if filter.Tags != nil {
for key, values := range filter.Tags {
if len(values) > 0 {
filterBson["tags."+key] = bson.M{"$in": values}
}
}
}
if filter.Since != nil {
filterBson["created_at"] = bson.M{"$gte": *filter.Since}
}
if filter.Until != nil {
if filterBson["created_at"] == nil {
filterBson["created_at"] = bson.M{"$lte": *filter.Until}
} else {
filterBson["created_at"].(bson.M)["$lte"] = *filter.Until
}
}
combinedFilters = append(combinedFilters, filterBson)
}
// Combine all filter conditions using the $or operator
query := bson.M{}
if len(combinedFilters) > 0 {
query["$or"] = combinedFilters
}
// Apply sorting by creation date (descending)
opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}})
// Apply limit if set in any filter
for _, filter := range filters {
if filter.Limit != nil {
opts.SetLimit(int64(*filter.Limit))
}
}
// If no kinds are specified in any filter, query all collections
var collections []string
if len(filters) > 0 && len(filters[0].Kinds) == 0 {
collections, _ = client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{})
} else {
// Collect all kinds from filters and query those collections
kindsMap := make(map[int]bool)
for _, filter := range filters {
for _, kind := range filter.Kinds {
kindsMap[kind] = true
}
}
// Construct collection names based on kinds
for kind := range kindsMap {
collectionName := fmt.Sprintf("event-kind%d", kind)
collections = append(collections, collectionName)
}
}
// Query each collection
for _, collectionName := range collections {
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), query, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
// Handle cursor errors
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
return results, nil
}

View File

@ -1,9 +1,9 @@
package db
package mongo
import (
"context"
"fmt"
"grain/server/handlers/kinds"
"grain/server/db/mongo/kinds"
"grain/server/handlers/response"
nostr "grain/server/types"
@ -16,15 +16,15 @@ func StoreMongoEvent(ctx context.Context, evt nostr.Event, ws *websocket.Conn) {
var err error
switch {
case evt.Kind == 0:
err = kinds.HandleKind0(ctx, evt, collection, ws)
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind == 1:
err = kinds.HandleKind1(ctx, evt, collection, ws)
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind == 2:
err = kinds.HandleKind2(ctx, evt, ws)
err = kinds.HandleDeprecatedKind(ctx, evt, ws)
case evt.Kind == 3:
err = kinds.HandleReplaceableKind(ctx, evt, collection, ws)
case evt.Kind == 5:
err = kinds.HandleKind5(ctx, evt, GetClient(), ws)
err = kinds.HandleDeleteKind(ctx, evt, GetClient(), ws)
case evt.Kind >= 4 && evt.Kind < 45:
err = kinds.HandleRegularKind(ctx, evt, collection, ws)
case evt.Kind >= 1000 && evt.Kind < 10000:

View File

@ -1,109 +0,0 @@
package db
import (
"context"
"fmt"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// QueryEvents queries events from the MongoDB collection(s) based on filters
func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) {
var results []relay.Event
for _, filter := range filters {
filterBson := bson.M{}
// Construct the BSON query based on the filters
if len(filter.IDs) > 0 {
filterBson["id"] = bson.M{"$in": filter.IDs}
}
if len(filter.Authors) > 0 {
filterBson["pubkey"] = bson.M{"$in": filter.Authors}
}
if len(filter.Kinds) > 0 {
filterBson["kind"] = bson.M{"$in": filter.Kinds}
}
if filter.Tags != nil {
for key, values := range filter.Tags {
if len(values) > 0 {
filterBson["tags."+key] = bson.M{"$in": values}
}
}
}
if filter.Since != nil {
filterBson["created_at"] = bson.M{"$gte": *filter.Since}
}
if filter.Until != nil {
if filterBson["created_at"] == nil {
filterBson["created_at"] = bson.M{"$lte": *filter.Until}
} else {
filterBson["created_at"].(bson.M)["$lte"] = *filter.Until
}
}
opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}})
if filter.Limit != nil {
opts.SetLimit(int64(*filter.Limit))
}
// If no specific kinds are specified, query all collections in the database
if len(filter.Kinds) == 0 {
collections, err := client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{})
if err != nil {
return nil, fmt.Errorf("error listing collections: %v", err)
}
for _, collectionName := range collections {
fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson)
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), filterBson, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
} else {
// Query specific collections based on kinds
for _, kind := range filter.Kinds {
collectionName := fmt.Sprintf("event-kind%d", kind)
fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson)
collection := client.Database(databaseName).Collection(collectionName)
cursor, err := collection.Find(context.TODO(), filterBson, opts)
if err != nil {
return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err)
}
defer cursor.Close(context.TODO())
for cursor.Next(context.TODO()) {
var event relay.Event
if err := cursor.Decode(&event); err != nil {
return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err)
}
results = append(results, event)
}
if err := cursor.Err(); err != nil {
return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err)
}
}
}
}
return results, nil
}

View File

@ -5,7 +5,7 @@ import (
"encoding/json"
"fmt"
"grain/config"
"grain/server/db"
"grain/server/db/mongo"
"grain/server/handlers/response"
"grain/server/utils"
@ -61,7 +61,7 @@ func HandleEvent(ws *websocket.Conn, message []interface{}) {
}
// This is where I'll handle storage for multiple database types in the future
db.StoreMongoEvent(context.TODO(), evt, ws)
mongo.StoreMongoEvent(context.TODO(), evt, ws)
fmt.Println("Event processed:", evt.ID)

View File

@ -1,50 +0,0 @@
package kinds
import (
"context"
"fmt"
"grain/server/handlers/response"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/net/websocket"
)
func HandleKind0(ctx context.Context, evt relay.Event, collection *mongo.Collection, ws *websocket.Conn) error {
filter := bson.M{"pubkey": evt.PubKey}
var existingEvent relay.Event
err := collection.FindOne(ctx, filter).Decode(&existingEvent)
if err != nil && err != mongo.ErrNoDocuments {
return fmt.Errorf("error finding existing event: %v", err)
}
if err != mongo.ErrNoDocuments {
if existingEvent.CreatedAt >= evt.CreatedAt {
response.SendOK(ws, evt.ID, false, "blocked: a newer kind 0 event already exists for this pubkey")
return nil
}
}
update := bson.M{
"$set": bson.M{
"id": evt.ID,
"created_at": evt.CreatedAt,
"kind": evt.Kind,
"tags": evt.Tags,
"content": evt.Content,
"sig": evt.Sig,
},
}
opts := options.Update().SetUpsert(true)
_, err = collection.UpdateOne(ctx, filter, update, opts)
if err != nil {
response.SendOK(ws, evt.ID, false, "error: could not connect to the database")
return fmt.Errorf("error updating/inserting event kind 0 into MongoDB: %v", err)
}
response.SendOK(ws, evt.ID, true, "")
fmt.Println("Upserted event kind 0 into MongoDB:", evt.ID)
return nil
}

View File

@ -1,24 +0,0 @@
// kinds/kind1.go
package kinds
import (
"context"
"fmt"
"grain/server/handlers/response"
relay "grain/server/types"
"go.mongodb.org/mongo-driver/mongo"
"golang.org/x/net/websocket"
)
func HandleKind1(ctx context.Context, evt relay.Event, collection *mongo.Collection, ws *websocket.Conn) error {
_, err := collection.InsertOne(ctx, evt)
if err != nil {
response.SendOK(ws, evt.ID, false, "error: could not connect to the database")
return fmt.Errorf("error inserting event into MongoDB: %v", err)
}
fmt.Println("Inserted event kind 1 into MongoDB:", evt.ID)
response.SendOK(ws, evt.ID, true, "")
return nil
}

View File

@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"grain/config"
"grain/server/db"
"grain/server/db/mongo"
"grain/server/handlers/response"
relay "grain/server/types"
"grain/server/utils"
@ -49,6 +49,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri
}
}
// processRequest handles the actual processing of each request
func processRequest(ws *websocket.Conn, message []interface{}) {
if len(message) < 3 {
@ -58,18 +59,17 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
}
subID, ok := message[1].(string)
if !ok {
fmt.Println("Invalid subscription ID format")
response.SendClosed(ws, "", "invalid: invalid subscription ID format")
if !ok || len(subID) == 0 || len(subID) > 64 {
fmt.Println("Invalid subscription ID format or length")
response.SendClosed(ws, "", "invalid: subscription ID must be between 1 and 64 characters long")
return
}
mu.Lock()
defer mu.Unlock()
// Check the current number of subscriptions for the client
// Remove oldest subscription if needed
if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient {
// Find and remove the oldest subscription (FIFO)
var oldestSubID string
for id := range subscriptions {
oldestSubID = id
@ -79,7 +79,7 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
fmt.Println("Dropped oldest subscription:", oldestSubID)
}
// Prepare filters based on the incoming message
// Parse and validate filters
filters := make([]relay.Filter, len(message)-2)
for i, filter := range message[2:] {
filterData, ok := filter.(map[string]interface{})
@ -101,25 +101,30 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
filters[i] = f
}
// Add the new subscription or update the existing one
// Validate filters
if !utils.ValidateFilters(filters) {
fmt.Println("Invalid filters: hex values not valid")
response.SendClosed(ws, subID, "invalid: filters contain invalid hex values")
return
}
// Add subscription
subscriptions[subID] = relay.Subscription{Filters: filters}
fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters))
// Query the database with filters and send back the results
queriedEvents, err := db.QueryEvents(filters, db.GetClient(), "grain")
queriedEvents, err := mongo.QueryEvents(filters, mongo.GetClient(), "grain")
if err != nil {
fmt.Println("Error querying events:", err)
response.SendClosed(ws, subID, "error: could not query events")
return
}
// Log the number of events retrieved
fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID)
if len(queriedEvents) == 0 {
fmt.Printf("No matching events found for subscription %s\n", subID)
}
// Send each event back to the client
for _, evt := range queriedEvents {
msg := []interface{}{"EVENT", subID, evt}
msgBytes, _ := json.Marshal(msg)
@ -131,7 +136,7 @@ func processRequest(ws *websocket.Conn, message []interface{}) {
}
}
// Indicate end of stored events
// Send EOSE message
eoseMsg := []interface{}{"EOSE", subID}
eoseBytes, _ := json.Marshal(eoseMsg)
err = websocket.Message.Send(ws, string(eoseBytes))

View File

@ -13,17 +13,31 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr"
)
// SerializeEvent manually constructs the JSON string for event serialization
// EscapeSpecialChars escapes special characters in the content according to NIP-01
func EscapeSpecialChars(content string) string {
content = strings.ReplaceAll(content, "\\", "\\\\")
content = strings.ReplaceAll(content, "\"", "\\\"")
content = strings.ReplaceAll(content, "\n", "\\n")
content = strings.ReplaceAll(content, "\r", "\\r")
content = strings.ReplaceAll(content, "\t", "\\t")
content = strings.ReplaceAll(content, "\b", "\\b")
content = strings.ReplaceAll(content, "\f", "\\f")
return content
}
// SerializeEvent manually constructs the JSON string for event serialization according to NIP-01
func SerializeEvent(evt relay.Event) string {
// Escape special characters in the content
escapedContent := EscapeSpecialChars(evt.Content)
// Manually construct the event data as a JSON array string
// Avoid escaping special characters like "&"
return fmt.Sprintf(
`[0,"%s",%d,%d,%s,"%s"]`,
evt.PubKey,
evt.CreatedAt,
evt.Kind,
serializeTags(evt.Tags),
evt.Content, // Special characters like "&" are not escaped here
escapedContent, // Special characters are escaped
)
}
@ -75,19 +89,16 @@ func CheckSignature(evt relay.Event) bool {
return false
}
// Parse the public key based on its length
var pubKey *btcec.PublicKey
// Since the public key is 32 bytes, prepend 0x02 (assuming y-coordinate is even)
if len(pubKeyBytes) == 32 {
// Handle 32-byte compressed public key (x-coordinate only)
pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...))
} else if len(pubKeyBytes) == 33 || len(pubKeyBytes) == 65 {
// Handle standard compressed (33-byte) or uncompressed (65-byte) public key
pubKey, err = btcec.ParsePubKey(pubKeyBytes)
pubKeyBytes = append([]byte{0x02}, pubKeyBytes...)
} else {
log.Printf("Malformed public key: invalid length: %d", len(pubKeyBytes))
return false
}
// Parse the public key
pubKey, err := btcec.ParsePubKey(pubKeyBytes)
if err != nil {
log.Printf("Error parsing public key: %v", err)
return false

View File

@ -0,0 +1,38 @@
package utils
import (
relay "grain/server/types"
"regexp"
)
// isValidHex validates if the given string is a 64-character lowercase hex string
func isValidHex(str string) bool {
return len(str) == 64 && regexp.MustCompile(`^[a-f0-9]{64}$`).MatchString(str)
}
// ValidateFilters ensures the IDs, Authors, and Tags follow the correct hex format
func ValidateFilters(filters []relay.Filter) bool {
for _, f := range filters {
// Validate IDs
for _, id := range f.IDs {
if !isValidHex(id) {
return false
}
}
// Validate Authors
for _, author := range f.Authors {
if !isValidHex(author) {
return false
}
}
// Validate Tags
for _, tags := range f.Tags {
for _, tag := range tags {
if !isValidHex(tag) {
return false
}
}
}
}
return true
}