// Package service contains the application services; this file implements the
// Neo4j integration on top of Neo4j's HTTP transactional endpoint.
package service

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"

	"x-agents/server/internal/model"
	"x-agents/server/internal/repository"

	"github.com/google/uuid"
)

const (
	// neo4jDefaultHost is used when a request carries no URI.
	neo4jDefaultHost = "localhost"
	// neo4jDefaultBoltPort is used when a request carries no (parsable) port.
	neo4jDefaultBoltPort = 7687
	// neo4jDefaultHTTPPort is always probed in addition to the derived port.
	neo4jDefaultHTTPPort = 7474
	// neo4jBoltToHTTPOffset is subtracted from the Bolt port to guess the HTTP
	// port (7687 - 1000 = 7687's conventional sibling 6687 is tried first,
	// then 7474, then the Bolt port itself).
	neo4jBoltToHTTPOffset = 1000
	// neo4jDefaultDatabase is used when a request names no database.
	neo4jDefaultDatabase = "neo4j"
	// neo4jDefaultLimit caps row counts when a request gives no positive limit.
	neo4jDefaultLimit = 10
	// neo4jHTTPTimeout bounds every HTTP call so that probing a closed or
	// non-HTTP port cannot block a request forever.
	neo4jHTTPTimeout = 15 * time.Second
)

// Neo4jService queries Neo4j over HTTP and keeps the corresponding connection
// records in the database repository.
type Neo4jService struct {
	client       *http.Client
	databaseRepo *repository.DatabaseRepository
}

// NewNeo4jService builds a Neo4jService backed by dbRepo. The HTTP client it
// creates has a request timeout of neo4jHTTPTimeout.
func NewNeo4jService(dbRepo *repository.DatabaseRepository) *Neo4jService {
	return &Neo4jService{
		client:       &http.Client{Timeout: neo4jHTTPTimeout},
		databaseRepo: dbRepo,
	}
}

// neo4jStatement is one statement of a transactional-endpoint request.
type neo4jStatement struct {
	Statement          string   `json:"statement"`
	ResultDataContents []string `json:"resultDataContents,omitempty"`
}

// neo4jTxRequest is the JSON body sent to /db/{name}/tx/commit.
type neo4jTxRequest struct {
	Statements []neo4jStatement `json:"statements"`
}

// neo4jRow is one result row in "row" format.
type neo4jRow struct {
	Row []interface{} `json:"row"`
}

// neo4jResult holds the rows produced by one statement.
type neo4jResult struct {
	Data []neo4jRow `json:"data"`
}

// neo4jTxError is one entry of the "errors" array of a response.
type neo4jTxError struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

// neo4jTxResponse is the subset of the transactional-endpoint response that
// this service reads.
type neo4jTxResponse struct {
	Results []neo4jResult  `json:"results"`
	Errors  []neo4jTxError `json:"errors"`
}

// neo4jParseAddress extracts host and port from rawURI, falling back to
// defaultHost / defaultPort for whatever part is missing or unparsable.
// Any scheme ("bolt://", "neo4j://", "bolt+s://", ...), user info, path,
// query and fragment are ignored. A bracketed IPv6 host keeps its brackets so
// that it can be placed directly into a URL.
func neo4jParseAddress(rawURI, defaultHost string, defaultPort int) (string, int) {
	host, port := defaultHost, defaultPort

	addr := strings.TrimSpace(rawURI)
	if i := strings.Index(addr, "://"); i >= 0 {
		addr = addr[i+len("://"):]
	}
	if i := strings.IndexAny(addr, "/?#"); i >= 0 {
		addr = addr[:i]
	}
	if i := strings.LastIndex(addr, "@"); i >= 0 {
		addr = addr[i+1:]
	}
	if addr == "" {
		return host, port
	}

	sep := strings.LastIndex(addr, ":")
	if sep < 0 || strings.HasSuffix(addr, "]") {
		// No port component (the colons, if any, belong to an IPv6 literal).
		return addr, port
	}
	if sep > 0 {
		host = addr[:sep]
	}
	if n, err := strconv.Atoi(addr[sep+1:]); err == nil && n > 0 && n <= 65535 {
		port = n
	}
	return host, port
}

