const.go

package main

// ColNames lists the raw CSV columns that dataPrepare selects first. The
// row callbacks address cells by position, so the order here matters.
var ColNames = []string{
    "feature",
    "document",
    "machine",
    "load_time",
    "search_time",
    "reduce_and_save",
}

// ResColNames lists the columns of the prepared frame, in their final order.
var ResColNames = []string{
    "feature",
    "document",
    "machine",
    "total",
}

fit_classification.go

package main

import (
    "fmt"
    "log"
    "math"
    "os"

    "github.com/go-gota/gota/dataframe"
    "github.com/go-gota/gota/series"
    "gonum.org/v1/gonum/optimize"
    "gonum.org/v1/plot"
    "gonum.org/v1/plot/plotter"
    "gonum.org/v1/plot/plotutil"
    "gonum.org/v1/plot/vg"
)

// getTotal is a row callback for DataFrame.Rapply. It adds the cells at row
// positions 3, 4 and 5 (load_time, search_time and reduce_and_save when the
// frame is in ColNames order) and returns the sum divided by 60 as a
// one-element float series (presumably converting seconds to minutes).
//
// The cells are read through Element.Float instead of a type assertion on
// Val(): a failed assertion with a discarded ok flag silently yields 0, which
// would corrupt the total whenever the row series is not Int-typed.
func getTotal(s series.Series) series.Series {
    loadTime := s.Elem(3).Float()
    searchTime := s.Elem(4).Float()
    reduceAndSave := s.Elem(5).Float()

    return series.Floats((loadTime + searchTime + reduceAndSave) / 60)
}

// getDoc is a row callback for DataFrame.Rapply. It reads the cell at row
// position 1 (the document column in ResColNames order) and returns
// 2*document/1000 as a one-element float series.
//
// The cell is read through Element.Float instead of asserting
// Val().(float64): that assertion only succeeds when the row series happens
// to be Float-typed, and with the ok flag discarded any other row type would
// silently produce 0.
func getDoc(s series.Series) series.Series {
    document := s.Elem(1).Float()
    return series.Floats(2 * document / 1000)
}

// dataPrepare rewrites *clsDF in place so that it ends up with exactly the
// columns named in ResColNames:
//   - "total" is derived per row by getTotal from the ColNames selection;
//   - "document" is replaced by the per-row value computed by getDoc.
func dataPrepare(clsDF *dataframe.DataFrame) {
    // Step 1: keep the raw columns, derive "total" and append it.
    *clsDF = clsDF.Select(ColNames)
    totals := clsDF.Rapply(getTotal)
    totals.SetNames("total")
    *clsDF = clsDF.CBind(totals).Select(ResColNames)

    // Step 2: derive the rescaled document column (document * 2 / 1000),
    // swap it in under the original name and restore the column order.
    scaledDocs := clsDF.Rapply(getDoc)
    scaledDocs.SetNames("new_doc")
    *clsDF = clsDF.
        CBind(scaledDocs).
        Drop([]string{"document"}).
        Rename("document", "new_doc").
        Select(ResColNames)
}

