Skip to content

Commit c068855

Browse files
Added RAG prototype
1 parent eb58f7a commit c068855

File tree

1 file changed: +92 additions, −45 deletions

prototype/main.jl

Lines changed: 92 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -3,16 +3,29 @@ using PromptingTools
33
using PromptingTools.Experimental.RAGTools: FileChunker, build_index, SimpleIndexer, airag
44
using JSON3, Serialization
55
using Statistics: mean
6-
using FileIO
7-
using JLD2
6+
using LibPQ
7+
88
const PT = PromptingTools
99
const RT = PromptingTools.Experimental.RAGTools
1010

11-
# Register models
12-
PT.register_model!(name="deepseek-r1:1.5b", schema=PT.OllamaSchema())
13-
PT.register_model!(name="nomic-embed-text:latest", schema=PT.OllamaSchema())
11+
# User-defined models
12+
MODEL_CHAT = "deepseek-r1:1.5b" # Default chat model
13+
MODEL_EMBEDDING = "nomic-embed-text:latest" # Default embedding model
14+
15+
# Overwrite package globals with user-defined models
16+
PromptingTools.MODEL_CHAT = MODEL_CHAT
17+
PromptingTools.MODEL_EMBEDDING = MODEL_EMBEDDING
18+
19+
# Register models with Ollama schema
20+
PT.register_model!(name=MODEL_CHAT, schema=PT.OllamaSchema())
21+
PT.register_model!(name=MODEL_EMBEDDING, schema=PT.OllamaSchema())
1422

15-
# File collection and combination functions (unchanged)
23+
# Pgvector conversion module
24+
module Pgvector
25+
convert(v::AbstractVector{T}) where T<:Real = string("[", join(v, ","), "]")
26+
end
27+
28+
# Function to collect files with specified extensions
1629
function collect_files_with_extensions(directory::String, extensions::Vector{String})
1730
files = String[]
1831
for (root, _, file_names) in walkdir(directory)
@@ -26,6 +39,7 @@ function collect_files_with_extensions(directory::String, extensions::Vector{Str
2639
return files
2740
end
2841

42+
# Function to combine files into a single output file
2943
function write_combined_file(files::Vector{String}, output_file::String)
3044
open(output_file, "w") do io
3145
for file in files
@@ -35,58 +49,91 @@ function write_combined_file(files::Vector{String}, output_file::String)
3549
println(io, line)
3650
end
3751
end
38-
println(io, "\n")
52+
println(io, "\n") # Add a separator between files
3953
end
4054
end
4155
end
4256

57+
# Prompt template for FunSQL.jl query generation
58+
const FUNSQL_PROMPT_TEMPLATE = """
59+
You are an expert in FunSQL.jl, a Julia library for compositional construction of SQL queries. Your task is to translate the given natural language question into a corresponding FunSQL.jl query. Provide only the FunSQL.jl query as output, without any additional text or explanation.
60+
61+
For example:
62+
Question: "Find all male patients in the database."
63+
FunSQL.jl Query: From(:person) |> Where(Get.gender .== "M") |> Select(Get.person_id)
64+
65+
Now, for the following question, generate only the FunSQL.jl query:
66+
67+
Question: {input_query}
68+
69+
FunSQL.jl Query:
70+
"""
71+
4372
# Define directory, extensions, and output file
4473
directory = "."
4574
extensions = [".jl", ".md", ".Rmd"]
4675
output_file = "combined_output.txt"
4776

48-
# Collect files and create combined output
77+
# Collect and combine files
4978
files = collect_files_with_extensions(directory, extensions)
5079
write_combined_file(files, output_file)
5180

52-
# Build index
81+
# Database connection
82+
conn = LibPQ.Connection("host=localhost port=5432 dbname=pgvectordemo user=postgres password=param")
83+
84+
# Build index with user-defined embedding model
5385
cfg = SimpleIndexer(chunker=FileChunker())
54-
index = build_index(cfg, [output_file];
55-
embedder_kwargs = (
56-
schema = PT.OllamaSchema(),
57-
model = "nomic-embed-text:latest"
58-
)
59-
)
86+
index = build_index(cfg, [output_file]; embedder_kwargs=(schema=PT.OllamaSchema(), model=MODEL_EMBEDDING))
6087
println("Index built with $(length(index)) chunks.")
6188

62-
# Save the index
63-
index_file = "index.jld2"
64-
println("Saving index to $index_file...")
65-
@save index_file index
66-
println("Index saved successfully.")
67-
68-
# Perform RAG query
69-
answer = airag(index;
70-
question = "Write a FunSQL.jl query to find all male patients in the database?",
71-
verbose = 2,
72-
retriever_kwargs = (
73-
model = "deepseek-r1:1.5b",
74-
schema = PT.OllamaSchema(),
75-
embedder_kwargs = (
76-
schema = PT.OllamaSchema(),
77-
model = "nomic-embed-text:latest",
78-
api_key = "ollama-dummy-key"
89+
# Function to store embeddings in a PostgreSQL database using pgvector
90+
function store_embeddings_in_pgvector(conn::LibPQ.Connection, embeddings::AbstractMatrix, chunks::AbstractVector, embedding_dimension::Int)
91+
# Ensure the table exists
92+
LibPQ.execute(conn, """
93+
CREATE TABLE IF NOT EXISTS embeddings (
94+
id SERIAL PRIMARY KEY,
95+
chunk TEXT NOT NULL,
96+
embedding VECTOR($embedding_dimension)
7997
)
80-
),
81-
generator_kwargs = (
82-
model = "deepseek-r1:1.5b",
83-
schema = PT.OllamaSchema(),
84-
embedder_kwargs = (
85-
schema = PT.OllamaSchema(),
86-
model = "nomic-embed-text:latest",
87-
api_key = "ollama-dummy-key"
88-
)
89-
),
90-
api_kwargs = (api_key = "ollama-dummy-key",)
91-
)
92-
println(answer)
98+
""")
99+
100+
# Convert embeddings to Float64 if necessary
101+
embeddings = convert(Matrix{Float64}, embeddings)
102+
103+
# Convert chunks to Vector{String} if necessary
104+
chunks = String.(chunks)
105+
106+
# Insert embeddings and their corresponding chunks
107+
for i in 1:size(embeddings, 2)
108+
chunk = chunks[i]
109+
embedding = Pgvector.convert(embeddings[:, i])
110+
LibPQ.execute(conn, """
111+
INSERT INTO embeddings (chunk, embedding)
112+
VALUES (\$1, \$2)
113+
""", (chunk, embedding)) # Pass parameters as a tuple
114+
end
115+
116+
println("Embeddings successfully stored in the database.")
117+
end
118+
119+
# Assuming the embedding model produces 768-dimensional embeddings
120+
embedding_dimension = 768 # Adjust based on actual model output
121+
store_embeddings_in_pgvector(conn, index.embeddings, index.chunks, embedding_dimension)
122+
123+
# Function to generate FunSQL.jl query using the prompt template
124+
function generate_funsql_query(index, question::String)
125+
# Fill the prompt template with the user's question
126+
prompt = replace(FUNSQL_PROMPT_TEMPLATE, "{input_query}" => question)
127+
128+
# Perform RAG query with the filled prompt
129+
answer = airag(index;
130+
question=prompt,
131+
retriever_kwargs=(model=MODEL_EMBEDDING, schema=PT.OllamaSchema(), embedder_kwargs=(schema=PT.OllamaSchema(), model=MODEL_EMBEDDING)),
132+
generator_kwargs=(model=MODEL_CHAT, schema=PT.OllamaSchema())
133+
)
134+
return answer
135+
end
136+
137+
question = "Write a FunSQL.jl query to find all male patients in the database?"
138+
answer = generate_funsql_query(index, question)
139+
println("Generated FunSQL.jl Query: $answer")

Comments (0)