golang初学数据库连接(mysql,sqlserver,oracle)

萧业
2023-12-01
package util

import (
	"database/sql"
	"encoding/json"
	log "erpgo/util/logUtil"
	"fmt"
	"reflect"

	"time"

	_ "github.com/denisenkom/go-mssqldb" //这个驱动可以连接sqlserver2019,但是好像不能够连接sql server 2008一下

	//编码转换
	_ "github.com/go-sql-driver/mysql" //mysql 数据库连接
	"github.com/goinggo/mapstructure"  //map转struct
	_ "github.com/mattn/go-adodb"      //连接sqlserver 2008等使用这个
	//_ "github.com/mattn/go-oci8"       // oracle数据库连接
)

// Config builds the connection string and driver name for a database type.
//
// dbType selects the dialect:
//   - "sqlserver": adodb (OLE DB) with SQL authentication; port 1433 is omitted from Data Source.
//   - "sqlserver_windows": adodb with integrated (SSPI) security; port, user and password are unused.
//   - "sqlserver_mssql": the mssql driver.
//   - "mysql": the mysql driver, charset utf8.
//   - "oracle": the oci8 driver.
//   - anything else: adodb with SQL authentication, always "server,port".
//
// version is not interpreted; it is returned unchanged as versions.
func Config(dbType string, server string, port int, user string, password string, database string, version int) (connString, qd string, versions int) {
	const oleDBTemplate = "Provider=SQLOLEDB;Data Source=%s;Initial Catalog=%s;user id=%s;password=%s;Connection Timeout=3600;Connect Timeout=3600;"
	hostWithPort := fmt.Sprintf("%s,%d", server, port)

	switch dbType {
	case "sqlserver":
		dataSource := hostWithPort
		if port == 1433 {
			dataSource = server
		}
		return fmt.Sprintf(oleDBTemplate, dataSource, database, user, password), "adodb", version
	case "sqlserver_windows":
		return fmt.Sprintf("Provider=SQLOLEDB;Data Source=%s;integrated security=SSPI;Initial Catalog=%s;", server, database), "adodb", version
	case "sqlserver_mssql":
		return fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s", server, port, database, user, password), "mssql", version
	case "mysql":
		return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s", user, password, server, port, database, "utf8"), "mysql", version
	case "oracle":
		return fmt.Sprintf("%s/%s@%s:%d/%s", user, password, server, port, database), "oci8", version
	default:
		return fmt.Sprintf(oleDBTemplate, hostWithPort, database, user, password), "adodb", version
	}
}
// getDb opens a database handle for driver qd using connection string cstr and pings it.
//
// It returns nil, after logging a warning, only when sql.Open itself fails (for example
// when driver qd is not registered). A failed Ping is logged, but the handle is still
// returned, so the caller sees the real connection error on its first query.
// The caller owns the returned *sql.DB and must Close it.
func getDb(cstr, qd string) *sql.DB {
	db, err := sql.Open(qd, cstr)
	if err != nil {
		log.Warn(fmt.Sprintf("打开数据库失败:%s", err.Error()))
		return nil
	}
	// Pool tuning (SetMaxOpenConns / SetMaxIdleConns / SetConnMaxLifetime) is left at the
	// database/sql defaults.
	if err := db.Ping(); err != nil {
		// Report through the project logger, like the Open failure above, not stdout.
		log.Warn(fmt.Sprintf("连接数据库失败:%s", err.Error()))
	}
	return db
}

// QueryOne reads the single column of the first row returned by sqlStr into data.
//
// data must be a pointer accepted by (*sql.Rows).Scan. If the query yields no rows, or
// the value is SQL NULL, data is left untouched and nil is returned. This lets callers
// scan into plain types such as *string without a NULL-conversion error.
// The query must return exactly one column; Scan reports an error otherwise.
func QueryOne(sqlStr, connString, qd string, data interface{}) error {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	// Execute the statement once and scan the current row twice: first into an
	// interface{} to detect NULL, then into the caller's destination.
	rows, err := db.Query(sqlStr)
	if err != nil {
		return err
	}
	defer rows.Close()

	if !rows.Next() {
		// No row at all: leave data as is; only an iteration error is reported.
		return rows.Err()
	}
	var raw interface{}
	if err := rows.Scan(&raw); err != nil {
		return err
	}
	if raw == nil {
		return nil
	}
	return rows.Scan(data)
}

