diff --git a/.gitignore b/.gitignore index 5976ef3..4bc1c0d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ config.yml relay_metadata.json grain.exe -/build \ No newline at end of file +/build +/logs \ No newline at end of file diff --git a/server/db/queryMongo.go b/server/db/queryMongo.go index 82583f6..ed3077a 100644 --- a/server/db/queryMongo.go +++ b/server/db/queryMongo.go @@ -13,11 +13,12 @@ import ( // QueryEvents queries events from the MongoDB collection(s) based on filters func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) { var results []relay.Event + var combinedFilters []bson.M + // Build MongoDB filters for each relay.Filter for _, filter := range filters { filterBson := bson.M{} - // Construct the BSON query based on the filters if len(filter.IDs) > 0 { filterBson["id"] = bson.M{"$in": filter.IDs} } @@ -45,62 +46,70 @@ func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName stri } } - opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}}) + combinedFilters = append(combinedFilters, filterBson) + } + + // Combine all filter conditions using the $or operator + query := bson.M{} + if len(combinedFilters) > 0 { + query["$or"] = combinedFilters + } + + opts := options.Find().SetSort(bson.D{{Key: "created_at", Value: -1}}) + + // Apply limit if set for initial query + for _, filter := range filters { if filter.Limit != nil { opts.SetLimit(int64(*filter.Limit)) } + } - // If no specific kinds are specified, query all collections in the database - if len(filter.Kinds) == 0 { - collections, err := client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{}) + // If no specific kinds are specified, query all collections + if len(filters[0].Kinds) == 0 { + collections, err := client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{}) + if err != nil { + return nil, fmt.Errorf("error listing collections: %v", err) + } + + for _, collectionName := range collections { + collection := client.Database(databaseName).Collection(collectionName) + cursor, err := collection.Find(context.TODO(), query, opts) if err != nil { - return nil, fmt.Errorf("error listing collections: %v", err) + return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) } + defer cursor.Close(context.TODO()) - for _, collectionName := range collections { - fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson) - - collection := client.Database(databaseName).Collection(collectionName) - cursor, err := collection.Find(context.TODO(), filterBson, opts) - if err != nil { - return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) - } - defer cursor.Close(context.TODO()) - - for cursor.Next(context.TODO()) { - var event relay.Event - if err := cursor.Decode(&event); err != nil { - return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) - } - results = append(results, event) - } - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) + for cursor.Next(context.TODO()) { + var event relay.Event + if err := cursor.Decode(&event); err != nil { + return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) } + results = append(results, event) } - } else { - // Query specific collections based on kinds - for _, kind := range filter.Kinds { - collectionName := fmt.Sprintf("event-kind%d", kind) - fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson) + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) + } + } + } else { + // Query specific collections based on kinds + for _, kind := range filters[0].Kinds { + collectionName := fmt.Sprintf("event-kind%d", kind) + collection := client.Database(databaseName).Collection(collectionName) + cursor, err := collection.Find(context.TODO(), query, opts) + if err != nil { + return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) + } + defer cursor.Close(context.TODO()) - collection := client.Database(databaseName).Collection(collectionName) - cursor, err := collection.Find(context.TODO(), filterBson, opts) - if err != nil { - return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) - } - defer cursor.Close(context.TODO()) - - for cursor.Next(context.TODO()) { - var event relay.Event - if err := cursor.Decode(&event); err != nil { - return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) - } - results = append(results, event) - } - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) + for cursor.Next(context.TODO()) { + var event relay.Event + if err := cursor.Decode(&event); err != nil { + return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) } + results = append(results, event) + } + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) } } } diff --git a/server/handlers/req.go b/server/handlers/req.go index eca958d..9b7b97b 100644 --- a/server/handlers/req.go +++ b/server/handlers/req.go @@ -49,99 +49,104 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri } } + // processRequest handles the actual processing of each request func processRequest(ws *websocket.Conn, message []interface{}) { - if len(message) < 3 { - fmt.Println("Invalid REQ message format") - response.SendClosed(ws, "", "invalid: invalid REQ message format") - return - } + if len(message) < 3 { + fmt.Println("Invalid REQ message format") + response.SendClosed(ws, "", "invalid: invalid REQ message format") + return + } - subID, ok := message[1].(string) - if !ok { - fmt.Println("Invalid subscription ID format") - response.SendClosed(ws, "", "invalid: invalid subscription ID format") - return - } + subID, ok := message[1].(string) + if !ok || len(subID) == 0 || len(subID) > 64 { + fmt.Println("Invalid subscription ID format or length") + response.SendClosed(ws, "", "invalid: subscription ID must be between 1 and 64 characters long") + return + } - mu.Lock() - defer mu.Unlock() + mu.Lock() + defer mu.Unlock() - // Check the current number of subscriptions for the client - if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { - // Find and remove the oldest subscription (FIFO) - var oldestSubID string - for id := range subscriptions { - oldestSubID = id - break - } - delete(subscriptions, oldestSubID) - fmt.Println("Dropped oldest subscription:", oldestSubID) - } + // Remove oldest subscription if needed + if len(subscriptions) >= config.GetConfig().Server.MaxSubscriptionsPerClient { + var oldestSubID string + for id := range subscriptions { + oldestSubID = id + break + } + delete(subscriptions, oldestSubID) + fmt.Println("Dropped oldest subscription:", oldestSubID) + } - // Prepare filters based on the incoming message - filters := make([]relay.Filter, len(message)-2) - for i, filter := range message[2:] { - filterData, ok := filter.(map[string]interface{}) - if !ok { - fmt.Println("Invalid filter format") - response.SendClosed(ws, subID, "invalid: invalid filter format") - return - } + // Parse and validate filters + filters := make([]relay.Filter, len(message)-2) + for i, filter := range message[2:] { + filterData, ok := filter.(map[string]interface{}) + if !ok { + fmt.Println("Invalid filter format") + response.SendClosed(ws, subID, "invalid: invalid filter format") + return + } - var f relay.Filter - f.IDs = utils.ToStringArray(filterData["ids"]) - f.Authors = utils.ToStringArray(filterData["authors"]) - f.Kinds = utils.ToIntArray(filterData["kinds"]) - f.Tags = utils.ToTagsMap(filterData["tags"]) - f.Since = utils.ToTime(filterData["since"]) - f.Until = utils.ToTime(filterData["until"]) - f.Limit = utils.ToInt(filterData["limit"]) + var f relay.Filter + f.IDs = utils.ToStringArray(filterData["ids"]) + f.Authors = utils.ToStringArray(filterData["authors"]) + f.Kinds = utils.ToIntArray(filterData["kinds"]) + f.Tags = utils.ToTagsMap(filterData["tags"]) + f.Since = utils.ToTime(filterData["since"]) + f.Until = utils.ToTime(filterData["until"]) + f.Limit = utils.ToInt(filterData["limit"]) - filters[i] = f - } + filters[i] = f + } - // Add the new subscription or update the existing one - subscriptions[subID] = relay.Subscription{Filters: filters} - fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters)) + // Validate filters + if !utils.ValidateFilters(filters) { + fmt.Println("Invalid filters: hex values not valid") + response.SendClosed(ws, subID, "invalid: filters contain invalid hex values") + return + } - // Query the database with filters and send back the results - queriedEvents, err := db.QueryEvents(filters, db.GetClient(), "grain") - if err != nil { - fmt.Println("Error querying events:", err) - response.SendClosed(ws, subID, "error: could not query events") - return - } + // Add subscription + subscriptions[subID] = relay.Subscription{Filters: filters} + fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters)) - // Log the number of events retrieved - fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID) - if len(queriedEvents) == 0 { - fmt.Printf("No matching events found for subscription %s\n", subID) - } + // Query the database with filters and send back the results + queriedEvents, err := db.QueryEvents(filters, db.GetClient(), "grain") + if err != nil { + fmt.Println("Error querying events:", err) + response.SendClosed(ws, subID, "error: could not query events") + return + } - // Send each event back to the client - for _, evt := range queriedEvents { - msg := []interface{}{"EVENT", subID, evt} - msgBytes, _ := json.Marshal(msg) - err = websocket.Message.Send(ws, string(msgBytes)) - if err != nil { - fmt.Println("Error sending event:", err) - response.SendClosed(ws, subID, "error: could not send event") - return - } - } + fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID) + if len(queriedEvents) == 0 { + fmt.Printf("No matching events found for subscription %s\n", subID) + } - // Indicate end of stored events - eoseMsg := []interface{}{"EOSE", subID} - eoseBytes, _ := json.Marshal(eoseMsg) - err = websocket.Message.Send(ws, string(eoseBytes)) - if err != nil { - fmt.Println("Error sending EOSE:", err) - response.SendClosed(ws, subID, "error: could not send EOSE") - return - } + for _, evt := range queriedEvents { + msg := []interface{}{"EVENT", subID, evt} + msgBytes, _ := json.Marshal(msg) + err = websocket.Message.Send(ws, string(msgBytes)) + if err != nil { + fmt.Println("Error sending event:", err) + response.SendClosed(ws, subID, "error: could not send event") + return + } + } - fmt.Println("Subscription handling completed, keeping connection open.") + // Send EOSE message + eoseMsg := []interface{}{"EOSE", subID} + eoseBytes, _ := json.Marshal(eoseMsg) + err = websocket.Message.Send(ws, string(eoseBytes)) + if err != nil { + fmt.Println("Error sending EOSE:", err) + response.SendClosed(ws, subID, "error: could not send EOSE") + return + } + + fmt.Println("Subscription handling completed, keeping connection open.") } // Initialize the worker pool when your server starts diff --git a/server/utils/validateFilter.go b/server/utils/validateFilter.go new file mode 100644 index 0000000..e406807 --- /dev/null +++ b/server/utils/validateFilter.go @@ -0,0 +1,38 @@ +package utils + +import ( + relay "grain/server/types" + "regexp" +) + +// isValidHex validates if the given string is a 64-character lowercase hex string +func isValidHex(str string) bool { + return len(str) == 64 && regexp.MustCompile(`^[a-f0-9]{64}$`).MatchString(str) +} + +// ValidateFilters ensures the IDs, Authors, and Tags follow the correct hex format +func ValidateFilters(filters []relay.Filter) bool { + for _, f := range filters { + // Validate IDs + for _, id := range f.IDs { + if !isValidHex(id) { + return false + } + } + // Validate Authors + for _, author := range f.Authors { + if !isValidHex(author) { + return false + } + } + // Validate Tags + for _, tags := range f.Tags { + for _, tag := range tags { + if !isValidHex(tag) { + return false + } + } + } + } + return true +}