langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699)
## Description
Related to https://github.com/mlflow/mlflow/pull/10420. MLflow AI
gateway will be deprecated and replaced by the `mlflow.deployments`
module. Happy to split this PR if it's too large.
```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```
## Dependencies
Install mlflow from https://github.com/mlflow/mlflow/pull/10420:
```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```
## Testing plan
The following code works fine on local and databricks:
<details><summary>Click</summary>
<p>
```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY
Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""
from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
###############################
# MLflow
###############################
chat = ChatMlflow(
target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))
embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])
llm = Mlflow(
target_uri="http://127.0.0.1:5000",
endpoint="completions",
params={"temperature": 0.1},
)
print(llm("I am"))
llm_chain = LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["adjective"],
template="Tell me a {adjective} joke",
),
)
print(llm_chain.run(adjective="funny"))
# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
print(tmpdir)
path = f"{tmpdir}/llm.yaml"
llm_chain.save(path)
loaded_chain = load_chain(path)
print(loaded_chain("funny"))
###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")
# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
name=name,
config={
"served_entities": [
{
"name": "test",
"external_model": {
"name": "gpt-4",
"provider": "openai",
"task": "llm/v1/chat",
"openai_config": {
"openai_api_key": "{{" + secret + "}}",
},
},
}
],
},
)
try:
chat = ChatDatabricks(
target_uri="databricks", endpoint=name, params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))
finally:
client.delete_endpoint(endpoint=name)
# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
name=name,
config={
"served_entities": [
{
"name": "test",
"external_model": {
"name": "text-embedding-ada-002",
"provider": "openai",
"task": "llm/v1/embeddings",
"openai_config": {
"openai_api_key": "{{" + secret + "}}",
},
},
}
],
},
)
try:
embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
client.delete_endpoint(endpoint=name)
# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
name=name,
config={
"served_entities": [
{
"name": "test",
"external_model": {
"name": "gpt-3.5-turbo-instruct",
"provider": "openai",
"task": "llm/v1/completions",
"openai_config": {
"openai_api_key": "{{" + secret + "}}",
},
},
}
],
},
)
try:
llm = Databricks(
endpoint_name=name,
model_kwargs={"temperature": 0.1},
)
print(llm("I am"))
finally:
client.delete_endpoint(endpoint=name)
# Foundation model - chat
chat = ChatDatabricks(
endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))
# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])
# Foundation model - completions
llm = Databricks(
endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))
llm_chain = LLMChain(
llm=llm,
prompt=PromptTemplate(
input_variables=["adjective"],
template="Tell me a {adjective} joke",
),
)
print(llm_chain.run(adjective="funny"))
# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
print(tmpdir)
path = f"{tmpdir}/llm.yaml"
llm_chain.save(path)
loaded_chain = load_chain(path)
print(loaded_chain("funny"))
```
Output:
```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:
Why don't scientists trust atoms?
Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
hello back
Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```
</p>
</details>
The existing workflow doesn't break:
<details><summary>click</summary>
<p>
```python
import uuid
import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema
class MyModel(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input):
return str(uuid.uuid4())
with mlflow.start_run():
mlflow.pyfunc.log_model(
"model",
python_model=MyModel(),
pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
signature=ModelSignature(
inputs=Schema(
[
ColSpec("string", "prompt"),
ColSpec("string", "stop"),
]
),
outputs=Schema(
[
ColSpec(name=None, type="string"),
]
),
),
registered_model_name=f"lang-{uuid.uuid4()}",
)
# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks
llm = Databricks(endpoint_name="<name>")
llm("hello") # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```
</p>
</details>
## Follow-up tasks
(This PR is too large. I'll file a separate one for follow-up tasks.)
- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and
`docs/docs/integrations/providers/databricks.md`.
---------
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>