// QueryScan returns the single column of the first row returned by sqlStr as an interface{}.
//
// The SQL text is printed to stdout. It returns nil when the database cannot be opened,
// when there are no rows (after printing a notice), or when the value is SQL NULL.
// Any other query error is logged as a warning and whatever was scanned (normally nil)
// is returned.
func QueryScan(sqlStr, connString, qd string) interface{} {
	var data interface{}
	db := getDb(connString, qd)
	if db == nil {
		// getDb has already logged why the database could not be opened.
		return nil
	}
	defer db.Close()

	err := db.QueryRow(sqlStr).Scan(&data)
	fmt.Println(sqlStr)
	switch {
	case err == sql.ErrNoRows:
		fmt.Println("读取不到数据")
	case err != nil:
		log.Warn(fmt.Sprintf("读取失败:%s,%s,%s", connString, sqlStr, err.Error()))
	}
	return data
}

// Exec executes a modifying statement (INSERT/UPDATE/DELETE/DDL) and returns its error.
//
// The SQL text is printed to stdout first. The number of affected rows is not
// inspected; zero affected rows is not an error.
func Exec(sqlStr, connString, qd string) error {
	fmt.Println(sqlStr)
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	_, err := db.Exec(sqlStr)
	return err
}

// ExecReturnID executes a modifying statement and returns the driver's LastInsertId.
//
// If the statement fails, the error is logged as a warning and returned with id 0.
// LastInsertId may not be supported by every driver; its error is deliberately ignored
// and id is whatever the driver returned in that case (normally 0). The returned error
// is the one from RowsAffected.
func ExecReturnID(sqlStr, connString, qd string) (int64, error) {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return 0, fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	result, err := db.Exec(sqlStr)
	if err != nil {
		log.Warn(fmt.Sprintf("修改错误:%s", err.Error()))
		// result is nil here; calling LastInsertId on it would panic.
		return 0, err
	}
	id, _ := result.LastInsertId() // best effort: unsupported by some drivers
	_, err = result.RowsAffected()
	return id, err
}

// ExecReturnRows executes a modifying statement and returns the number of affected rows.
//
// If the statement fails, the error is logged as a warning and returned with a count of 0.
func ExecReturnRows(sqlStr, connString, qd string) (int64, error) {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return 0, fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	result, err := db.Exec(sqlStr)
	if err != nil {
		log.Warn(fmt.Sprintf("修改错误:%s", err.Error()))
		// result is nil here; calling RowsAffected on it would panic.
		return 0, err
	}
	return result.RowsAffected()
}

// TransactionSQLReturnRows runs sqlStrs in order inside one transaction and returns the
// total number of affected rows.
//
// Each statement, and its affected-row count, is printed to stdout. On the first
// Exec/RowsAffected error, or a Commit error, the transaction is rolled back and
// (0, err) is returned. Statements that affect zero rows are not treated as errors.
func TransactionSQLReturnRows(sqlStrs []string, connString, qd string) (int64, error) {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return 0, fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	tx, err := db.Begin()
	if err != nil {
		return 0, err
	}
	// Rolls back on every early return; after a successful Commit it is a no-op that
	// returns sql.ErrTxDone, which is intentionally ignored.
	defer tx.Rollback()

	var total int64
	for _, sqlStr := range sqlStrs {
		fmt.Println(sqlStr)
		result, err := tx.Exec(sqlStr)
		if err != nil {
			return 0, err
		}
		rowsaff, err := result.RowsAffected()
		fmt.Println("是否受到影响", rowsaff, err)
		if err != nil {
			return 0, err
		}
		total += rowsaff
	}
	if err := tx.Commit(); err != nil {
		return 0, err
	}
	return total, nil
}