// neo4jCandidatePorts lists the ports to probe for the HTTP API, in order:
// boltPort-1000, 7474, then boltPort itself. Duplicates and values outside
// 1..65535 are dropped.
func neo4jCandidatePorts(boltPort int) []int {
	candidates := []int{boltPort - neo4jBoltToHTTPOffset, neo4jDefaultHTTPPort, boltPort}
	ports := make([]int, 0, len(candidates))
	seen := make(map[int]bool, len(candidates))
	for _, p := range candidates {
		if p <= 0 || p > 65535 || seen[p] {
			continue
		}
		seen[p] = true
		ports = append(ports, p)
	}
	return ports
}

// neo4jEndpoint builds the transactional commit URL for the given database.
// The database name is path-escaped so unusual names cannot alter the path.
func neo4jEndpoint(host string, port int, db string) string {
	return fmt.Sprintf("http://%s:%d/db/%s/tx/commit", host, port, url.PathEscape(db))
}

// neo4jBasicAuth returns the base64 payload of an HTTP Basic Authorization
// header for the given credentials.
func neo4jBasicAuth(username, password string) string {
	return base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
}

// neo4jDatabaseOrDefault returns name, or the default database when empty.
func neo4jDatabaseOrDefault(name string) string {
	if name == "" {
		return neo4jDefaultDatabase
	}
	return name
}

// neo4jQuoteIdentifier wraps a label or relationship type in backticks and
// doubles embedded backticks, so the name is always read by Cypher as a single
// identifier and can never terminate the pattern it is placed in.
func neo4jQuoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// neo4jCount converts a JSON number (decoded as float64) to int. The second
// result is false when v is not a number.
func neo4jCount(v interface{}) (int, bool) {
	f, ok := v.(float64)
	if !ok {
		return 0, false
	}
	return int(f), true
}

// neo4jPropertyInfos describes every entry of props by name and Go type of
// the decoded JSON value. The result is sorted by name so output is stable
// across calls, and is never nil.
func neo4jPropertyInfos(props map[string]interface{}) []model.PropertyInfo {
	infos := make([]model.PropertyInfo, 0, len(props))
	for name, value := range props {
		infos = append(infos, model.PropertyInfo{
			Name: name,
			Type: fmt.Sprintf("%T", value),
		})
	}
	sort.Slice(infos, func(i, j int) bool { return infos[i].Name < infos[j].Name })
	return infos
}

// neo4jNewRequest builds the authenticated JSON POST request for endpoint.
func neo4jNewRequest(endpoint, encodedAuth string, payload []byte) (*http.Request, error) {
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Basic "+encodedAuth)
	return httpReq, nil
}

// query runs the given Cypher statements in one request (results in "row"
// format) and returns one neo4jResult per statement. The request body is
// produced by json.Marshal, so statement text never needs manual escaping.
// An error is returned for transport failures, undecodable responses, and
// whenever the server reports at least one error.
func (s *Neo4jService) query(endpoint, encodedAuth string, statements ...string) ([]neo4jResult, error) {
	request := neo4jTxRequest{Statements: make([]neo4jStatement, 0, len(statements))}
	for _, stmt := range statements {
		request.Statements = append(request.Statements, neo4jStatement{
			Statement:          stmt,
			ResultDataContents: []string{"row"},
		})
	}
	payload, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}

	httpReq, err := neo4jNewRequest(endpoint, encodedAuth, payload)
	if err != nil {
		return nil, err
	}
	resp, err := s.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	// Closed when this helper returns, i.e. once per request even when the
	// caller issues requests in a loop.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var decoded neo4jTxResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return nil, err
	}
	if len(decoded.Errors) > 0 {
		if msg := decoded.Errors[0].Message; msg != "" {
			return nil, errors.New(msg)
		}
		return nil, errors.New("query error")
	}
	return decoded.Results, nil
}