// dataOptimize fits the straight line y = fa*x + fb to the prepared frame.
//
// Every row contributes one observed point with x = document/machine
// (columns 1 and 2) and y = total (column 3). The line parameters are found
// with Nelder-Mead, starting from (1, 1), by minimizing the sum of absolute
// residuals |fa*x + fb - y|. Note that this is a least-absolute-deviations
// fit, not a least-squares fit.
//
// Returns:
//   - actPoints: the observed points, one per row;
//   - expPoints: the fitted line evaluated at the same x values;
//   - fa, fb: slope and intercept of the fitted line.
//
// Panics if the optimizer reports an error.
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
    rows := clsDF.Nrow()

    // Observed points. Cells are read with Element.Float so the code does not
    // panic when a column's inferred type differs (e.g. int vs float).
    actPoints = make(plotter.XYs, 0, rows)
    for i := 0; i < rows; i++ {
        document := clsDF.Elem(i, 1).Float()
        machine := clsDF.Elem(i, 2).Float()
        total := clsDF.Elem(i, 3).Float()

        actPoints = append(actPoints, plotter.XY{
            X: document / machine,
            Y: total,
        })
    }

    // Objective: sum of absolute residuals of the line x[0]*X + x[1].
    objective := func(x []float64) float64 {
        if len(x) != 2 {
            panic("illegal x")
        }
        a, b := x[0], x[1]
        var sum float64
        for _, p := range actPoints {
            sum += math.Abs(a*p.X + b - p.Y)
        }
        return sum
    }

    result, err := optimize.Minimize(
        optimize.Problem{Func: objective},
        []float64{1, 1},
        &optimize.Settings{},
        &optimize.NelderMead{},
    )
    if err != nil {
        panic(err)
    }
    fa, fb = result.X[0], result.X[1]

    // Fitted points reuse the x values already computed above instead of
    // re-reading the frame.
    expPoints = make(plotter.XYs, 0, len(actPoints))
    for _, p := range actPoints {
        expPoints = append(expPoints, plotter.XY{
            X: p.X,
            Y: fa*p.X + fb,
        })
    }

    return actPoints, expPoints, fa, fb
}

// dataPlot draws the fitted points (expPoints) and the observed points
// (actPoints) as two line-and-point series on one chart and saves it as a
// 5x5 inch image named classification-fit.png. Both axes are fixed to the
// range [0, 10]. Any plotting or saving error causes a panic.
func dataPlot(actPoints, expPoints plotter.XYs) {
    chart, err := plot.New()
    if err != nil {
        panic(err)
    }

    // Fixed axis limits.
    chart.X.Min, chart.X.Max = 0, 10
    chart.Y.Min, chart.Y.Max = 0, 10

    err = plotutil.AddLinePoints(chart, "expPoints", expPoints, "actPoints", actPoints)
    if err != nil {
        panic(err)
    }

    err = chart.Save(5*vg.Inch, 5*vg.Inch, "classification-fit.png")
    if err != nil {
        panic(err)
    }
}

// FitClassification runs the whole classification curve-fitting pipeline:
// it loads classification_data.csv from the working directory, preprocesses
// it with dataPrepare, fits a line with dataOptimize, prints the fitted
// parameters and writes the chart via dataPlot.
//
// The process exits through log.Fatal if the file cannot be opened, the CSV
// cannot be parsed, or preprocessing leaves the frame in an error state.
func FitClassification() {
    clsData, err := os.Open("classification_data.csv")
    if err != nil {
        log.Fatal(err)
    }

    clsDF := dataframe.ReadCSV(clsData)
    // The frame has been built, so the file is closed right away instead of
    // with defer: log.Fatal below would skip deferred calls. The close error
    // of a read-only handle is not actionable and is deliberately ignored.
    _ = clsData.Close()
    // gota reports parse failures through the Err field, not a return value.
    if clsDF.Err != nil {
        log.Fatal(clsDF.Err)
    }

    // Preprocessing. Failures (e.g. a missing column) are also carried in
    // the Err field of the resulting frame.
    dataPrepare(&clsDF)
    if clsDF.Err != nil {
        log.Fatal(clsDF.Err)
    }
    fmt.Println("数据预处理完成...")
    fmt.Println(clsDF)

    // Fit the line and print its slope (Fa) and intercept (Fb).
    actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
    fmt.Println("Fa", fa, "Fb", fb)

    // Draw the chart.
    dataPlot(actPoints, expPoints)
    fmt.Println("绘制完成,图形地址: classification-fit.png")
}

main.go

package main

// main is intentionally empty: the fitting pipeline is exercised through the
// test in main_test.go (go test -v -run=FitClass), not by running the binary.
func main() {

}

main_test.go

package main

import "testing"