// TransactionSQL runs sqlStrs in order inside one transaction.
//
// On the first Exec/RowsAffected error, or a Commit error, the transaction is rolled
// back and that error is returned. Statements that affect zero rows are not treated as
// errors.
func TransactionSQL(sqlStrs []string, connString, qd string) error {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// Rolls back on every early return; after a successful Commit it is a no-op that
	// returns sql.ErrTxDone, which is intentionally ignored.
	defer tx.Rollback()

	for _, sqlStr := range sqlStrs {
		result, err := tx.Exec(sqlStr)
		if err != nil {
			return err
		}
		// A RowsAffected failure also aborts the transaction.
		if _, err := result.RowsAffected(); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// QueryJSONStr runs sqlStr and returns all result rows as a JSON array of objects keyed
// by column name.
//
// []byte column values are converted to strings; all other values are emitted as the
// driver returned them. An empty result encodes as "[]". Query, scan, iteration and
// marshal failures are returned as errors.
func QueryJSONStr(sqlStr, connString, qd string) (string, error) {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return "", fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	// Each call opens its own pool; release it when done.
	defer db.Close()

	rows, err := db.Query(sqlStr)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return "", err
	}
	count := len(columns)
	values := make([]interface{}, count)
	valuePtrs := make([]interface{}, count)
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	tableData := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return "", err
		}
		entry := make(map[string]interface{}, count)
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				// json.Marshal would base64-encode a []byte; emit it as text instead.
				entry[col] = string(b)
			} else {
				// TODO: the character encoding is not verified; a GB2312 source would
				// need conversion to UTF-8 here.
				entry[col] = values[i]
			}
		}
		tableData = append(tableData, entry)
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	jsonData, err := json.Marshal(tableData)
	if err != nil {
		return "", fmt.Errorf("Map转错误:%w", err)
	}
	return string(jsonData), nil
}

// QueryEntityNew runs sqlStr and decodes the rows into resp with mapstructure instead
// of a JSON round-trip.
//
// Each row becomes a map of column name to value ([]byte values converted to string),
// and the resulting slice is passed to mapstructure.Decode. resp should presumably point
// at a slice of structs (confirm with callers).
// TODO: type conversion has known problems; see
// https://www.cnblogs.com/akidongzi/p/12036096.html
func QueryEntityNew(sqlStr, connString, qd string, resp interface{}) error {
	db := getDb(connString, qd)
	if db == nil {
		// getDb returns nil when sql.Open fails and has already logged the cause.
		return fmt.Errorf("打开数据库失败 (driver %q)", qd)
	}
	defer db.Close()

	rows, err := db.Query(sqlStr)
	if err != nil {
		return err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	count := len(columns)
	values := make([]interface{}, count)
	valuePtrs := make([]interface{}, count)
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	tableData := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return err
		}
		entry := make(map[string]interface{}, count)
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				entry[col] = string(b)
			} else {
				entry[col] = values[i]
			}
		}
		tableData = append(tableData, entry)
	}
	if err := rows.Err(); err != nil {
		return err
	}
	return mapstructure.Decode(tableData, resp)
}

// QueryEntity runs sqlStr and decodes the result into data through a JSON round-trip.
//
// The SQL text is logged at info level first. Rows are serialized by QueryJSONStr as a
// JSON array of column-name/value objects, so data should point at something
// json.Unmarshal can fill from such an array (presumably *[]T for some struct T).
func QueryEntity(sqlStr, connString, qd string, data interface{}) error {
	log.Info(sqlStr)
	jsonText, err := QueryJSONStr(sqlStr, connString, qd)
	if err == nil {
		err = json.Unmarshal([]byte(jsonText), data)
	}
	return err
}

// printRow prints one result row to stdout as tab-separated cells, followed by a newline.
//
// colsdata holds the column values of a single row. Each element may be a plain value or
// a pointer to one; main passes the *interface{} targets it gave to rows.Scan.
func printRow(colsdata []interface{}) {
	for _, val := range colsdata {
		fmt.Print(formatRowCell(val))
		fmt.Print("\t")
	}
	fmt.Println()
}

// formatRowCell renders a single column value for printRow.
//
// Pointer and interface layers are unwrapped first. NULL (nil) prints as "NULL", bools
// print as "True"/"False", []byte prints as text, time.Time uses the layout
// "2006-01-02 15:04:05.999", and anything else uses fmt's default formatting.
func formatRowCell(val interface{}) string {
	rv := reflect.ValueOf(val)
	for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface {
		if rv.IsNil() {
			return "NULL"
		}
		rv = rv.Elem()
	}
	if !rv.IsValid() { // val itself was a nil interface
		return "NULL"
	}

	switch v := rv.Interface().(type) {
	case bool:
		if v {
			return "True"
		}
		return "False"
	case []byte:
		return string(v)
	case time.Time:
		// Go layouts are written in terms of the reference time Mon Jan 2 15:04:05 2006.
		return v.Format("2006-01-02 15:04:05.999")
	default:
		return fmt.Sprint(v)
	}
}

