db_connect.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package tool
  import (
      "database/sql"
      "encoding/json"
      "os"
      "path/filepath"
      "strings"
      "time"

      _ "github.com/go-sql-driver/mysql"
      _ "modernc.org/sqlite"
  )
  // db_set holds the active database settings (keys such as "db_type",
  // "db_name" and the "db_mysql_*" connection fields). DB_boot replaces it;
  // the connect helpers and Get_DB_type read from it.
  var db_set = make(map[string]string)
  12. func Get_DB_set() map[string]string {
  13. new_db_set := map[string]string{}
  14. db_env, db_env_exist := os.LookupEnv("NAMU_DB")
  15. db_env_type, db_env_type_exist := os.LookupEnv("NAMU_DB_TYPE")
  16. if db_env_exist || db_env_type_exist {
  17. new_db_set["db_name"] = Choose(db_env, "data")
  18. new_db_set["db_type"] = Choose(db_env_type, "sqlite")
  19. return new_db_set
  20. }
  21. path_dir := filepath.Join("..", "data", "set.json")
  22. if File_exist_check(path_dir) {
  23. raw, err := os.ReadFile(path_dir)
  24. if err == nil {
  25. tmp := map[string]string{}
  26. if err := json.Unmarshal(raw, &tmp); err == nil {
  27. if v, ok := tmp["db"]; ok {
  28. new_db_set["db_name"] = v
  29. } else {
  30. new_db_set["db_name"] = "data"
  31. }
  32. if v, ok := tmp["db_type"]; ok {
  33. new_db_set["db_type"] = v
  34. } else {
  35. new_db_set["db_type"] = "sqlite"
  36. }
  37. return new_db_set
  38. }
  39. }
  40. }
  41. new_db_set["db_name"] = "data"
  42. new_db_set["db_type"] = "sqlite"
  43. return new_db_set
  44. }
  // Get_DB_set_MySQL returns a copy of new_db_set extended with the MySQL
  // connection settings read from ../data/mysql.json (relative to the working
  // directory). The caller's map is never modified.
  //
  // Every top-level string or number in mysql.json is copied into the result
  // under its own key. Numbers keep their JSON spelling, so "port": 3306
  // becomes "3306". Other value kinds (booleans, objects, null) are skipped.
  // The keys used by the connect helpers are then derived:
  //
  //   db_mysql_host  from "host", or "127.0.0.1" when missing or empty
  //   db_mysql_port  from "port", or "3306" when missing or empty
  //   db_mysql_user  from "user", only if present
  //   db_mysql_pw    from "password", only if present
  //
  // A missing, unreadable or malformed mysql.json contributes no values, but
  // the host and port defaults are still applied so the DSN stays well formed.
  func Get_DB_set_MySQL(new_db_set map[string]string) map[string]string {
      out := make(map[string]string, len(new_db_set)+8)
      for k, v := range new_db_set {
          out[k] = v
      }

      if f, err := os.Open(filepath.Join("..", "data", "mysql.json")); err == nil {
          defer f.Close()
          dec := json.NewDecoder(f)
          // Keep numeric literals verbatim (json.Number) instead of float64.
          dec.UseNumber()
          var parsed map[string]any
          if dec.Decode(&parsed) == nil {
              for k, v := range parsed {
                  switch v := v.(type) {
                  case string:
                      out[k] = v
                  case json.Number:
                      out[k] = v.String()
                  }
              }
          }
      }

      out["db_mysql_host"] = "127.0.0.1"
      if host := out["host"]; host != "" {
          out["db_mysql_host"] = host
      }
      out["db_mysql_port"] = "3306"
      if port := out["port"]; port != "" {
          out["db_mysql_port"] = port
      }
      if user, ok := out["user"]; ok {
          out["db_mysql_user"] = user
      }
      if pw, ok := out["password"]; ok {
          out["db_mysql_pw"] = pw
      }
      return out
  }
  76. func Exec_DB(db *sql.DB, query string, values ...any) {
  77. const retryDelay = 10 * time.Millisecond
  78. stmt, err := db.Prepare(DB_change(query))
  79. if err != nil {
  80. panic(err)
  81. }
  82. defer stmt.Close()
  83. for {
  84. _, err = stmt.Exec(values...)
  85. if err == nil {
  86. return
  87. }
  88. if strings.Contains(err.Error(), "database is locked") {
  89. time.Sleep(retryDelay)
  90. continue
  91. }
  92. panic(err)
  93. }
  94. }
  95. func Query_DB(db *sql.DB, query string, values ...any) *sql.Rows {
  96. const retryDelay = 10 * time.Millisecond
  97. stmt, err := db.Prepare(DB_change(query))
  98. if err != nil {
  99. panic(err)
  100. }
  101. defer stmt.Close()
  102. for {
  103. rows, err := stmt.Query(values...)
  104. if err == nil {
  105. return rows
  106. }
  107. if strings.Contains(err.Error(), "database is locked") {
  108. time.Sleep(retryDelay)
  109. continue
  110. }
  111. panic(err)
  112. }
  113. }
  114. // QueryRow_DB 이래서 포인터를 배우는구나...
  115. func QueryRow_DB(db *sql.DB, query string, var_list []any, values ...any) bool {
  116. const retryDelay = 10 * time.Millisecond
  117. stmt, err := db.Prepare(DB_change(query))
  118. if err != nil {
  119. panic(err)
  120. }
  121. defer stmt.Close()
  122. for {
  123. row := stmt.QueryRow(values...)
  124. err := row.Scan(var_list...)
  125. switch err {
  126. case nil:
  127. return true
  128. case sql.ErrNoRows:
  129. return false
  130. }
  131. if strings.Contains(err.Error(), "database is locked") {
  132. time.Sleep(retryDelay)
  133. continue
  134. }
  135. panic(err)
  136. }
  137. }
  138. func DB_boot() map[string]string {
  139. new_db_set := Get_DB_set()
  140. if new_db_set["db_type"] == "mysql" {
  141. new_db_set = Get_DB_set_MySQL(new_db_set)
  142. }
  143. db_set = new_db_set
  144. return new_db_set
  145. }
  146. func DB_connect_init() (*sql.DB, error) {
  147. if db_set["db_type"] == "sqlite" {
  148. db, err := sql.Open("sqlite", filepath.Join("..", db_set["db_name"] + ".db") + "?_journal_mode=WAL&_busy_timeout=5000")
  149. if err != nil {
  150. return nil, err
  151. }
  152. if err := db.Ping(); err != nil {
  153. db.Close()
  154. return nil, err
  155. }
  156. return db, nil
  157. } else {
  158. db, err := sql.Open("mysql", db_set["db_mysql_user"] + ":" + db_set["db_mysql_pw"] + "@tcp(" + db_set["db_mysql_host"] + ":" + db_set["db_mysql_port"] + ")")
  159. if err != nil {
  160. return nil, err
  161. }
  162. if err := db.Ping(); err != nil {
  163. db.Close()
  164. return nil, err
  165. }
  166. return db, nil
  167. }
  168. }
  169. func DB_connect() *sql.DB {
  170. // log.Default().Println("DB open")
  171. if db_set["db_type"] == "sqlite" {
  172. db, err := sql.Open("sqlite", filepath.Join("..", db_set["db_name"] + ".db") + "?_journal_mode=WAL&_busy_timeout=5000")
  173. if err != nil {
  174. panic(err)
  175. }
  176. return db
  177. } else {
  178. db, err := sql.Open("mysql", db_set["db_mysql_user"] + ":" + db_set["db_mysql_pw"] + "@tcp(" + db_set["db_mysql_host"] + ":" + db_set["db_mysql_port"] + ")/" + db_set["db_name"])
  179. if err != nil {
  180. panic(err)
  181. }
  182. return db
  183. }
  184. }
  // DB_close closes db so that no new queries can start on it. The error
  // returned by Close is intentionally discarded.
  func DB_close(db *sql.DB) {
      _ = db.Close()
  }
  189. func Get_DB_type() string {
  190. return db_set["db_type"]
  191. }
  192. func DB_change(data string) string {
  193. if Get_DB_type() == "mysql" {
  194. data = strings.Replace(data, "random()", "rand()", -1)
  195. data = strings.Replace(data, "collate nocase", "collate utf8mb4_general_ci", -1)
  196. }
  197. return data
  198. }