// sampleProperties runs statement and returns the property map found in the
// first column of the first row of the first result. The second result is
// false when the query failed or produced no such map.
func (s *Neo4jService) sampleProperties(endpoint, encodedAuth, statement string) (map[string]interface{}, bool) {
	results, err := s.query(endpoint, encodedAuth, statement)
	if err != nil || len(results) == 0 || len(results[0].Data) == 0 {
		return nil, false
	}
	row := results[0].Data[0].Row
	if len(row) == 0 {
		return nil, false
	}
	props, ok := row[0].(map[string]interface{})
	return props, ok
}

// findGraphOverview probes the candidate HTTP ports for port and returns the
// first overview that could be fetched, or nil when every port failed.
func (s *Neo4jService) findGraphOverview(host string, port int, db, encodedAuth string) *model.GraphOverview {
	for _, p := range neo4jCandidatePorts(port) {
		graph, err := s.getGraphOverviewWithURL(neo4jEndpoint(host, p, db), encodedAuth, db)
		if err == nil && graph != nil {
			return graph
		}
	}
	return nil
}

// GetGraphs returns the graph overview for the connection described by req.
// Host and port come from req.URI (defaults localhost:7687). Failure to reach
// Neo4j is reported through Success=false in the response; the error result
// is always nil.
func (s *Neo4jService) GetGraphs(req model.Neo4jGraphRequest) (*model.Neo4jGraphResponse, error) {
	host, port := neo4jParseAddress(req.URI, neo4jDefaultHost, neo4jDefaultBoltPort)
	db := neo4jDatabaseOrDefault(req.Database)
	encodedAuth := neo4jBasicAuth(req.Username, req.Password)

	if graph := s.findGraphOverview(host, port, db, encodedAuth); graph != nil {
		return &model.Neo4jGraphResponse{
			Success: true,
			Message: "success",
			Graphs:  graph,
		}, nil
	}
	return &model.Neo4jGraphResponse{
		Success: false,
		Message: "failed to connect to Neo4j",
	}, nil
}

// Check tests the Neo4j connection described by req. On the first port that
// answers successfully it looks up (or creates) the matching database record
// and returns its ID, name and description together with the server version.
// A failure to persist the record is logged but does not turn a working
// connection into a failed check; the record fields are then left empty.
// The error result is always nil; failure is reported via Success=false.
func (s *Neo4jService) Check(req model.Neo4jCheckRequest) (*model.Neo4jCheckResponse, error) {
	db := neo4jDatabaseOrDefault(req.Database)
	encodedAuth := neo4jBasicAuth(req.Username, req.Password)

	for _, port := range neo4jCandidatePorts(req.Port) {
		resp, err := s.checkWithURL(neo4jEndpoint(req.Host, port, db), encodedAuth, db)
		if err != nil || resp == nil || !resp.Success {
			continue
		}

		// Connection works: fetch or create the database record.
		log.Printf("[Check] Neo4j 连接成功,准备获取/创建数据库记录, host=%s, port=%d, db=%s", req.Host, req.Port, db)
		dbInfo, err := s.ensureNeo4jDatabase(req, db)
		if err != nil {
			log.Printf("[Check] 确保数据库记录失败: %v", err)
		}
		if dbInfo == nil {
			// ensureNeo4jDatabase returns nil on error; use an empty record
			// instead of dereferencing it.
			dbInfo = &Neo4jDatabaseInfo{}
		}
		log.Printf("[Check] 数据库记录ID: %s, Name: %s", dbInfo.ID, dbInfo.Name)

		return &model.Neo4jCheckResponse{
			Success:     true,
			Message:     "connection successful",
			Version:     resp.Version,
			Databases:   []string{db},
			DatabaseID:  dbInfo.ID,
			Name:        dbInfo.Name,
			Description: dbInfo.Description,
		}, nil
	}

	return &model.Neo4jCheckResponse{
		Success: false,
		Message: fmt.Sprintf("connection failed on all ports (7474, %d)", req.Port),
	}, nil
}