// main is a standalone manual check. It connects to an Oracle database through the
// "oci8" driver, runs a fixed goods/stock query, and prints the column names followed by
// every row, tab-separated, to stdout.
//
// NOTE(review): the go-oci8 driver import is commented out at the top of this file, so
// sql.Open("oci8", ...) presumably fails with an unknown-driver error unless the driver
// is registered elsewhere. Also, since this package is "util", this main is not a
// program entry point. Connection settings are hard-coded; every failure terminates
// the process via log.Fatal.
func main() {
	const isdebug = true
	server := "192.168.1.254"
	port := 1521
	user := "combrain"
	password := "combrain"
	database := "combrain"

	// Connection string in the form user/password@host:port/database.
	connString := fmt.Sprintf("%s/%s@%s:%d/%s", user, password, server, port, database)
	if isdebug {
		fmt.Println(connString)
	}

	conn, err := sql.Open("oci8", connString)
	if err != nil {
		log.Fatal(err.Error())
	}
	defer conn.Close()

	stmt, err := conn.Prepare(`SELECT DISTINCT bi.it_id AS goodsId, bi.itname AS productNm, bi.itname AS byname, bi.spec AS spec, bi.madein AS factory, bi.madein AS proArea, bi.madein AS productBrand, bi.unit AS goodsUnit, bi.bar_code AS barcode, NVL(bi.pass_no, ' ') AS approvalNumber, NVL (jg.direct_price * 1.03, 9999) AS marketPrice, NVL (jg.direct_price * 1.03, 9999) AS salePrice, NVL (info.num_1, 0) AS COST, 1 AS mpackTotal, bc.class_name AS businessRange, NVL (pb.batch_no, ' ') AS batchNum, NVL ( TO_CHAR (pb.avail_time, 'yyyy-mm-dd'), ' ') AS effectiveDate, NVL ( TO_CHAR (pb.made_time, 'yyyy-mm-dd'), ' ' ) AS scrq FROM BD_ITEMDOC bi INNER JOIN bd_invinfo info ON info.it_id = bi.it_id LEFT JOIN bd_class bc ON bi.item_class = bc.class_code LEFT JOIN sd_Saleprice jg ON bi.it_id = jg.it_id AND jg.corp_id = '00' AND jg.pricetype = '01' LEFT JOIN im_placebatch ip ON ip.it_id = bi.it_id LEFT JOIN pm_batchinfo pb ON ip.batch_id = pb.batch_id AND (ip.all_amt - ip.lock_amt) > 0 AND wp_price > 0.01 AND ip.all_amt > 0 AND ip.place_type = '00' AND pb.is_freeze = '00' AND pb.sorgid = 'YW01' AND ip.wh_id IN ('04') INNER JOIN ( SELECT info.it_id, MIN (avail_time) minTime FROM im_placebatch ip LEFT JOIN pm_batchinfo info on info.batch_id=ip.batch_id WHERE ip.all_amt - ip.lock_amt > 0 GROUP BY info.it_id ) t1 ON t1.it_id = ip.it_id AND pb.avail_time = t1.minTime`)
	if err != nil {
		log.Fatal(err.Error())
	}
	defer stmt.Close()

	rows, err := stmt.Query()
	if err != nil {
		log.Fatal(err.Error())
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		log.Fatal(err.Error())
	}
	// Each cell is scanned into its own *interface{}, so any column type is accepted.
	colsdata := make([]interface{}, len(cols))
	for i, name := range cols {
		colsdata[i] = new(interface{})
		fmt.Print(name)
		fmt.Print("\t")
	}
	fmt.Println() // terminate the header line before the first data row

	for rows.Next() {
		if err := rows.Scan(colsdata...); err != nil {
			log.Fatal(err.Error())
		}
		printRow(colsdata)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err.Error())
	}
}

 

 类似资料: