我们在开发基于 ChatGPT 的本地知识库功能时,必须要知道如何操作向量数据库。
原因是,GPT 的本地知识库需要先通过向量数据库搜索出相关的数据,然后再把这些数据发送给 GPT 的 chat 接口,让 GPT 润色后回答。
下面是使用golang实现的向量数据库qdrant操作封装函数,包括:
创建集合,删除集合,查询集合信息
创建向量,搜索向量
代码放在了自己的utils包下,可以根据自己情况自行修改
package utils
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
)
// Connection settings shared by every helper in this file; each helper
// formats them into a URL of the form "http://<QdrantBase>:<QdrantPort>/...".
var (
// QdrantBase is the host part of the request URLs.
QdrantBase = "127.0.0.1"
// QdrantPort is the port part of the request URLs, kept as a string because
// it is inserted into the URL with %s.
QdrantPort = "6333"
)
// PutCollection creates the collection collectionName by sending a PUT request
// to /collections/{collectionName} on the server addressed by QdrantBase and
// QdrantPort. The collection is configured with vectors of size 1536 compared
// by Cosine distance.
//
// It returns an error when the request cannot be built or sent, or when the
// server answers with a non-2xx status code.
func PutCollection(collectionName string) error {
	url := fmt.Sprintf("http://%s:%s/collections/%s", QdrantBase, QdrantPort, collectionName)
	requestBody, err := json.Marshal(map[string]interface{}{
		"name": collectionName,
		"vectors": map[string]interface{}{
			"size":     1536,
			"distance": "Cosine",
		},
	})
	if err != nil {
		return err
	}

	request, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(requestBody))
	if err != nil {
		return err
	}
	request.Header.Set("Content-Type", "application/json")

	client := http.Client{}
	response, err := client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// A completed round trip is not the same as a successful operation: the
	// server reports rejected requests through the status code.
	if response.StatusCode < 200 || response.StatusCode > 299 {
		// Best effort only: the body is used purely to enrich the error
		// message, so a failure to read it is deliberately ignored.
		detail, _ := ioutil.ReadAll(response.Body)
		return fmt.Errorf("creating collection %q: unexpected status code %d: %s",
			collectionName, response.StatusCode, bytes.TrimSpace(detail))
	}
	return nil
}
// DeleteCollection removes the collection collectionName by sending a DELETE
// request to /collections/{collectionName} on the server addressed by
// QdrantBase and QdrantPort.
//
// It returns an error when the request cannot be built or sent, or when the
// server answers with a non-2xx status code.
func DeleteCollection(collectionName string) error {
	url := fmt.Sprintf("http://%s:%s/collections/%s", QdrantBase, QdrantPort, collectionName)
	request, err := http.NewRequest(http.MethodDelete, url, nil)
	if err != nil {
		return err
	}

	client := http.Client{}
	response, err := client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// A completed round trip is not the same as a successful operation: the
	// server reports rejected requests through the status code.
	if response.StatusCode < 200 || response.StatusCode > 299 {
		// Best effort only: the body is used purely to enrich the error
		// message, so a failure to read it is deliberately ignored.
		detail, _ := ioutil.ReadAll(response.Body)
		return fmt.Errorf("deleting collection %q: unexpected status code %d: %s",
			collectionName, response.StatusCode, bytes.TrimSpace(detail))
	}
	return nil
}
// GetCollection fetches the description of collectionName with a GET request
// to /collections/{collectionName} on the server addressed by QdrantBase and
// QdrantPort, and returns the raw response body.
//
// An error is returned when the request fails, when the body cannot be read
// completely, or when the status code is not 200. In the last case the body
// that was received is still returned alongside the error (the same convention
// PutPoints uses) so callers can inspect the server's message.
func GetCollection(collectionName string) ([]byte, error) {
	url := fmt.Sprintf("http://%s:%s/collections/%s", QdrantBase, QdrantPort, collectionName)
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// A read failure would otherwise hand the caller a silently truncated body.
	result, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading collection %q response: %w", collectionName, err)
	}
	if resp.StatusCode != http.StatusOK {
		return result, fmt.Errorf("getting collection %q: unexpected status code %d", collectionName, resp.StatusCode)
	}
	return result, nil
}
// PutPoints sends points to collectionName with a PUT request to
// /collections/{collectionName}/points on the server addressed by QdrantBase
// and QdrantPort. The points are placed under the "points" key of the JSON
// request body exactly as given.
//
// It returns the response body as a string. An error is returned when the
// request cannot be built or sent, when the body cannot be read completely, or
// when the status code is not 200; in the last case the body that was received
// is still returned alongside the error.
func PutPoints(collectionName string, points []map[string]interface{}) (string, error) {
	url := fmt.Sprintf("http://%s:%s/collections/%s/points", QdrantBase, QdrantPort, collectionName)

	// Build the request body.
	requestBodyBytes, err := json.Marshal(map[string]interface{}{
		"points": points,
	})
	if err != nil {
		return "", err
	}

	// Send the request.
	req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(requestBodyBytes))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// A read failure would otherwise hand the caller a silently truncated body.
	res, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading response for PUT points to collection %s: %w", collectionName, err)
	}

	// Check the response status code.
	if resp.StatusCode != http.StatusOK {
		return string(res), fmt.Errorf("failed to PUT points to collection %s, status code: %d", collectionName, resp.StatusCode)
	}
	return string(res), nil
}
//搜索向量数据
func SearchPoints(collectionName string, params map[string]interface{}, vector []float64, limit int) ([]byte, error) {
// 构造请求体
requestBody := map[string]interface{}{
"params": params,
"vector": vector,
"limit": limit,
"with_payload": true,
}
requestBodyBytes, err := json.Marshal(requestBody)
if err != nil {
return nil, err
}
// 构造请求
url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", QdrantBase, QdrantPort, collectionName)
request, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBodyBytes))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
// 发送请求
client := http.DefaultClient
response, err := client.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
// 处理响应
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, err
}
return responseBody, nil
}如何使用上面的函数,请参考下面的测试用例:
// TestPutCollection creates the "data_collection" collection on the configured
// server, fails the test if PutCollection reports an error, and logs the
// returned error value either way.
func TestPutCollection(t *testing.T) {
	const target = "data_collection"

	putErr := PutCollection(target)
	if putErr != nil {
		t.Errorf("Error putting collection: %v", putErr)
	}
	log.Println(putErr)
}
// TestDeleteCollection deletes the "data_collection" collection on the
// configured server, fails the test if DeleteCollection reports an error, and
// logs the returned error value either way.
func TestDeleteCollection(t *testing.T) {
	collectionName := "data_collection"
	err := DeleteCollection(collectionName)
	if err != nil {
		// The message names the operation under test so a failure is not
		// mistaken for a problem creating the collection.
		t.Errorf("Error deleting collection: %v", err)
	}
	log.Println(err)
}
// TestPutPoints stores one sample point (id 1, with a title/text payload) in
// the "data_collection" collection, fails the test if PutPoints reports an
// error, and logs the response together with the error value.
func TestPutPoints(t *testing.T) {
	collectionName := "data_collection"

	// PutCollection creates collections whose vectors have size 1536, so the
	// sample vector must have the same number of dimensions to be accepted.
	vector := make([]float64, 1536)
	for i := range vector {
		vector[i] = 0.9
	}

	points := []map[string]interface{}{
		{
			"id":      1,
			"payload": map[string]interface{}{"title": "测试标题", "text": "测试内容"},
			"vector":  vector,
		},
	}
	res, err := PutPoints(collectionName, points)
	if err != nil {
		t.Errorf("Error putting points: %v", err)
	}
	log.Println(res, err)
}
// TestSearchPoints searches the "data_collection" collection for the 10 points
// nearest to a sample vector, fails the test if SearchPoints reports an error,
// and logs the raw response.
func TestSearchPoints(t *testing.T) {
	collectionName := "data_collection"
	params := map[string]interface{}{"exact": false, "hnsw_ef": 128}

	// PutCollection creates collections whose vectors have size 1536, so the
	// query vector must have the same number of dimensions to be accepted.
	vector := make([]float64, 1536)
	for i := range vector {
		vector[i] = 0.9
	}

	limit := 10
	points, err := SearchPoints(collectionName, params, vector, limit)
	if err != nil {
		t.Errorf("Error searching points: %v", err)
	}
	log.Println(string(points))
}