// TestFitClassification runs the classification curve-fitting pipeline by
// calling FitClassification. It makes no assertions of its own; it serves as
// the entry point for running the fit.
func TestFitClassification(t *testing.T) {
    FitClassification()
}

运行数据(保存为 classification_data.csv,与代码放在同一目录)

feature,document,machine,load_time,search_time,reduce_and_save

100,5000,4,19,130,67

100,5000,4,12,130,61

100,5000,4,13,127,61

100,5000,4,13,124,63

100,5000,4,13,129,59

100,5000,4,13,125,60

100,5000,4,13,123,63

100,5000,4,13,129,61

100,5000,4,12,127,61

100,5000,4,12,125,62

100,5000,4,13,128,59

100,5000,4,13,128,61

100,5000,4,12,130,60

100,5000,4,12,125,61

100,5000,4,13,127,60

100,5000,4,13,126,63

100,5000,4,13,127,64

100,5000,3,18,160,67

100,5000,3,13,166,59

100,5000,3,12,167,61

100,5000,3,12,168,60

100,5000,3,12,170,61

100,5000,3,12,154,63

100,5000,3,13,168,60

100,5000,3,12,167,60

100,5000,3,12,148,64

100,5000,3,12,167,65

100,5000,3,12,164,60

100,5000,3,12,150,59

100,5000,2,20,217,65

100,5000,2,13,205,63

100,5000,2,14,204,60

100,5000,2,13,205,55

100,5000,2,14,210,59

100,5000,2,13,201,59

100,5000,2,13,211,59

100,5000,2,13,217,59

100,5000,2,14,207,60

100,5000,2,14,209,59

100,5000,2,14,214,61

100,5000,2,13,210,61

100,5000,1,24,376,60

100,5000,1,21,393,58

100,5000,1,20,386,58

100,5000,1,22,384,59

100,5000,1,21,387,59

100,4000,4,18,112,70

100,4000,4,13,118,62

100,4000,4,12,114,63

100,4000,4,14,112,65

100,4000,4,12,113,62

100,4000,4,14,109,61

100,4000,4,13,118,63

100,4000,4,12,112,61

100,4000,4,12,110,61

100,4000,4,11,111,63

100,4000,4,13,112,67

100,4000,4,12,110,60

100,4000,4,12,113,60

100,3000,4,19,100,66

100,3000,4,13,99,64

100,3000,4,12,100,65

100,3000,4,13,103,61

100,3000,4,14,104,63

100,3000,4,14,99,63

100,2000,4,18,90,67

100,2000,4,13,86,65

100,2000,4,13,85,63

100,2000,4,14,87,62

100,2000,4,13,85,61

100,2000,4,13,86,64

100,2000,4,13,85,61

100,2000,4,13,89,58

100,2000,4,13,85,61

100,2000,4,12,85,60

100,2000,4,13,85,66

100,2000,4,12,86,59

100,2000,4,12,86,61

100,2000,4,12,82,60

100,2000,4,13,87,62

100,2000,4,12,83,65

100,2000,4,12,85,60

100,2000,4,13,87,60

100,2000,4,12,86,59

100,1000,4,19,75,63

100,1000,4,14,71,61

100,1000,4,13,75,59

100,1000,4,14,72,61

100,1000,4,13,72,59

100,1000,4,13,71,59

100,1000,4,13,72,59

100,1000,4,14,70,62

100,1000,4,12,72,58

100,1000,4,13,71,59

100,1000,4,13,70,62

100,1000,4,12,72,59

100,1000,4,20,71,58

100,1000,4,13,69,60

100,1000,4,12,73,60

100,1000,4,13,69,59

100,1000,4,13,71,60

100,1000,4,13,73,62

100,1000,4,12,71,59

100,1000,4,12,70,56

100,1000,4,13,70,58

100,1000,4,12,69,57

运行方法
go test -v -run=FitClass
最终输出的拟合参数(Fa、Fb)与 scipy 的拟合结果基本一致。

程序输出


输出的拟合图像如下