// Neo4jDatabaseInfo is the part of a stored database record that Check
// reports back to the caller.
type Neo4jDatabaseInfo struct {
	ID          string
	Name        string
	Description string
}

// ensureNeo4jDatabase returns the stored record whose URI, username and
// database name match the request, creating a new record when none exists.
// The URI is req.URI, or "bolt://host:port" built from the request when
// req.URI is empty. Repository errors are logged and returned.
func (s *Neo4jService) ensureNeo4jDatabase(req model.Neo4jCheckRequest, dbName string) (*Neo4jDatabaseInfo, error) {
	log.Printf("[ensureNeo4jDatabase] 开始处理, host=%s, port=%d, username=%s, dbName=%s, uri=%s",
		req.Host, req.Port, req.Username, dbName, req.URI)

	// Load every record; matching is done in memory below.
	databases, err := s.databaseRepo.FindAll()
	if err != nil {
		log.Printf("[ensureNeo4jDatabase] FindAll 失败: %v", err)
		return nil, err
	}
	log.Printf("[ensureNeo4jDatabase] 找到 %d 条数据库记录", len(databases))

	// Build the URI used as part of the record's identity.
	uri := req.URI
	if uri == "" {
		uri = fmt.Sprintf("bolt://%s:%d", req.Host, req.Port)
	}
	log.Printf("[ensureNeo4jDatabase] 使用 URI: %s", uri)

	// Look for an existing record.
	for _, d := range databases {
		log.Printf("[ensureNeo4jDatabase] 对比: URI=%s, Username=%s, Database=%s", d.URI, d.Username, d.Database)
		if d.URI == uri && d.Username == req.Username && d.Database == dbName {
			log.Printf("[ensureNeo4jDatabase] 找到已存在的记录, id=%s, name=%s", d.ID, d.Name)
			return &Neo4jDatabaseInfo{
				ID:          d.ID,
				Name:        d.Name,
				Description: d.Description,
			}, nil
		}
	}

	// Nothing matched: create a new record, defaulting name and description.
	log.Printf("[ensureNeo4jDatabase] 未找到匹配记录,创建新记录")
	name := req.Name
	if name == "" {
		name = fmt.Sprintf("Neo4j-%s", dbName)
	}
	description := req.Description
	if description == "" {
		description = fmt.Sprintf("Neo4j %s@%s:%d", dbName, req.Host, req.Port)
	}

	newDB := &model.DatabaseInfo{
		ID:          uuid.New().String(),
		Name:        name,
		DBType:      "neo4j",
		Host:        req.Host,
		Port:        req.Port,
		Username:    req.Username,
		Password:    req.Password,
		Database:    dbName,
		URI:         uri,
		Description: description,
	}
	log.Printf("[ensureNeo4jDatabase] 创建新数据库: ID=%s, Name=%s, Host=%s, Port=%d, URI=%s",
		newDB.ID, newDB.Name, newDB.Host, newDB.Port, newDB.URI)

	if err := s.databaseRepo.Create(newDB); err != nil {
		log.Printf("[ensureNeo4jDatabase] Create 失败: %v", err)
		return nil, err
	}
	log.Printf("[ensureNeo4jDatabase] 创建成功, 返回 ID=%s", newDB.ID)

	return &Neo4jDatabaseInfo{
		ID:          newDB.ID,
		Name:        newDB.Name,
		Description: newDB.Description,
	}, nil
}

// GetGraphOverview returns the graph overview for the connection described by
// req. req.Host and req.Port are used unless req.URI is set, in which case the
// host and port found in the URI take precedence. An error is returned when
// no candidate port yields an overview.
func (s *Neo4jService) GetGraphOverview(req model.CheckRequest) (*model.GraphOverview, error) {
	host, port := neo4jParseAddress(req.URI, req.Host, req.Port)
	db := neo4jDatabaseOrDefault(req.Database)
	encodedAuth := neo4jBasicAuth(req.Username, req.Password)

	if graph := s.findGraphOverview(host, port, db, encodedAuth); graph != nil {
		return graph, nil
	}
	return nil, fmt.Errorf("failed to connect to Neo4j")
}

// getGraphOverviewWithURL collects labels, relationship types, their counts
// and one sampled property set per label / relationship type from the
// endpoint at url. db is passed through to the property helpers.
//
// Node counts are summed over every label combination a label appears in, so
// a node carrying several labels is counted once for each of them.
func (s *Neo4jService) getGraphOverviewWithURL(url, encodedAuth, db string) (*model.GraphOverview, error) {
	results, err := s.query(url, encodedAuth,
		"CALL db.labels() YIELD label RETURN label",
		"CALL db.relationshipTypes() YIELD relationshipType RETURN relationshipType",
		"MATCH (n) RETURN labels(n) as nodeLabels, count(*) as count",
		"MATCH ()-[r]->() RETURN type(r) as relType, count(*) as count",
	)
	if err != nil {
		return nil, err
	}
	if len(results) < 4 {
		return nil, fmt.Errorf("invalid response format")
	}

	graph := &model.GraphOverview{
		Labels:            []model.LabelCount{},
		RelationshipTypes: []model.RelTypeCount{},
		Nodes:             []model.NodeProperty{},
		Relationships:     []model.RelProperty{},
	}

	// Labels.
	for _, item := range results[0].Data {
		if len(item.Row) > 0 {
			graph.Labels = append(graph.Labels, model.LabelCount{
				Name: fmt.Sprintf("%v", item.Row[0]),
			})
		}
	}

	// Relationship types.
	for _, item := range results[1].Data {
		if len(item.Row) > 0 {
			graph.RelationshipTypes = append(graph.RelationshipTypes, model.RelTypeCount{
				Name: fmt.Sprintf("%v", item.Row[0]),
			})
		}
	}

	// Node counts: each row is (label combination, count).
	nodeCounts := make(map[string]int)
	for _, item := range results[2].Data {
		if len(item.Row) < 2 {
			continue
		}
		labels, ok := item.Row[0].([]interface{})
		if !ok {
			continue
		}
		count, ok := neo4jCount(item.Row[1])
		if !ok {
			continue
		}
		for _, label := range labels {
			nodeCounts[fmt.Sprintf("%v", label)] += count
		}
	}
	for i := range graph.Labels {
		graph.Labels[i].Count = nodeCounts[graph.Labels[i].Name]
	}

	// Relationship counts: each row is (type, count).
	relCounts := make(map[string]int)
	for _, item := range results[3].Data {
		if len(item.Row) < 2 {
			continue
		}
		if count, ok := neo4jCount(item.Row[1]); ok {
			relCounts[fmt.Sprintf("%v", item.Row[0])] += count
		}
	}
	for i := range graph.RelationshipTypes {
		graph.RelationshipTypes[i].Count = relCounts[graph.RelationshipTypes[i].Name]
	}

	// Property definitions are best effort: failures leave entries out.
	s.fillNodeProperties(url, encodedAuth, db, graph)
	s.fillRelProperties(url, encodedAuth, db, graph)

	return graph, nil
}

// fillNodeProperties appends to graph.Nodes, for every label in graph.Labels,
// the properties of one sampled node carrying that label. Labels whose query
// fails or matches no node are skipped.
func (s *Neo4jService) fillNodeProperties(url, encodedAuth, db string, graph *model.GraphOverview) {
	for _, label := range graph.Labels {
		statement := fmt.Sprintf("MATCH (n:%s) RETURN properties(n) as props LIMIT 1",
			neo4jQuoteIdentifier(label.Name))
		props, ok := s.sampleProperties(url, encodedAuth, statement)
		if !ok {
			continue
		}
		graph.Nodes = append(graph.Nodes, model.NodeProperty{
			Label:      label.Name,
			Properties: neo4jPropertyInfos(props),
		})
	}
}

// fillRelProperties appends to graph.Relationships, for every type in
// graph.RelationshipTypes, the properties of one sampled relationship of that
// type. Types whose query fails or matches nothing are skipped.
func (s *Neo4jService) fillRelProperties(url, encodedAuth, db string, graph *model.GraphOverview) {
	for _, relType := range graph.RelationshipTypes {
		statement := fmt.Sprintf("MATCH ()-[r:%s]->() RETURN properties(r) as props LIMIT 1",
			neo4jQuoteIdentifier(relType.Name))
		props, ok := s.sampleProperties(url, encodedAuth, statement)
		if !ok {
			continue
		}
		graph.Relationships = append(graph.Relationships, model.RelProperty{
			Type:       relType.Name,
			Properties: neo4jPropertyInfos(props),
		})
	}
}

// GetNodes returns up to req.Limit nodes (default 10) carrying req.Label,
// plus the union of their property names and types. Host and port come from
// req.URI (defaults localhost:7687). Failure is reported through
// Success=false; the error result is always nil.
func (s *Neo4jService) GetNodes(req model.Neo4jNodeRequest) (*model.Neo4jNodeResponse, error) {
	host, port := neo4jParseAddress(req.URI, neo4jDefaultHost, neo4jDefaultBoltPort)
	db := neo4jDatabaseOrDefault(req.Database)
	limit := req.Limit
	if limit <= 0 {
		limit = neo4jDefaultLimit
	}
	encodedAuth := neo4jBasicAuth(req.Username, req.Password)

	for _, p := range neo4jCandidatePorts(port) {
		nodes, props, err := s.getNodesWithURL(neo4jEndpoint(host, p, db), encodedAuth, req.Label, limit)
		if err == nil {
			return &model.Neo4jNodeResponse{
				Success:    true,
				Message:    "success",
				Nodes:      nodes,
				Properties: props,
			}, nil
		}
	}

	return &model.Neo4jNodeResponse{
		Success: false,
		Message: "failed to connect to Neo4j",
	}, nil
}

// getNodesWithURL fetches up to limit nodes with the given label from the
// endpoint at url. Each node is returned as its property map plus an "id" key
// holding elementId(n); a node property that is itself named "id" replaces
// that value. The property list holds one entry per distinct property name,
// typed after the first value seen, sorted by name. label is treated as one
// literal label name. Both slices are non-nil on success.
func (s *Neo4jService) getNodesWithURL(url, encodedAuth, label string, limit int) ([]map[string]interface{}, []model.PropertyInfo, error) {
	statement := fmt.Sprintf("MATCH (n:%s) RETURN properties(n) as props, elementId(n) as id LIMIT %d",
		neo4jQuoteIdentifier(label), limit)
	results, err := s.query(url, encodedAuth, statement)
	if err != nil {
		return nil, nil, err
	}

	nodes := []map[string]interface{}{}
	properties := []model.PropertyInfo{}
	if len(results) == 0 {
		return nodes, properties, nil
	}

	propTypes := make(map[string]string) // first-seen type per property name
	for _, item := range results[0].Data {
		if len(item.Row) < 2 {
			continue
		}
		node := map[string]interface{}{"id": item.Row[1]}
		if props, ok := item.Row[0].(map[string]interface{}); ok {
			for k, v := range props {
				node[k] = v
				if _, exists := propTypes[k]; !exists {
					propTypes[k] = fmt.Sprintf("%T", v)
				}
			}
		}
		nodes = append(nodes, node)
	}

	for name, propType := range propTypes {
		properties = append(properties, model.PropertyInfo{Name: name, Type: propType})
	}
	sort.Slice(properties, func(i, j int) bool { return properties[i].Name < properties[j].Name })

	return nodes, properties, nil
}

// GetRelationships returns up to req.Limit relationships (default 10) of type
// req.RelationshipType. Host and port come from req.URI (defaults
// localhost:7687). Failure is reported through Success=false; the error
// result is always nil.
func (s *Neo4jService) GetRelationships(req model.Neo4jRelRequest) (*model.Neo4jRelResponse, error) {
	host, port := neo4jParseAddress(req.URI, neo4jDefaultHost, neo4jDefaultBoltPort)
	db := neo4jDatabaseOrDefault(req.Database)
	limit := req.Limit
	if limit <= 0 {
		limit = neo4jDefaultLimit
	}
	encodedAuth := neo4jBasicAuth(req.Username, req.Password)

	for _, p := range neo4jCandidatePorts(port) {
		rels, err := s.getRelationshipsWithURL(neo4jEndpoint(host, p, db), encodedAuth, req.RelationshipType, limit)
		if err == nil {
			return &model.Neo4jRelResponse{
				Success:       true,
				Message:       "success",
				Relationships: rels,
			}, nil
		}
	}

	return &model.Neo4jRelResponse{
		Success: false,
		Message: "failed to connect to Neo4j",
	}, nil
}

// getRelationshipsWithURL fetches up to limit relationships of type relType
// from the endpoint at url. ID, Source and Target are the element IDs of the
// relationship, its start node and its end node. relType is treated as one
// literal type name. The returned slice is non-nil on success.
func (s *Neo4jService) getRelationshipsWithURL(url, encodedAuth, relType string, limit int) ([]model.Neo4jRelationship, error) {
	statement := fmt.Sprintf("MATCH (a)-[r:%s]->(b) RETURN properties(r) as props, elementId(a) as startId, elementId(b) as endId, labels(a) as startLabels, labels(b) as endLabels, elementId(r) as relId LIMIT %d",
		neo4jQuoteIdentifier(relType), limit)
	results, err := s.query(url, encodedAuth, statement)
	if err != nil {
		return nil, err
	}

	rels := []model.Neo4jRelationship{}
	if len(results) == 0 {
		return rels, nil
	}

	for _, item := range results[0].Data {
		// Columns: props, startId, endId, startLabels, endLabels, relId.
		if len(item.Row) < 6 {
			continue
		}
		rel := model.Neo4jRelationship{
			ID:     fmt.Sprintf("%v", item.Row[5]),
			Source: fmt.Sprintf("%v", item.Row[1]),
			Target: fmt.Sprintf("%v", item.Row[2]),
		}
		if props, ok := item.Row[0].(map[string]interface{}); ok {
			rel.Properties = props
		}
		rels = append(rels, rel)
	}

	return rels, nil
}

// checkWithURL runs "RETURN 1" against the endpoint at url. It returns
// (nil, err) when the request could not be sent, and otherwise a response
// whose Success flag tells whether the server answered without errors.
// Version is taken from the X-Neo4j-Version response header, or "unknown"
// when the header is absent.
func (s *Neo4jService) checkWithURL(url, encodedAuth, db string) (*model.Neo4jCheckResponse, error) {
	failed := func(message string) (*model.Neo4jCheckResponse, error) {
		return &model.Neo4jCheckResponse{Success: false, Message: message}, nil
	}

	payload := []byte(`{"statements": [{"statement": "RETURN 1 AS num"}]}`)
	httpReq, err := neo4jNewRequest(url, encodedAuth, payload)
	if err != nil {
		return failed(fmt.Sprintf("failed to create request: %v", err))
	}

	resp, err := s.client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return failed(fmt.Sprintf("failed to read response: %v", err))
	}

	var decoded neo4jTxResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return failed(fmt.Sprintf("failed to parse response: %v", err))
	}
	if len(decoded.Errors) > 0 {
		msg := "connection failed"
		if m := decoded.Errors[0].Message; m != "" {
			msg = m
		}
		return failed(msg)
	}

	version := "unknown"
	if v := resp.Header.Get("X-Neo4j-Version"); v != "" {
		version = v
	}

	return &model.Neo4jCheckResponse{
		Success:   true,
		Message:   "connection successful",
		Version:   version,
		Databases: []string{db},
	}, nil
}