journal/Journal.AI/LlamaSharpAiService.cs
Jacob Schmidt 27cc379eb8 feat: add Vulkan GPU backend and fix GpuLayerCount config
- Downgrade LLamaSharp packages to 0.25.0 to match Vulkan backend availability
- Add LLamaSharp.Backend.Vulkan for AMD/Intel/NVIDIA GPU acceleration
- Fix _gpuLayers bug: was reading LlamaCppTimeout instead of a dedicated field
- Add GpuLayerCount to JournalConfig, sourced from JOURNAL_GPU_LAYERS env var
- Document AI/LLM notes in README (version pinning, known vulkaninfo issue)

Co-Authored-By: Oz <oz-agent@warp.dev>
2026-03-02 22:44:01 -06:00

330 lines
12 KiB
C#

using System.Globalization;
using System.Text;
using System.Text.RegularExpressions;
using Journal.Core.Dtos;
using Journal.Core.Models;
using Journal.Core.Services.Ai;
using LLama;
using LLama.Common;
using LLama.Sampling;
namespace Journal.AI;
/// <summary>
/// <see cref="IAiService"/> implementation that runs a local GGUF model through LLamaSharp.
/// The model file is taken from <c>config.GgufModelPath</c> when it exists; otherwise a default
/// Phi-3-mini model is downloaded into LocalApplicationData on first use. Weights are loaded
/// lazily once and kept until <see cref="Dispose"/>.
/// </summary>
/// <param name="config">Supplies the model path, context-token budget and GPU layer count.</param>
public sealed partial class LlamaSharpAiService(JournalConfig config) : IAiService, IDisposable
{
    private const string DefaultModelUrl =
        "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf";
    private const string DefaultModelFileName = "Phi-3-mini-4k-instruct-q4.gguf";
    private const string ModelSubDirectory = "ai-models";

    private const int ChatMaxTokens = 512;
    private const int JsonMaxTokens = 2048;
    private const float ChatTemperature = 0.7f;
    private const float JsonTemperature = 0.2f;

    // Upper bound on the whole model download (headers + body).
    private static readonly TimeSpan DownloadTimeout = TimeSpan.FromMinutes(30);

    // Generation stops as soon as the model emits any of these chat-template markers.
    private static readonly string[] StopSequences = ["<|user|>", "<|system|>", "<|end|>", "<|endoftext|>"];

    // Chat-template markers removed from generated text before it is returned.
    private static readonly string[] SpecialTokens =
        ["<|assistant|>", "<|user|>", "<|system|>", "<|end|>", "<|endoftext|>"];

    private readonly string _configuredModelPath = config.GgufModelPath;
    private readonly uint _contextSize = (uint)Math.Clamp(config.ModelContextTokens, 512, 4096);
    private readonly int _gpuLayers = config.GpuLayerCount;

    // Guards weight loading and disposal.
    private readonly Lock _sync = new();

    // Serializes the first-use download so concurrent callers don't fight over the same temp file.
    // Not disposed in Dispose(): an in-flight download would then fail in Release(), and a
    // SemaphoreSlim whose AvailableWaitHandle is never touched holds no unmanaged resources.
    private readonly SemaphoreSlim _downloadGate = new(1, 1);

    private string? _resolvedModelPath;

    // volatile: read outside _sync on the fast path of EnsureWeights / HealthAsync.
    private volatile LLamaWeights? _weights;
    private volatile bool _disposed;

    /// <summary>
    /// Reports whether a model is loaded or at least present on disk. Never triggers a download.
    /// </summary>
    public Task<AiHealthDto> HealthAsync(CancellationToken cancellationToken = default)
    {
        var resolved = _resolvedModelPath ?? _configuredModelPath;
        var modelExists = File.Exists(resolved) || File.Exists(GetDefaultModelPath());
        var loaded = _weights is not null;
        return Task.FromResult(new AiHealthDto(
            Provider: "llamasharp",
            Enabled: true,
            Healthy: modelExists || loaded,
            Message: loaded
                ? "Model loaded."
                : modelExists
                    ? "Model found (will load on first use)."
                    : "Model not found locally. It will be downloaded on first use."));
    }

    /// <summary>Single-turn chat; the reply is cleaned of template markers and markdown bold.</summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null or whitespace.</exception>
    public async Task<string> ChatAsync(string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));

        var fullPrompt = BuildPrompt(BuildChatSystemPrompt(), [], prompt);
        var raw = await InferAsync(fullPrompt, ChatMaxTokens, ChatTemperature, cancellationToken);
        return CleanChatResponse(raw);
    }

    /// <summary>
    /// Multi-turn chat. Each history item whose role is "user" (case-insensitive) is sent as a
    /// user turn; every other role is sent as an assistant turn.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="history"/> is null.</exception>
    public async Task<string> ChatWithHistoryAsync(IReadOnlyList<(string Role, string Text)> history,
        string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));
        ArgumentNullException.ThrowIfNull(history);

        var fullPrompt = BuildPrompt(BuildChatSystemPrompt(), history, prompt);
        var raw = await InferAsync(fullPrompt, ChatMaxTokens, ChatTemperature, cancellationToken);
        return CleanChatResponse(raw);
    }

    /// <summary>
    /// Asks the model for a single JSON object (low temperature). The output is only stripped of
    /// template markers; it is NOT validated as JSON here.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="prompt"/> is null or whitespace.</exception>
    internal async Task<string> ChatJsonAsync(string prompt, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(prompt))
            throw new ArgumentException("Prompt is required.", nameof(prompt));

        var systemPrompt =
            "You are a coaching assistant inside a private journaling app. " +
            $"Today's date is {FormatToday()}. " +
            "You MUST respond with ONLY a single valid JSON object. " +
            "Do NOT write any text, explanation, or commentary before or after the JSON. " +
            "Output MUST start with { and end with }.";

        return await InferAsync(BuildPrompt(systemPrompt, [], prompt), JsonMaxTokens, JsonTemperature,
            cancellationToken);
    }

    /// <summary>Summarizes one journal entry, optionally naming it by <paramref name="fileStem"/>.</summary>
    /// <exception cref="ArgumentException"><paramref name="content"/> is null or whitespace.</exception>
    public async Task<string> SummarizeEntryAsync(string content, string? fileStem = null,
        CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(content))
            throw new ArgumentException("Entry content is required.", nameof(content));

        var prompt = fileStem is not null
            ? $"Summarize this journal entry ({fileStem}) concisely:\n\n{content}"
            : $"Summarize this journal entry concisely:\n\n{content}";
        return await ChatAsync(prompt, cancellationToken);
    }

    /// <summary>
    /// Summarizes several entries in one prompt (entries joined by "---" separators).
    /// Returns a fixed message instead of calling the model when there is nothing to summarize.
    /// </summary>
    public async Task<string> SummarizeAllAsync(IReadOnlyList<string> entries,
        CancellationToken cancellationToken = default)
    {
        if (entries is null || entries.Count == 0)
            return "No entries to summarize.";

        var combined = string.Join("\n\n---\n\n", entries);
        var prompt = $"Summarize the following {entries.Count} journal entries into a concise overview:\n\n{combined}";
        return await ChatAsync(prompt, cancellationToken);
    }

    /// <summary>
    /// Computes embeddings for <paramref name="content"/> and flattens every returned vector into
    /// one list. Best-effort: inference failures yield an empty list, but cancellation and
    /// use-after-dispose still propagate. Model download failures also propagate.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="content"/> is null or whitespace.</exception>
    public async Task<IReadOnlyList<double>> EmbedAsync(string content,
        CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrWhiteSpace(content))
            throw new ArgumentException("Content is required.", nameof(content));

        var modelPath = await EnsureModelAsync(cancellationToken);
        try
        {
            var weights = EnsureWeights(modelPath);

            // LLamaEmbedder owns a native context; dispose it so each call doesn't leak one.
            using var embedder = new LLamaEmbedder(weights, new ModelParams(modelPath)
            {
                Embeddings = true,
                ContextSize = _contextSize,
                GpuLayerCount = _gpuLayers
            });

            var embeddingArrays = await embedder.GetEmbeddings(content, cancellationToken);
            var result = new List<double>();
            foreach (var vector in embeddingArrays)
            {
                foreach (var value in vector)
                {
                    result.Add(value);
                }
            }
            return result;
        }
        catch (Exception ex) when (ex is not OperationCanceledException and not ObjectDisposedException)
        {
            return [];
        }
    }

    // ── Prompt construction ────────────────────────────────────────────────

    // Chat system prompt; includes today's date so the model can reason about "today".
    private static string BuildChatSystemPrompt() =>
        "You are a supportive conversational coach inside a private journaling app. " +
        $"Today's date is {FormatToday()}. " +
        "Reply in plain natural language only. Never output JSON, code blocks, or structured data. " +
        "Be warm, practical, and concise. Do not repeat yourself.";

    // Local date as e.g. "March 2, 2026". Invariant culture keeps the month name English,
    // consistent with the English prompt text around it.
    private static string FormatToday() =>
        DateTime.Now.ToString("MMMM d, yyyy", CultureInfo.InvariantCulture);

    // Renders system prompt, prior turns and the new user prompt in the <|role|>…<|end|> markup
    // (the Phi-3 chat format used by the default model), ending with an open assistant turn.
    private static string BuildPrompt(string systemPrompt,
        IReadOnlyList<(string Role, string Text)> history, string prompt)
    {
        var builder = new StringBuilder();
        builder.Append($"<|system|>\n{systemPrompt}<|end|>\n");
        foreach (var (role, text) in history)
        {
            var tag = string.Equals(role, "user", StringComparison.OrdinalIgnoreCase) ? "user" : "assistant";
            builder.Append($"<|{tag}|>\n{text}<|end|>\n");
        }
        builder.Append($"<|user|>\n{prompt}<|end|>\n");
        builder.Append("<|assistant|>\n");
        return builder.ToString();
    }

    // ── Model download ─────────────────────────────────────────────────────

    private static string GetDefaultModelPath()
    {
        var modelDirectory = Path.Combine(
            Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData),
            "ProjectJournal",
            ModelSubDirectory);
        return Path.Combine(modelDirectory, DefaultModelFileName);
    }

    /// <summary>
    /// Returns a path to a model file on disk: the configured path if it exists, else the default
    /// path, downloading the default model there first if necessary.
    /// </summary>
    private async Task<string> EnsureModelAsync(CancellationToken cancellationToken = default)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);

        if (File.Exists(_configuredModelPath))
            return _configuredModelPath;

        var defaultPath = GetDefaultModelPath();
        if (File.Exists(defaultPath))
            return defaultPath;

        await _downloadGate.WaitAsync(cancellationToken);
        try
        {
            // Another caller may have completed the download while we were waiting.
            if (File.Exists(defaultPath))
                return defaultPath;

            Directory.CreateDirectory(Path.GetDirectoryName(defaultPath)!);

            // Download to a side file and rename, so a partial download never looks like a model.
            var tempPath = defaultPath + ".download";
            try
            {
                await DownloadToFileAsync(DefaultModelUrl, tempPath, cancellationToken);
            }
            catch
            {
                TryDeleteFile(tempPath);
                throw;
            }

            File.Move(tempPath, defaultPath, overwrite: true);
            return defaultPath;
        }
        finally
        {
            _downloadGate.Release();
        }
    }

    // Streams url into destinationPath; the file is fully flushed and closed on return.
    private static async Task DownloadToFileAsync(string url, string destinationPath,
        CancellationToken cancellationToken)
    {
        // HttpClient.Timeout does not cover reading the body with ResponseHeadersRead,
        // so the linked token enforces the overall limit as well.
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        cts.CancelAfter(DownloadTimeout);

        using var httpClient = new HttpClient { Timeout = DownloadTimeout };
        using var response = await httpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cts.Token);
        response.EnsureSuccessStatusCode();

        await using var contentStream = await response.Content.ReadAsStreamAsync(cts.Token);
        await using var fileStream = new FileStream(destinationPath, FileMode.Create, FileAccess.Write, FileShare.None);
        await contentStream.CopyToAsync(fileStream, cts.Token);
        await fileStream.FlushAsync(cts.Token);
    }

    // Best-effort cleanup of a partial download; failures here must not mask the original error.
    private static void TryDeleteFile(string path)
    {
        try
        {
            File.Delete(path);
        }
        catch (IOException)
        {
            // ignored: leftover temp file is overwritten (FileMode.Create) on the next attempt
        }
        catch (UnauthorizedAccessException)
        {
            // ignored: same as above
        }
    }

    // ── Inference / weights lifecycle ──────────────────────────────────────

    private ModelParams CreateModelParams(string modelPath) => new(modelPath)
    {
        ContextSize = _contextSize,
        GpuLayerCount = _gpuLayers
    };

    // Runs one stateless generation over an already-templated prompt and returns the text with
    // template markers stripped and surrounding whitespace trimmed.
    private async Task<string> InferAsync(string fullPrompt, int maxTokens, float temperature,
        CancellationToken cancellationToken)
    {
        var modelPath = await EnsureModelAsync(cancellationToken);
        var weights = EnsureWeights(modelPath);

        // StatelessExecutor creates (and disposes) its own context per InferAsync call, so no
        // separate LLamaContext is created here; the params alone are enough.
        var executor = new StatelessExecutor(weights, CreateModelParams(modelPath));
        var inferenceParams = new InferenceParams
        {
            MaxTokens = maxTokens,
            AntiPrompts = StopSequences,
            SamplingPipeline = new DefaultSamplingPipeline { Temperature = temperature }
        };

        var output = new StringBuilder();
        await foreach (var token in executor.InferAsync(fullPrompt, inferenceParams, cancellationToken))
        {
            output.Append(token);
        }
        return StripSpecialTokens(output.ToString());
    }

    private static string StripSpecialTokens(string raw)
    {
        var text = raw;
        foreach (var marker in SpecialTokens)
            text = text.Replace(marker, "");
        return text.Trim();
    }

    // Post-processing for conversational replies: drops template markers, "System:/User:/
    // Assistant:" role labels, markdown bold markers, and collapses 3+ newlines to one blank line.
    private static string CleanChatResponse(string raw)
    {
        var text = StripSpecialTokens(raw);
        text = RoleMarkerRegex().Replace(text, "");
        text = text.Replace("**", "");
        text = ExcessNewlinesRegex().Replace(text, "\n\n");
        return text.Trim();
    }

    // Loads the weights once (double-checked under _sync) and returns them.
    // Throws ObjectDisposedException after Dispose().
    private LLamaWeights EnsureWeights(string modelPath)
    {
        var current = _weights;
        if (current is not null)
            return current;

        lock (_sync)
        {
            ObjectDisposedException.ThrowIf(_disposed, this);
            if (_weights is { } existing)
                return existing;

            var loaded = LLamaWeights.LoadFromFile(CreateModelParams(modelPath));
            _resolvedModelPath = modelPath;
            _weights = loaded;
            return loaded;
        }
    }

    /// <summary>Releases the loaded model weights. Later inference calls throw ObjectDisposedException.</summary>
    public void Dispose()
    {
        LLamaWeights? toDispose;
        lock (_sync)
        {
            if (_disposed) return;
            _disposed = true;
            toDispose = _weights;
            // Clear the reference so no caller can pick up freed native weights afterwards.
            _weights = null;
        }
        toDispose?.Dispose();
    }

    [GeneratedRegex(@"\*{0,2}(System|Assistant|User):\*{0,2}", RegexOptions.IgnoreCase, "en-US")]
    private static partial Regex RoleMarkerRegex();

    [GeneratedRegex(@"\n{3,}")]
    private static partial Regex ExcessNewlinesRegex();
}